aboutsummaryrefslogtreecommitdiff
path: root/AST.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2021-06-27 18:34:35 +0200
committerTom Smeding <tom@tomsmeding.com>2021-06-27 18:34:35 +0200
commitd4abcc3b2dfefbbcb7cd4a182eec64f1da42d951 (patch)
tree1ab301617043ac6df228ef617afa22633a01a671 /AST.hs
parent0fefe4822c65bde81ec4c0da1b5b32a9b411ca79 (diff)
Diffstat (limited to 'AST.hs')
-rw-r--r--AST.hs58
1 files changed, 49 insertions, 9 deletions
diff --git a/AST.hs b/AST.hs
index 7e9c69c..3e1d2f6 100644
--- a/AST.hs
+++ b/AST.hs
@@ -1,6 +1,8 @@
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
module AST where
@@ -26,6 +28,7 @@ data Exp env a where
Ifold :: ShapeType sh -> Exp env ((s, sh) -> s) -> Exp env s -> Exp env sh -> Exp env s
Index :: Exp env (Array sh a) -> Exp env sh -> Exp env a
Shape :: Exp env (Array sh a) -> Exp env sh
+ Undef :: Type a -> Exp env a
data Constant a where
CAddI :: Constant ((Int, Int) -> Int)
@@ -42,6 +45,7 @@ data Constant a where
CRound :: Constant (Double -> Int)
CLtI :: Constant ((Int, Int) -> Bool)
+ CLeI :: Constant ((Int, Int) -> Bool)
CLtF :: Constant ((Double, Double) -> Bool)
CEq :: Type a -> Constant ((a, a) -> Bool)
CAnd :: Constant ((Bool, Bool) -> Bool)
@@ -72,11 +76,11 @@ data Literal a where
data Shape sh where
Z :: Shape ()
- (:.) :: Int -> Shape sh -> Shape (Int, sh)
+ (:.) :: Shape sh -> Int -> Shape (sh, Int)
data ShapeType sh where
STZ :: ShapeType ()
- STC :: ShapeType sh -> ShapeType (Int, sh)
+ STC :: ShapeType sh -> ShapeType (sh, Int)
data Array sh a where
Array :: Shape sh -> Type a -> Vector a -> Array sh a
@@ -86,15 +90,19 @@ deriving instance Show (Constant a)
deriving instance Show (Type a)
deriving instance Show (Idx env a)
deriving instance Show (Literal a)
-deriving instance Show (Shape a)
deriving instance Show (ShapeType a)
+instance Show (Shape sh) where
+ showsPrec _ Z = showString "Z"
+ showsPrec p (sh :. n) = showParen (p > 0) $
+ showsPrec 10 sh . showString " :. " . shows n
+
instance Show (Array sh a) where
showsPrec p (Array sh t v) =
showParen (p > 10) $
showString "Array "
- . showsPrec 11 sh
- . showsPrec 11 t
+ . showsPrec 11 sh . showString " "
+ . showsPrec 11 t . showString " "
. (case typeHasShow t of
Just Has -> showsPrec 11 v
Nothing -> showString ("[_ * " ++ show (V.length v) ++ "]"))
@@ -134,6 +142,8 @@ instance GEq (Exp env) where
geq Index{} _ = Nothing
geq (Shape a) (Shape a') | Just Refl <- geq a a' = Just Refl
geq Shape{} _ = Nothing
+ geq (Undef a) (Undef a') | Just Refl <- geq a a' = Just Refl
+ geq Undef{} _ = Nothing
instance GEq Constant where
geq CAddI CAddI = Just Refl ; geq CAddI _ = Nothing
@@ -149,6 +159,7 @@ instance GEq Constant where
geq CtoF CtoF = Just Refl ; geq CtoF _ = Nothing
geq CRound CRound = Just Refl ; geq CRound _ = Nothing
geq CLtI CLtI = Just Refl ; geq CLtI _ = Nothing
+ geq CLeI CLeI = Just Refl ; geq CLeI _ = Nothing
geq CLtF CLtF = Just Refl ; geq CLtF _ = Nothing
geq (CEq t) (CEq t') | Just Refl <- geq t t' = Just Refl ; geq CEq{} _ = Nothing
geq CAnd CAnd = Just Refl ; geq CAnd _ = Nothing
@@ -187,7 +198,7 @@ instance GEq Literal where
instance GEq Shape where
geq Z Z = Just Refl
- geq (n :. sh) (n' :. sh') | n == n', Just Refl <- geq sh sh' = Just Refl
+ geq (sh :. n) (sh' :. n') | n == n', Just Refl <- geq sh sh' = Just Refl
geq _ _ = Nothing
instance GEq ShapeType where
@@ -195,17 +206,36 @@ instance GEq ShapeType where
geq (STC sht) (STC sht') | Just Refl <- geq sht sht' = Just Refl
geq _ _ = Nothing
+-- Requires that the given term is neither 'Lam' nor 'Let'.
+recurseMon :: Monoid s => (forall t. Exp env t -> s) -> Exp env a -> s
+recurseMon f = \case
+ App a b -> f a <> f b
+ Lam _ _ -> error "recurseMon: Given Lam"
+ Var _ _ -> mempty
+ Let _ _ -> error "recurseMon: Given Let"
+ Lit _ -> mempty
+ Cond a b c -> f a <> f b <> f c
+ Const _ -> mempty
+ Pair a b -> f a <> f b
+ Fst a -> f a
+ Snd a -> f a
+ Build _ a b -> f a <> f b
+ Ifold _ a b c -> f a <> f b <> f c
+ Index a b -> f a <> f b
+ Shape a -> f a
+ Undef _ -> mempty
+
shapeType :: Shape sh -> ShapeType sh
shapeType Z = STZ
-shapeType (_ :. sh) = STC (shapeType sh)
+shapeType (sh :. _) = STC (shapeType sh)
shapeType' :: Shape sh -> Type sh
shapeType' Z = TNil
-shapeType' (_ :. sh) = TPair TInt (shapeType' sh)
+shapeType' (sh :. _) = TPair (shapeType' sh) TInt
shapeTypeType :: ShapeType sh -> Type sh
shapeTypeType STZ = TNil
-shapeTypeType (STC sht) = TPair TInt (shapeTypeType sht)
+shapeTypeType (STC sht) = TPair (shapeTypeType sht) TInt
literalType :: Literal a -> Type a
literalType LInt{} = TInt
@@ -230,6 +260,7 @@ constType CExp = TFun TDouble TDouble
constType CtoF = TFun TInt TDouble
constType CRound = TFun TDouble TInt
constType CLtI = TFun (TPair TInt TInt) TBool
+constType CLeI = TFun (TPair TInt TInt) TBool
constType CLtF = TFun (TPair TDouble TDouble) TBool
constType (CEq t) = TFun (TPair t t) TBool
constType CAnd = TFun (TPair TBool TBool) TBool
@@ -251,6 +282,15 @@ typeof (Build sht _ e) = let TFun _ t = typeof e in TArray sht t
typeof (Ifold _ _ e _) = typeof e
typeof (Index e _) = let TArray _ t = typeof e in t
typeof (Shape e) = let TArray sht _ = typeof e in shapeTypeType sht
+typeof (Undef t) = t
+
+idxToInt :: Idx env a -> Int
+idxToInt Zero = 0
+idxToInt (Succ i) = idxToInt i + 1
+
+shtToInt :: ShapeType sh -> Int
+shtToInt STZ = 0
+shtToInt (STC sht) = shtToInt sht + 1
data Has c a where
Has :: c a => Has c a