From 0fefe4822c65bde81ec4c0da1b5b32a9b411ca79 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 24 Jun 2021 23:14:54 +0200 Subject: Initial --- AST.hs | 288 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 288 insertions(+) create mode 100644 AST.hs (limited to 'AST.hs') diff --git a/AST.hs b/AST.hs new file mode 100644 index 0000000..7e9c69c --- /dev/null +++ b/AST.hs @@ -0,0 +1,288 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +module AST where + +import Data.GADT.Compare +import Data.Type.Equality +import qualified Data.Vector as V +import Data.Vector (Vector) + + +data Exp env a where + App :: Exp env (a -> b) -> Exp env a -> Exp env b + Lam :: Type t -> Exp (t ': env) a -> Exp env (t -> a) + Var :: Type a -> Idx env a -> Exp env a + Let :: Exp env t -> Exp (t ': env) a -> Exp env a + Lit :: Literal a -> Exp env a + Cond :: Exp env Bool -> Exp env a -> Exp env a -> Exp env a + Const :: Constant a -> Exp env a + Pair :: Exp env a -> Exp env b -> Exp env (a, b) + Fst :: Exp env (a, b) -> Exp env a + Snd :: Exp env (a, b) -> Exp env b + Build :: ShapeType sh -> Exp env sh -> Exp env (sh -> a) -> Exp env (Array sh a) + Ifold :: ShapeType sh -> Exp env ((s, sh) -> s) -> Exp env s -> Exp env sh -> Exp env s + Index :: Exp env (Array sh a) -> Exp env sh -> Exp env a + Shape :: Exp env (Array sh a) -> Exp env sh + +data Constant a where + CAddI :: Constant ((Int, Int) -> Int) + CSubI :: Constant ((Int, Int) -> Int) + CMulI :: Constant ((Int, Int) -> Int) + CDivI :: Constant ((Int, Int) -> Int) + CAddF :: Constant ((Double, Double) -> Double) + CSubF :: Constant ((Double, Double) -> Double) + CMulF :: Constant ((Double, Double) -> Double) + CDivF :: Constant ((Double, Double) -> Double) + CLog :: Constant (Double -> Double) + CExp :: Constant (Double -> Double) + CtoF :: Constant (Int -> Double) + CRound :: Constant (Double -> Int) + + CLtI :: Constant ((Int, Int) -> Bool) + CLtF :: Constant ((Double, Double) -> Bool) + CEq :: Type a -> Constant ((a, a) -> Bool) + CAnd :: Constant ((Bool, Bool) -> Bool) + COr :: Constant ((Bool, Bool) -> Bool) + CNot :: Constant (Bool -> Bool) + +data Type a where + TInt :: Type Int + TBool :: Type Bool + TDouble :: Type Double + TArray :: ShapeType sh -> Type a -> Type (Array sh a) + TNil :: Type () + TPair :: Type a -> Type b -> Type (a, b) + TFun :: Type a -> Type b -> Type (a -> b) + +data Idx env a where + Zero :: Idx (a ': env) a + Succ :: Idx env a -> Idx (t ': env) a + +data Literal a where + LInt :: Int -> Literal Int + LBool :: Bool -> Literal Bool + LDouble :: Double -> Literal Double + LArray :: Array sh a -> Literal (Array sh a) + LShape :: Shape sh -> Literal sh + LNil :: Literal () + LPair :: Literal a -> Literal b -> Literal (a, b) + +data Shape sh where + Z :: Shape () + (:.) :: Int -> Shape sh -> Shape (Int, sh) + +data ShapeType sh where + STZ :: ShapeType () + STC :: ShapeType sh -> ShapeType (Int, sh) + +data Array sh a where + Array :: Shape sh -> Type a -> Vector a -> Array sh a + +deriving instance Show (Exp env a) +deriving instance Show (Constant a) +deriving instance Show (Type a) +deriving instance Show (Idx env a) +deriving instance Show (Literal a) +deriving instance Show (Shape a) +deriving instance Show (ShapeType a) + +instance Show (Array sh a) where + showsPrec p (Array sh t v) = + showParen (p > 10) $ + showString "Array " + . showsPrec 11 sh + . showsPrec 11 t + . (case typeHasShow t of + Just Has -> showsPrec 11 v + Nothing -> showString ("[_ * " ++ show (V.length v) ++ "]")) + +deriving instance Eq (Type a) +deriving instance Eq (Shape sh) +deriving instance Eq (ShapeType sh) +deriving instance Eq a => Eq (Array sh a) + +instance GEq (Exp env) where + geq (App a b) (App a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl + geq App{} _ = Nothing + geq (Lam t e) (Lam t' e') | Just Refl <- geq t t', Just Refl <- geq e e' = Just Refl + geq Lam{} _ = Nothing + geq (Var t i) (Var t' i') | Just Refl <- geq t t', Just Refl <- geq i i' = Just Refl + geq Var{} _ = Nothing + geq (Let a b) (Let a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl + geq Let{} _ = Nothing + geq (Lit l) (Lit l') | Just Refl <- geq l l' = Just Refl + geq Lit{} _ = Nothing + geq (Cond a b c) (Cond a' b' c') | Just Refl <- geq a a', Just Refl <- geq b b', Just Refl <- geq c c' = Just Refl + geq Cond{} _ = Nothing + geq (Const c) (Const c') | Just Refl <- geq c c' = Just Refl + geq Const{} _ = Nothing + geq (Pair a b) (Pair a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl + geq Pair{} _ = Nothing + geq (Fst a) (Fst a') | Just Refl <- geq a a' = Just Refl + geq Fst{} _ = Nothing + geq (Snd a) (Snd a') | Just Refl <- geq a a' = Just Refl + geq Snd{} _ = Nothing + geq (Build t a b) (Build t' a' b') | Just Refl <- geq t t', Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl + geq Build{} _ = Nothing + geq (Ifold t a b c) (Ifold t' a' b' c') | Just Refl <- geq t t', Just Refl <- geq a a', Just Refl <- geq b b' , Just Refl <- geq c c' + = Just Refl + geq Ifold{} _ = Nothing + geq (Index a b) (Index a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl + geq Index{} _ = Nothing + geq (Shape a) (Shape a') | Just Refl <- geq a a' = Just Refl + geq Shape{} _ = Nothing + +instance GEq Constant where + geq CAddI CAddI = Just Refl ; geq CAddI _ = Nothing + geq CSubI CSubI = Just Refl ; geq CSubI _ = Nothing + geq CMulI CMulI = Just Refl ; geq CMulI _ = Nothing + geq CDivI CDivI = Just Refl ; geq CDivI _ = Nothing + geq CAddF CAddF = Just Refl ; geq CAddF _ = Nothing + geq CSubF CSubF = Just Refl ; geq CSubF _ = Nothing + geq CMulF CMulF = Just Refl ; geq CMulF _ = Nothing + geq CDivF CDivF = Just Refl ; geq CDivF _ = Nothing + geq CLog CLog = Just Refl ; geq CLog _ = Nothing + geq CExp CExp = Just Refl ; geq CExp _ = Nothing + geq CtoF CtoF = Just Refl ; geq CtoF _ = Nothing + geq CRound CRound = Just Refl ; geq CRound _ = Nothing + geq CLtI CLtI = Just Refl ; geq CLtI _ = Nothing + geq CLtF CLtF = Just Refl ; geq CLtF _ = Nothing + geq (CEq t) (CEq t') | Just Refl <- geq t t' = Just Refl ; geq CEq{} _ = Nothing + geq CAnd CAnd = Just Refl ; geq CAnd _ = Nothing + geq COr COr = Just Refl ; geq COr _ = Nothing + geq CNot CNot = Just Refl ; geq CNot _ = Nothing + +instance GEq Type where + geq TInt TInt = Just Refl ; geq TInt _ = Nothing + geq TBool TBool = Just Refl ; geq TBool _ = Nothing + geq TDouble TDouble = Just Refl ; geq TDouble _ = Nothing + geq (TArray sht t) (TArray sht' t') | Just Refl <- geq sht sht', Just Refl <- geq t t' = Just Refl ; geq TArray{} _ = Nothing + geq TNil TNil = Just Refl ; geq TNil _ = Nothing + geq (TPair a b) (TPair a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl ; geq TPair{} _ = Nothing + geq (TFun a b) (TFun a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl ; geq TFun{} _ = Nothing + +instance GEq (Idx env) where + geq Zero Zero = Just Refl + geq (Succ i) (Succ i') | Just Refl <- geq i i' = Just Refl + geq _ _ = Nothing + +instance GEq Literal where + geq (LInt a) (LInt a') | a == a' = Just Refl ; geq LInt{} _ = Nothing + geq (LBool a) (LBool a') | a == a' = Just Refl ; geq LBool{} _ = Nothing + geq (LDouble a) (LDouble a') | a == a' = Just Refl ; geq LDouble{} _ = Nothing + geq (LArray (Array sht t v)) (LArray (Array sht' t' v')) + | Just Refl <- geq sht sht' + , Just Refl <- geq t t' + = case typeHasEq t of + Just Has | v == v' -> Just Refl + | otherwise -> Nothing + Nothing -> error "GEq Literal: Literal array of incomparable values" + geq LArray{} _ = Nothing + geq (LShape a) (LShape a') | Just Refl <- geq a a' = Just Refl ; geq LShape{} _ = Nothing + geq LNil LNil = Just Refl ; geq LNil _ = Nothing + geq (LPair a b) (LPair a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl ; geq LPair{} _ = Nothing + +instance GEq Shape where + geq Z Z = Just Refl + geq (n :. sh) (n' :. sh') | n == n', Just Refl <- geq sh sh' = Just Refl + geq _ _ = Nothing + +instance GEq ShapeType where + geq STZ STZ = Just Refl + geq (STC sht) (STC sht') | Just Refl <- geq sht sht' = Just Refl + geq _ _ = Nothing + +shapeType :: Shape sh -> ShapeType sh +shapeType Z = STZ +shapeType (_ :. sh) = STC (shapeType sh) + +shapeType' :: Shape sh -> Type sh +shapeType' Z = TNil +shapeType' (_ :. sh) = TPair TInt (shapeType' sh) + +shapeTypeType :: ShapeType sh -> Type sh +shapeTypeType STZ = TNil +shapeTypeType (STC sht) = TPair TInt (shapeTypeType sht) + +literalType :: Literal a -> Type a +literalType LInt{} = TInt +literalType LBool{} = TBool +literalType LDouble{} = TDouble +literalType (LArray (Array sh t _)) = TArray (shapeType sh) t +literalType (LShape sh) = shapeType' sh +literalType LNil{} = TNil +literalType (LPair a b) = TPair (literalType a) (literalType b) + +constType :: Constant a -> Type a +constType CAddI = TFun (TPair TInt TInt) TInt +constType CSubI = TFun (TPair TInt TInt) TInt +constType CMulI = TFun (TPair TInt TInt) TInt +constType CDivI = TFun (TPair TInt TInt) TInt +constType CAddF = TFun (TPair TDouble TDouble) TDouble +constType CSubF = TFun (TPair TDouble TDouble) TDouble +constType CMulF = TFun (TPair TDouble TDouble) TDouble +constType CDivF = TFun (TPair TDouble TDouble) TDouble +constType CLog = TFun TDouble TDouble +constType CExp = TFun TDouble TDouble +constType CtoF = TFun TInt TDouble +constType CRound = TFun TDouble TInt +constType CLtI = TFun (TPair TInt TInt) TBool +constType CLtF = TFun (TPair TDouble TDouble) TBool +constType (CEq t) = TFun (TPair t t) TBool +constType CAnd = TFun (TPair TBool TBool) TBool +constType COr = TFun (TPair TBool TBool) TBool +constType CNot = TFun TBool TBool + +typeof :: Exp env a -> Type a +typeof (App e _) = let TFun _ t = typeof e in t +typeof (Lam t e) = TFun t (typeof e) +typeof (Var t _) = t +typeof (Let _ e) = typeof e +typeof (Lit l) = literalType l +typeof (Cond _ e _) = typeof e +typeof (Const c) = constType c +typeof (Pair e1 e2) = TPair (typeof e1) (typeof e2) +typeof (Fst e) = let TPair t _ = typeof e in t +typeof (Snd e) = let TPair _ t = typeof e in t +typeof (Build sht _ e) = let TFun _ t = typeof e in TArray sht t +typeof (Ifold _ _ e _) = typeof e +typeof (Index e _) = let TArray _ t = typeof e in t +typeof (Shape e) = let TArray sht _ = typeof e in shapeTypeType sht + +data Has c a where + Has :: c a => Has c a + +typeHasShow :: Type a -> Maybe (Has Show a) +typeHasShow TInt = Just Has +typeHasShow TBool = Just Has +typeHasShow TDouble = Just Has +typeHasShow TArray{} = Just Has +typeHasShow TNil = Just Has +typeHasShow (TPair a b) + | Just Has <- typeHasShow a + , Just Has <- typeHasShow b + = Just Has + | otherwise + = Nothing +typeHasShow TFun{} = Nothing + +typeHasEq :: Type a -> Maybe (Has Eq a) +typeHasEq TInt = Just Has +typeHasEq TBool = Just Has +typeHasEq TDouble = Just Has +typeHasEq (TArray _ t) + | Just Has <- typeHasEq t + = Just Has + | otherwise + = Nothing +typeHasEq TNil = Just Has +typeHasEq (TPair a b) + | Just Has <- typeHasEq a + , Just Has <- typeHasEq b + = Just Has + | otherwise + = Nothing +typeHasEq TFun{} = Nothing -- cgit v1.2.3-70-g09d2