summaryrefslogtreecommitdiff
path: root/src/AST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST.hs')
-rw-r--r--src/AST.hs49
1 files changed, 33 insertions, 16 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 785e34a..2132bc6 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -20,6 +20,7 @@ import Data.Kind (Type)
import Data.Int
import Data.Type.Equality
+import Array
import AST.Env
import AST.Weaken
import Data
@@ -91,6 +92,13 @@ type family ScalRep t where
ScalRep TF64 = Double
ScalRep TBool = Bool
+type family ScalIsNumeric t where
+ ScalIsNumeric TI32 = True
+ ScalIsNumeric TI64 = True
+ ScalIsNumeric TF32 = True
+ ScalIsNumeric TF64 = True
+ ScalIsNumeric TBool = False
+
-- | This index is flipped around from the usual direction: the smallest index
-- is at the _heart_ of the nesting, not at the outside. The outermost layer
-- indexes into the _outer_ dimension of the type @t@. This makes indices into
@@ -128,11 +136,13 @@ data Expr x env t where
ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
-- array operations
+ EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t)
EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
- EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
+ EFold1Inner :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
+ ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
- -- EReplicate :: x (TArr (S n) t) -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) -- TODO: unused
+ EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)
-- expression operations
EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
@@ -160,13 +170,16 @@ type family Tup env where
Tup '[] = TNil
Tup (t : ts) = TPair (Tup ts) t
+mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b))
+ -> SList f list -> f (Tup list)
+mkTup nil _ SNil = nil
+mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e
+
tTup :: SList STy env -> STy (Tup env)
-tTup SNil = STNil
-tTup (SCons t ts) = STPair (tTup ts) t
+tTup = mkTup STNil STPair
eTup :: SList (Ex env) list -> Ex env (Tup list)
-eTup SNil = ENil ext
-eTup (e `SCons` es) = EPair ext (eTup es) e
+eTup = mkTup (ENil ext) (EPair ext)
type family InvTup core env where
InvTup core '[] = core
@@ -174,12 +187,12 @@ type family InvTup core env where
type SOp :: Ty -> Ty -> Type
data SOp a t where
- OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
- OMul :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
- ONeg :: SScalTy a -> SOp (TScal a) (TScal a)
- OLt :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
- OLe :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
- OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
+ OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
+ OMul :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
+ ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
+ OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
+ OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
+ OEq :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
ONot :: SOp (TScal TBool) (TScal TBool)
OIf :: SOp (TScal TBool) (TEither TNil TNil)
deriving instance Show (SOp a t)
@@ -208,11 +221,13 @@ typeOf = \case
EInr _ t1 e -> STEither t1 (typeOf e)
ECase _ _ a _ -> typeOf a
+ EConstArr _ n t _ -> STArr n (STScal t)
EBuild1 _ _ e -> STArr (SS SZ) (typeOf e)
EBuild _ n _ e -> STArr n (typeOf e)
- EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ EFold1Inner _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
EUnit _ e -> STArr SZ (typeOf e)
- -- EReplicate _ e | STArr n t <- typeOf e -> STArr (SS n) t
+ EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t
EConst _ t _ -> STScal t
EIdx0 _ e | STArr _ t <- typeOf e -> t
@@ -273,11 +288,13 @@ subst' f w = \case
EInl x t e -> EInl x t (subst' f w e)
EInr x t e -> EInr x t (subst' f w e)
ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
+ EConstArr x n t a -> EConstArr x n t a
EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b)
EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
- EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)
+ EFold1Inner x a b -> EFold1Inner x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)
+ ESum1Inner x e -> ESum1Inner x (subst' f w e)
EUnit x e -> EUnit x (subst' f w e)
- -- EReplicate x e -> EReplicate x (subst' f w e)
+ EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (subst' f w e)
EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)