From 4c9ae47dd5bbd27b1acb6dc5d4a55657ac1f026f Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 30 Oct 2025 15:58:08 +0100 Subject: Simplify foldD2 to not sum x0 contributions --- src/CHAD.hs | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) (limited to 'src/CHAD.hs') diff --git a/src/CHAD.hs b/src/CHAD.hs index 04c4231..7594a0f 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1184,7 +1184,7 @@ drev des accumMap sd = \case subx₀af (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in elet - (uninvertTup (d2e envPro) (STPair (d2 eltty) (STArr (SS ndim) (d2 eltty))) $ + (uninvertTup (d2e envPro) (STPair (STArr ndim (d2 eltty)) (STArr (SS ndim) (d2 eltty))) $ makeAccumulators (autoWeak library #propr layout1) envPro $ let layout2 = #d2acPro :++: layout1 in EFold1InnerD2 ext commut @@ -1198,8 +1198,6 @@ drev des accumMap sd = \case .> wPro (subList (bindingsBinds ef0) subtapeEf)) ef2) $ EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ))) - (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (autoWeak library #pzi layout2 @> IZ))) - (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) (ezip (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ)) (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) @@ -1207,10 +1205,16 @@ drev des accumMap sd = \case (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) (EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $ plus_x₀a_f - (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $ - plus_x₀_a - (subst0 (EFst ext (EFst ext (evar IZ))) ex₀2) - (subst0 (ESnd ext (EFst ext (evar IZ))) ea2)) + (plus_x₀_a + (elet (EIdx0 ext + (EFold1Inner ext Commut + (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) + (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ))) + (eflatten (EFst ext (EFst ext (evar IZ)))))) $ + weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) + ex₀2) + (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $ + subst0 (ESnd ext (EFst ext (evar IZ))) ea2)) (ESnd ext (evar IZ))) } -- cgit v1.2.3-70-g09d2