aboutsummaryrefslogtreecommitdiff
path: root/AST.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2021-06-24 23:14:54 +0200
committerTom Smeding <tom@tomsmeding.com>2021-06-24 23:14:54 +0200
commit0fefe4822c65bde81ec4c0da1b5b32a9b411ca79 (patch)
tree0efeffb8b1b6d6126bc806209a2f5a64fb32c96f /AST.hs
Initial
Diffstat (limited to 'AST.hs')
-rw-r--r--AST.hs288
1 files changed, 288 insertions, 0 deletions
diff --git a/AST.hs b/AST.hs
new file mode 100644
index 0000000..7e9c69c
--- /dev/null
+++ b/AST.hs
@@ -0,0 +1,288 @@
+{-# LANGUAGE ConstraintKinds #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeOperators #-}
+module AST where
+
+import Data.GADT.Compare
+import Data.Type.Equality
+import qualified Data.Vector as V
+import Data.Vector (Vector)
+
+
+data Exp env a where
+ App :: Exp env (a -> b) -> Exp env a -> Exp env b
+ Lam :: Type t -> Exp (t ': env) a -> Exp env (t -> a)
+ Var :: Type a -> Idx env a -> Exp env a
+ Let :: Exp env t -> Exp (t ': env) a -> Exp env a
+ Lit :: Literal a -> Exp env a
+ Cond :: Exp env Bool -> Exp env a -> Exp env a -> Exp env a
+ Const :: Constant a -> Exp env a
+ Pair :: Exp env a -> Exp env b -> Exp env (a, b)
+ Fst :: Exp env (a, b) -> Exp env a
+ Snd :: Exp env (a, b) -> Exp env b
+ Build :: ShapeType sh -> Exp env sh -> Exp env (sh -> a) -> Exp env (Array sh a)
+ Ifold :: ShapeType sh -> Exp env ((s, sh) -> s) -> Exp env s -> Exp env sh -> Exp env s
+ Index :: Exp env (Array sh a) -> Exp env sh -> Exp env a
+ Shape :: Exp env (Array sh a) -> Exp env sh
+
+data Constant a where
+ CAddI :: Constant ((Int, Int) -> Int)
+ CSubI :: Constant ((Int, Int) -> Int)
+ CMulI :: Constant ((Int, Int) -> Int)
+ CDivI :: Constant ((Int, Int) -> Int)
+ CAddF :: Constant ((Double, Double) -> Double)
+ CSubF :: Constant ((Double, Double) -> Double)
+ CMulF :: Constant ((Double, Double) -> Double)
+ CDivF :: Constant ((Double, Double) -> Double)
+ CLog :: Constant (Double -> Double)
+ CExp :: Constant (Double -> Double)
+ CtoF :: Constant (Int -> Double)
+ CRound :: Constant (Double -> Int)
+
+ CLtI :: Constant ((Int, Int) -> Bool)
+ CLtF :: Constant ((Double, Double) -> Bool)
+ CEq :: Type a -> Constant ((a, a) -> Bool)
+ CAnd :: Constant ((Bool, Bool) -> Bool)
+ COr :: Constant ((Bool, Bool) -> Bool)
+ CNot :: Constant (Bool -> Bool)
+
+data Type a where
+ TInt :: Type Int
+ TBool :: Type Bool
+ TDouble :: Type Double
+ TArray :: ShapeType sh -> Type a -> Type (Array sh a)
+ TNil :: Type ()
+ TPair :: Type a -> Type b -> Type (a, b)
+ TFun :: Type a -> Type b -> Type (a -> b)
+
+data Idx env a where
+ Zero :: Idx (a ': env) a
+ Succ :: Idx env a -> Idx (t ': env) a
+
+data Literal a where
+ LInt :: Int -> Literal Int
+ LBool :: Bool -> Literal Bool
+ LDouble :: Double -> Literal Double
+ LArray :: Array sh a -> Literal (Array sh a)
+ LShape :: Shape sh -> Literal sh
+ LNil :: Literal ()
+ LPair :: Literal a -> Literal b -> Literal (a, b)
+
+data Shape sh where
+ Z :: Shape ()
+ (:.) :: Int -> Shape sh -> Shape (Int, sh)
+
+data ShapeType sh where
+ STZ :: ShapeType ()
+ STC :: ShapeType sh -> ShapeType (Int, sh)
+
+data Array sh a where
+ Array :: Shape sh -> Type a -> Vector a -> Array sh a
+
+deriving instance Show (Exp env a)
+deriving instance Show (Constant a)
+deriving instance Show (Type a)
+deriving instance Show (Idx env a)
+deriving instance Show (Literal a)
+deriving instance Show (Shape a)
+deriving instance Show (ShapeType a)
+
+instance Show (Array sh a) where
+ showsPrec p (Array sh t v) =
+ showParen (p > 10) $
+ showString "Array "
+ . showsPrec 11 sh
+ . showsPrec 11 t
+ . (case typeHasShow t of
+ Just Has -> showsPrec 11 v
+ Nothing -> showString ("[_ * " ++ show (V.length v) ++ "]"))
+
+deriving instance Eq (Type a)
+deriving instance Eq (Shape sh)
+deriving instance Eq (ShapeType sh)
+deriving instance Eq a => Eq (Array sh a)
+
+instance GEq (Exp env) where
+ geq (App a b) (App a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl
+ geq App{} _ = Nothing
+ geq (Lam t e) (Lam t' e') | Just Refl <- geq t t', Just Refl <- geq e e' = Just Refl
+ geq Lam{} _ = Nothing
+ geq (Var t i) (Var t' i') | Just Refl <- geq t t', Just Refl <- geq i i' = Just Refl
+ geq Var{} _ = Nothing
+ geq (Let a b) (Let a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl
+ geq Let{} _ = Nothing
+ geq (Lit l) (Lit l') | Just Refl <- geq l l' = Just Refl
+ geq Lit{} _ = Nothing
+ geq (Cond a b c) (Cond a' b' c') | Just Refl <- geq a a', Just Refl <- geq b b', Just Refl <- geq c c' = Just Refl
+ geq Cond{} _ = Nothing
+ geq (Const c) (Const c') | Just Refl <- geq c c' = Just Refl
+ geq Const{} _ = Nothing
+ geq (Pair a b) (Pair a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl
+ geq Pair{} _ = Nothing
+ geq (Fst a) (Fst a') | Just Refl <- geq a a' = Just Refl
+ geq Fst{} _ = Nothing
+ geq (Snd a) (Snd a') | Just Refl <- geq a a' = Just Refl
+ geq Snd{} _ = Nothing
+ geq (Build t a b) (Build t' a' b') | Just Refl <- geq t t', Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl
+ geq Build{} _ = Nothing
+ geq (Ifold t a b c) (Ifold t' a' b' c') | Just Refl <- geq t t', Just Refl <- geq a a', Just Refl <- geq b b' , Just Refl <- geq c c'
+ = Just Refl
+ geq Ifold{} _ = Nothing
+ geq (Index a b) (Index a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl
+ geq Index{} _ = Nothing
+ geq (Shape a) (Shape a') | Just Refl <- geq a a' = Just Refl
+ geq Shape{} _ = Nothing
+
+instance GEq Constant where
+ geq CAddI CAddI = Just Refl ; geq CAddI _ = Nothing
+ geq CSubI CSubI = Just Refl ; geq CSubI _ = Nothing
+ geq CMulI CMulI = Just Refl ; geq CMulI _ = Nothing
+ geq CDivI CDivI = Just Refl ; geq CDivI _ = Nothing
+ geq CAddF CAddF = Just Refl ; geq CAddF _ = Nothing
+ geq CSubF CSubF = Just Refl ; geq CSubF _ = Nothing
+ geq CMulF CMulF = Just Refl ; geq CMulF _ = Nothing
+ geq CDivF CDivF = Just Refl ; geq CDivF _ = Nothing
+ geq CLog CLog = Just Refl ; geq CLog _ = Nothing
+ geq CExp CExp = Just Refl ; geq CExp _ = Nothing
+ geq CtoF CtoF = Just Refl ; geq CtoF _ = Nothing
+ geq CRound CRound = Just Refl ; geq CRound _ = Nothing
+ geq CLtI CLtI = Just Refl ; geq CLtI _ = Nothing
+ geq CLtF CLtF = Just Refl ; geq CLtF _ = Nothing
+ geq (CEq t) (CEq t') | Just Refl <- geq t t' = Just Refl ; geq CEq{} _ = Nothing
+ geq CAnd CAnd = Just Refl ; geq CAnd _ = Nothing
+ geq COr COr = Just Refl ; geq COr _ = Nothing
+ geq CNot CNot = Just Refl ; geq CNot _ = Nothing
+
+instance GEq Type where
+ geq TInt TInt = Just Refl ; geq TInt _ = Nothing
+ geq TBool TBool = Just Refl ; geq TBool _ = Nothing
+ geq TDouble TDouble = Just Refl ; geq TDouble _ = Nothing
+ geq (TArray sht t) (TArray sht' t') | Just Refl <- geq sht sht', Just Refl <- geq t t' = Just Refl ; geq TArray{} _ = Nothing
+ geq TNil TNil = Just Refl ; geq TNil _ = Nothing
+ geq (TPair a b) (TPair a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl ; geq TPair{} _ = Nothing
+ geq (TFun a b) (TFun a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl ; geq TFun{} _ = Nothing
+
+instance GEq (Idx env) where
+ geq Zero Zero = Just Refl
+ geq (Succ i) (Succ i') | Just Refl <- geq i i' = Just Refl
+ geq _ _ = Nothing
+
+instance GEq Literal where
+ geq (LInt a) (LInt a') | a == a' = Just Refl ; geq LInt{} _ = Nothing
+ geq (LBool a) (LBool a') | a == a' = Just Refl ; geq LBool{} _ = Nothing
+ geq (LDouble a) (LDouble a') | a == a' = Just Refl ; geq LDouble{} _ = Nothing
+ geq (LArray (Array sht t v)) (LArray (Array sht' t' v'))
+ | Just Refl <- geq sht sht'
+ , Just Refl <- geq t t'
+ = case typeHasEq t of
+ Just Has | v == v' -> Just Refl
+ | otherwise -> Nothing
+ Nothing -> error "GEq Literal: Literal array of incomparable values"
+ geq LArray{} _ = Nothing
+ geq (LShape a) (LShape a') | Just Refl <- geq a a' = Just Refl ; geq LShape{} _ = Nothing
+ geq LNil LNil = Just Refl ; geq LNil _ = Nothing
+ geq (LPair a b) (LPair a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl ; geq LPair{} _ = Nothing
+
+instance GEq Shape where
+ geq Z Z = Just Refl
+ geq (n :. sh) (n' :. sh') | n == n', Just Refl <- geq sh sh' = Just Refl
+ geq _ _ = Nothing
+
+instance GEq ShapeType where
+ geq STZ STZ = Just Refl
+ geq (STC sht) (STC sht') | Just Refl <- geq sht sht' = Just Refl
+ geq _ _ = Nothing
+
+shapeType :: Shape sh -> ShapeType sh
+shapeType Z = STZ
+shapeType (_ :. sh) = STC (shapeType sh)
+
+shapeType' :: Shape sh -> Type sh
+shapeType' Z = TNil
+shapeType' (_ :. sh) = TPair TInt (shapeType' sh)
+
+shapeTypeType :: ShapeType sh -> Type sh
+shapeTypeType STZ = TNil
+shapeTypeType (STC sht) = TPair TInt (shapeTypeType sht)
+
+literalType :: Literal a -> Type a
+literalType LInt{} = TInt
+literalType LBool{} = TBool
+literalType LDouble{} = TDouble
+literalType (LArray (Array sh t _)) = TArray (shapeType sh) t
+literalType (LShape sh) = shapeType' sh
+literalType LNil{} = TNil
+literalType (LPair a b) = TPair (literalType a) (literalType b)
+
+constType :: Constant a -> Type a
+constType CAddI = TFun (TPair TInt TInt) TInt
+constType CSubI = TFun (TPair TInt TInt) TInt
+constType CMulI = TFun (TPair TInt TInt) TInt
+constType CDivI = TFun (TPair TInt TInt) TInt
+constType CAddF = TFun (TPair TDouble TDouble) TDouble
+constType CSubF = TFun (TPair TDouble TDouble) TDouble
+constType CMulF = TFun (TPair TDouble TDouble) TDouble
+constType CDivF = TFun (TPair TDouble TDouble) TDouble
+constType CLog = TFun TDouble TDouble
+constType CExp = TFun TDouble TDouble
+constType CtoF = TFun TInt TDouble
+constType CRound = TFun TDouble TInt
+constType CLtI = TFun (TPair TInt TInt) TBool
+constType CLtF = TFun (TPair TDouble TDouble) TBool
+constType (CEq t) = TFun (TPair t t) TBool
+constType CAnd = TFun (TPair TBool TBool) TBool
+constType COr = TFun (TPair TBool TBool) TBool
+constType CNot = TFun TBool TBool
+
+typeof :: Exp env a -> Type a
+typeof (App e _) = let TFun _ t = typeof e in t
+typeof (Lam t e) = TFun t (typeof e)
+typeof (Var t _) = t
+typeof (Let _ e) = typeof e
+typeof (Lit l) = literalType l
+typeof (Cond _ e _) = typeof e
+typeof (Const c) = constType c
+typeof (Pair e1 e2) = TPair (typeof e1) (typeof e2)
+typeof (Fst e) = let TPair t _ = typeof e in t
+typeof (Snd e) = let TPair _ t = typeof e in t
+typeof (Build sht _ e) = let TFun _ t = typeof e in TArray sht t
+typeof (Ifold _ _ e _) = typeof e
+typeof (Index e _) = let TArray _ t = typeof e in t
+typeof (Shape e) = let TArray sht _ = typeof e in shapeTypeType sht
+
+data Has c a where
+ Has :: c a => Has c a
+
+typeHasShow :: Type a -> Maybe (Has Show a)
+typeHasShow TInt = Just Has
+typeHasShow TBool = Just Has
+typeHasShow TDouble = Just Has
+typeHasShow TArray{} = Just Has
+typeHasShow TNil = Just Has
+typeHasShow (TPair a b)
+ | Just Has <- typeHasShow a
+ , Just Has <- typeHasShow b
+ = Just Has
+ | otherwise
+ = Nothing
+typeHasShow TFun{} = Nothing
+
+typeHasEq :: Type a -> Maybe (Has Eq a)
+typeHasEq TInt = Just Has
+typeHasEq TBool = Just Has
+typeHasEq TDouble = Just Has
+typeHasEq (TArray _ t)
+ | Just Has <- typeHasEq t
+ = Just Has
+ | otherwise
+ = Nothing
+typeHasEq TNil = Just Has
+typeHasEq (TPair a b)
+ | Just Has <- typeHasEq a
+ , Just Has <- typeHasEq b
+ = Just Has
+ | otherwise
+ = Nothing
+typeHasEq TFun{} = Nothing