aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs35
1 files changed, 23 insertions, 12 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 25d26a6..93fabf9 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1133,9 +1133,11 @@ drev des accumMap sd = \case
let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in
let bogTy = STArr (SS ndim) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf)))
primalTy = STPair (STArr ndim (d1 eltty)) bogTy
+ zipPrimalTy = STPair (d1 eltty) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf)))
library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil)
&. #parr (auto1 @(TArr (S n) (D1 elt)))
&. #px₀ (auto1 @(D1 elt))
+ &. #px (auto1 @(D1 elt))
&. #pzi (auto1 @(ZeroInfo (D2 elt)))
&. #primal (primalTy `SCons` SNil)
&. #darr (auto1 @(TArr n sdElt))
@@ -1145,6 +1147,7 @@ drev des accumMap sd = \case
&. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a)
&. #ftapebinds (subList (bindingsBinds ef0) subtapeEf)
&. #ftape (auto1 @(Tape e_tape))
+ &. #primalzip (zipPrimalTy `SCons` SNil)
&. #efPrerebinds efPrerebinds
&. #propr (d1e envPro)
&. #d1env (desD1E des)
@@ -1166,11 +1169,14 @@ drev des accumMap sd = \case
(#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
layout)
ef0)) $
- EPair ext
- (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#fbinds :++: layout))
- ef1)
- (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: layout))))
+ elet (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#fbinds :++: layout))
+ ef1) $
+ EPair ext
+ (evar IZ)
+ (EPair ext
+ (evar IZ)
+ (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#px :++: #fbinds :++: layout)))))
(EVar ext (d1 eltty) (IS (IS IZ)))
(EVar ext (STArr (SS ndim) (d1 eltty)) IZ))
(SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro)))))))
@@ -1181,19 +1187,24 @@ drev des accumMap sd = \case
(uninvertTup (d2e envPro) (STPair (d2 eltty) (STArr (SS ndim) (d2 eltty))) $
makeAccumulators (autoWeak library #propr layout1) envPro $
let layout2 = #d2acPro :++: layout1 in
- EFold1InnerD2 ext commut (d2M eltty)
- (letBinds (efRebinds (IS (IS (IS IZ)))) $
- let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #d :++: #xy :++: #ftape :++: layout2 in
+ EFold1InnerD2 ext commut
+ (elet (ESnd ext (ESnd ext (EVar ext zipPrimalTy (IS IZ)))) $
+ elet (EFst ext (ESnd ext (EVar ext zipPrimalTy (IS (IS IZ))))) $
+ elet (EFst ext (EVar ext zipPrimalTy (IS (IS (IS IZ))))) $
+ letBinds (efRebinds (IS (IS IZ))) $
+ let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #xy :++: #ftape :++: #d :++: #primalzip :++: layout2 in
elet (expandSubenvZeros (autoWeak library #xy layout3) (eltty `SCons` eltty `SCons` SNil) subEf $
weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) layout3
.> wPro (subList (bindingsBinds ef0) subtapeEf))
ef2) $
EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ)))
- (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ))
- (EVar ext (tZeroInfo (d2M eltty)) (autoWeak library #pzi layout2 @> IZ))
- (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ))
+ (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (autoWeak library #pzi layout2 @> IZ)))
+ (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
+ (ezip
+ (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ))
+ (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))
(ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ)))
- (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))
+ (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))
(EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $
plus_x₀a_f
(weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $