summaryrefslogtreecommitdiff
path: root/src/AST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST.hs')
-rw-r--r--src/AST.hs94
1 files changed, 49 insertions, 45 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 802ee2a..f389467 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -30,7 +30,7 @@ data Ty
| TEither Ty Ty
| TArr Nat Ty -- ^ rank, element type
| TScal ScalTy
- | TAccum Nat Ty -- ^ rank and element type of the array being accumulated to
+ | TAccum Ty
deriving (Show, Eq, Ord)
data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
@@ -43,7 +43,7 @@ data STy t where
STEither :: STy a -> STy b -> STy (TEither a b)
STArr :: SNat n -> STy t -> STy (TArr n t)
STScal :: SScalTy t -> STy (TScal t)
- STAccum :: SNat n -> STy t -> STy (TAccum n t)
+ STAccum :: STy t -> STy (TAccum t)
deriving instance Show (STy t)
data SScalTy t where
@@ -66,10 +66,23 @@ type family ScalRep t where
ScalRep TF64 = Double
ScalRep TBool = Bool
-type ConsN :: Nat -> a -> [a] -> [a]
-type family ConsN n x l where
- ConsN Z x l = l
- ConsN (S n) x l = x : ConsN n x l
+-- | This index is flipped around from the usual direction: the smallest index
+-- is at the _heart_ of the nesting, not at the outside. The outermost layer
+-- indexes into the _outer_ dimension of the type @t@. This makes indices into
+-- compound structures work properly with coproducts.
+type family AcIdx t i where
+ AcIdx t Z = TNil
+ AcIdx (TPair a b) (S i) = TEither (AcIdx a i) (AcIdx b i)
+ AcIdx (TEither a b) (S i) = TEither (AcIdx a i) (AcIdx b i)
+ AcIdx (TArr Z t) (S i) = AcIdx t i
+ AcIdx (TArr (S n) t) (S i) = TPair TIx (AcIdx (TArr n t) i)
+
+type family AcVal t i where
+ AcVal t Z = t
+ AcVal (TPair a b) (S i) = TEither (AcVal a i) (AcVal b i)
+ AcVal (TEither a b) (S i) = TEither (AcVal a i) (AcVal b i)
+ AcVal (TArr Z t) (S i) = AcVal t i
+ AcVal (TArr (S n) t) (S i) = AcVal (TArr n t) i
-- General assumption: head of the list (whatever way it is associated) is the
-- inner variable / inner array dimension. In pretty printing, the inner
@@ -91,22 +104,23 @@ data Expr x env t where
-- array operations
EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t)
- EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t)
+ EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
- EReplicate :: x (TArr (S n) t) -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) -- TODO: unused
+ -- EReplicate :: x (TArr (S n) t) -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) -- TODO: unused
-- expression operations
EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t
EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
- EIdx :: x t -> Expr x env (TArr n t) -> Vec n (Expr x env TIx) -> Expr x env t
+ EIdx :: x t -> SNat n -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t
EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))
EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t
-- accumulation effect
- EWith :: Expr x env (TArr n t) -> Expr x (TAccum n t : env) a -> Expr x env (TPair a (TArr n t))
- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil
+ EWith :: Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
+ EAccum :: SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil
+ -- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil
-- partiality
EError :: STy a -> String -> Expr x env a
@@ -117,10 +131,6 @@ type Ex = Expr (Const ())
ext :: Const () a
ext = Const ()
-type family Replicate n x where
- Replicate Z x = '[]
- Replicate (S n) x = x : Replicate n x
-
type family Tup env where
Tup '[] = TNil
Tup (t : ts) = TPair (Tup ts) t
@@ -129,6 +139,14 @@ tTup :: SList STy env -> STy (Tup env)
tTup SNil = STNil
tTup (SCons t ts) = STPair (tTup ts) t
+eTup :: SList (Ex env) list -> Ex env (Tup list)
+eTup SNil = ENil ext
+eTup (e `SCons` es) = EPair ext (eTup es) e
+
+type family InvTup core env where
+ InvTup core '[] = core
+ InvTup core (t : ts) = InvTup (TPair core t) ts
+
type SOp :: Ty -> Ty -> Type
data SOp a t where
OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
@@ -169,17 +187,17 @@ typeOf = \case
EBuild _ n _ e -> STArr n (typeOf e)
EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
EUnit _ e -> STArr SZ (typeOf e)
- EReplicate _ e | STArr n t <- typeOf e -> STArr (SS n) t
+ -- EReplicate _ e | STArr n t <- typeOf e -> STArr (SS n) t
EConst _ t _ -> STScal t
EIdx0 _ e | STArr _ t <- typeOf e -> t
EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t
- EIdx _ e _ | STArr _ t <- typeOf e -> t
- -- EShape _ e | STArr n _ <- typeOf e -> _
+ EIdx _ _ e _ | STArr _ t <- typeOf e -> t
+ EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx)
EOp _ op _ -> opt2 op
EWith e1 e2 -> STPair (typeOf e2) (typeOf e1)
- EAccum1 _ _ _ -> STNil
+ EAccum _ _ _ _ -> STNil
EError t _ -> t
@@ -194,7 +212,7 @@ unSTy = \case
STEither a b -> TEither (unSTy a) (unSTy b)
STArr n t -> TArr (unSNat n) (unSTy t)
STScal t -> TScal (unSScalTy t)
- STAccum n t -> TAccum (unSNat n) (unSTy t)
+ STAccum t -> TAccum (unSTy t)
unSList :: SList STy env -> [Ty]
unSList SNil = []
@@ -231,17 +249,18 @@ subst' f w = \case
EInr x t e -> EInr x t (subst' f w e)
ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b)
- EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkFN n f) (wcopyN n w) b)
+ EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)
EUnit x e -> EUnit x (subst' f w e)
- EReplicate x e -> EReplicate x (subst' f w e)
+ -- EReplicate x e -> EReplicate x (subst' f w e)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (subst' f w e)
EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
- EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es)
+ EIdx x n e es -> EIdx x n (subst' f w e) (subst' f w es)
+ EShape x e -> EShape x (subst' f w e)
EOp x op e -> EOp x op (subst' f w e)
EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
- EAccum1 e1 e2 e3 -> EAccum1 (subst' f w e1) (subst' f w e2) (subst' f w e3)
+ EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3)
EError t s -> EError t s
where
sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
@@ -250,28 +269,9 @@ subst' f w = \case
IZ -> EVar x' t (w' @> IZ)
IS i -> f' x' t (WPop w') i
- sinkFN :: SNat n
- -> (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
- -> x t -> STy t -> (ConsN n TIx env' :> env2) -> Idx (ConsN n TIx env) t -> Expr x env2 t
- sinkFN SZ f' x t w' i = f' x t w' i
- sinkFN (SS _) _ x t w' IZ = EVar x t (w' @> IZ)
- sinkFN (SS n) f' x t w' (IS i) = sinkFN n f' x t (WPop w') i
-
weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))
-wsinkN :: SNat n -> env :> ConsN n TIx env
-wsinkN SZ = WId
-wsinkN (SS n) = WSink .> wsinkN n
-
-wcopyN :: SNat n -> env :> env' -> ConsN n TIx env :> ConsN n TIx env'
-wcopyN SZ w = w
-wcopyN (SS n) w = WCopy (wcopyN n w)
-
-wpopN :: SNat n -> ConsN n TIx env :> env' -> env :> env'
-wpopN SZ w = w
-wpopN (SS n) w = wpopN n (WPop w)
-
wUndoSubenv :: Subenv env env' -> env' :> env
wUndoSubenv SETop = WId
wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub)
@@ -299,11 +299,15 @@ instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair kn
instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
-instance (KnownNat n, KnownTy t) => KnownTy (TAccum n t) where knownTy = STAccum knownNat knownTy
+instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy
class KnownEnv env where knownEnv :: SList STy env
instance KnownEnv '[] where knownEnv = SNil
instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv
ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t)
-ebuildUp1 n sh size f = EBuild ext (SS n) (EPair ext sh size) (error "TODO" f)
+ebuildUp1 n sh size f =
+ EBuild ext (SS n) (EPair ext sh size) $
+ let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ
+ in EIdx ext n (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f))
+ (EFst ext arg)