{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
-- | A typed, environment-indexed expression AST, together with runtime type
-- representations, literals, shapes and arrays, 'Show' / 'GEq' instances, and
-- helpers that reconstruct the type of a term.
module AST where

import Data.GADT.Compare
import Data.Type.Equality
import qualified Data.Vector as V
import Data.Vector (Vector)

-- | Expressions.
--
-- @env@ is a type-level list with the types of the variables in scope. The
-- most recently bound variable comes first (see 'Zero'). @a@ is the type of
-- the value the expression denotes.
data Exp env a where
  -- | Function application.
  App :: Exp env (a -> b) -> Exp env a -> Exp env b
  -- | Lambda abstraction. The body lives in the environment extended with the
  -- argument type.
  Lam :: Type t -> Exp (t ': env) a -> Exp env (t -> a)
  -- | Variable reference by typed de Bruijn index.
  Var :: Type a -> Idx env a -> Exp env a
  -- | @Let rhs body@: @body@ sees @rhs@ as its innermost variable.
  Let :: Exp env t -> Exp (t ': env) a -> Exp env a
  -- | A literal value.
  Lit :: Literal a -> Exp env a
  -- | @Cond c t e@: conditional expression.
  Cond :: Exp env Bool -> Exp env a -> Exp env a -> Exp env a
  -- | A primitive operation (see 'Constant').
  Const :: Constant a -> Exp env a
  Pair :: Exp env a -> Exp env b -> Exp env (a, b)
  Fst :: Exp env (a, b) -> Exp env a
  Snd :: Exp env (a, b) -> Exp env b
  -- | @Build sht shape f@: an array of shape @shape@ whose elements come from
  -- the index-to-element function @f@.
  Build :: ShapeType sh -> Exp env sh -> Exp env (sh -> a) -> Exp env (Array sh a)
  -- | @Ifold sht f initial shape@. Presumably it folds @f@ over the indices of
  -- @shape@, starting from @initial@; @f@ receives the accumulator paired with
  -- an index. NOTE(review): evaluation order is not defined in this module.
  Ifold :: ShapeType sh -> Exp env ((s, sh) -> s) -> Exp env s -> Exp env sh -> Exp env s
  -- | Array indexing.
  Index :: Exp env (Array sh a) -> Exp env sh -> Exp env a
  -- | The shape of an array.
  Shape :: Exp env (Array sh a) -> Exp env sh
  -- | A placeholder of the given type. NOTE(review): its runtime meaning is
  -- not defined in this module.
  Undef :: Type a -> Exp env a

-- | Primitive operations. Binary operations take both operands as one pair.
-- The @I@ / @F@ suffix selects the 'Int' / 'Double' variant. The full type of
-- each constant is given by 'constType'.
data Constant a where
  CAddI :: Constant ((Int, Int) -> Int)
  CSubI :: Constant ((Int, Int) -> Int)
  CMulI :: Constant ((Int, Int) -> Int)
  CDivI :: Constant ((Int, Int) -> Int)
  CAddF :: Constant ((Double, Double) -> Double)
  CSubF :: Constant ((Double, Double) -> Double)
  CMulF :: Constant ((Double, Double) -> Double)
  CDivF :: Constant ((Double, Double) -> Double)
  CLog :: Constant (Double -> Double)
  CExp :: Constant (Double -> Double)
  CtoF :: Constant (Int -> Double)
  CRound :: Constant (Double -> Int)
  CLtI :: Constant ((Int, Int) -> Bool)
  CLeI :: Constant ((Int, Int) -> Bool)
  CLtF :: Constant ((Double, Double) -> Bool)
  -- | Equality at the given type.
  CEq :: Type a -> Constant ((a, a) -> Bool)
  CAnd :: Constant ((Bool, Bool) -> Bool)
  COr :: Constant ((Bool, Bool) -> Bool)
  CNot :: Constant (Bool -> Bool)

-- | Runtime representation of the types an expression can have.
data Type a where
  TInt :: Type Int
  TBool :: Type Bool
  TDouble :: Type Double
  TArray :: ShapeType sh -> Type a -> Type (Array sh a)
  TNil :: Type ()
  TPair :: Type a -> Type b -> Type (a, b)
  TFun :: Type a -> Type b -> Type (a -> b)

-- | Typed de Bruijn index of a variable of type @a@ in environment @env@.
data Idx env a where
  Zero :: Idx (a ': env) a
  Succ :: Idx env a -> Idx (t ': env) a

-- | Literal values.
data Literal a where
  LInt :: Int -> Literal Int
  LBool :: Bool -> Literal Bool
  LDouble :: Double -> Literal Double
  LArray :: Array sh a -> Literal (Array sh a)
  LShape :: Shape sh -> Literal sh
  LNil :: Literal ()
  LPair :: Literal a -> Literal b -> Literal (a, b)

-- | A concrete shape as a snoc-list of 'Int' components.
--
-- Its type index is the nested-pair type that is also used for array
-- indices, e.g. @Z :. 2 :. 3 :: Shape (((), Int), Int)@.
data Shape sh where
  Z :: Shape ()
  (:.) :: Shape sh -> Int -> Shape (sh, Int)

-- | The structure (number of 'Int' components) of a shape, without values.
data ShapeType sh where
  STZ :: ShapeType ()
  STC :: ShapeType sh -> ShapeType (sh, Int)

-- | A concrete array: its shape, its element type, and its elements.
-- NOTE(review): the element order inside the vector is not fixed by this
-- module.
data Array sh a where
  Array :: Shape sh -> Type a -> Vector a -> Array sh a

deriving instance Show (Exp env a)
deriving instance Show (Constant a)
deriving instance Show (Type a)
deriving instance Show (Idx env a)
deriving instance Show (Literal a)
deriving instance Show (ShapeType a)

-- | Renders shapes the way they are written in source, e.g. @(Z :. 2) :. 3@.
instance Show (Shape sh) where
  showsPrec _ Z = showString "Z"
  -- Both operands are shown at precedence 10, as a derived instance for an
  -- infixl 9 constructor would do. This parenthesises negative components,
  -- so the output stays valid Haskell.
  showsPrec p (sh :. n) =
    showParen (p > 0) $
      showsPrec 10 sh . showString " :. " . showsPrec 10 n

-- | Shows the elements when the element type has a 'Show' instance (see
-- 'typeHasShow'). Otherwise it shows only the number of elements.
instance Show (Array sh a) where
  showsPrec p (Array sh t v) =
    showParen (p > 10) $
      showString "Array "
        . showsPrec 11 sh
        . showString " "
        . showsPrec 11 t
        . showString " "
        . showElems
    where
      showElems = case typeHasShow t of
        Just Has -> showsPrec 11 v
        Nothing -> showString ("[_ * " ++ show (V.length v) ++ "]")

deriving instance Eq (Type a)
deriving instance Eq (Shape sh)
deriving instance Eq (ShapeType sh)
deriving instance Eq a => Eq (Array sh a)

-- | Structural equality of expressions. It yields a proof that the two result
-- types coincide.
instance GEq (Exp env) where
  geq (App a b) (App a' b')
    | Just Refl <- geq a a'
    , Just Refl <- geq b b'
    = Just Refl
  geq App{} _ = Nothing
  geq (Lam t e) (Lam t' e')
    | Just Refl <- geq t t'
    , Just Refl <- geq e e'
    = Just Refl
  geq Lam{} _ = Nothing
  geq (Var t i) (Var t' i')
    | Just Refl <- geq t t'
    , Just Refl <- geq i i'
    = Just Refl
  geq Var{} _ = Nothing
  geq (Let a b) (Let a' b')
    | Just Refl <- geq a a'
    , Just Refl <- geq b b'
    = Just Refl
  geq Let{} _ = Nothing
  geq (Lit l) (Lit l')
    | Just Refl <- geq l l'
    = Just Refl
  geq Lit{} _ = Nothing
  geq (Cond a b c) (Cond a' b' c')
    | Just Refl <- geq a a'
    , Just Refl <- geq b b'
    , Just Refl <- geq c c'
    = Just Refl
  geq Cond{} _ = Nothing
  geq (Const c) (Const c')
    | Just Refl <- geq c c'
    = Just Refl
  geq Const{} _ = Nothing
  geq (Pair a b) (Pair a' b')
    | Just Refl <- geq a a'
    , Just Refl <- geq b b'
    = Just Refl
  geq Pair{} _ = Nothing
  geq (Fst a) (Fst a')
    | Just Refl <- geq a a'
    = Just Refl
  geq Fst{} _ = Nothing
  geq (Snd a) (Snd a')
    | Just Refl <- geq a a'
    = Just Refl
  geq Snd{} _ = Nothing
  geq (Build t a b) (Build t' a' b')
    | Just Refl <- geq t t'
    , Just Refl <- geq a a'
    , Just Refl <- geq b b'
    = Just Refl
  geq Build{} _ = Nothing
  geq (Ifold t a b c) (Ifold t' a' b' c')
    | Just Refl <- geq t t'
    , Just Refl <- geq a a'
    , Just Refl <- geq b b'
    , Just Refl <- geq c c'
    = Just Refl
  geq Ifold{} _ = Nothing
  geq (Index a b) (Index a' b')
    | Just Refl <- geq a a'
    , Just Refl <- geq b b'
    = Just Refl
  geq Index{} _ = Nothing
  geq (Shape a) (Shape a')
    | Just Refl <- geq a a'
    = Just Refl
  geq Shape{} _ = Nothing
  geq (Undef a) (Undef a')
    | Just Refl <- geq a a'
    = Just Refl
  geq Undef{} _ = Nothing

-- | Equality of primitive operations.
--
-- Each constructor has its own fall-through clause instead of one catch-all.
-- That way, a constructor added later without a matching clause is reported
-- by GHC's incomplete-pattern warning.
instance GEq Constant where
  geq CAddI CAddI = Just Refl
  geq CAddI _ = Nothing
  geq CSubI CSubI = Just Refl
  geq CSubI _ = Nothing
  geq CMulI CMulI = Just Refl
  geq CMulI _ = Nothing
  geq CDivI CDivI = Just Refl
  geq CDivI _ = Nothing
  geq CAddF CAddF = Just Refl
  geq CAddF _ = Nothing
  geq CSubF CSubF = Just Refl
  geq CSubF _ = Nothing
  geq CMulF CMulF = Just Refl
  geq CMulF _ = Nothing
  geq CDivF CDivF = Just Refl
  geq CDivF _ = Nothing
  geq CLog CLog = Just Refl
  geq CLog _ = Nothing
  geq CExp CExp = Just Refl
  geq CExp _ = Nothing
  geq CtoF CtoF = Just Refl
  geq CtoF _ = Nothing
  geq CRound CRound = Just Refl
  geq CRound _ = Nothing
  geq CLtI CLtI = Just Refl
  geq CLtI _ = Nothing
  geq CLeI CLeI = Just Refl
  geq CLeI _ = Nothing
  geq CLtF CLtF = Just Refl
  geq CLtF _ = Nothing
  geq (CEq t) (CEq t')
    | Just Refl <- geq t t'
    = Just Refl
  geq CEq{} _ = Nothing
  geq CAnd CAnd = Just Refl
  geq CAnd _ = Nothing
  geq COr COr = Just Refl
  geq COr _ = Nothing
  geq CNot CNot = Just Refl
  geq CNot _ = Nothing

-- | Equality of type representations.
instance GEq Type where
  geq TInt TInt = Just Refl
  geq TInt _ = Nothing
  geq TBool TBool = Just Refl
  geq TBool _ = Nothing
  geq TDouble TDouble = Just Refl
  geq TDouble _ = Nothing
  geq (TArray sht t) (TArray sht' t')
    | Just Refl <- geq sht sht'
    , Just Refl <- geq t t'
    = Just Refl
  geq TArray{} _ = Nothing
  geq TNil TNil = Just Refl
  geq TNil _ = Nothing
  geq (TPair a b) (TPair a' b')
    | Just Refl <- geq a a'
    , Just Refl <- geq b b'
    = Just Refl
  geq TPair{} _ = Nothing
  geq (TFun a b) (TFun a' b')
    | Just Refl <- geq a a'
    , Just Refl <- geq b b'
    = Just Refl
  geq TFun{} _ = Nothing

-- | Two indices into the same environment are equal iff they have the same
-- depth.
instance GEq (Idx env) where
  geq Zero Zero = Just Refl
  geq (Succ i) (Succ i')
    | Just Refl <- geq i i'
    = Just Refl
  geq _ _ = Nothing

-- | Value equality of literals.
instance GEq Literal where
  geq (LInt a) (LInt a') | a == a' = Just Refl
  geq LInt{} _ = Nothing
  geq (LBool a) (LBool a') | a == a' = Just Refl
  geq LBool{} _ = Nothing
  -- NOTE(review): Doubles are compared with '=='. A NaN literal is therefore
  -- never equal to itself, while 0.0 and -0.0 compare equal.
  geq (LDouble a) (LDouble a') | a == a' = Just Refl
  geq LDouble{} _ = Nothing
  geq (LArray (Array sht t v)) (LArray (Array sht' t' v'))
    | Just Refl <- geq sht sht'
    , Just Refl <- geq t t'
    = case typeHasEq t of
        Just Has
          | v == v' -> Just Refl
          | otherwise -> Nothing
        -- The element type contains a function type, so the elements cannot
        -- be compared; this is treated as a hard error.
        Nothing -> error "GEq Literal: Literal array of incomparable values"
  geq LArray{} _ = Nothing
  geq (LShape a) (LShape a')
    | Just Refl <- geq a a'
    = Just Refl
  geq LShape{} _ = Nothing
  geq LNil LNil = Just Refl
  geq LNil _ = Nothing
  geq (LPair a b) (LPair a' b')
    | Just Refl <- geq a a'
    , Just Refl <- geq b b'
    = Just Refl
  geq LPair{} _ = Nothing

-- | Value equality of shapes. The components must match, not just the rank.
instance GEq Shape where
  geq Z Z = Just Refl
  geq (sh :. n) (sh' :. n')
    | n == n'
    , Just Refl <- geq sh sh'
    = Just Refl
  geq _ _ = Nothing

-- | Two shape types are equal iff they have the same number of components.
instance GEq ShapeType where
  geq STZ STZ = Just Refl
  geq (STC sht) (STC sht')
    | Just Refl <- geq sht sht'
    = Just Refl
  geq _ _ = Nothing

-- | Applies @f@ to every immediate subexpression of a term and combines the
-- results with '<>' in left-to-right operand order. Leaves ('Var', 'Lit',
-- 'Const', 'Undef') yield 'mempty'.
--
-- Requires that the given term is neither 'Lam' nor 'Let'. Their bodies live
-- in an extended environment, so @f@ (which only accepts terms in @env@)
-- cannot be applied to them; for those constructors this function calls
-- 'error'.
recurseMon :: Monoid s => (forall t. Exp env t -> s) -> Exp env a -> s
recurseMon f = \case
  App a b -> f a <> f b
  Lam _ _ -> error "recurseMon: Given Lam"
  Var _ _ -> mempty
  Let _ _ -> error "recurseMon: Given Let"
  Lit _ -> mempty
  Cond a b c -> f a <> f b <> f c
  Const _ -> mempty
  Pair a b -> f a <> f b
  Fst a -> f a
  Snd a -> f a
  Build _ a b -> f a <> f b
  Ifold _ a b c -> f a <> f b <> f c
  Index a b -> f a <> f b
  Shape a -> f a
  Undef _ -> mempty

-- | Forgets the components of a shape, keeping only its structure.
shapeType :: Shape sh -> ShapeType sh
shapeType Z = STZ
shapeType (sh :. _) = STC (shapeType sh)

-- | The 'Type' of a shape value (nested 'TPair's of 'TInt', ending in 'TNil').
--
-- Defined via 'shapeTypeType' so that both functions describe a shape's type
-- in exactly the same way.
shapeType' :: Shape sh -> Type sh
shapeType' = shapeTypeType . shapeType

-- | The 'Type' of values described by a 'ShapeType'.
shapeTypeType :: ShapeType sh -> Type sh
shapeTypeType STZ = TNil
shapeTypeType (STC sht) = TPair (shapeTypeType sht) TInt

-- | The type of a literal.
literalType :: Literal a -> Type a
literalType LInt{} = TInt
literalType LBool{} = TBool
literalType LDouble{} = TDouble
literalType (LArray (Array sh t _)) = TArray (shapeType sh) t
literalType (LShape sh) = shapeType' sh
literalType LNil{} = TNil
literalType (LPair a b) = TPair (literalType a) (literalType b)

-- | The (function) type of a primitive operation.
constType :: Constant a -> Type a
constType CAddI = TFun (TPair TInt TInt) TInt
constType CSubI = TFun (TPair TInt TInt) TInt
constType CMulI = TFun (TPair TInt TInt) TInt
constType CDivI = TFun (TPair TInt TInt) TInt
constType CAddF = TFun (TPair TDouble TDouble) TDouble
constType CSubF = TFun (TPair TDouble TDouble) TDouble
constType CMulF = TFun (TPair TDouble TDouble) TDouble
constType CDivF = TFun (TPair TDouble TDouble) TDouble
constType CLog = TFun TDouble TDouble
constType CExp = TFun TDouble TDouble
constType CtoF = TFun TInt TDouble
constType CRound = TFun TDouble TInt
constType CLtI = TFun (TPair TInt TInt) TBool
constType CLeI = TFun (TPair TInt TInt) TBool
constType CLtF = TFun (TPair TDouble TDouble) TBool
constType (CEq t) = TFun (TPair t t) TBool
constType CAnd = TFun (TPair TBool TBool) TBool
constType COr = TFun (TPair TBool TBool) TBool
constType CNot = TFun TBool TBool

-- | Reconstructs the type of an expression.
--
-- The lazy @let@ patterns below cannot fail for a non-bottom 'Type', because
-- the type index determines the constructor. For example, a value of type
-- @Type (a -> b)@ can only be built with 'TFun', and one of type
-- @Type (Array sh a)@ only with 'TArray'.
typeof :: Exp env a -> Type a
typeof (App e _) = let TFun _ t = typeof e in t
typeof (Lam t e) = TFun t (typeof e)
typeof (Var t _) = t
typeof (Let _ e) = typeof e
typeof (Lit l) = literalType l
typeof (Cond _ e _) = typeof e
typeof (Const c) = constType c
typeof (Pair e1 e2) = TPair (typeof e1) (typeof e2)
typeof (Fst e) = let TPair t _ = typeof e in t
typeof (Snd e) = let TPair _ t = typeof e in t
typeof (Build sht _ e) = let TFun _ t = typeof e in TArray sht t
typeof (Ifold _ _ e _) = typeof e
typeof (Index e _) = let TArray _ t = typeof e in t
typeof (Shape e) = let TArray sht _ = typeof e in shapeTypeType sht
typeof (Undef t) = t

-- | The numeric value of a de Bruijn index ('Zero' is 0).
idxToInt :: Idx env a -> Int
idxToInt Zero = 0
idxToInt (Succ i) = idxToInt i + 1

-- | The number of 'Int' components described by a 'ShapeType' (the number of
-- 'STC' layers).
shtToInt :: ShapeType sh -> Int
shtToInt STZ = 0
shtToInt (STC sht) = shtToInt sht + 1

-- | Reified evidence that the constraint @c a@ holds. Pattern matching on
-- 'Has' brings @c a@ into scope.
data Has c a where
  Has :: c a => Has c a

-- | Decides whether values of the given type can be shown. Every type except
-- those containing a function type can be shown. Arrays can always be shown,
-- because 'Show' for 'Array' elides elements it cannot show.
typeHasShow :: Type a -> Maybe (Has Show a)
typeHasShow TInt = Just Has
typeHasShow TBool = Just Has
typeHasShow TDouble = Just Has
typeHasShow TArray{} = Just Has
typeHasShow TNil = Just Has
typeHasShow (TPair a b)
  | Just Has <- typeHasShow a
  , Just Has <- typeHasShow b
  = Just Has
  | otherwise = Nothing
typeHasShow TFun{} = Nothing

-- | Decides whether values of the given type can be compared with '=='. This
-- holds exactly when the type contains no function type (arrays need 'Eq' on
-- their elements).
typeHasEq :: Type a -> Maybe (Has Eq a)
typeHasEq TInt = Just Has
typeHasEq TBool = Just Has
typeHasEq TDouble = Just Has
typeHasEq (TArray _ t)
  | Just Has <- typeHasEq t
  = Just Has
  | otherwise = Nothing
typeHasEq TNil = Just Has
typeHasEq (TPair a b)
  | Just Has <- typeHasEq a
  , Just Has <- typeHasEq b
  = Just Has
  | otherwise = Nothing
typeHasEq TFun{} = Nothing