summaryrefslogtreecommitdiff
path: root/src/AST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST.hs')
-rw-r--r--src/AST.hs40
1 files changed, 40 insertions, 0 deletions
diff --git a/src/AST.hs b/src/AST.hs
index f389467..785e34a 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -18,6 +18,7 @@ module AST (module AST, module AST.Weaken) where
import Data.Functor.Const
import Data.Kind (Type)
import Data.Int
+import Data.Type.Equality
import AST.Env
import AST.Weaken
@@ -46,6 +47,15 @@ data STy t where
STAccum :: STy t -> STy (TAccum t)
deriving instance Show (STy t)
+instance TestEquality STy where
+ testEquality STNil STNil = Just Refl
+ testEquality (STPair a b) (STPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
+ testEquality (STEither a b) (STEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
+ testEquality (STArr a b) (STArr a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
+ testEquality (STScal a) (STScal a') | Just Refl <- testEquality a a' = Just Refl
+ testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl
+ testEquality _ _ = Nothing
+
data SScalTy t where
STI32 :: SScalTy TI32
STI64 :: SScalTy TI64
@@ -54,6 +64,21 @@ data SScalTy t where
STBool :: SScalTy TBool
deriving instance Show (SScalTy t)
+instance TestEquality SScalTy where
+ testEquality STI32 STI32 = Just Refl
+ testEquality STI64 STI64 = Just Refl
+ testEquality STF32 STF32 = Just Refl
+ testEquality STF64 STF64 = Just Refl
+ testEquality STBool STBool = Just Refl
+ testEquality _ _ = Nothing
+
+scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
+scalRepIsShow STI32 = Dict
+scalRepIsShow STI64 = Dict
+scalRepIsShow STF32 = Dict
+scalRepIsShow STF64 = Dict
+scalRepIsShow STBool = Dict
+
type TIx = TScal TI64
tIx :: STy TIx
@@ -305,6 +330,21 @@ class KnownEnv env where knownEnv :: SList STy env
instance KnownEnv '[] where knownEnv = SNil
instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv
+styKnown :: STy t -> Dict (KnownTy t)
+styKnown STNil = Dict
+styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
+styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
+styKnown (STAccum t) | Dict <- styKnown t = Dict
+
+sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
+sscaltyKnown STI32 = Dict
+sscaltyKnown STI64 = Dict
+sscaltyKnown STF32 = Dict
+sscaltyKnown STF64 = Dict
+sscaltyKnown STBool = Dict
+
ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t)
ebuildUp1 n sh size f =
EBuild ext (SS n) (EPair ext sh size) $