diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-04 23:09:21 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-04 23:09:21 +0100 |
| commit | 57779d4303f377004705c8da06a5ac46177950b2 (patch) | |
| tree | 0407089403d3d5c2de778c1aab7aed8adf2d01c0 /src/CHAD.hs | |
| parent | 351667a3ff14c96a8dfe3a2f1dd76b6e1a996542 (diff) | |
Diffstat (limited to 'src/CHAD.hs')
| -rw-r--r-- | src/CHAD.hs | 112 |
1 files changed, 42 insertions, 70 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 72ce36d..9da5395 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1077,37 +1077,29 @@ drev des accumMap sd = \case ESnd ext $ wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $ EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $ - -- the tape for this element - ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> sinkOverEnvPro @> IS IZ)) - (EVar ext shty IZ)) $ -- the cotangent for this element - ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> WSink .> sinkOverEnvPro @> IZ)) + ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) + (EVar ext shty IZ)) $ + -- the tape for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) (EVar ext shty (IS IZ))) $ - weakenExpr (autoWeak library (#d :++: #tape :++: #pro :++: #d2acEnv) - (#d :++: #tape :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv)) + weakenExpr (autoWeak library (#tape :++: #d :++: #pro :++: #d2acEnv) + (#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv)) e2) - EMap{} -> undefined + EMap{} -> error "TODO: CHAD EMap" EFold1Inner _ commut origef ex₀ earr | SpArr @_ @sdElt sdElt <- sd , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr , Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil) <- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil -> - deleteUnused (descrList des) (occEnvPopSome (occEnvPopSome (occCountAll origef))) $ \(usedSub :: Subenv env env') -> - let ef = unsafeWeakenWithSubenv (SEYesR (SEYesR usedSub)) origef in - subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> - accumPromote (d2 eltty) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> - let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in - let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in - let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in - let (mergePrimalBindings', _) = weakenBindingsE (sinkWithBindings bindsx₀a) mergePrimalBindings in - case drev (prodes `DPush` (eltty, Nothing, SMerge) `DPush` (eltty, Nothing, SMerge)) accumMapPro (spDense (d2M eltty)) ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 -> - let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in - let bogTy = STArr (SS ndim) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf))) + drevLambda des accumMap (STPair eltty eltty, SMerge) (spDense (d2M eltty)) origef $ \(provars :: SList _ envPro) efsub proPrimalBinds ef0 ef1 (ef1tape :: Ex _ ef_tape) subEf wrapAccum ef2 -> + let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings bindsx₀a) proPrimalBinds in + let bogEltTy = STPair (STPair (d1 eltty) (d1 eltty)) (typeOf ef1tape) + bogTy = STArr (SS ndim) bogEltTy primalTy = STPair (STArr ndim (d1 eltty)) bogTy - zipPrimalTy = STPair (d1 eltty) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf))) - library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil) + library = #xy (STPair (d1 eltty) (d1 eltty) `SCons` SNil) &. #parr (auto1 @(TArr (S n) (D1 elt))) &. #px₀ (auto1 @(D1 elt)) &. #px (auto1 @(D1 elt)) @@ -1118,70 +1110,52 @@ drev des accumMap sd = \case &. #x₀abinds (bindingsBinds bindsx₀a) &. #fbinds (bindingsBinds ef0) &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a) - &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) - &. #ftape (auto1 @(Tape e_tape)) - &. #primalzip (zipPrimalTy `SCons` SNil) - &. #efPrerebinds efPrerebinds - &. #propr (d1e envPro) + &. #ftape (auto1 @ef_tape) + &. #bogelt (bogEltTy `SCons` SNil) + &. #propr (d1e provars) &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes) - &. #d2acUsed (d2ace (select SAccum usedDes)) &. #d2acEnv (d2ace (select SAccum des)) - &. #d2acPro (d2ace envPro) + &. #d2acPro (d2ace provars) &. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro)))) wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a -> - subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) $ \subx₀af _ _ plus_x₀a_f -> - Ret (bconcat bindsx₀a mergePrimalBindings' + subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) $ \subx₀af _ _ plus_x₀a_f -> + Ret (bconcat bindsx₀a proPrimalBinds' `bpush` weakenExpr wOverPrimalBindings ex₀1 `bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ) `bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1 `bpush` EFold1InnerD1 ext commut (let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in - letBinds (fst (weakenBindingsE (autoWeak library - (#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - layout) - ef0)) $ - elet (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#fbinds :++: layout)) - ef1) $ - EPair ext - (evar IZ) + letBinds (fst (weakenBindingsE (autoWeak library (#xy :++: #d1env) layout) ef0)) $ + EPair ext -- (out, ((in1, in2), tape)); the "additional stores" are ((in1, in2), tape) + (weakenExpr (autoWeak library (#fbinds :++: #xy :++: #d1env) (#fbinds :++: layout)) ef1) (EPair ext - (evar IZ) - (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#px :++: #fbinds :++: layout))))) + (EVar ext (STPair (d1 eltty) (d1 eltty)) (autoWeak library #xy (#fbinds :++: layout) @> IZ)) + (weakenExpr (autoWeak library (#fbinds :++: #xy :++: #d1env) (#fbinds :++: layout)) ef1tape))) (EVar ext (d1 eltty) (IS (IS IZ))) (EVar ext (STArr (SS ndim) (d1 eltty)) IZ)) - (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro))))))) + (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e provars))))))) (EFst ext (EVar ext primalTy IZ)) subx₀af (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in elet - (uninvertTup (d2e envPro) (STPair (STArr ndim (d2 eltty)) (STArr (SS ndim) (d2 eltty))) $ - makeAccumulators (autoWeak library #propr layout1) envPro $ - let layout2 = #d2acPro :++: layout1 in - EFold1InnerD2 ext commut - (elet (ESnd ext (ESnd ext (EVar ext zipPrimalTy (IS IZ)))) $ - elet (EFst ext (ESnd ext (EVar ext zipPrimalTy (IS (IS IZ))))) $ - elet (EFst ext (EVar ext zipPrimalTy (IS (IS (IS IZ))))) $ - letBinds (efRebinds (IS (IS IZ))) $ - let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #xy :++: #ftape :++: #d :++: #primalzip :++: layout2 in - elet (expandSubenvZeros (autoWeak library #xy layout3) (eltty `SCons` eltty `SCons` SNil) subEf $ - weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) layout3 - .> wPro (subList (bindingsBinds ef0) subtapeEf)) - ef2) $ - EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ))) - (ezip - (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ)) - (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) - (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) - (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) - (EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $ + (wrapAccum (autoWeak library #propr layout1) $ + let layout2 = #d2acPro :++: layout1 in + EFold1InnerD2 ext commut + (elet (ESnd ext (EVar ext bogEltTy (IS IZ))) $ + let layout3 = #ftape :++: #d :++: #bogelt :++: layout2 in + expandSparse (STPair eltty eltty) subEf (EFst ext (EVar ext bogEltTy (IS (IS IZ)))) $ + weakenExpr (autoWeak library (#ftape :++: #d :++: #d2acPro :++: #d2acEnv) layout3) ef2) + (ESnd ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) + (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) + (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) + (EFst ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))))) $ plus_x₀a_f (plus_x₀_a (elet (EIdx0 ext (EFold1Inner ext Commut - (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) + (let t = STPair (d2 eltty) (d2 eltty) + in EPlus ext (d2M eltty) (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ))) (eflatten (EFst ext (EFst ext (evar IZ)))))) $ weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) @@ -1189,7 +1163,6 @@ drev des accumMap sd = \case (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $ subst0 (ESnd ext (EFst ext (evar IZ))) ea2)) (ESnd ext (evar IZ))) - } EUnit _ e | SpArr sdElt <- sd @@ -1213,9 +1186,8 @@ drev des accumMap sd = \case (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1) sub (ELet ext (EFold1Inner ext Commut - (sparsePlus (d2M eltty) sdElt' - (EVar ext (applySparse sdElt' (d2 eltty)) (IS IZ)) - (EVar ext (applySparse sdElt' (d2 eltty)) IZ)) + (let t = STPair (applySparse sdElt' (d2 eltty)) (applySparse sdElt' (d2 eltty)) + in sparsePlus (d2M eltty) sdElt' (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) (inj2 (ENil ext)) (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $ weakenExpr (WCopy WSink) e2) @@ -1494,7 +1466,7 @@ drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False) D1E provars :> env' -> Ex (Append (D2AcE provars) env') b -> Ex ( env') (TPair b (Tup (D2E provars)))) - -> Ex (dt : tape : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a' + -> Ex (tape : dt : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a' -> r) -> r drevLambda des accumMap (argty, argsto) sd origef k = @@ -1535,10 +1507,10 @@ drevLambda des accumMap (argty, argsto) sd origef k = uninvertTup (d2e envPro) (typeOf body) $ makeAccumulators wpro1 envPro $ body) - (letBinds (efRebinds (IS IZ)) $ + (letBinds (efRebinds IZ) $ weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) - ((#ftapebinds :++: #efPrerebinds) :++: #d :++: #ftape :++: #d2acPro :++: #d2acEnv) + ((#ftapebinds :++: #efPrerebinds) :++: #ftape :++: #d :++: #d2acPro :++: #d2acEnv) .> wPro (subList (bindingsBinds ef0) subtapeEf)) (getSparseArg ef2)) }} |
