diff options
| -rw-r--r-- | src/AST.hs | 14 | ||||
| -rw-r--r-- | src/CHAD.hs | 90 | ||||
| -rw-r--r-- | src/CHAD/Accum.hs | 5 |
3 files changed, 106 insertions, 3 deletions
@@ -62,6 +62,7 @@ data Expr x env t where -- array operations EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) + -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right) EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) @@ -69,6 +70,19 @@ data Expr x env t where EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) + EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) t1)) -> Commutative -> Expr x (t1 : t1 : env) (TPair t1 tape) -> Expr x env t1 -> Expr x env (TArr (S n) t1) + -> Expr x env (TPair (TArr n t1) -- normal primal fold output + (TArr (S n) (TPair t1 tape))) -- bag-of-goodies: zip (prescanl) (the tape stores) + -- TODO: as-is, the primal input array is mostly unused; it is used only if the combination function returns sparse cotangents that need to be expanded, and nowhere else. That's wasteful storage. + EFold1InnerD2 :: x (TArr (S n) t2) -> Commutative + -> SMTy t2 -- t2 must be a monoid in order to be able to add all inner-vector contributions to the single x0 + -- TODO: `fold1i (*)` should have zero tape stores since primals are directly made available here, but this is not yet true + -> Expr x (t2 : t1 : t1 : tape : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) + -> Expr x env (TArr (S n) t1) -- primal input array + -> Expr x env (TArr (S n) (TPair t1 tape)) -- bag-of-goodies from EFold1InnerD1 + -> Expr x env (TArr n t2) -- incoming cotangent + -> Expr x env (TPair t2 (TArr (S n) t2)) -- outgoing cotangents to x0 and input array + -- expression operations EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t diff --git a/src/CHAD.hs b/src/CHAD.hs index cfae98d..ec719e8 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1116,6 +1116,89 @@ drev des accumMap sd = \case e2) }} + EFold1Inner _ commut origef ex₀ earr + | SpArr @_ @sdElt sdElt <- sd + , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr + , Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil) + <- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil -> + deleteUnused (descrList des) (occEnvPopSome (occEnvPopSome (occCountAll origef))) $ \(usedSub :: Subenv env env') -> + let ef = unsafeWeakenWithSubenv (SEYesR (SEYesR usedSub)) origef in + subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> + accumPromote (d2 eltty) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> + let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in + let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in + let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in + let (mergePrimalBindings', _) = weakenBindingsE (sinkWithBindings bindsx₀a) mergePrimalBindings in + case drev (prodes `DPush` (eltty, Nothing, SMerge) `DPush` (eltty, Nothing, SMerge)) accumMapPro (spDense (d2M eltty)) ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 -> + let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in + let bogTy = STArr (SS ndim) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf))) + primalTy = STPair (STArr ndim (d1 eltty)) bogTy + library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil) + &. #parr (auto1 @(TArr (S n) (D1 elt))) + &. #px₀ (auto1 @(D1 elt)) + &. #primal (primalTy `SCons` SNil) + &. #darr (auto1 @(TArr n sdElt)) + &. #d (auto1 @(D2 elt)) + &. #x₀abinds (bindingsBinds bindsx₀a) + &. #fbinds (bindingsBinds ef0) + &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a) + &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) + &. #ftape (auto1 @(Tape e_tape)) + &. #efPrerebinds efPrerebinds + &. #propr (d1e envPro) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes) + &. #d2acUsed (d2ace (select SAccum usedDes)) + &. #d2acEnv (d2ace (select SAccum des)) + &. #d2acPro (d2ace envPro) + &. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro)))) + wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in + subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a -> + subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) $ \subx₀af _ _ plus_x₀a_f -> + Ret (bconcat bindsx₀a mergePrimalBindings' + `bpush` weakenExpr wOverPrimalBindings ex₀1 + `bpush` weakenExpr (WSink .> wOverPrimalBindings) ea1 + `bpush` EFold1InnerD1 ext commut + (letBinds (fst (weakenBindingsE (autoWeak library + (#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env)) + ef0)) $ + EPair ext + (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#fbinds :++: #xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env)) + ef1) + (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #xy :++: #parr :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env)))) + (EVar ext (d1 eltty) (IS IZ)) + (EVar ext (STArr (SS ndim) (d1 eltty)) IZ)) + (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro)))))) + (EFst ext (EVar ext primalTy IZ)) + subx₀af + (let layout1 = #darr :++: #primal :++: #parr :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in + elet + (uninvertTup (d2e envPro) (STPair (d2 eltty) (STArr (SS ndim) (d2 eltty))) $ + makeAccumulators (autoWeak library #propr layout1) envPro $ + let layout2 = #d2acPro :++: layout1 in + EFold1InnerD2 ext commut (d2M eltty) + (letBinds (efRebinds (IS (IS (IS IZ)))) $ + let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #d :++: #xy :++: #ftape :++: layout2 in + elet (expandSubenvZeros (autoWeak library #xy layout3) (eltty `SCons` eltty `SCons` SNil) subEf $ + weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) layout3 + .> wPro (subList (bindingsBinds ef0) subtapeEf)) + ef2) $ + EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ))) + (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ)) + (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)) + (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) + (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) + (EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $ + plus_x₀a_f + (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $ + plus_x₀_a + (subst0 (EFst ext (EFst ext (evar IZ))) ex₀2) + (subst0 (ESnd ext (EFst ext (evar IZ))) ea2)) + (ESnd ext (evar IZ))) + } + EUnit _ e | SpArr sdElt <- sd , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e -> @@ -1229,9 +1312,6 @@ drev des accumMap sd = \case EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e - -- These should be the next to be implemented, I think - EFold1Inner{} -> err_unsupported "EFold1Inner" - ENothing{} -> err_unsupported "ENothing" EJust{} -> err_unsupported "EJust" EMaybe{} -> err_unsupported "EMaybe" @@ -1246,10 +1326,14 @@ drev des accumMap sd = \case EPlus{} -> err_monoid EOneHot{} -> err_monoid + EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" + EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" + where err_accum = error "Accumulator operations unsupported in the source program" err_monoid = error "Monoid operations unsupported in the source program" err_unsupported s = error $ "CHAD: unsupported " ++ s + err_targetlang s = error $ "CHAD: Target language operation " ++ s ++ " not supported in source program" contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs) contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub)) diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs index 7212232..a7bc53f 100644 --- a/src/CHAD/Accum.hs +++ b/src/CHAD/Accum.hs @@ -44,6 +44,11 @@ d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program" +-- The weakening is necessary because we need to initialise the created +-- accumulators with zeros. Those zeros are deep and need full primals. This +-- means, in the end, that primals corresponding to environment entries +-- promoted to an accumulator with accumPromote in CHAD need to be stored for +-- the dual. makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) makeAccumulators _ SNil e = e makeAccumulators w (t `SCons` envpro) e = |
