From 955af83f664639701fdbee54718186e07b31d42f Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 28 Oct 2025 11:56:40 +0100 Subject: Better fold D{1,2} primitives --- src/CHAD.hs | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) (limited to 'src/CHAD.hs') diff --git a/src/CHAD.hs b/src/CHAD.hs index 25d26a6..93fabf9 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1133,9 +1133,11 @@ drev des accumMap sd = \case let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in let bogTy = STArr (SS ndim) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf))) primalTy = STPair (STArr ndim (d1 eltty)) bogTy + zipPrimalTy = STPair (d1 eltty) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf))) library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil) &. #parr (auto1 @(TArr (S n) (D1 elt))) &. #px₀ (auto1 @(D1 elt)) + &. #px (auto1 @(D1 elt)) &. #pzi (auto1 @(ZeroInfo (D2 elt))) &. #primal (primalTy `SCons` SNil) &. #darr (auto1 @(TArr n sdElt)) @@ -1145,6 +1147,7 @@ drev des accumMap sd = \case &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a) &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) &. #ftape (auto1 @(Tape e_tape)) + &. #primalzip (zipPrimalTy `SCons` SNil) &. #efPrerebinds efPrerebinds &. #propr (d1e envPro) &. #d1env (desD1E des) @@ -1166,11 +1169,14 @@ drev des accumMap sd = \case (#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) layout) ef0)) $ - EPair ext - (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#fbinds :++: layout)) - ef1) - (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: layout)))) + elet (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#fbinds :++: layout)) + ef1) $ + EPair ext + (evar IZ) + (EPair ext + (evar IZ) + (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#px :++: #fbinds :++: layout))))) (EVar ext (d1 eltty) (IS (IS IZ))) (EVar ext (STArr (SS ndim) (d1 eltty)) IZ)) (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro))))))) @@ -1181,19 +1187,24 @@ drev des accumMap sd = \case (uninvertTup (d2e envPro) (STPair (d2 eltty) (STArr (SS ndim) (d2 eltty))) $ makeAccumulators (autoWeak library #propr layout1) envPro $ let layout2 = #d2acPro :++: layout1 in - EFold1InnerD2 ext commut (d2M eltty) - (letBinds (efRebinds (IS (IS (IS IZ)))) $ - let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #d :++: #xy :++: #ftape :++: layout2 in + EFold1InnerD2 ext commut + (elet (ESnd ext (ESnd ext (EVar ext zipPrimalTy (IS IZ)))) $ + elet (EFst ext (ESnd ext (EVar ext zipPrimalTy (IS (IS IZ))))) $ + elet (EFst ext (EVar ext zipPrimalTy (IS (IS (IS IZ))))) $ + letBinds (efRebinds (IS (IS IZ))) $ + let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #xy :++: #ftape :++: #d :++: #primalzip :++: layout2 in elet (expandSubenvZeros (autoWeak library #xy layout3) (eltty `SCons` eltty `SCons` SNil) subEf $ weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) layout3 .> wPro (subList (bindingsBinds ef0) subtapeEf)) ef2) $ EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ))) - (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ)) - (EVar ext (tZeroInfo (d2M eltty)) (autoWeak library #pzi layout2 @> IZ)) - (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)) + (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (autoWeak library #pzi layout2 @> IZ))) + (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) + (ezip + (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ)) + (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) - (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) + (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) (EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $ plus_x₀a_f (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $ -- cgit v1.2.3-70-g09d2