diff options
Diffstat (limited to 'AST.hs')
-rw-r--r-- | AST.hs | 58 |
1 files changed, 49 insertions, 9 deletions
@@ -1,6 +1,8 @@ {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeOperators #-} module AST where @@ -26,6 +28,7 @@ data Exp env a where Ifold :: ShapeType sh -> Exp env ((s, sh) -> s) -> Exp env s -> Exp env sh -> Exp env s Index :: Exp env (Array sh a) -> Exp env sh -> Exp env a Shape :: Exp env (Array sh a) -> Exp env sh + Undef :: Type a -> Exp env a data Constant a where CAddI :: Constant ((Int, Int) -> Int) @@ -42,6 +45,7 @@ data Constant a where CRound :: Constant (Double -> Int) CLtI :: Constant ((Int, Int) -> Bool) + CLeI :: Constant ((Int, Int) -> Bool) CLtF :: Constant ((Double, Double) -> Bool) CEq :: Type a -> Constant ((a, a) -> Bool) CAnd :: Constant ((Bool, Bool) -> Bool) @@ -72,11 +76,11 @@ data Literal a where data Shape sh where Z :: Shape () - (:.) :: Int -> Shape sh -> Shape (Int, sh) + (:.) :: Shape sh -> Int -> Shape (sh, Int) data ShapeType sh where STZ :: ShapeType () - STC :: ShapeType sh -> ShapeType (Int, sh) + STC :: ShapeType sh -> ShapeType (sh, Int) data Array sh a where Array :: Shape sh -> Type a -> Vector a -> Array sh a @@ -86,15 +90,19 @@ deriving instance Show (Constant a) deriving instance Show (Type a) deriving instance Show (Idx env a) deriving instance Show (Literal a) -deriving instance Show (Shape a) deriving instance Show (ShapeType a) +instance Show (Shape sh) where + showsPrec _ Z = showString "Z" + showsPrec p (sh :. n) = showParen (p > 0) $ + showsPrec 10 sh . showString " :. " . shows n + instance Show (Array sh a) where showsPrec p (Array sh t v) = showParen (p > 10) $ showString "Array " - . showsPrec 11 sh - . showsPrec 11 t + . showsPrec 11 sh . showString " " + . showsPrec 11 t . showString " " . (case typeHasShow t of Just Has -> showsPrec 11 v Nothing -> showString ("[_ * " ++ show (V.length v) ++ "]")) @@ -134,6 +142,8 @@ instance GEq (Exp env) where geq Index{} _ = Nothing geq (Shape a) (Shape a') | Just Refl <- geq a a' = Just Refl geq Shape{} _ = Nothing + geq (Undef a) (Undef a') | Just Refl <- geq a a' = Just Refl + geq Undef{} _ = Nothing instance GEq Constant where geq CAddI CAddI = Just Refl ; geq CAddI _ = Nothing @@ -149,6 +159,7 @@ instance GEq Constant where geq CtoF CtoF = Just Refl ; geq CtoF _ = Nothing geq CRound CRound = Just Refl ; geq CRound _ = Nothing geq CLtI CLtI = Just Refl ; geq CLtI _ = Nothing + geq CLeI CLeI = Just Refl ; geq CLeI _ = Nothing geq CLtF CLtF = Just Refl ; geq CLtF _ = Nothing geq (CEq t) (CEq t') | Just Refl <- geq t t' = Just Refl ; geq CEq{} _ = Nothing geq CAnd CAnd = Just Refl ; geq CAnd _ = Nothing @@ -187,7 +198,7 @@ instance GEq Literal where instance GEq Shape where geq Z Z = Just Refl - geq (n :. sh) (n' :. sh') | n == n', Just Refl <- geq sh sh' = Just Refl + geq (sh :. n) (sh' :. n') | n == n', Just Refl <- geq sh sh' = Just Refl geq _ _ = Nothing instance GEq ShapeType where @@ -195,17 +206,36 @@ instance GEq ShapeType where geq (STC sht) (STC sht') | Just Refl <- geq sht sht' = Just Refl geq _ _ = Nothing +-- Requires that the given term is neither 'Lam' nor 'Let'. +recurseMon :: Monoid s => (forall t. Exp env t -> s) -> Exp env a -> s +recurseMon f = \case + App a b -> f a <> f b + Lam _ _ -> error "recurseMon: Given Lam" + Var _ _ -> mempty + Let _ _ -> error "recurseMon: Given Let" + Lit _ -> mempty + Cond a b c -> f a <> f b <> f c + Const _ -> mempty + Pair a b -> f a <> f b + Fst a -> f a + Snd a -> f a + Build _ a b -> f a <> f b + Ifold _ a b c -> f a <> f b <> f c + Index a b -> f a <> f b + Shape a -> f a + Undef _ -> mempty + shapeType :: Shape sh -> ShapeType sh shapeType Z = STZ -shapeType (_ :. sh) = STC (shapeType sh) +shapeType (sh :. _) = STC (shapeType sh) shapeType' :: Shape sh -> Type sh shapeType' Z = TNil -shapeType' (_ :. sh) = TPair TInt (shapeType' sh) +shapeType' (sh :. _) = TPair (shapeType' sh) TInt shapeTypeType :: ShapeType sh -> Type sh shapeTypeType STZ = TNil -shapeTypeType (STC sht) = TPair TInt (shapeTypeType sht) +shapeTypeType (STC sht) = TPair (shapeTypeType sht) TInt literalType :: Literal a -> Type a literalType LInt{} = TInt @@ -230,6 +260,7 @@ constType CExp = TFun TDouble TDouble constType CtoF = TFun TInt TDouble constType CRound = TFun TDouble TInt constType CLtI = TFun (TPair TInt TInt) TBool +constType CLeI = TFun (TPair TInt TInt) TBool constType CLtF = TFun (TPair TDouble TDouble) TBool constType (CEq t) = TFun (TPair t t) TBool constType CAnd = TFun (TPair TBool TBool) TBool @@ -251,6 +282,15 @@ typeof (Build sht _ e) = let TFun _ t = typeof e in TArray sht t typeof (Ifold _ _ e _) = typeof e typeof (Index e _) = let TArray _ t = typeof e in t typeof (Shape e) = let TArray sht _ = typeof e in shapeTypeType sht +typeof (Undef t) = t + +idxToInt :: Idx env a -> Int +idxToInt Zero = 0 +idxToInt (Succ i) = idxToInt i + 1 + +shtToInt :: ShapeType sh -> Int +shtToInt STZ = 0 +shtToInt (STC sht) = shtToInt sht + 1 data Has c a where Has :: c a => Has c a |