diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-10-23 23:53:37 +0200 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-10-23 23:53:37 +0200 |
| commit | 2542f5ef42452967fec1d2376927c1f36bf263f4 (patch) | |
| tree | 717d97be4d21c4ac0355270ac81df33296b8b852 /src/CHAD.hs | |
| parent | f805440cf8833d238f848dd07f89b8ed5bc69e90 (diff) | |
WIP fold: Implement D[fold1i]
Still need to handle the new primitives in the rest of the library
Diffstat (limited to 'src/CHAD.hs')
| -rw-r--r-- | src/CHAD.hs | 90 |
1 files changed, 87 insertions, 3 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index cfae98d..ec719e8 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1116,6 +1116,89 @@ drev des accumMap sd = \case e2) }} + EFold1Inner _ commut origef ex₀ earr + | SpArr @_ @sdElt sdElt <- sd + , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr + , Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil) + <- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil -> + deleteUnused (descrList des) (occEnvPopSome (occEnvPopSome (occCountAll origef))) $ \(usedSub :: Subenv env env') -> + let ef = unsafeWeakenWithSubenv (SEYesR (SEYesR usedSub)) origef in + subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> + accumPromote (d2 eltty) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> + let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in + let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in + let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in + let (mergePrimalBindings', _) = weakenBindingsE (sinkWithBindings bindsx₀a) mergePrimalBindings in + case drev (prodes `DPush` (eltty, Nothing, SMerge) `DPush` (eltty, Nothing, SMerge)) accumMapPro (spDense (d2M eltty)) ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 -> + let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in + let bogTy = STArr (SS ndim) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf))) + primalTy = STPair (STArr ndim (d1 eltty)) bogTy + library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil) + &. #parr (auto1 @(TArr (S n) (D1 elt))) + &. #px₀ (auto1 @(D1 elt)) + &. #primal (primalTy `SCons` SNil) + &. #darr (auto1 @(TArr n sdElt)) + &. #d (auto1 @(D2 elt)) + &. #x₀abinds (bindingsBinds bindsx₀a) + &. #fbinds (bindingsBinds ef0) + &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a) + &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) + &. #ftape (auto1 @(Tape e_tape)) + &. #efPrerebinds efPrerebinds + &. #propr (d1e envPro) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes) + &. #d2acUsed (d2ace (select SAccum usedDes)) + &. #d2acEnv (d2ace (select SAccum des)) + &. #d2acPro (d2ace envPro) + &. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro)))) + wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in + subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a -> + subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) $ \subx₀af _ _ plus_x₀a_f -> + Ret (bconcat bindsx₀a mergePrimalBindings' + `bpush` weakenExpr wOverPrimalBindings ex₀1 + `bpush` weakenExpr (WSink .> wOverPrimalBindings) ea1 + `bpush` EFold1InnerD1 ext commut + (letBinds (fst (weakenBindingsE (autoWeak library + (#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env)) + ef0)) $ + EPair ext + (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#fbinds :++: #xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env)) + ef1) + (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env)))) + (EVar ext (d1 eltty) (IS IZ)) + (EVar ext (STArr (SS ndim) (d1 eltty)) IZ)) + (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro)))))) + (EFst ext (EVar ext primalTy IZ)) + subx₀af + (let layout1 = #darr :++: #primal :++: #parr :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in + elet + (uninvertTup (d2e envPro) (STPair (d2 eltty) (STArr (SS ndim) (d2 eltty))) $ + makeAccumulators (autoWeak library #propr layout1) envPro $ + let layout2 = #d2acPro :++: layout1 in + EFold1InnerD2 ext commut (d2M eltty) + (letBinds (efRebinds (IS (IS (IS IZ)))) $ + let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #d :++: #xy :++: #ftape :++: layout2 in + elet (expandSubenvZeros (autoWeak library #xy layout3) (eltty `SCons` eltty `SCons` SNil) subEf $ + weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) layout3 + .> wPro (subList (bindingsBinds ef0) subtapeEf)) + ef2) $ + EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ))) + (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ)) + (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)) + (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) + (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) + (EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $ + plus_x₀a_f + (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $ + plus_x₀_a + (subst0 (EFst ext (EFst ext (evar IZ))) ex₀2) + (subst0 (ESnd ext (EFst ext (evar IZ))) ea2)) + (ESnd ext (evar IZ))) + } + EUnit _ e | SpArr sdElt <- sd , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e -> @@ -1229,9 +1312,6 @@ drev des accumMap sd = \case EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e - -- These should be the next to be implemented, I think - EFold1Inner{} -> err_unsupported "EFold1Inner" - ENothing{} -> err_unsupported "ENothing" EJust{} -> err_unsupported "EJust" EMaybe{} -> err_unsupported "EMaybe" @@ -1246,10 +1326,14 @@ drev des accumMap sd = \case EPlus{} -> err_monoid EOneHot{} -> err_monoid + EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" + EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" + where err_accum = error "Accumulator operations unsupported in the source program" err_monoid = error "Monoid operations unsupported in the source program" err_unsupported s = error $ "CHAD: unsupported " ++ s + err_targetlang s = error $ "CHAD: Target language operation " ++ s ++ " not supported in source program" contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs) contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub)) |
