aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-10-24 23:34:30 +0200
committerTom Smeding <tom@tomsmeding.com>2025-10-24 23:34:30 +0200
commit42176d4a8a0fe7954f17da5c0506721695aa477f (patch)
tree8a29e847faa613e9becf1bccdcaad010187e639b /src/CHAD.hs
parent7729c45a325fe653421d654ed4c28b040585fce9 (diff)
WIP fold: everything but Compile (slow, but should be sound)
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs20
1 files changed, 12 insertions, 8 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index ec719e8..25d26a6 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1136,6 +1136,7 @@ drev des accumMap sd = \case
library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil)
&. #parr (auto1 @(TArr (S n) (D1 elt)))
&. #px₀ (auto1 @(D1 elt))
+ &. #pzi (auto1 @(ZeroInfo (D2 elt)))
&. #primal (primalTy `SCons` SNil)
&. #darr (auto1 @(TArr n sdElt))
&. #d (auto1 @(D2 elt))
@@ -1157,23 +1158,25 @@ drev des accumMap sd = \case
subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) $ \subx₀af _ _ plus_x₀a_f ->
Ret (bconcat bindsx₀a mergePrimalBindings'
`bpush` weakenExpr wOverPrimalBindings ex₀1
- `bpush` weakenExpr (WSink .> wOverPrimalBindings) ea1
+ `bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ)
+ `bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1
`bpush` EFold1InnerD1 ext commut
- (letBinds (fst (weakenBindingsE (autoWeak library
+ (let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in
+ letBinds (fst (weakenBindingsE (autoWeak library
(#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env))
+ layout)
ef0)) $
EPair ext
(weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#fbinds :++: #xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env))
+ (#fbinds :++: layout))
ef1)
- (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env))))
- (EVar ext (d1 eltty) (IS IZ))
+ (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: layout))))
+ (EVar ext (d1 eltty) (IS (IS IZ)))
(EVar ext (STArr (SS ndim) (d1 eltty)) IZ))
- (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro))))))
+ (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro)))))))
(EFst ext (EVar ext primalTy IZ))
subx₀af
- (let layout1 = #darr :++: #primal :++: #parr :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in
+ (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in
elet
(uninvertTup (d2e envPro) (STPair (d2 eltty) (STArr (SS ndim) (d2 eltty))) $
makeAccumulators (autoWeak library #propr layout1) envPro $
@@ -1187,6 +1190,7 @@ drev des accumMap sd = \case
ef2) $
EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ)))
(EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ))
+ (EVar ext (tZeroInfo (d2M eltty)) (autoWeak library #pzi layout2 @> IZ))
(ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ))
(ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ)))
(EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))