aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-10-23 23:53:37 +0200
committerTom Smeding <tom@tomsmeding.com>2025-10-23 23:53:37 +0200
commit2542f5ef42452967fec1d2376927c1f36bf263f4 (patch)
tree717d97be4d21c4ac0355270ac81df33296b8b852
parentf805440cf8833d238f848dd07f89b8ed5bc69e90 (diff)
WIP fold: Implement D[fold1i]
Still need to handle the new primitives in the rest of the library
-rw-r--r--src/AST.hs14
-rw-r--r--src/CHAD.hs90
-rw-r--r--src/CHAD/Accum.hs5
3 files changed, 106 insertions, 3 deletions
diff --git a/src/AST.hs b/src/AST.hs
index a10f1ae..275abcd 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -62,6 +62,7 @@ data Expr x env t where
-- array operations
EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
+ -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right)
EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
@@ -69,6 +70,19 @@ data Expr x env t where
EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
+ EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) t1)) -> Commutative -> Expr x (t1 : t1 : env) (TPair t1 tape) -> Expr x env t1 -> Expr x env (TArr (S n) t1)
+ -> Expr x env (TPair (TArr n t1) -- normal primal fold output
+ (TArr (S n) (TPair t1 tape))) -- bag-of-goodies: zip (prescanl) (the tape stores)
+ -- TODO: as-is, the primal input array is mostly unused; it is used only if the combination function returns sparse cotangents that need to be expanded, and nowhere else. That's wasteful storage.
+ EFold1InnerD2 :: x (TArr (S n) t2) -> Commutative
+ -> SMTy t2 -- t2 must be a monoid in order to be able to add all inner-vector contributions to the single x0
+ -- TODO: `fold1i (*)` should have zero tape stores since primals are directly made available here, but this is not yet true
+ -> Expr x (t2 : t1 : t1 : tape : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation)
+ -> Expr x env (TArr (S n) t1) -- primal input array
+ -> Expr x env (TArr (S n) (TPair t1 tape)) -- bag-of-goodies from EFold1InnerD1
+ -> Expr x env (TArr n t2) -- incoming cotangent
+ -> Expr x env (TPair t2 (TArr (S n) t2)) -- outgoing cotangents to x0 and input array
+
-- expression operations
EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t
diff --git a/src/CHAD.hs b/src/CHAD.hs
index cfae98d..ec719e8 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1116,6 +1116,89 @@ drev des accumMap sd = \case
e2)
}}
+ EFold1Inner _ commut origef ex₀ earr
+ | SpArr @_ @sdElt sdElt <- sd
+ , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr
+ , Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil)
+ <- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil ->
+ deleteUnused (descrList des) (occEnvPopSome (occEnvPopSome (occCountAll origef))) $ \(usedSub :: Subenv env env') ->
+ let ef = unsafeWeakenWithSubenv (SEYesR (SEYesR usedSub)) origef in
+ subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed ->
+ accumPromote (d2 eltty) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
+ let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
+ let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in
+ let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in
+ let (mergePrimalBindings', _) = weakenBindingsE (sinkWithBindings bindsx₀a) mergePrimalBindings in
+ case drev (prodes `DPush` (eltty, Nothing, SMerge) `DPush` (eltty, Nothing, SMerge)) accumMapPro (spDense (d2M eltty)) ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 ->
+ let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in
+ let bogTy = STArr (SS ndim) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf)))
+ primalTy = STPair (STArr ndim (d1 eltty)) bogTy
+ library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil)
+ &. #parr (auto1 @(TArr (S n) (D1 elt)))
+ &. #px₀ (auto1 @(D1 elt))
+ &. #primal (primalTy `SCons` SNil)
+ &. #darr (auto1 @(TArr n sdElt))
+ &. #d (auto1 @(D2 elt))
+ &. #x₀abinds (bindingsBinds bindsx₀a)
+ &. #fbinds (bindingsBinds ef0)
+ &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a)
+ &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf)
+ &. #ftape (auto1 @(Tape e_tape))
+ &. #efPrerebinds efPrerebinds
+ &. #propr (d1e envPro)
+ &. #d1env (desD1E des)
+ &. #d1env' (desD1E usedDes)
+ &. #d2acUsed (d2ace (select SAccum usedDes))
+ &. #d2acEnv (d2ace (select SAccum des))
+ &. #d2acPro (d2ace envPro)
+ &. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro))))
+ wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in
+ subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) $ \subx₀af _ _ plus_x₀a_f ->
+ Ret (bconcat bindsx₀a mergePrimalBindings'
+ `bpush` weakenExpr wOverPrimalBindings ex₀1
+ `bpush` weakenExpr (WSink .> wOverPrimalBindings) ea1
+ `bpush` EFold1InnerD1 ext commut
+ (letBinds (fst (weakenBindingsE (autoWeak library
+ (#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env))
+ ef0)) $
+ EPair ext
+ (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#fbinds :++: #xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env))
+ ef1)
+ (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env))))
+ (EVar ext (d1 eltty) (IS IZ))
+ (EVar ext (STArr (SS ndim) (d1 eltty)) IZ))
+ (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro))))))
+ (EFst ext (EVar ext primalTy IZ))
+ subx₀af
+ (let layout1 = #darr :++: #primal :++: #parr :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in
+ elet
+ (uninvertTup (d2e envPro) (STPair (d2 eltty) (STArr (SS ndim) (d2 eltty))) $
+ makeAccumulators (autoWeak library #propr layout1) envPro $
+ let layout2 = #d2acPro :++: layout1 in
+ EFold1InnerD2 ext commut (d2M eltty)
+ (letBinds (efRebinds (IS (IS (IS IZ)))) $
+ let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #d :++: #xy :++: #ftape :++: layout2 in
+ elet (expandSubenvZeros (autoWeak library #xy layout3) (eltty `SCons` eltty `SCons` SNil) subEf $
+ weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) layout3
+ .> wPro (subList (bindingsBinds ef0) subtapeEf))
+ ef2) $
+ EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ)))
+ (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ))
+ (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ))
+ (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ)))
+ (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))
+ (EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $
+ plus_x₀a_f
+ (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $
+ plus_x₀_a
+ (subst0 (EFst ext (EFst ext (evar IZ))) ex₀2)
+ (subst0 (ESnd ext (EFst ext (evar IZ))) ea2))
+ (ESnd ext (evar IZ)))
+ }
+
EUnit _ e
| SpArr sdElt <- sd
, Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e ->
@@ -1229,9 +1312,6 @@ drev des accumMap sd = \case
EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e
EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e
- -- These should be the next to be implemented, I think
- EFold1Inner{} -> err_unsupported "EFold1Inner"
-
ENothing{} -> err_unsupported "ENothing"
EJust{} -> err_unsupported "EJust"
EMaybe{} -> err_unsupported "EMaybe"
@@ -1246,10 +1326,14 @@ drev des accumMap sd = \case
EPlus{} -> err_monoid
EOneHot{} -> err_monoid
+ EFold1InnerD1{} -> err_targetlang "EFold1InnerD1"
+ EFold1InnerD2{} -> err_targetlang "EFold1InnerD2"
+
where
err_accum = error "Accumulator operations unsupported in the source program"
err_monoid = error "Monoid operations unsupported in the source program"
err_unsupported s = error $ "CHAD: unsupported " ++ s
+ err_targetlang s = error $ "CHAD: Target language operation " ++ s ++ " not supported in source program"
contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs)
contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub))
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs
index 7212232..a7bc53f 100644
--- a/src/CHAD/Accum.hs
+++ b/src/CHAD/Accum.hs
@@ -44,6 +44,11 @@ d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e
d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext
d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+-- The weakening is necessary because we need to initialise the created
+-- accumulators with zeros. Those zeros are deep and need full primals. This
+-- means, in the end, that primals corresponding to environment entries
+-- promoted to an accumulator with accumPromote in CHAD need to be stored for
+-- the dual.
makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
makeAccumulators _ SNil e = e
makeAccumulators w (t `SCons` envpro) e =