summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs718
1 files changed, 426 insertions, 292 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index b5a9af0..241825e 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -11,6 +11,7 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
@@ -62,15 +63,21 @@ tapeTy :: SList STy binds -> STy (Tape binds)
tapeTy SNil = STNil
tapeTy (SCons t ts) = STPair t (tapeTy ts)
-bindingsCollectTape :: Bindings f env binds -> Subenv binds tapebinds
- -> Append binds env :> env2 -> Ex env2 (Tape tapebinds)
-bindingsCollectTape BTop SETop _ = ENil ext
-bindingsCollectTape (BPush binds (t, _)) (SEYesR sub) w =
+bindingsCollectTape :: SList STy binds -> Subenv binds tapebinds
+ -> binds :> env2 -> Ex env2 (Tape tapebinds)
+bindingsCollectTape SNil SETop _ = ENil ext
+bindingsCollectTape (t `SCons` binds) (SEYesR sub) w =
EPair ext (EVar ext t (w @> IZ))
(bindingsCollectTape binds sub (w .> WSink))
-bindingsCollectTape (BPush binds _) (SENo sub) w =
+bindingsCollectTape (_ `SCons` binds) (SENo sub) w =
bindingsCollectTape binds sub (w .> WSink)
+-- bindingsCollectTape' :: forall f env binds tapebinds env2. Bindings f env binds -> Subenv binds tapebinds
+-- -> Append binds env :> env2 -> Ex env2 (Tape tapebinds)
+-- bindingsCollectTape' binds sub w
+-- | Refl <- lemAppendNil @binds
+-- = bindingsCollectTape (bindingsBinds binds) sub (w .> wCopies @_ @_ @'[] (bindingsBinds binds) (WClosed @env))
+
-- In order from large to small: i.e. in reverse order from what we want,
-- because in a Bindings, the head of the list is the bottom-most entry.
type family TapeUnfoldings binds where
@@ -325,6 +332,21 @@ conv2Idx (DPush des (_, _, SDiscr)) (IS i) =
Idx2Di j -> Idx2Di (IS j)
conv2Idx DTop i = case i of {}
+opt2UnSparse :: SOp a b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b)
+opt2UnSparse = go . opt2
+ where
+ go :: STy b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b)
+ go (STScal STI32) SpAbsent = \_ -> ENil ext
+ go (STScal STI64) SpAbsent = \_ -> ENil ext
+ go (STScal STF32) SpAbsent = \_ -> EZero ext (SMTScal STF32) (ENil ext)
+ go (STScal STF64) SpAbsent = \_ -> EZero ext (SMTScal STF64) (ENil ext)
+ go (STScal STBool) SpAbsent = \_ -> ENil ext
+ go (STScal STF32) SpScal = id
+ go (STScal STF64) SpScal = id
+ go STNil _ = \_ -> ENil ext
+ go (STPair t1 t2) (SpPair s1 s2) = \e -> eunPair e $ \_ e1 e2 -> EPair ext (go t1 s1 e1) (go t2 s2 e2)
+ go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary"
+
------------------------------------ MONOIDS -----------------------------------
@@ -355,7 +377,7 @@ subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub)
subenvD1E (SENo sub) = SENo (subenvD1E sub)
expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a)
-expandSparse _ SpDense _ e = e
+expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e
expandSparse t (SpSparse sp) epr e =
EMaybe ext
(EZero ext (d2M t) (d2zeroInfo t epr))
@@ -376,12 +398,6 @@ expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e =
(ECase ext (weakenExpr WSink epr)
(EError ext (d2 (STEither t1 t2)) "expspa l<-dr")
(ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ)))))
-expandSparse (STEither t1 t2) (SpLeft s) epr e =
- let epr' = ECase ext epr (EVar ext (d1 t1) IZ) (EError ext (d1 t1) "expspa r<-dL")
- in ELInl ext (d2 t2) (expandSparse t1 s epr' e)
-expandSparse (STEither t1 t2) (SpRight s) epr e =
- let epr' = ECase ext epr (EError ext (d1 t2) "expspa l<-dR") (EVar ext (d1 t2) IZ)
- in ELInr ext (d2 t1) (expandSparse t2 s epr' e)
expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e =
ELCase ext e
(EZero ext (d2M (STEither t1 t2)) (ENil ext))
@@ -393,12 +409,6 @@ expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e =
(EError ext (d2 (STEither t1 t2)) "expspa ln<-dr")
(EError ext (d2 (STEither t1 t2)) "expspa ll<-dr")
(ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ)))))
-expandSparse (STLEither t1 t2) (SpLeft s) epr e =
- let epr' = ELCase ext epr (EError ext (d1 t1) "expspa ln<-dL") (EVar ext (d1 t1) IZ) (EError ext (d1 t1) "expspa r<-dL")
- in ELInl ext (d2 t2) (expandSparse t1 s epr' e)
-expandSparse (STLEither t1 t2) (SpRight s) epr e =
- let epr' = ELCase ext epr (EError ext (d1 t2) "expspa ln<-dR") (EError ext (d1 t2) "expspa l<-dR") (EVar ext (d1 t2) IZ)
- in ELInr ext (d2 t1) (expandSparse t2 s epr' e)
expandSparse (STMaybe t) (SpMaybe s) epr e =
EMaybe ext
(ENothing ext (d2 t))
@@ -407,55 +417,72 @@ expandSparse (STMaybe t) (SpMaybe s) epr e =
e
expandSparse (STArr _ t) (SpArr s) epr e =
ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e
-expandSparse (STScal sty) _ _ _ = case sty of {} -- SpDense and SpSparse handled already
+expandSparse (STScal STF32) SpScal _ e = e
+expandSparse (STScal STF64) SpScal _ e = e
expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program"
-sparsePlus
- :: SMTy t -> Sparse t t1 -> Sparse t t2
- -> (forall t3. Sparse t t3
- -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3)
- -> r)
- -> r
-sparsePlus t sp1 sp2 k = sparsePlusS SF SF t sp1 sp2 $ \sp3 _ _ plus -> k sp3 plus
-
-subenvPlus :: SList STy env
- -> SubenvS (D2E env) env1 -> SubenvS (D2E env) env2
- -> (forall env3. SubenvS (D2E env) env3
- -> SubenvS env3 env1
- -> SubenvS env3 env2
- -> (Ex exenv (Tup env1)
- -> Ex exenv (Tup env2)
- -> Ex exenv (Tup env3))
+subenvPlus :: SBool req1 -> SBool req2
+ -> SList SMTy env
+ -> SubenvS env env1 -> SubenvS env env2
+ -> (forall env3. SubenvS env env3
+ -> Injection req1 (Tup env1) (Tup env3)
+ -> Injection req2 (Tup env2) (Tup env3)
+ -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3))
-> r)
-> r
-subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext)
-subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
+subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\_ _ -> ENil ext)
+
+subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k =
+ subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl ->
k (SENo sub3) s31 s32 pl
-subenvPlus (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sp1 sub3) (SEYes SpDense s31) (SENo s32) $ \e1 e2 ->
- ELet ext e1 $
- EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
- (weakenExpr WSink e2))
- (ESnd ext (EVar ext (typeOf e1) IZ))
-subenvPlus (SCons _ env) (SENo sub1) (SEYes sp2 sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sp2 sub3) (SENo s31) (SEYes SpDense s32) $ \e1 e2 ->
- ELet ext e2 $
- EPair ext (pl (weakenExpr WSink e1)
- (EFst ext (EVar ext (typeOf e2) IZ)))
- (ESnd ext (EVar ext (typeOf e2) IZ))
-subenvPlus (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYesR sub3) (SEYesR s31) (SEYesR s32) $ \e1 e2 ->
- ELet ext e1 $
- ELet ext (weakenExpr WSink e2) $
- EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ)))
- (EFst ext (EVar ext (typeOf e2) IZ)))
- (EPlus ext (d2M t)
- (ESnd ext (EVar ext (typeOf e1) (IS IZ)))
- (ESnd ext (EVar ext (typeOf e2) IZ)))
+
+subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k =
+ subenvPlus req1 SF env sub1 sub2 $ \sub3 minj13 _ pl ->
+ k (SEYes sp1 sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) e1b)
+ Noinj
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (ESnd ext (EVar ext (typeOf e1) IZ)))
+subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k =
+ subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl ->
+ k (SEYes (SpSparse sp1) sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) (EJust ext e1b))
+ (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t))))
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ))))
+
+subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k =
+ subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl ->
+ k sub3 minj13 minj23 (flip pl)
+
+subenvPlus req1 req2 (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k =
+ subenvPlus req1 req2 env sub1 sub2 $ \sub3 minj13 minj23 pl ->
+ sparsePlusS req1 req2 t sp1 sp2 $ \sp3 mTinj13 mTinj23 plus ->
+ k (SEYes sp3 sub3)
+ (withInj2 minj13 mTinj13 $ \inj13 tinj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) (tinj13 e1b))
+ (withInj2 minj23 mTinj23 $ \inj23 tinj23 ->
+ \e2 -> eunPair e2 $ \_ e2a e2b ->
+ EPair ext (inj23 e2a) (tinj23 e2b))
+ (\e1 e2 ->
+ ELet ext e1 $
+ ELet ext (weakenExpr WSink e2) $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ)))
+ (EFst ext (EVar ext (typeOf e2) IZ)))
+ (plus
+ (ESnd ext (EVar ext (typeOf e1) (IS IZ)))
+ (ESnd ext (EVar ext (typeOf e2) IZ))))
expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs
-> Ex env (Tup contribs) -> Ex env (Tup (D2E env0))
@@ -470,10 +497,10 @@ expandSubenvZeros w (SCons t ts) (SENo sub) e =
(expandSubenvZeros (WPop w) ts sub e)
(EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
-assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[]
+assertSubenvEmpty :: HasCallStack => Subenv' s env env' -> env' :~: '[]
assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl
assertSubenvEmpty SETop = Refl
-assertSubenvEmpty SEYesR{} = error "assertSubenvEmpty: not empty"
+assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
--------------------------------- ACCUMULATORS ---------------------------------
@@ -523,8 +550,8 @@ accumPromote :: forall dt env sto proxy r.
-- accumulators.
-> (forall shbinds.
SList STy shbinds
- -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum")))
- :> Append (D2AcE envPro) (D2 dt : Append shbinds (D2AcE (Select env sto "accum"))))
+ -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum")))
+ :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum"))))
-- ^ A weakening that converts a computation in the
-- revised environment to one in the original environment
-- extended with some accumulators.
@@ -541,11 +568,11 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
(SEYesR accrevsub)
(VarMap.sink1 accumMap)
(\shbinds ->
- autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr)))
+ autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr)))
(#acc :++: (#pro :++: #d :++: #shb :++: #tl))
(#pro :++: #d :++: #shb :++: #acc :++: #tl)
.> WCopy (wf shbinds)
- .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl)))
+ .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl)))
(#d :++: #shb :++: #acc :++: #tl)
(#acc :++: (#d :++: #shb :++: #tl)))
@@ -582,7 +609,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
-- goal: | ARE EQUAL ||
-- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
WCopy (wf shbinds)
- .> WPick @(TAccum (D2 t)) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
+ .> WPick @(TAccum (D2 t)) @(dt : shbinds) (Const () `SCons` shbindsC)
(WId @(D2AcE (Select env1 stoRepl "accum"))))
-- Discrete values are left as-is, nothing to do
@@ -614,23 +641,41 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
---------------------------- RETURN TRIPLE FROM CHAD ---------------------------
-data Ret env0 sto t =
+data Ret env0 sto sd t =
forall shbinds tapebinds contribs.
Ret (Bindings Ex (D1E env0) shbinds) -- shared binds
(Subenv shbinds tapebinds)
(Ex (Append shbinds (D1E env0)) (D1 t))
(SubenvS (D2E (Select env0 sto "merge")) contribs)
- (forall sd. Sparse (D2 t) sd
- -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs))
-deriving instance Show (Ret env0 sto t)
-
-data RetPair env0 sto env shbinds tapebinds t =
- forall contribs.
- RetPair (Ex (Append shbinds env) (D1 t))
- (SubenvS (D2E (Select env0 sto "merge")) contribs)
- (forall sd. Sparse (D2 t) sd
- -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs))
-deriving instance Show (RetPair env0 sto env shbinds tapebinds t)
+ (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs))
+deriving instance Show (Ret env0 sto sd t)
+
+type data TyTyPair = MkTyTyPair Ty Ty
+
+data SingleRet env0 sto (pair :: TyTyPair) =
+ forall shbinds tapebinds.
+ SingleRet
+ (Bindings Ex (D1E env0) shbinds) -- shared binds
+ (Subenv shbinds tapebinds)
+ (RetPair env0 sto (D1E env0) shbinds tapebinds pair)
+
+-- pattern Ret1 :: forall env0 sto Bindings Ex (D1E env0) shbinds
+-- -> Subenv shbinds tapebinds
+-- -> Ex (Append shbinds (D1E env0)) (D1 t)
+-- -> SubenvS (D2E (Select env0 sto "merge")) contribs
+-- -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)
+-- -> SingleRet env0 sto (MkTyTyPair sd t)
+-- pattern Ret1 e0 subtape e1 sub e2 = SingleRet e0 subtape (RetPair e1 sub e2)
+-- {-# COMPLETE Ret1 #-}
+
+data RetPair env0 sto env shbinds tapebinds (pair :: TyTyPair) where
+ RetPair :: forall sd t contribs -- existentials
+ env0 sto env shbinds tapebinds. -- universals
+ Ex (Append shbinds env) (D1 t)
+ -> SubenvS (D2E (Select env0 sto "merge")) contribs
+ -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)
+ -> RetPair env0 sto env shbinds tapebinds (MkTyTyPair sd t)
+deriving instance Show (RetPair env0 sto env shbinds tapebinds pair)
data Rets env0 sto env list =
forall shbinds tapebinds.
@@ -639,8 +684,11 @@ data Rets env0 sto env list =
(SList (RetPair env0 sto env shbinds tapebinds) list)
deriving instance Show (Rets env0 sto env list)
+toSingleRet :: Ret env0 sto sd t -> SingleRet env0 sto (MkTyTyPair sd t)
+toSingleRet (Ret e0 subtape e1 sub e2) = SingleRet e0 subtape (RetPair e1 sub e2)
+
weakenRetPair :: SList STy shbinds -> env :> env'
- -> RetPair env0 sto env shbinds tapebinds t -> RetPair env0 sto env' shbinds tapebinds t
+ -> RetPair env0 sto env shbinds tapebinds pair -> RetPair env0 sto env' shbinds tapebinds pair
weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2
weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list
@@ -648,46 +696,47 @@ weakenRets w (Rets binds tapesub list) =
let (binds', _) = weakenBindings weakenExpr w binds
in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list)
-rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto t f.
+rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f.
Descr env0 sto
-> SList f b1 -> SList f b2
-> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2
- -> RetPair env0 sto (Append b1 env) b2 tapebinds2 t
- -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) t
-rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair p sub d)
+ -> RetPair env0 sto (Append b1 env) b2 tapebinds2 pair
+ -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) pair
+rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2)
| Refl <- lemAppendAssoc @b2 @b1 @env =
- RetPair p sub (weakenExpr (autoWeak
- (#d (auto1 @(D2 t))
- &. #t2 (subList b2 subtape2)
- &. #t1 (subList b1 subtape1)
- &. #tl (d2ace (select SAccum descr)))
- (#d :++: (#t2 :++: #tl))
- (#d :++: ((#t2 :++: #t1) :++: #tl)))
- d)
-
-retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list
+ RetPair e1 sub
+ (weakenExpr (autoWeak
+ (#d (auto1 @sd)
+ &. #t2 (subList b2 subtape2)
+ &. #t1 (subList b1 subtape1)
+ &. #tl (d2ace (select SAccum descr)))
+ (#d :++: (#t2 :++: #tl))
+ (#d :++: ((#t2 :++: #t1) :++: #tl)))
+ e2)
+
+retConcat :: forall env0 sto list. Descr env0 sto -> SList (SingleRet env0 sto) list -> Rets env0 sto (D1E env0) list
retConcat _ SNil = Rets BTop SETop SNil
-retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) p sub d) list)
+retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2)) list)
| Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs
- <- weakenRets (sinkWithBindings b) (retConcat descr list)
+ <- weakenRets (sinkWithBindings e0) (retConcat descr list)
, Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0)
, Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum"))
- = Rets (bconcat b binds)
+ = Rets (bconcat e0 binds)
(subenvConcat subtape subtape2)
- (SCons (RetPair (weakenExpr (sinkWithBindings binds) p)
+ (SCons (RetPair (weakenExpr (sinkWithBindings binds) e1)
sub
- (weakenExpr (WCopy (sinkWithSubenv subtape2)) d))
- (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds)
+ (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2))
+ (slistMap (rebaseRetPair descr (bindingsBinds e0) (bindingsBinds binds)
subtape subtape2)
pairs))
freezeRet :: Descr env sto
- -> Ret env sto t
+ -> Ret env sto (D2 t) t
-> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge"))))
-freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =
+freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) =
let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0
e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2
- tContribs = tTup (subList (d2e (select SMerge descr)) sub)
+ tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub))
library = #d (auto1 @(D2 t))
&. #tape (subList (bindingsBinds e0) subtape)
&. #shbinds (bindingsBinds e0)
@@ -709,11 +758,34 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =
---------------------------- THE CHAD TRANSFORMATION ---------------------------
-drev :: forall env sto t.
+drev :: forall env sto sd t.
(?config :: CHADConfig)
=> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
- -> Expr ValId env t -> Ret env sto t
-drev des accumMap = \case
+ -> Sparse (D2 t) sd
+ -> Expr ValId env t -> Ret env sto sd t
+drev des _ sd | isAbsent sd =
+ \e ->
+ Ret BTop
+ SETop
+ (drevPrimal des e)
+ (subenvNone (d2e (select SMerge des)))
+ (ENil ext)
+drev _ _ SpAbsent = error "Absent should be isAbsent"
+
+drev des accumMap (SpSparse sd) =
+ \e ->
+ case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 ->
+ subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
+ Ret e0
+ subtape
+ e1
+ sub'
+ (emaybe (evar IZ)
+ (inj2 (ENil ext))
+ (inj1 (weakenExpr (WCopy WSink) e2)))
+ }
+
+drev des accumMap sd = \case
EVar _ t i ->
case conv2Idx des i of
Idx2Ac accI ->
@@ -721,14 +793,15 @@ drev des accumMap = \case
SETop
(EVar ext (d1 t) (conv1Idx i))
(subenvNone (d2e (select SMerge des)))
- (EAccum ext (d2M t) SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))
+ (let ty = applySparse sd (d2M t)
+ in EAccum ext (d2M t) (_ sd) (ENil ext) (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))
Idx2Me tupI ->
Ret BTop
SETop
(EVar ext (d1 t) (conv1Idx i))
- (subenvOnehot (d2e (select SMerge des)) tupI)
- (EPair ext (ENil ext) (EVar ext (d2 t) IZ))
+ (subenvOnehot (d2e (select SMerge des)) tupI sd)
+ (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t)) IZ))
Idx2Di _ ->
Ret BTop
@@ -738,20 +811,22 @@ drev des accumMap = \case
(ENil ext)
ELet _ (rhs :: Expr _ _ a) body
- | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs
- , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
- , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) body
+ | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
+ , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body
+ , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs
, let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
- , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
- , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->
- subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body ->
- let bodyResType = STPair (tTup (subList (d2e (select SMerge des)) subBody)) (d2 (typeOf rhs)) in
+ , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds
+ , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env)
+ , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum"))
+ ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body ->
+ let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in
Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
- (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody)
+ (subenvConcat subtapeRHS subtapeBody)
(weakenExpr wbody0' body1)
subBoth
- (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #body (subList (bindingsBinds body0) subtapeBody)
+ (ELet ext (weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #body (subList (bindingsBinds body0 `sappend` SCons (d1 (typeOf rhs)) SNil) subtapeBody)
&. #rhs (subList (bindingsBinds rhs0) subtapeRHS)
&. #tl (d2ace (select SAccum des)))
(#d :++: #body :++: #tl)
@@ -761,14 +836,15 @@ drev des accumMap = \case
(ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $
plus_RHS_Body
- (EVar ext (tTup (subList (d2e (select SMerge des)) subRHS)) IZ)
+ (EVar ext (contribTupTy des subRHS) IZ)
(EFst ext (EVar ext bodyResType (IS IZ))))
EPair _ a b
- | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
- <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil
- , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) ->
- subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B ->
+ | SpPair sd1 sd2 <- sd
+ , Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
+ <- retConcat des $ toSingleRet (drev des accumMap sd1 a) `SCons` toSingleRet (drev des accumMap sd2 b) `SCons` SNil
+ , let dt = STPair (applySparse sd1 (d2 (typeOf a))) (applySparse sd2 (d2 (typeOf b))) ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B ->
Ret binds
subtape
(EPair ext a1 b1)
@@ -778,147 +854,155 @@ drev des accumMap = \case
ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
(weakenExpr (WCopy (WSink .> WSink)) b2)) $
plus_A_B
- (EVar ext (tTup (subList (d2e (select SMerge des)) subA)) (IS IZ))
- (EVar ext (tTup (subList (d2e (select SMerge des)) subB)) IZ))
+ (EVar ext (contribTupTy des subA) (IS IZ))
+ (EVar ext (contribTupTy des subB) IZ))
EFst _ e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
- , STPair t1 t2 <- typeOf e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair sd SpAbsent) e
+ , STPair t1 _ <- typeOf e ->
Ret e0
subtape
(EFst ext e1)
sub
- (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (ezeroD2 t2))) $
+ (ELet ext (EPair ext (EVar ext (applySparse sd (d2 t1)) IZ) (ENil ext)) $
weakenExpr (WCopy WSink) e2)
ESnd _ e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
- , STPair t1 t2 <- typeOf e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair SpAbsent sd) e
+ , STPair _ t2 <- typeOf e ->
Ret e0
subtape
(ESnd ext e1)
sub
- (ELet ext (EJust ext (EPair ext (ezeroD2 t1) (EVar ext (d2 t2) IZ))) $
+ (ELet ext (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t2)) IZ)) $
weakenExpr (WCopy WSink) e2)
- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext)
+ -- Don't need to handle ENil, because its cotangent is always absent!
+ -- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (d2e (select SMerge des))) (ENil ext)
EInl _ t2 e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
+ | SpLEither sd1 sd2 <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap sd1 e ->
+ subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
Ret e0
subtape
(EInl ext (d1 t2) e1)
- sub
+ sub'
(ELCase ext
- (EVar ext (STLEither (d2 (typeOf e)) (d2 t2)) IZ)
- (zeroTup (subList (select SMerge des) sub))
- (weakenExpr (WCopy WSink) e2)
- (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr"))
+ (EVar ext (STLEither (applySparse sd1 (d2 (typeOf e))) (applySparse sd2 (d2 t2))) IZ)
+ (inj2 $ ENil ext)
+ (inj1 $ weakenExpr (WCopy WSink) e2)
+ (EError ext (contribTupTy des sub') "inl<-dinr"))
EInr _ t1 e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
+ | SpLEither sd1 sd2 <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap sd2 e ->
+ subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
Ret e0
subtape
(EInr ext (d1 t1) e1)
- sub
+ sub'
(ELCase ext
- (EVar ext (STLEither (d2 t1) (d2 (typeOf e))) IZ)
- (zeroTup (subList (select SMerge des) sub))
- (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")
- (weakenExpr (WCopy WSink) e2))
+ (EVar ext (STLEither (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 (typeOf e)))) IZ)
+ (inj2 $ ENil ext)
+ (EError ext (contribTupTy des sub') "inr<-dinl")
+ (inj1 $ weakenExpr (WCopy WSink) e2))
ECase _ e (a :: Expr _ _ t) b
- | STEither t1 t2 <- typeOf e
- , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e
+ | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e
, ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge
, ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge
, let (bindids1, bindids2) = validSplitEither (extOf e)
- , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 bindids1 a
- , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 bindids2 b
+ , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2
+ <- drevScoped des accumMap t1 storage1 bindids1 sd a
+ , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB sd2 b2
+ <- drevScoped des accumMap t2 storage2 bindids2 sd b
+ , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap (SpLEither sd1 sd2) e
, Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
, Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
- , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA)
- , let tapeB = tapeTy (subList (bindingsBinds b0) subtapeB)
- , let collectA = bindingsCollectTape a0 subtapeA
- , let collectB = bindingsCollectTape b0 subtapeB
+ , let subtapeListA = subList (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA
+ , let subtapeListB = subList (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB
+ , let tapeA = tapeTy subtapeListA
+ , let tapeB = tapeTy subtapeListB
+ , let collectA = bindingsCollectTape @_ @_ @(Append rhs_a_binds (D1 a : Append e_binds (D1E env)))
+ (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA
+ , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env)))
+ (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB
, (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB)
, let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0
, let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0
+ , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a])
+ , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b])
+ , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env)
+ , Refl <- lemAppendAssoc @rhs_b_binds @'[D1 b] @(D1E env)
+ , let wa0'' = wa0' .> wCopies (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) (WClosed @(D1E env))
+ , let wb0'' = wb0' .> wCopies (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) (WClosed @(D1E env))
->
- subenvPlus (select SMerge des) subA subB $ \subAB sAB_A sAB_B _ ->
- subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E ->
- let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STLEither (d2 t1) (d2 t2)) in
+ subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E ->
Ret (e0 `BPush`
(tPrimal,
ECase ext e1
- (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'))))
- (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0'))))))
+ (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0''))))
+ (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0''))))))
(SEYesR subtapeE)
(EFst ext (EVar ext tPrimal IZ))
subOut
- (ELet ext
+ (elet
(ECase ext (ESnd ext (EVar ext tPrimal (IS IZ)))
- (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds a0) subtapeA) IZ
+ (let (rebinds, prerebinds) = reconstructBindings subtapeListA IZ
in letBinds rebinds $
ELet ext
- (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds a0) subtapeA) prerebinds) @> IS IZ)) $
- ELet ext
- (weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #ta0 (subList (bindingsBinds a0) subtapeA)
+ (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $
+ elet
+ (weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #ta0 subtapeListA
&. #prea0 prerebinds
- &. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil)
+ &. #recon (tapeA `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil)
&. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE)
&. #tl (d2ace (select SAccum des)))
(#d :++: #ta0 :++: #tl)
(#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl))
a2) $
- EPair ext
- (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $
- EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ))
- (ELInl ext (d2 t2)
- (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ))))
- (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds b0) subtapeB) IZ
+ EPair ext (sAB_A $ EFst ext (evar IZ))
+ (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ))))
+ (let (rebinds, prerebinds) = reconstructBindings subtapeListB IZ
in letBinds rebinds $
ELet ext
- (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds b0) subtapeB) prerebinds) @> IS IZ)) $
- ELet ext
- (weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #tb0 (subList (bindingsBinds b0) subtapeB)
+ (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $
+ elet
+ (weakenExpr (autoWeak (#d (auto1 @sd)
+ &. #tb0 subtapeListB
&. #preb0 prerebinds
- &. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil)
+ &. #recon (tapeB `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil)
&. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE)
&. #tl (d2ace (select SAccum des)))
(#d :++: #tb0 :++: #tl)
(#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl))
b2) $
- EPair ext
- (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $
- EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))
- (ELInr ext (d2 t1)
- (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $
- ELet ext
- (ELet ext (ESnd ext (EVar ext tCaseRet IZ)) $
- weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $
+ EPair ext (sAB_B $ EFst ext (evar IZ))
+ (ELInr ext (applySparse sd1 (d2 t1)) (ESnd ext (evar IZ))))) $
plus_AB_E
- (EFst ext (EVar ext tCaseRet (IS IZ)))
- (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ))
+ (EFst ext (evar IZ))
+ (ELet ext (ESnd ext (evar IZ)) $
+ weakenExpr (WCopy (wSinks' @[_,_,_])) e2))
EConst _ t val ->
Ret BTop
SETop
(EConst ext t val)
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(ENil ext)
EOp _ op e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (spDense (d2M (opt1 op))) e ->
case d2op op of
Linear d2opfun ->
Ret e0
subtape
(d1op op e1)
sub
- (ELet ext (d2opfun (EVar ext (d2 (opt2 op)) IZ))
+ (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ)))
(weakenExpr (WCopy WSink) e2))
Nonlinear d2opfun ->
Ret (e0 `BPush` (d1 (typeOf e), e1))
@@ -926,36 +1010,51 @@ drev des accumMap = \case
(d1op op $ EVar ext (d1 (typeOf e)) IZ)
sub
(ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ))
- (EVar ext (d2 (opt2 op)) IZ))
+ (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ)))
(weakenExpr (WCopy (wSinks' @[_,_])) e2))
- ECustom _ _ _ storety _ pr du a b
+ ECustom _ _ tb storety srce pr du a b
-- allowed to ignore a2 because 'a' is the part of the input that is inactive
- | Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil)
- <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil ->
- Ret (binds `BPush` (typeOf a1, a1)
- `BPush` (typeOf b1, weakenExpr WSink b1)
- `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))
- `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ)))
- (SEYesR (SENo (SENo (SENo subtape))))
- (EFst ext (EVar ext (typeOf pr) (IS IZ)))
- bsub
- (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
- weakenExpr (WCopy (WSink .> WSink)) b2)
-
- -- TODO: compute primal in direct form here instead of taking the redundantly inefficient CHAD primal
+ | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b ->
+ case isDense (d2M (typeOf srce)) sd of
+ Just Refl ->
+ Ret (b0 `BPush` (d1 (typeOf a), weakenExpr (sinkWithBindings b0) (drevPrimal des a))
+ `BPush` (typeOf b1, weakenExpr WSink b1)
+ `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))
+ `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ)))
+ (SEYesR (SENo (SENo (SENo bsubtape))))
+ (EFst ext (EVar ext (typeOf pr) (IS IZ)))
+ bsub
+ (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
+ weakenExpr (WCopy (WSink .> WSink)) b2)
+
+ Nothing ->
+ Ret (b0 `BPush` (d1 (typeOf a), weakenExpr (sinkWithBindings b0) (drevPrimal des a))
+ `BPush` (typeOf b1, weakenExpr WSink b1)
+ `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)))
+ (SEYesR (SENo (SENo bsubtape)))
+ (EFst ext (EVar ext (typeOf pr) IZ))
+ bsub
+ (ELet ext (ESnd ext (EVar ext (typeOf pr) (IS IZ))) $ -- tape
+ ELet ext (expandSparse (typeOf srce) sd -- expanded incoming cotangent
+ (EFst ext (EVar ext (typeOf pr) (IS (IS IZ))))
+ (EVar ext (applySparse sd (d2 (typeOf srce))) (IS IZ))) $
+ ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) b2)
+
ERecompute _ e ->
deleteUnused (descrList des) (occCountAll e) $ \usedSub ->
let smallE = unsafeWeakenWithSubenv usedSub e in
subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
- case drev usedDes (VarMap.subMap subAccumUsed accumMap) smallE of { Ret e0 subtape e1 sub e2 ->
+ case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 ->
+ let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in
Ret (collectBindings (desD1E des) subD1eUsed)
(subenvAll (desD1E usedDes))
- (weakenExpr (wRaiseAbove (desD1E usedDes) (desD1E des)) $ letBinds e0 e1)
- (subenvCompose subMergeUsed sub)
+ (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e)
+ (subenvCompose subMergeUsed' sub)
(letBinds (fst (weakenBindings weakenExpr (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $
weakenExpr
- (autoWeak (#d (auto1 @(D2 t))
+ (autoWeak (#d (auto1 @sd)
&. #shbinds (bindingsBinds e0)
&. #tape (subList (bindingsBinds e0) subtape)
&. #d1env (desD1E usedDes)
@@ -970,31 +1069,32 @@ drev des accumMap = \case
Ret BTop
SETop
(EError ext (d1 t) s)
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(ENil ext)
EConstArr _ n t val ->
Ret BTop
SETop
(EConstArr ext n t val)
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(ENil ext)
EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty)
- | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des accumMap she -- allowed to ignore she2 here because she has a discrete result
+ | SpArr @_ @sdElt sdElt <- sd
, let eltty = typeOf orige
, shty :: STy shty <- tTup (sreplicate ndim tIx)
, Refl <- indexTupD1Id ndim ->
deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') ->
let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in
- subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
- accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
+ subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed ->
+ accumPromote sdElt usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
- case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->
+ case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro sdElt e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->
case assertSubenvEmpty sub of { Refl ->
+ case lemAppendNil @e_binds of { Refl ->
let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in
- let collectexpr = bindingsCollectTape e0 subtapeE in
- Ret (BTop `BPush` (shty, letBinds she0 she1)
+ let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in
+ Ret (BTop `BPush` (shty, drevPrimal des she)
`BPush` (STArr ndim (STPair (d1 eltty) tapety)
,EBuild ext ndim
(EVar ext shty IZ)
@@ -1012,58 +1112,59 @@ drev des accumMap = \case
&. #d1env' (desD1E usedDes))
(#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
(#e0 :++: #ix :++: #sh :++: #d1env)
- in EPair ext (weakenExpr w e1) (collectexpr w)))
+ w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env'))
+ in EPair ext (weakenExpr w e1) (collectexpr w')))
`BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ))
(EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ)))
(SEYesR (SENo (SEYesR SETop)))
(emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ))
(EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))
- (subenvCompose subMergeUsed proSub)
- (let sinkOverEnvPro = wSinks @(TArr ndim (D2 eltty) : D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in
- EMaybe ext
- (zeroTup envPro)
- (ESnd ext $
- uninvertTup (d2e envPro) (STArr ndim STNil) $
- makeAccumulators @_ @_ @(TArr ndim TNil) envPro $
- EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS (IS IZ)))) $
- -- the cotangent for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ))
- (EVar ext shty IZ)) $
- -- the tape for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS (IS IZ)))
- (EVar ext shty (IS IZ))) $
- let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ
- in letBinds rebinds $
- weakenExpr (autoWeak (#d (auto1 @(D2 eltty))
- &. #pro (d2ace envPro)
- &. #etape (subList (bindingsBinds e0) subtapeE)
- &. #prerebinds prerebinds
- &. #tape (auto1 @(Tape e_tape))
- &. #ix (auto1 @shty)
- &. #darr (auto1 @(TArr ndim (D2 eltty)))
- &. #mdarr (auto1 @(TMaybe (TArr ndim (D2 eltty))))
- &. #tapearr (auto1 @(TArr ndim (Tape e_tape)))
- &. #sh (auto1 @shty)
- &. #d2acUsed (d2ace (select SAccum usedDes))
- &. #d2acEnv (d2ace (select SAccum des)))
- (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
- ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #mdarr :++: #tapearr :++: #sh :++: #d2acEnv)
- .> wPro (subList (bindingsBinds e0) subtapeE))
- e2)
- (EVar ext (d2 (STArr ndim eltty)) IZ))
- }}
+ (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub)))
+ (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in
+ ESnd ext $
+ uninvertTup (d2e envPro) (STArr ndim STNil) $
+ -- TODO: what's happening here is that because of the sparsity
+ -- rewrite, makeAccumulators needs primals where it previously
+ -- didn't. The build derivative is currently not saving those
+ -- primals, so the hole below cannot currently be filled. The
+ -- appropriate primals (waves hands) need to be stored, so that a
+ -- weakening can be provided here.
+ makeAccumulators @_ @_ @(TArr ndim TNil) (_ (subenvCompose subMergeUsed proSub)) envPro $
+ EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $
+ -- the cotangent for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ))
+ (EVar ext shty IZ)) $
+ -- the tape for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
+ (EVar ext shty (IS IZ))) $
+ let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ
+ in letBinds rebinds $
+ weakenExpr (autoWeak (#d (auto1 @sdElt)
+ &. #pro (d2ace envPro)
+ &. #etape (subList (bindingsBinds e0) subtapeE)
+ &. #prerebinds prerebinds
+ &. #tape (auto1 @(Tape e_tape))
+ &. #ix (auto1 @shty)
+ &. #darr (auto1 @(TArr ndim sdElt))
+ &. #tapearr (auto1 @(TArr ndim (Tape e_tape)))
+ &. #sh (auto1 @shty)
+ &. #d2acUsed (d2ace (select SAccum usedDes))
+ &. #d2acEnv (d2ace (select SAccum des)))
+ (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
+ ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv)
+ .> wPro (subList (bindingsBinds e0) subtapeE))
+ e2)
+ }}}
EUnit _ e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
+ | SpArr sdElt <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e ->
Ret e0
subtape
(EUnit ext e1)
sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $
- weakenExpr (WCopy (WSink .> WSink)) e2)
- (EVar ext (STMaybe (STArr SZ (d2 (typeOf e)))) IZ))
+ (ELet ext (EIdx0 ext (EVar ext (STArr SZ (applySparse sdElt (d2 (typeOf e)))) IZ)) $
+ weakenExpr (WCopy WSink) e2)
EReplicate1Inner _ en e
-- We're allowed to ignore en2 here because the output of 'ei' is discrete.
@@ -1177,7 +1278,6 @@ drev des accumMap = \case
ELCase{} -> err_unsupported "ELCase"
EWith{} -> err_accum
- EAccum{} -> err_accum
EZero{} -> err_monoid
EPlus{} -> err_monoid
EOneHot{} -> err_monoid
@@ -1189,7 +1289,8 @@ drev des accumMap = \case
deriv_extremum :: ScalIsNumeric t' ~ True
=> (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t')))
- -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t'))
+ -> Sparse (TArr n (D2s t')) sd'
+ -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto sd' (TArr n (TScal t'))
deriv_extremum extremum e
| Ret e0 subtape e1 sub e2 <- drev des accumMap e
, at@(STArr (SS n) t@(STScal st)) <- typeOf e
@@ -1212,70 +1313,103 @@ drev des accumMap = \case
weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2)
(EVar ext (d2 at') IZ))
+ contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs)
+ contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub))
+
data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s)
-data RetScoped env0 sto a s t =
- forall shbinds tapebinds contribs.
+data RetScoped env0 sto a s sd t =
+ forall shbinds tapebinds contribs sa.
RetScoped
(Bindings Ex (D1E (a : env0)) shbinds) -- shared binds
- (Subenv shbinds tapebinds)
+ (Subenv (Append shbinds '[D1 a]) tapebinds)
(Ex (Append shbinds (D1E (a : env0))) (D1 t))
(SubenvS (D2E (Select env0 sto "merge")) contribs)
-- ^ merge contributions to the _enclosing_ merge environment
- (forall sd. Sparse (D2 t) sd
- -> Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum")))
- (If (s == "discr") (Tup contribs)
- (TPair (Tup contribs) (D2 a))))
+ (Sparse (D2 a) sa)
+ -- ^ contribution to the argument
+ (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum")))
+ (If (s == "discr") (Tup contribs)
+ (TPair (Tup contribs) sa)))
-- ^ the merge contributions, plus the cotangent to the argument
-- (if there is any)
-deriving instance Show (RetScoped env0 sto a s t)
+deriving instance Show (RetScoped env0 sto a s sd t)
-drevScoped :: forall a s env sto t.
+drevScoped :: forall a s env sto sd t.
(?config :: CHADConfig)
=> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
-> STy a -> Storage s -> Maybe (ValId a)
+ -> Sparse (D2 t) sd
-> Expr ValId (a : env) t
- -> RetScoped env sto a s t
-drevScoped des accumMap argty argsto argids expr = case argsto of
+ -> RetScoped env sto a s sd t
+drevScoped des accumMap argty argsto argids sd expr = case argsto of
SMerge
- | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr ->
+ | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr
+ , Refl <- lemAppendNil @tapebinds ->
case sub of
- SEYesR sub' -> RetScoped e0 subtape e1 sub' e2
- SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (ezeroD2 argty))
+ SEYes sp sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' sp e2
+ SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext))
SAccum
| Just (VIArr i _) <- argids
, Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap
, Just Refl <- testEquality foundTy (STAccum (d2M argty))
- , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) expr ->
- RetScoped e0 subtape e1 sub $
+ , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr
+ , Refl <- lemAppendNil @tapebinds ->
+ -- Our contribution to the binding's cotangent _here_ is zero (absent),
+ -- because we're contributing to an earlier binding of the same value
+ -- instead.
+ RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent $
let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in
ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $
- weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ weakenExpr (autoWeak (#d (auto1 @sd)
&. #body (subList (bindingsBinds e0) subtape)
&. #ac (auto1 @(TAccum (D2 a)))
&. #tl (d2ace (select SAccum des)))
(#d :++: #body :++: #ac :++: #tl)
(#ac :++: #d :++: #body :++: #tl))
- -- Our contribution to the binding's cotangent _here_ is
- -- zero, because we're contributing to an earlier binding
- -- of the same value instead.
- (EPair ext e2 (ezeroD2 argty))
+ (EPair ext e2 (ENil ext))
| let accumMap' = case argids of
Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap)
_ -> VarMap.sink1 accumMap
- , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' expr ->
- RetScoped e0 subtape e1 sub $
- EWith ext (d2M argty) (ezeroD2 argty) $
- weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #body (subList (bindingsBinds e0) subtape)
- &. #ac (auto1 @(TAccum (D2 a)))
- &. #tl (d2ace (select SAccum des)))
+ , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr ->
+ let library = #d (auto1 @sd)
+ &. #p (auto1 @(D1 a))
+ &. #body (subList (bindingsBinds e0) subtape)
+ &. #ac (auto1 @(TAccum (D2 a)))
+ &. #tl (d2ace (select SAccum des))
+ in
+ RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub SpDense $
+ let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in
+ EWith ext (d2M argty) (EZero ext (d2M argty) (d2zeroInfo argty (EVar ext (d1 argty) primalIdx))) $
+ weakenExpr (autoWeak library
(#d :++: #body :++: #ac :++: #tl)
- (#ac :++: #d :++: #body :++: #tl))
+ (#ac :++: #d :++: (#body :++: #p) :++: #tl))
e2
SDiscr
- | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr ->
- RetScoped e0 subtape e1 sub e2
+ | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr
+ , Refl <- lemAppendNil @tapebinds ->
+ RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2
+
+-- TODO: proper primal-only transform that doesn't depend on D1 = Id
+drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t)
+drevPrimal des e
+ | Refl <- chadD1Id (typeOf e)
+ , Refl <- chadD1EId (descrList des)
+ = mapExt (const ext) e
+ where
+ chadD1Id :: STy a -> D1 a :~: a
+ chadD1Id STNil = Refl
+ chadD1Id (STPair a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl
+ chadD1Id (STEither a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl
+ chadD1Id (STLEither a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl
+ chadD1Id (STMaybe a) | Refl <- chadD1Id a = Refl
+ chadD1Id (STArr _ a) | Refl <- chadD1Id a = Refl
+ chadD1Id (STScal _) = Refl
+ chadD1Id STAccum{} = error "accumulators not allowed in source program"
+
+ chadD1EId :: SList STy l -> D1E l :~: l
+ chadD1EId SNil = Refl
+ chadD1EId (SCons t l) | Refl <- chadD1Id t, Refl <- chadD1EId l = Refl