diff options
Diffstat (limited to 'src/AST.hs')
-rw-r--r-- | src/AST.hs | 300 |
1 files changed, 219 insertions, 81 deletions
@@ -16,13 +16,16 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE FlexibleInstances #-} module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where import Data.Functor.Const +import Data.Functor.Identity import Data.Kind (Type) import Array import AST.Accum +import AST.Sparse.Types import AST.Types import AST.Weaken import CHAD.Types @@ -33,11 +36,9 @@ import Data -- inner variable / inner array dimension. In pretty printing, the inner -- variable / inner dimension is printed on the _right_. -- --- Note that the 'EZero' and 'EPlus' constructs have typing that depend on the --- type transformation of CHAD. Indeed, these constructors are created _by_ --- CHAD, and are intended to be eliminated after simplification, so that the --- input program as well as the output program do not contain these --- constructors. +-- All the monoid operations are unsupposed as the input to CHAD, and are +-- intended to be eliminated after simplification, so that the input program as +-- well as the output program do not contain these constructors. -- TODO: ensure this by a "stage" type parameter. type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type data Expr x env t where @@ -60,7 +61,7 @@ data Expr x env t where -- array operations EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) - EFold1Inner :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) + EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) @@ -87,15 +88,28 @@ data Expr x env t where -> Expr x env a -> Expr x env b -> Expr x env t + -- fake halfway checkpointing + ERecompute :: x t -> Expr x env t -> Expr x env t + -- accumulation effect on monoids - EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum t : env) a -> Expr x env (TPair a (D2 t)) - -- TODO: let this contain a OneHotTerm that is shared with EOneHot for uniformity in Simplify - EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum t) -> Expr x env TNil + -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it + -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not + -- need to create any zeros. + EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) + -- The 'Sparse' here is eliminated to dense by UnMonoid. + EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil -- monoidal operations (to be desugared to regular operations after simplification) - EZero :: x (D2 t) -> STy t -> Expr x env (D2 t) - EPlus :: x (D2 t) -> STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t) - EOneHot :: x (D2 t) -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (D2 t) + EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t + EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t + EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t + EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t + + -- interface of abstract monoidal types + ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) + ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b) + ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b) + ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c -- partiality EError :: x a -> STy a -> String -> Expr x env a @@ -106,6 +120,9 @@ type Ex = Expr (Const ()) ext :: Const () a ext = Const () +data Commutative = Commut | Noncommut + deriving (Show) + type SOp :: Ty -> Ty -> Type data SOp a t where OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) @@ -124,6 +141,7 @@ data SOp a t where OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) + OMod :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) deriving instance Show (SOp a t) opt1 :: SOp a t -> STy a @@ -144,6 +162,7 @@ opt1 = \case OExp t -> STScal t OLog t -> STScal t OIDiv t -> STPair (STScal t) (STScal t) + OMod t -> STPair (STScal t) (STScal t) opt2 :: SOp a t -> STy t opt2 = \case @@ -163,6 +182,7 @@ opt2 = \case OExp t -> STScal t OLog t -> STScal t OIDiv t -> STScal t + OMod t -> STScal t typeOf :: Expr x env t -> STy t typeOf = \case @@ -179,10 +199,14 @@ typeOf = \case ENothing _ t -> STMaybe t EJust _ e -> STMaybe (typeOf e) EMaybe _ e _ _ -> typeOf e + ELNil _ t1 t2 -> STLEither t1 t2 + ELInl _ t2 e -> STLEither (typeOf e) t2 + ELInr _ t1 e -> STLEither t1 (typeOf e) + ELCase _ _ a _ _ -> typeOf a EConstArr _ n t _ -> STArr n (STScal t) EBuild _ n _ e -> STArr n (typeOf e) - EFold1Inner _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t + EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t EUnit _ e -> STArr SZ (typeOf e) EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t @@ -197,13 +221,15 @@ typeOf = \case EOp _ op _ -> opt2 op ECustom _ _ _ _ e _ _ _ _ -> typeOf e + ERecompute _ e -> typeOf e EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) - EAccum _ _ _ _ _ _ -> STNil + EAccum _ _ _ _ _ _ _ -> STNil - EZero _ t -> d2 t - EPlus _ t _ _ -> d2 t - EOneHot _ t _ _ _ -> d2 t + EZero _ t _ -> fromSMTy t + EDeepZero _ t _ -> fromSMTy t + EPlus _ t _ _ -> fromSMTy t + EOneHot _ t _ _ _ -> fromSMTy t EError _ t _ -> t @@ -221,9 +247,13 @@ extOf = \case ENothing x _ -> x EJust x _ -> x EMaybe x _ _ _ -> x + ELNil x _ _ -> x + ELInl x _ _ -> x + ELInr x _ _ -> x + ELCase x _ _ _ _ -> x EConstArr x _ _ _ -> x EBuild x _ _ _ -> x - EFold1Inner x _ _ _ -> x + EFold1Inner x _ _ _ _ -> x ESum1Inner x _ -> x EUnit x _ -> x EReplicate1Inner x _ _ -> x @@ -236,38 +266,70 @@ extOf = \case EShape x _ -> x EOp x _ _ -> x ECustom x _ _ _ _ _ _ _ _ -> x + ERecompute x _ -> x EWith x _ _ _ -> x - EAccum x _ _ _ _ _ -> x - EZero x _ -> x + EAccum x _ _ _ _ _ _ -> x + EZero x _ _ -> x + EDeepZero x _ _ -> x EPlus x _ _ _ -> x EOneHot x _ _ _ _ -> x EError x _ _ -> x -unSTy :: STy t -> Ty -unSTy = \case - STNil -> TNil - STPair a b -> TPair (unSTy a) (unSTy b) - STEither a b -> TEither (unSTy a) (unSTy b) - STMaybe t -> TMaybe (unSTy t) - STArr n t -> TArr (unSNat n) (unSTy t) - STScal t -> TScal (unSScalTy t) - STAccum t -> TAccum (unSTy t) - -unSEnv :: SList STy env -> [Ty] -unSEnv SNil = [] -unSEnv (SCons t l) = unSTy t : unSEnv l - -unSScalTy :: SScalTy t -> ScalTy -unSScalTy = \case - STI32 -> TI32 - STI64 -> TI64 - STF32 -> TF32 - STF64 -> TF64 - STBool -> TBool - -subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t -subst1 repl = subst $ \x t -> \case IZ -> repl - IS i -> EVar x t i +mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t +mapExt f = runIdentity . travExt (Identity . f) + +{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-} +travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t) +travExt f = \case + EVar x t i -> EVar <$> f x <*> pure t <*> pure i + ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body + EPair x a b -> EPair <$> f x <*> travExt f a <*> travExt f b + EFst x e -> EFst <$> f x <*> travExt f e + ESnd x e -> ESnd <$> f x <*> travExt f e + ENil x -> ENil <$> f x + EInl x t e -> EInl <$> f x <*> pure t <*> travExt f e + EInr x t e -> EInr <$> f x <*> pure t <*> travExt f e + ECase x e a b -> ECase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b + ENothing x t -> ENothing <$> f x <*> pure t + EJust x e -> EJust <$> f x <*> travExt f e + EMaybe x a b e -> EMaybe <$> f x <*> travExt f a <*> travExt f b <*> travExt f e + ELNil x t1 t2 -> ELNil <$> f x <*> pure t1 <*> pure t2 + ELInl x t e -> ELInl <$> f x <*> pure t <*> travExt f e + ELInr x t e -> ELInr <$> f x <*> pure t <*> travExt f e + ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c + EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a + EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b + EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c + ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e + EUnit x e -> EUnit <$> f x <*> travExt f e + EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b + EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e + EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e + EConst x t v -> EConst <$> f x <*> pure t <*> pure v + EIdx0 x e -> EIdx0 <$> f x <*> travExt f e + EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b + EIdx x e es -> EIdx <$> f x <*> travExt f e <*> travExt f es + EShape x e -> EShape <$> f x <*> travExt f e + EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e + ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2 + ERecompute x e -> ERecompute <$> f x <*> travExt f e + EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2 + EAccum x t p e1 sp e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp <*> travExt f e2 <*> travExt f e3 + EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e + EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e + EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b + EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b + EError x t s -> EError <$> f x <*> pure t <*> pure s + +substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t +substInline repl = + subst $ \x t -> \case IZ -> repl + IS i -> EVar x t i + +subst0 :: Ex (b : env) a -> Ex (a : env) t -> Ex (b : env) t +subst0 repl = + subst $ \_ t -> \case IZ -> repl + IS i -> EVar ext t (IS i) subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a) -> Expr x env t -> Expr x env' t @@ -290,9 +352,13 @@ subst' f w = \case ENothing x t -> ENothing x t EJust x e -> EJust x (subst' f w e) EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e) + ELNil x t1 t2 -> ELNil x t1 t2 + ELInl x t e -> ELInl x t (subst' f w e) + ELInr x t e -> ELInr x t (subst' f w e) + ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c) EConstArr x n t a -> EConstArr x n t a EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) - EFold1Inner x a b c -> EFold1Inner x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) + EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) ESum1Inner x e -> ESum1Inner x (subst' f w e) EUnit x e -> EUnit x (subst' f w e) EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) @@ -305,11 +371,13 @@ subst' f w = \case EShape x e -> EShape x (subst' f w e) EOp x op e -> EOp x op (subst' f w e) ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) + ERecompute x e -> ERecompute x (subst' f w e) EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) - EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3) - EZero x t -> EZero x t + EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3) + EZero x t e -> EZero x t (subst' f w e) + EDeepZero x t e -> EDeepZero x t (subst' f w e) EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) - EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) + EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) EError x t s -> EError x t s where sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) @@ -321,15 +389,6 @@ subst' f w = \case weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i)) -slistIdx :: SList f list -> Idx list t -> f t -slistIdx (SCons x _) IZ = x -slistIdx (SCons _ list) (IS i) = slistIdx list i -slistIdx SNil i = case i of {} - -idx2int :: Idx env t -> Int -idx2int IZ = 0 -idx2int (IS n) = 1 + idx2int n - class KnownScalTy t where knownScalTy :: SScalTy t instance KnownScalTy TI32 where knownScalTy = STI32 instance KnownScalTy TI64 where knownScalTy = STI64 @@ -341,10 +400,19 @@ class KnownTy t where knownTy :: STy t instance KnownTy TNil where knownTy = STNil instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy +instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy -instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy +instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy + +class KnownMTy t where knownMTy :: SMTy t +instance KnownMTy TNil where knownMTy = SMTNil +instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy +instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy +instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy +instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy +instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy class KnownEnv env where knownEnv :: SList STy env instance KnownEnv '[] where knownEnv = SNil @@ -354,10 +422,19 @@ styKnown :: STy t -> Dict (KnownTy t) styKnown STNil = Dict styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict styKnown (STMaybe t) | Dict <- styKnown t = Dict styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict styKnown (STScal t) | Dict <- sscaltyKnown t = Dict -styKnown (STAccum t) | Dict <- styKnown t = Dict +styKnown (STAccum t) | Dict <- smtyKnown t = Dict + +smtyKnown :: SMTy t -> Dict (KnownMTy t) +smtyKnown SMTNil = Dict +smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict +smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict +smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict +smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict +smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t) sscaltyKnown STI32 = Dict @@ -394,27 +471,30 @@ eidxEq (SS n) a b (eidxEq n (EFst ext (EVar ext ty (IS IZ))) (EFst ext (EVar ext ty IZ))) -emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b) -emap f arr = - let STArr n t = typeOf arr - in ELet ext arr $ - EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ - ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) f - -ezipWith :: Ex (b : a : env) c -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) -ezipWith f arr1 arr2 = - let STArr n t1 = typeOf arr1 - STArr _ t2 = typeOf arr2 - in ELet ext arr1 $ - ELet ext (weakenExpr WSink arr2) $ - EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ - ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $ - weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f +emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b) +emap f arr + | STArr n t <- typeOf arr + , Dict <- styKnown t + = ELet ext arr $ + EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ + ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) + (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) f + +ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) +ezipWith f arr1 arr2 + | STArr n t1 <- typeOf arr1 + , STArr _ t2 <- typeOf arr2 + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = ELet ext arr1 $ + ELet ext (weakenExpr WSink arr2) $ + EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ + ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $ + weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) ezip arr1 arr2 = @@ -435,3 +515,61 @@ eshapeEmpty (SS n) e = (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)) (EConst ext STI64 0))) (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) + +-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t) +-- ezeroD2 t ezi = EZero ext (d2M t) ezi + +-- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil +-- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea + +-- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t) +-- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev + +eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r +eunPair (EPair _ e1 e2) k = k WId e1 e2 +eunPair e k = + elet e $ + k WSink + (EFst ext (evar IZ)) + (ESnd ext (evar IZ)) + +efst :: Ex env (TPair a b) -> Ex env a +efst (EPair _ e1 _) = e1 +efst e = EFst ext e + +esnd :: Ex env (TPair a b) -> Ex env b +esnd (EPair _ _ e2) = e2 +esnd e = ESnd ext e + +elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b +elet rhs body + | Dict <- styKnown (typeOf rhs) + = ELet ext rhs body + +emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b +emaybe e a b + | STMaybe t <- typeOf e + , Dict <- styKnown t + = EMaybe ext a b e + +elcase :: Ex env (TLEither a b) -> Ex env c -> (KnownTy a => Ex (a : env) c) -> (KnownTy b => Ex (b : env) c) -> Ex env c +elcase e a b c + | STLEither t1 t2 <- typeOf e + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = ELCase ext e a b c + +evar :: KnownTy a => Idx env a -> Ex env a +evar = EVar ext knownTy + +makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) +makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ) + where + -- invariant: expression argument is duplicable + go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) + go SMTNil _ = ENil ext + go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e)) + go SMTLEither{} _ = ENil ext + go SMTMaybe{} _ = ENil ext + go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e + go SMTScal{} _ = ENil ext |