diff options
Diffstat (limited to 'src/CHAD.hs')
| -rw-r--r-- | src/CHAD.hs | 210 |
1 files changed, 146 insertions, 64 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 7594a0f..72ce36d 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1048,73 +1048,46 @@ drev des accumMap sd = \case , let eltty = typeOf orige , shty :: STy shty <- tTup (sreplicate ndim tIx) , Refl <- indexTupD1Id ndim -> - deleteUnused (descrList des) (occEnvPopSome (occCountAll orige)) $ \(usedSub :: Subenv env env') -> - let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in - subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> - accumPromote sdElt usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> - let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in - case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro sdElt e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 SETop e2 -> - case lemAppendNil @e_binds of { Refl -> - let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in - let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in - let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in - let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in - Ret (mergePrimalBindings - `bpush` weakenExpr (wSinks (d1e envPro)) (drevPrimal des she) + drevLambda des accumMap (shty, SDiscr) sdElt orige $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 -> + let library = #ix (shty `SCons` SNil) + &. #e0 (bindingsBinds e0) + &. #propr (d1e provars) + &. #d1env (desD1E des) + &. #d (auto1 @sdElt) + &. #tape (auto1 @e_tape) + &. #pro (d2ace provars) + &. #d2acEnv (d2ace (select SAccum des)) + &. #darr (auto1 @(TArr ndim sdElt)) + &. #tapearr (auto1 @(TArr ndim e_tape)) in + Ret (proPrimalBinds `bpush` EBuild ext ndim - (EVar ext shty IZ) - (letBinds (fst (weakenBindingsE (autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #propr (d1e envPro) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#ix :++: #sh :++: #propr :++: #d1env)) + (weakenExpr (wSinks (d1e provars)) (drevPrimal des she)) + (letBinds (fst (weakenBindingsE (autoWeak library + (#ix :++: #d1env) + (#ix :++: #propr :++: #d1env)) e0)) $ - let w = autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #e0 (bindingsBinds e0) - &. #propr (d1e envPro) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#e0 :++: #ix :++: #sh :++: #propr :++: #d1env) - w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env')) - in EPair ext (weakenExpr w e1) (collectexpr w')) - `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ)) - (SEYesR (SENo (SEYesR (subenvAll (d1e envPro))))) - (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) - (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) - (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace envPro) in + weakenExpr (autoWeak library (#e0 :++: #ix :++: #d1env) + (#e0 :++: #ix :++: #propr :++: #d1env)) + (EPair ext e1 e1tape)) + `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) IZ)) + (SEYesR (SENo (subenvAll (d1e provars)))) + (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) (IS IZ))) + (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) esub) + (let sinkOverEnvPro = wSinks @(sd : TArr ndim e_tape : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace provars) in ESnd ext $ - uninvertTup (d2e envPro) (STArr ndim STNil) $ - makeAccumulators @_ @_ @(TArr ndim TNil) (WSink .> WSink .> WSink .> wRaiseAbove (d1e envPro) (d2ace (select SAccum des))) envPro $ - EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ - -- the cotangent for this element - ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) - (EVar ext shty IZ)) $ - -- the tape for this element - ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) - (EVar ext shty (IS IZ))) $ - let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) - in letBinds (rebinds IZ) $ - weakenExpr (autoWeak (#d (auto1 @sdElt) - &. #pro (d2ace envPro) - &. #etape (subList (bindingsBinds e0) subtapeE) - &. #prerebinds prerebinds - &. #tape (auto1 @(Tape e_tape)) - &. #ix (auto1 @shty) - &. #darr (auto1 @(TArr ndim sdElt)) - &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) - &. #sh (auto1 @shty) - &. #propr (d1e envPro) - &. #d2acUsed (d2ace (select SAccum usedDes)) - &. #d2acEnv (d2ace (select SAccum des))) - (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) - ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #propr :++: #d2acEnv) - .> wPro (subList (bindingsBinds e0) subtapeE)) - e2) - }} + wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $ + EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $ + -- the tape for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> sinkOverEnvPro @> IS IZ)) + (EVar ext shty IZ)) $ + -- the cotangent for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> WSink .> sinkOverEnvPro @> IZ)) + (EVar ext shty (IS IZ))) $ + weakenExpr (autoWeak library (#d :++: #tape :++: #pro :++: #d2acEnv) + (#d :++: #tape :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv)) + e2) + + EMap{} -> undefined EFold1Inner _ commut origef ex₀ earr | SpArr @_ @sdElt sdElt <- sd @@ -1346,6 +1319,33 @@ drev des accumMap sd = \case (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ weakenExpr (WCopy (WSink .> WSink)) e2) + EZip _ a b + | SpArr sd' <- sd + , STArr n t1 <- typeOf a + , STArr _ t2 <- typeOf b -> + splitSparsePair (STPair (d2 t1) (d2 t2)) sd' $ \sd1 sd2 pairSplitE -> + case retConcat des (toSingleRet (drev des accumMap (SpArr sd1) a) `SCons` + toSingleRet (drev des accumMap (SpArr sd2) b) `SCons` SNil) of + { Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) -> + subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> + Ret binds + subtape + (EZip ext a1 b1) + subBoth + (case pairSplitE of + Left Refl -> + let t' = STArr n (STPair (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 t2))) in + plus_A_B + (elet (emap (EFst ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) a2) + (elet (emap (ESnd ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) b2) + Right f -> f IZ $ \wrapPair pick1 pick2 -> + elet (emap (wrapPair (EPair ext pick1 pick2)) + (EVar ext (applySparse (SpArr sd') (STArr n (STPair (d2 t1) (d2 t2)))) IZ)) $ + plus_A_B + (elet (emap (EFst ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) a2) + (elet (emap (ESnd ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) b2)) + } + ENothing{} -> err_unsupported "ENothing" EJust{} -> err_unsupported "EJust" EMaybe{} -> err_unsupported "EMaybe" @@ -1476,6 +1476,88 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of , Refl <- lemAppendNil @tapebinds -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2 +drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False) + => Descr env sto + -> VarMap Int (D2AcE (Select env sto "accum")) + -> (STy a, Storage s) + -> Sparse (D2 t) dt + -> Expr ValId (a : env) t + -> (forall provars shbinds tape d2a'. + SList STy provars + -> Subenv (D2E (Select env sto "merge")) (D2E provars) + -> Bindings Ex (D1E env) (D1E provars) -- accum-promoted free variables of which we need a primal in the reverse pass (to initialise the accumulator) + -> Bindings Ex (D1 a : D1E env) shbinds + -> Ex (Append shbinds (D1 a : D1E env)) (D1 t) + -> Ex (Append shbinds (D1 a : D1E env)) tape + -> Sparse (D2 a) d2a' + -> (forall env' b. + D1E provars :> env' + -> Ex (Append (D2AcE provars) env') b + -> Ex ( env') (TPair b (Tup (D2E provars)))) + -> Ex (dt : tape : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a' + -> r) + -> r +drevLambda des accumMap (argty, argsto) sd origef k = + let t = typeOf origef in + deleteUnused (descrList des) (occEnvPopSome (occCountAll origef)) $ \(usedSub :: Subenv env env') -> + let ef = unsafeWeakenWithSubenv (SEYesR usedSub) origef in + subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> + accumPromote (applySparse sd (d2 t)) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> + let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in + let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in + let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in + case prf1 prodes argty argsto of { Refl -> + case drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 -> + let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in + extractContrib prodes argty argsto subEf $ \argSp getSparseArg -> + let library = #fbinds (bindingsBinds ef0) + &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) + &. #ftape (auto1 @(Tape e_tape)) + &. #arg (d1 argty `SCons` SNil) + &. #d (applySparse sd (d2 t) `SCons` SNil) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes) + &. #propr (d1e envPro) + &. #d2acUsed (d2ace (select SAccum usedDes)) + &. #d2acEnv (d2ace (select SAccum des)) + &. #d2acPro (d2ace envPro) + &. #efPrerebinds efPrerebinds in + k envPro + (subenvD2E (subenvCompose subMergeUsed proSub)) + mergePrimalBindings + (fst (weakenBindingsE (WCopy (wUndoSubenv subD1eUsed)) ef0)) + (weakenExpr (autoWeak library (#fbinds :++: #arg :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#fbinds :++: #arg :++: #d1env)) + ef1) + (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #arg :++: #d1env))) + argSp + (\wpro1 body -> + uninvertTup (d2e envPro) (typeOf body) $ + makeAccumulators wpro1 envPro $ + body) + (letBinds (efRebinds (IS IZ)) $ + weakenExpr + (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) + ((#ftapebinds :++: #efPrerebinds) :++: #d :++: #ftape :++: #d2acPro :++: #d2acEnv) + .> wPro (subList (bindingsBinds ef0) subtapeEf)) + (getSparseArg ef2)) + }} + where + extractContrib :: (Select env sto "merge" ~ '[], (s == "accum") ~ False) + => proxy env sto -> proxy2 a -> Storage s + -- if s == "merge", this simplifies to SubenvS '[D2 a] t' + -- if s == "discr", this simplifies to SubenvS '[] t' + -> SubenvS (D2E (Select (a : env) (s : sto) "merge")) t' + -> (forall d'. Sparse (D2 a) d' -> (forall env'. Ex env' (Tup t') -> Ex env' d') -> r) -> r + extractContrib _ _ SMerge (SENo SETop) k' = k' SpAbsent id + extractContrib _ _ SMerge (SEYes s SETop) k' = k' s (ESnd ext) + extractContrib _ _ SDiscr SETop k' = k' SpAbsent id + + prf1 :: (s == "accum") ~ False => proxy env sto -> proxy2 a -> Storage s + -> Select (a : env) (s : sto) "accum" :~: Select env sto "accum" + prf1 _ _ SMerge = Refl + prf1 _ _ SDiscr = Refl + -- TODO: proper primal-only transform that doesn't depend on D1 = Id drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t) drevPrimal des e |
