aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs18
1 files changed, 11 insertions, 7 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 04c4231..7594a0f 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1184,7 +1184,7 @@ drev des accumMap sd = \case
subx₀af
(let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in
elet
- (uninvertTup (d2e envPro) (STPair (d2 eltty) (STArr (SS ndim) (d2 eltty))) $
+ (uninvertTup (d2e envPro) (STPair (STArr ndim (d2 eltty)) (STArr (SS ndim) (d2 eltty))) $
makeAccumulators (autoWeak library #propr layout1) envPro $
let layout2 = #d2acPro :++: layout1 in
EFold1InnerD2 ext commut
@@ -1198,8 +1198,6 @@ drev des accumMap sd = \case
.> wPro (subList (bindingsBinds ef0) subtapeEf))
ef2) $
EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ)))
- (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (autoWeak library #pzi layout2 @> IZ)))
- (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
(ezip
(EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ))
(ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))
@@ -1207,10 +1205,16 @@ drev des accumMap sd = \case
(EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))
(EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $
plus_x₀a_f
- (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $
- plus_x₀_a
- (subst0 (EFst ext (EFst ext (evar IZ))) ex₀2)
- (subst0 (ESnd ext (EFst ext (evar IZ))) ea2))
+ (plus_x₀_a
+ (elet (EIdx0 ext
+ (EFold1Inner ext Commut
+ (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
+ (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ)))
+ (eflatten (EFst ext (EFst ext (evar IZ)))))) $
+ weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1))
+ ex₀2)
+ (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $
+ subst0 (ESnd ext (EFst ext (evar IZ))) ea2))
(ESnd ext (evar IZ)))
}