summaryrefslogtreecommitdiff
path: root/src/AST
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST')
-rw-r--r--src/AST/Count.hs3
-rw-r--r--src/AST/Pretty.hs7
-rw-r--r--src/AST/Types.hs95
3 files changed, 105 insertions, 0 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 6a00e83..40a46f6 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -46,6 +46,7 @@ Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)
-- | This code is executed many times
scaleMany :: Occ -> Occ
+scaleMany (Occ l Zero) = Occ l Zero
scaleMany (Occ l _) = Occ l Many
occCount :: Idx env a -> Expr x env t -> Occ
@@ -124,6 +125,8 @@ occCountGeneral onehot unpush alter many = go WId
EOp _ _ e -> re e
EWith a b -> re a <> re1 b
EAccum _ a b e -> re a <> re b <> re e
+ EZero _ -> mempty
+ EPlus _ a b -> re a <> re b
EError{} -> mempty
where
re :: Monoid (r env') => Expr x env' t'' -> r env'
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 2ce883b..f5e681a 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -181,6 +181,13 @@ ppExpr' d val = \case
return $ showParen (d > 10) $
showString ("accum " ++ show (unSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3'
+ EZero _ -> return $ showString "zero"
+
+ EPlus _ a b -> do
+ a' <- ppExpr' 11 val a
+ b' <- ppExpr' 11 val b
+ return $ showParen (d > 10) $ showString "plus " . a' . showString " " . b'
+
EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)
ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
new file mode 100644
index 0000000..a3e5080
--- /dev/null
+++ b/src/AST/Types.hs
@@ -0,0 +1,95 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeFamilies #-}
+module AST.Types where
+
+import Data.Int (Int32, Int64)
+import Data.Kind (Type)
+import Data.Type.Equality
+
+import Data
+
+
+data Ty
+ = TNil
+ | TPair Ty Ty
+ | TEither Ty Ty
+ | TMaybe Ty
+ | TArr Nat Ty -- ^ rank, element type
+ | TScal ScalTy
+ | TAccum Ty
+ deriving (Show, Eq, Ord)
+
+data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
+ deriving (Show, Eq, Ord)
+
+type STy :: Ty -> Type
+data STy t where
+ STNil :: STy TNil
+ STPair :: STy a -> STy b -> STy (TPair a b)
+ STEither :: STy a -> STy b -> STy (TEither a b)
+ STMaybe :: STy a -> STy (TMaybe a)
+ STArr :: SNat n -> STy t -> STy (TArr n t)
+ STScal :: SScalTy t -> STy (TScal t)
+ STAccum :: STy t -> STy (TAccum t)
+deriving instance Show (STy t)
+
+instance TestEquality STy where
+ testEquality STNil STNil = Just Refl
+ testEquality STNil _ = Nothing
+ testEquality (STPair a b) (STPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
+ testEquality STPair{} _ = Nothing
+ testEquality (STEither a b) (STEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
+ testEquality STEither{} _ = Nothing
+ testEquality (STMaybe a) (STMaybe a') | Just Refl <- testEquality a a' = Just Refl
+ testEquality STMaybe{} _ = Nothing
+ testEquality (STArr a b) (STArr a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
+ testEquality STArr{} _ = Nothing
+ testEquality (STScal a) (STScal a') | Just Refl <- testEquality a a' = Just Refl
+ testEquality STScal{} _ = Nothing
+ testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl
+ testEquality STAccum{} _ = Nothing
+
+data SScalTy t where
+ STI32 :: SScalTy TI32
+ STI64 :: SScalTy TI64
+ STF32 :: SScalTy TF32
+ STF64 :: SScalTy TF64
+ STBool :: SScalTy TBool
+deriving instance Show (SScalTy t)
+
+instance TestEquality SScalTy where
+ testEquality STI32 STI32 = Just Refl
+ testEquality STI64 STI64 = Just Refl
+ testEquality STF32 STF32 = Just Refl
+ testEquality STF64 STF64 = Just Refl
+ testEquality STBool STBool = Just Refl
+ testEquality _ _ = Nothing
+
+scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
+scalRepIsShow STI32 = Dict
+scalRepIsShow STI64 = Dict
+scalRepIsShow STF32 = Dict
+scalRepIsShow STF64 = Dict
+scalRepIsShow STBool = Dict
+
+type TIx = TScal TI64
+
+tIx :: STy TIx
+tIx = STScal STI64
+
+type family ScalRep t where
+ ScalRep TI32 = Int32
+ ScalRep TI64 = Int64
+ ScalRep TF32 = Float
+ ScalRep TF64 = Double
+ ScalRep TBool = Bool
+
+type family ScalIsNumeric t where
+ ScalIsNumeric TI32 = True
+ ScalIsNumeric TI64 = True
+ ScalIsNumeric TF32 = True
+ ScalIsNumeric TF64 = True
+ ScalIsNumeric TBool = False