summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs298
1 files changed, 210 insertions, 88 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index df792ce..b5a9af0 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -42,8 +42,8 @@ import AST
import AST.Bindings
import AST.Count
import AST.Env
+import AST.Sparse
import AST.Weaken.Auto
-import CHAD.Accum
import CHAD.EnvDescr
import CHAD.Types
import Data
@@ -65,7 +65,7 @@ tapeTy (SCons t ts) = STPair t (tapeTy ts)
bindingsCollectTape :: Bindings f env binds -> Subenv binds tapebinds
-> Append binds env :> env2 -> Ex env2 (Tape tapebinds)
bindingsCollectTape BTop SETop _ = ENil ext
-bindingsCollectTape (BPush binds (t, _)) (SEYes sub) w =
+bindingsCollectTape (BPush binds (t, _)) (SEYesR sub) w =
EPair ext (EVar ext t (w @> IZ))
(bindingsCollectTape binds sub (w .> WSink))
bindingsCollectTape (BPush binds _) (SENo sub) w =
@@ -227,26 +227,37 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
d2op :: SOp a t -> D2Op a t
d2op op = case op of
- OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d)
+ OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d
OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->
- EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
- (EOp ext (OMul t) (EPair ext (EFst ext e) d)))
+ EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
+ (EOp ext (OMul t) (EPair ext (EFst ext e) d))
ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d
- OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
- OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
- OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
+ OLt t -> Linear $ \_ -> pairZero t
+ OLe t -> Linear $ \_ -> pairZero t
+ OEq t -> Linear $ \_ -> pairZero t
ONot -> Linear $ \_ -> ENil ext
- OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
- OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
OIf -> Linear $ \_ -> ENil ext
- ORound64 -> Linear $ \_ -> EConst ext STF64 0.0
+ ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext)
OToFl64 -> Linear $ \_ -> ENil ext
ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d)
OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d)
OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d)
- OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil)
- OMod t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
where
+ pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a)))
+ pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext))
+ (EZero ext (d2M (STScal t)) (ENil ext))
+ where
+ ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r
+ ziNil STI32 k = k
+ ziNil STI64 k = k
+ ziNil STF32 k = k
+ ziNil STF64 k = k
+ ziNil STBool k = k
+
d2opUnArrangeInt :: SScalTy a
-> (D2s a ~ TScal a => D2Op (TScal a) t)
-> D2Op (TScal a) t
@@ -261,11 +272,11 @@ d2op op = case op of
-> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)
-> D2Op (TPair (TScal a) (TScal a)) t
d2opBinArrangeInt ty float = case ty of
- STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
- STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
+ STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
STF32 -> float
STF64 -> float
- STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext)
floatingD2 :: ScalIsFloating a ~ True
=> SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r
@@ -293,7 +304,7 @@ conv1Idx (IS i) = IS (conv1Idx i)
data Idx2 env sto t
= Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t)))
- | Idx2Me (Idx (Select env sto "merge") t)
+ | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t))
| Idx2Di (Idx (Select env sto "discr") t)
conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t
@@ -317,44 +328,127 @@ conv2Idx DTop i = case i of {}
------------------------------------ MONOIDS -----------------------------------
-zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
-zeroTup SNil = ENil ext
-zeroTup (t `SCons` env) = EPair ext (zeroTup env) (ezeroD2 t)
-
-
------------------------------------- SUBENVS -----------------------------------
+d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t))
+d2zeroInfo STNil _ = ENil ext
+d2zeroInfo (STPair a b) e =
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2)
+d2zeroInfo STEither{} _ = ENil ext
+d2zeroInfo STLEither{} _ = ENil ext
+d2zeroInfo STMaybe{} _ = ENil ext
+d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e
+d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext
+d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+
+zeroTup :: SList STy env0 -> D1E env0 :> env -> Ex env (Tup (D2E env0))
+zeroTup SNil _ = ENil ext
+zeroTup (t `SCons` env) w =
+ EPair ext (zeroTup env (WPop w))
+ (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
+
+
+----------------------------------- SPARSITY -----------------------------------
+
+subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env')
+subenvD1E SETop = SETop
+subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub)
+subenvD1E (SENo sub) = SENo (subenvD1E sub)
+
+expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a)
+expandSparse _ SpDense _ e = e
+expandSparse t (SpSparse sp) epr e =
+ EMaybe ext
+ (EZero ext (d2M t) (d2zeroInfo t epr))
+ (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ))
+ e
+expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr)
+expandSparse (STPair t1 t2) (SpPair s1 s2) epr e =
+ eunPair epr $ \w1 epr1 epr2 ->
+ eunPair (weakenExpr w1 e) $ \w2 e1 e2 ->
+ EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1)
+ (expandSparse t2 s2 (weakenExpr w2 epr2) e2)
+expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e =
+ ELCase ext e
+ (EZero ext (d2M (STEither t1 t2)) (ENil ext))
+ (ECase ext (weakenExpr WSink epr)
+ (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ))))
+ (EError ext (d2 (STEither t1 t2)) "expspa r<-dl"))
+ (ECase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa l<-dr")
+ (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ)))))
+expandSparse (STEither t1 t2) (SpLeft s) epr e =
+ let epr' = ECase ext epr (EVar ext (d1 t1) IZ) (EError ext (d1 t1) "expspa r<-dL")
+ in ELInl ext (d2 t2) (expandSparse t1 s epr' e)
+expandSparse (STEither t1 t2) (SpRight s) epr e =
+ let epr' = ECase ext epr (EError ext (d1 t2) "expspa l<-dR") (EVar ext (d1 t2) IZ)
+ in ELInr ext (d2 t1) (expandSparse t2 s epr' e)
+expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e =
+ ELCase ext e
+ (EZero ext (d2M (STEither t1 t2)) (ENil ext))
+ (ELCase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl")
+ (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ))))
+ (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl"))
+ (ELCase ext (weakenExpr WSink epr)
+ (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr")
+ (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr")
+ (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ)))))
+expandSparse (STLEither t1 t2) (SpLeft s) epr e =
+ let epr' = ELCase ext epr (EError ext (d1 t1) "expspa ln<-dL") (EVar ext (d1 t1) IZ) (EError ext (d1 t1) "expspa r<-dL")
+ in ELInl ext (d2 t2) (expandSparse t1 s epr' e)
+expandSparse (STLEither t1 t2) (SpRight s) epr e =
+ let epr' = ELCase ext epr (EError ext (d1 t2) "expspa ln<-dR") (EError ext (d1 t2) "expspa l<-dR") (EVar ext (d1 t2) IZ)
+ in ELInr ext (d2 t1) (expandSparse t2 s epr' e)
+expandSparse (STMaybe t) (SpMaybe s) epr e =
+ EMaybe ext
+ (ENothing ext (d2 t))
+ (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr
+ in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ)))
+ e
+expandSparse (STArr _ t) (SpArr s) epr e =
+ ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e
+expandSparse (STScal sty) _ _ _ = case sty of {} -- SpDense and SpSparse handled already
+expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program"
+
+sparsePlus
+ :: SMTy t -> Sparse t t1 -> Sparse t t2
+ -> (forall t3. Sparse t t3
+ -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3)
+ -> r)
+ -> r
+sparsePlus t sp1 sp2 k = sparsePlusS SF SF t sp1 sp2 $ \sp3 _ _ plus -> k sp3 plus
subenvPlus :: SList STy env
- -> Subenv env env1 -> Subenv env env2
- -> (forall env3. Subenv env env3
- -> Subenv env3 env1
- -> Subenv env3 env2
- -> (Ex exenv (Tup (D2E env1))
- -> Ex exenv (Tup (D2E env2))
- -> Ex exenv (Tup (D2E env3)))
+ -> SubenvS (D2E env) env1 -> SubenvS (D2E env) env2
+ -> (forall env3. SubenvS (D2E env) env3
+ -> SubenvS env3 env1
+ -> SubenvS env3 env2
+ -> (Ex exenv (Tup env1)
+ -> Ex exenv (Tup env2)
+ -> Ex exenv (Tup env3))
-> r)
-> r
subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext)
subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k =
subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
k (SENo sub3) s31 s32 pl
-subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k =
+subenvPlus (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k =
subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 ->
+ k (SEYes sp1 sub3) (SEYes SpDense s31) (SENo s32) $ \e1 e2 ->
ELet ext e1 $
EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
(weakenExpr WSink e2))
(ESnd ext (EVar ext (typeOf e1) IZ))
-subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k =
+subenvPlus (SCons _ env) (SENo sub1) (SEYes sp2 sub2) k =
subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 ->
+ k (SEYes sp2 sub3) (SENo s31) (SEYes SpDense s32) $ \e1 e2 ->
ELet ext e2 $
EPair ext (pl (weakenExpr WSink e1)
(EFst ext (EVar ext (typeOf e2) IZ)))
(ESnd ext (EVar ext (typeOf e2) IZ))
-subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k =
+subenvPlus (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k =
subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
- k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 ->
+ k (SEYesR sub3) (SEYesR s31) (SEYesR s32) $ \e1 e2 ->
ELet ext e1 $
ELet ext (weakenExpr WSink e2) $
EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ)))
@@ -363,22 +457,44 @@ subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k =
(ESnd ext (EVar ext (typeOf e1) (IS IZ)))
(ESnd ext (EVar ext (typeOf e2) IZ)))
-expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0))
-expandSubenvZeros _ SETop _ = ENil ext
-expandSubenvZeros (SCons t ts) (SEYes sub) e =
- ELet ext e $
- let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ
- in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var)
-expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (ezeroD2 t)
+expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs
+ -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0))
+expandSubenvZeros _ SNil SETop _ = ENil ext
+expandSubenvZeros w (SCons t ts) (SEYes sp sub) e =
+ eunPair e $ \w1 e1 e2 ->
+ EPair ext
+ (expandSubenvZeros (w1 .> WPop w) ts sub e1)
+ (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2)
+expandSubenvZeros w (SCons t ts) (SENo sub) e =
+ EPair ext
+ (expandSubenvZeros (WPop w) ts sub e)
+ (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[]
assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl
assertSubenvEmpty SETop = Refl
-assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
+assertSubenvEmpty SEYesR{} = error "assertSubenvEmpty: not empty"
--------------------------------- ACCUMULATORS ---------------------------------
+makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
+makeAccumulators _ SNil e = e
+makeAccumulators w (t `SCons` envpro) e =
+ makeAccumulators (WPop w) envpro $
+ EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e
+
+uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
+uninvertTup SNil _ e = EPair ext e (ENil ext)
+uninvertTup (t `SCons` list) tcore e =
+ ELet ext (uninvertTup list (STPair tcore t) e) $
+ let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding
+ in EPair ext
+ (EFst ext (EFst ext (EVar ext recT IZ)))
+ (EPair ext
+ (ESnd ext (EVar ext recT IZ))
+ (ESnd ext (EFst ext (EVar ext recT IZ))))
+
fromArrayValId :: Maybe (ValId t) -> Maybe Int
fromArrayValId (Just (VIArr i _)) = Just i
fromArrayValId _ = Nothing
@@ -422,7 +538,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
k (storepl `DPush` (t, vid, SAccum))
envpro
prosub
- (SEYes accrevsub)
+ (SEYesR accrevsub)
(VarMap.sink1 accumMap)
(\shbinds ->
autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr)))
@@ -449,7 +565,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
k (storepl `DPush` (t, vid, SAccum))
(t `SCons` envpro)
- (SEYes prosub)
+ (SEYesR prosub)
(SENo accrevsub)
(let accumMap' = VarMap.sink1 accumMap
in case fromArrayValId vid of
@@ -499,19 +615,21 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
---------------------------- RETURN TRIPLE FROM CHAD ---------------------------
data Ret env0 sto t =
- forall shbinds tapebinds env0Merge.
+ forall shbinds tapebinds contribs.
Ret (Bindings Ex (D1E env0) shbinds) -- shared binds
(Subenv shbinds tapebinds)
(Ex (Append shbinds (D1E env0)) (D1 t))
- (Subenv (Select env0 sto "merge") env0Merge)
- (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
+ (SubenvS (D2E (Select env0 sto "merge")) contribs)
+ (forall sd. Sparse (D2 t) sd
+ -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs))
deriving instance Show (Ret env0 sto t)
data RetPair env0 sto env shbinds tapebinds t =
- forall env0Merge.
+ forall contribs.
RetPair (Ex (Append shbinds env) (D1 t))
- (Subenv (Select env0 sto "merge") env0Merge)
- (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
+ (SubenvS (D2E (Select env0 sto "merge")) contribs)
+ (forall sd. Sparse (D2 t) sd
+ -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs))
deriving instance Show (RetPair env0 sto env shbinds tapebinds t)
data Rets env0 sto env list =
@@ -569,18 +687,24 @@ freezeRet :: Descr env sto
freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =
let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0
e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2
+ tContribs = tTup (subList (d2e (select SMerge descr)) sub)
+ library = #d (auto1 @(D2 t))
+ &. #tape (subList (bindingsBinds e0) subtape)
+ &. #shbinds (bindingsBinds e0)
+ &. #d2ace (d2ace (select SAccum descr))
+ &. #tl (desD1E descr)
+ &. #contribs (SCons tContribs SNil)
in letBinds e0' $
EPair ext
(weakenExpr wInsertD2Ac e1)
- (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #tape (subList (bindingsBinds e0) subtape)
- &. #shbinds (bindingsBinds e0)
- &. #d2ace (d2ace (select SAccum descr))
- &. #tl (desD1E descr))
+ (ELet ext (weakenExpr (autoWeak library
(#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl)
(#shbinds :++: #d :++: #d2ace :++: #tl))
e2') $
- expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ))
+ expandSubenvZeros
+ (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl)
+ .> wUndoSubenv (subenvD1E (selectSub SMerge descr)))
+ (select SMerge descr) sub (EVar ext tContribs IZ))
---------------------------- THE CHAD TRANSFORMATION ---------------------------
@@ -596,21 +720,21 @@ drev des accumMap = \case
Ret BTop
SETop
(EVar ext (d1 t) (conv1Idx i))
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(EAccum ext (d2M t) SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))
Idx2Me tupI ->
Ret BTop
SETop
(EVar ext (d1 t) (conv1Idx i))
- (subenvOnehot (select SMerge des) tupI)
+ (subenvOnehot (d2e (select SMerge des)) tupI)
(EPair ext (ENil ext) (EVar ext (d2 t) IZ))
Idx2Di _ ->
Ret BTop
SETop
(EVar ext (d1 t) (conv1Idx i))
- (subenvNone (select SMerge des))
+ (subenvNone (d2e (select SMerge des)))
(ENil ext)
ELet _ (rhs :: Expr _ _ a) body
@@ -621,7 +745,7 @@ drev des accumMap = \case
, Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
, Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->
subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body ->
- let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in
+ let bodyResType = STPair (tTup (subList (d2e (select SMerge des)) subBody)) (d2 (typeOf rhs)) in
Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
(subenvConcat (SENo @d1_a subtapeRHS) subtapeBody)
(weakenExpr wbody0' body1)
@@ -637,7 +761,7 @@ drev des accumMap = \case
(ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $
plus_RHS_Body
- (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)
+ (EVar ext (tTup (subList (d2e (select SMerge des)) subRHS)) IZ)
(EFst ext (EVar ext bodyResType (IS IZ))))
EPair _ a b
@@ -649,16 +773,13 @@ drev des accumMap = \case
subtape
(EPair ext a1 b1)
subBoth
- (EMaybe ext
- (zeroTup (subList (select SMerge des) subBoth))
- (ELet ext (ELet ext (EFst ext (EVar ext dt IZ))
- (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $
- ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
- (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $
- plus_A_B
- (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ))
- (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ))
- (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ))
+ (ELet ext (ELet ext (EFst ext (EVar ext dt IZ))
+ (weakenExpr (WCopy WSink) a2)) $
+ ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
+ (weakenExpr (WCopy (WSink .> WSink)) b2)) $
+ plus_A_B
+ (EVar ext (tTup (subList (d2e (select SMerge des)) subA)) (IS IZ))
+ (EVar ext (tTup (subList (d2e (select SMerge des)) subB)) IZ))
EFst _ e
| Ret e0 subtape e1 sub e2 <- drev des accumMap e
@@ -732,7 +853,7 @@ drev des accumMap = \case
ECase ext e1
(letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'))))
(letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0'))))))
- (SEYes subtapeE)
+ (SEYesR subtapeE)
(EFst ext (EVar ext tPrimal IZ))
subOut
(ELet ext
@@ -801,7 +922,7 @@ drev des accumMap = \case
(weakenExpr (WCopy WSink) e2))
Nonlinear d2opfun ->
Ret (e0 `BPush` (d1 (typeOf e), e1))
- (SEYes subtape)
+ (SEYesR subtape)
(d1op op $ EVar ext (d1 (typeOf e)) IZ)
sub
(ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ))
@@ -816,7 +937,7 @@ drev des accumMap = \case
`BPush` (typeOf b1, weakenExpr WSink b1)
`BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))
`BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ)))
- (SEYes (SENo (SENo (SENo subtape))))
+ (SEYesR (SENo (SENo (SENo subtape))))
(EFst ext (EVar ext (typeOf pr) (IS IZ)))
bsub
(ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
@@ -865,7 +986,7 @@ drev des accumMap = \case
, shty :: STy shty <- tTup (sreplicate ndim tIx)
, Refl <- indexTupD1Id ndim ->
deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') ->
- let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in
+ let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in
subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
@@ -894,7 +1015,7 @@ drev des accumMap = \case
in EPair ext (weakenExpr w e1) (collectexpr w)))
`BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ))
(EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ)))
- (SEYes (SENo (SEYes SETop)))
+ (SEYesR (SENo (SEYesR SETop)))
(emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ))
(EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))
(subenvCompose subMergeUsed proSub)
@@ -981,7 +1102,7 @@ drev des accumMap = \case
, STArr (SS n) eltty <- typeOf e ->
Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1)
`BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ)))
- (SEYes (SENo subtape))
+ (SEYesR (SENo subtape))
(EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ))
(weakenExpr (WSink .> WSink) ei1))
sub
@@ -1002,7 +1123,7 @@ drev des accumMap = \case
Ret (binds `BPush` (STArr n (d1 eltty), e1)
`BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ))
`BPush` (tIxN, weakenExpr (WSink .> WSink) ei1))
- (SEYes (SEYes (SENo subtape)))
+ (SEYesR (SEYesR (SENo subtape)))
(EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
(EVar ext (tTup (sreplicate n tIx)) IZ))
sub
@@ -1030,7 +1151,7 @@ drev des accumMap = \case
, STArr (SS n) t <- typeOf e ->
Ret (e0 `BPush` (STArr (SS n) t, e1)
`BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ)))
- (SEYes (SENo subtape))
+ (SEYesR (SENo subtape))
(ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ)))
sub
(EMaybe ext
@@ -1076,7 +1197,7 @@ drev des accumMap = \case
, let tIxN = tTup (sreplicate (SS n) tIx) =
Ret (e0 `BPush` (at, e1)
`BPush` (at', extremum (EVar ext at IZ)))
- (SEYes (SEYes subtape))
+ (SEYesR (SEYesR subtape))
(EVar ext at' IZ)
sub
(EMaybe ext
@@ -1094,16 +1215,17 @@ drev des accumMap = \case
data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s)
data RetScoped env0 sto a s t =
- forall shbinds tapebinds env0Merge.
+ forall shbinds tapebinds contribs.
RetScoped
(Bindings Ex (D1E (a : env0)) shbinds) -- shared binds
(Subenv shbinds tapebinds)
(Ex (Append shbinds (D1E (a : env0))) (D1 t))
- (Subenv (Select env0 sto "merge") env0Merge)
+ (SubenvS (D2E (Select env0 sto "merge")) contribs)
-- ^ merge contributions to the _enclosing_ merge environment
- (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum")))
- (If (s == "discr") (Tup (D2E env0Merge))
- (TPair (Tup (D2E env0Merge)) (D2 a))))
+ (forall sd. Sparse (D2 t) sd
+ -> Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum")))
+ (If (s == "discr") (Tup contribs)
+ (TPair (Tup contribs) (D2 a))))
-- ^ the merge contributions, plus the cotangent to the argument
-- (if there is any)
deriving instance Show (RetScoped env0 sto a s t)
@@ -1118,7 +1240,7 @@ drevScoped des accumMap argty argsto argids expr = case argsto of
SMerge
| Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr ->
case sub of
- SEYes sub' -> RetScoped e0 subtape e1 sub' e2
+ SEYesR sub' -> RetScoped e0 subtape e1 sub' e2
SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (ezeroD2 argty))
SAccum