aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-04 23:09:21 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-04 23:09:21 +0100
commit57779d4303f377004705c8da06a5ac46177950b2 (patch)
tree0407089403d3d5c2de778c1aab7aed8adf2d01c0 /src/CHAD.hs
parent351667a3ff14c96a8dfe3a2f1dd76b6e1a996542 (diff)
drevLambda works, TODO D[map]HEADmaster
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs112
1 files changed, 42 insertions, 70 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 72ce36d..9da5395 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1077,37 +1077,29 @@ drev des accumMap sd = \case
ESnd ext $
wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $
EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $
- -- the tape for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> sinkOverEnvPro @> IS IZ))
- (EVar ext shty IZ)) $
-- the cotangent for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> WSink .> sinkOverEnvPro @> IZ))
+ ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ))
+ (EVar ext shty IZ)) $
+ -- the tape for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
(EVar ext shty (IS IZ))) $
- weakenExpr (autoWeak library (#d :++: #tape :++: #pro :++: #d2acEnv)
- (#d :++: #tape :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv))
+ weakenExpr (autoWeak library (#tape :++: #d :++: #pro :++: #d2acEnv)
+ (#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv))
e2)
- EMap{} -> undefined
+ EMap{} -> error "TODO: CHAD EMap"
EFold1Inner _ commut origef ex₀ earr
| SpArr @_ @sdElt sdElt <- sd
, STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr
, Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil)
<- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil ->
- deleteUnused (descrList des) (occEnvPopSome (occEnvPopSome (occCountAll origef))) $ \(usedSub :: Subenv env env') ->
- let ef = unsafeWeakenWithSubenv (SEYesR (SEYesR usedSub)) origef in
- subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed ->
- accumPromote (d2 eltty) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
- let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
- let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in
- let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in
- let (mergePrimalBindings', _) = weakenBindingsE (sinkWithBindings bindsx₀a) mergePrimalBindings in
- case drev (prodes `DPush` (eltty, Nothing, SMerge) `DPush` (eltty, Nothing, SMerge)) accumMapPro (spDense (d2M eltty)) ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 ->
- let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in
- let bogTy = STArr (SS ndim) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf)))
+ drevLambda des accumMap (STPair eltty eltty, SMerge) (spDense (d2M eltty)) origef $ \(provars :: SList _ envPro) efsub proPrimalBinds ef0 ef1 (ef1tape :: Ex _ ef_tape) subEf wrapAccum ef2 ->
+ let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings bindsx₀a) proPrimalBinds in
+ let bogEltTy = STPair (STPair (d1 eltty) (d1 eltty)) (typeOf ef1tape)
+ bogTy = STArr (SS ndim) bogEltTy
primalTy = STPair (STArr ndim (d1 eltty)) bogTy
- zipPrimalTy = STPair (d1 eltty) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf)))
- library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil)
+ library = #xy (STPair (d1 eltty) (d1 eltty) `SCons` SNil)
&. #parr (auto1 @(TArr (S n) (D1 elt)))
&. #px₀ (auto1 @(D1 elt))
&. #px (auto1 @(D1 elt))
@@ -1118,70 +1110,52 @@ drev des accumMap sd = \case
&. #x₀abinds (bindingsBinds bindsx₀a)
&. #fbinds (bindingsBinds ef0)
&. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a)
- &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf)
- &. #ftape (auto1 @(Tape e_tape))
- &. #primalzip (zipPrimalTy `SCons` SNil)
- &. #efPrerebinds efPrerebinds
- &. #propr (d1e envPro)
+ &. #ftape (auto1 @ef_tape)
+ &. #bogelt (bogEltTy `SCons` SNil)
+ &. #propr (d1e provars)
&. #d1env (desD1E des)
- &. #d1env' (desD1E usedDes)
- &. #d2acUsed (d2ace (select SAccum usedDes))
&. #d2acEnv (d2ace (select SAccum des))
- &. #d2acPro (d2ace envPro)
+ &. #d2acPro (d2ace provars)
&. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro))))
wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in
subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a ->
- subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) $ \subx₀af _ _ plus_x₀a_f ->
- Ret (bconcat bindsx₀a mergePrimalBindings'
+ subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) $ \subx₀af _ _ plus_x₀a_f ->
+ Ret (bconcat bindsx₀a proPrimalBinds'
`bpush` weakenExpr wOverPrimalBindings ex₀1
`bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ)
`bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1
`bpush` EFold1InnerD1 ext commut
(let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in
- letBinds (fst (weakenBindingsE (autoWeak library
- (#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- layout)
- ef0)) $
- elet (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#fbinds :++: layout))
- ef1) $
- EPair ext
- (evar IZ)
+ letBinds (fst (weakenBindingsE (autoWeak library (#xy :++: #d1env) layout) ef0)) $
+ EPair ext -- (out, ((in1, in2), tape)); the "additional stores" are ((in1, in2), tape)
+ (weakenExpr (autoWeak library (#fbinds :++: #xy :++: #d1env) (#fbinds :++: layout)) ef1)
(EPair ext
- (evar IZ)
- (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#px :++: #fbinds :++: layout)))))
+ (EVar ext (STPair (d1 eltty) (d1 eltty)) (autoWeak library #xy (#fbinds :++: layout) @> IZ))
+ (weakenExpr (autoWeak library (#fbinds :++: #xy :++: #d1env) (#fbinds :++: layout)) ef1tape)))
(EVar ext (d1 eltty) (IS (IS IZ)))
(EVar ext (STArr (SS ndim) (d1 eltty)) IZ))
- (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro)))))))
+ (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e provars)))))))
(EFst ext (EVar ext primalTy IZ))
subx₀af
(let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in
elet
- (uninvertTup (d2e envPro) (STPair (STArr ndim (d2 eltty)) (STArr (SS ndim) (d2 eltty))) $
- makeAccumulators (autoWeak library #propr layout1) envPro $
- let layout2 = #d2acPro :++: layout1 in
- EFold1InnerD2 ext commut
- (elet (ESnd ext (ESnd ext (EVar ext zipPrimalTy (IS IZ)))) $
- elet (EFst ext (ESnd ext (EVar ext zipPrimalTy (IS (IS IZ))))) $
- elet (EFst ext (EVar ext zipPrimalTy (IS (IS (IS IZ))))) $
- letBinds (efRebinds (IS (IS IZ))) $
- let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #xy :++: #ftape :++: #d :++: #primalzip :++: layout2 in
- elet (expandSubenvZeros (autoWeak library #xy layout3) (eltty `SCons` eltty `SCons` SNil) subEf $
- weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) layout3
- .> wPro (subList (bindingsBinds ef0) subtapeEf))
- ef2) $
- EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ)))
- (ezip
- (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ))
- (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))
- (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ)))
- (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))
- (EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $
+ (wrapAccum (autoWeak library #propr layout1) $
+ let layout2 = #d2acPro :++: layout1 in
+ EFold1InnerD2 ext commut
+ (elet (ESnd ext (EVar ext bogEltTy (IS IZ))) $
+ let layout3 = #ftape :++: #d :++: #bogelt :++: layout2 in
+ expandSparse (STPair eltty eltty) subEf (EFst ext (EVar ext bogEltTy (IS (IS IZ)))) $
+ weakenExpr (autoWeak library (#ftape :++: #d :++: #d2acPro :++: #d2acEnv) layout3) ef2)
+ (ESnd ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))
+ (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ)))
+ (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))
+ (EFst ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))))) $
plus_x₀a_f
(plus_x₀_a
(elet (EIdx0 ext
(EFold1Inner ext Commut
- (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
+ (let t = STPair (d2 eltty) (d2 eltty)
+ in EPlus ext (d2M eltty) (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ)))
(EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ)))
(eflatten (EFst ext (EFst ext (evar IZ)))))) $
weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1))
@@ -1189,7 +1163,6 @@ drev des accumMap sd = \case
(weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $
subst0 (ESnd ext (EFst ext (evar IZ))) ea2))
(ESnd ext (evar IZ)))
- }
EUnit _ e
| SpArr sdElt <- sd
@@ -1213,9 +1186,8 @@ drev des accumMap sd = \case
(EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1)
sub
(ELet ext (EFold1Inner ext Commut
- (sparsePlus (d2M eltty) sdElt'
- (EVar ext (applySparse sdElt' (d2 eltty)) (IS IZ))
- (EVar ext (applySparse sdElt' (d2 eltty)) IZ))
+ (let t = STPair (applySparse sdElt' (d2 eltty)) (applySparse sdElt' (d2 eltty))
+ in sparsePlus (d2M eltty) sdElt' (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ)))
(inj2 (ENil ext))
(emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $
weakenExpr (WCopy WSink) e2)
@@ -1494,7 +1466,7 @@ drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False)
D1E provars :> env'
-> Ex (Append (D2AcE provars) env') b
-> Ex ( env') (TPair b (Tup (D2E provars))))
- -> Ex (dt : tape : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a'
+ -> Ex (tape : dt : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a'
-> r)
-> r
drevLambda des accumMap (argty, argsto) sd origef k =
@@ -1535,10 +1507,10 @@ drevLambda des accumMap (argty, argsto) sd origef k =
uninvertTup (d2e envPro) (typeOf body) $
makeAccumulators wpro1 envPro $
body)
- (letBinds (efRebinds (IS IZ)) $
+ (letBinds (efRebinds IZ) $
weakenExpr
(autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
- ((#ftapebinds :++: #efPrerebinds) :++: #d :++: #ftape :++: #d2acPro :++: #d2acEnv)
+ ((#ftapebinds :++: #efPrerebinds) :++: #ftape :++: #d :++: #d2acPro :++: #d2acEnv)
.> wPro (subList (bindingsBinds ef0) subtapeEf))
(getSparseArg ef2))
}}