diff options
Diffstat (limited to 'src/CHAD.hs')
| -rw-r--r-- | src/CHAD.hs | 346 |
1 files changed, 153 insertions, 193 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 241825e..143376a 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -34,7 +34,6 @@ module CHAD ( import Data.Functor.Const import Data.Some -import Data.Type.Bool (If) import Data.Type.Equality (type (==), testEquality) import GHC.Stack (HasCallStack) @@ -45,6 +44,7 @@ import AST.Count import AST.Env import AST.Sparse import AST.Weaken.Auto +import CHAD.Accum import CHAD.EnvDescr import CHAD.Types import Data @@ -348,34 +348,8 @@ opt2UnSparse = go . opt2 go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary" ------------------------------------- MONOIDS ----------------------------------- - -d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t)) -d2zeroInfo STNil _ = ENil ext -d2zeroInfo (STPair a b) e = - eunPair e $ \_ e1 e2 -> - EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2) -d2zeroInfo STEither{} _ = ENil ext -d2zeroInfo STLEither{} _ = ENil ext -d2zeroInfo STMaybe{} _ = ENil ext -d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e -d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext -d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" - -zeroTup :: SList STy env0 -> D1E env0 :> env -> Ex env (Tup (D2E env0)) -zeroTup SNil _ = ENil ext -zeroTup (t `SCons` env) w = - EPair ext (zeroTup env (WPop w)) - (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) - - ----------------------------------- SPARSITY ----------------------------------- -subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env') -subenvD1E SETop = SETop -subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub) -subenvD1E (SENo sub) = SENo (subenvD1E sub) - expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a) expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e expandSparse t (SpSparse sp) epr e = @@ -430,7 +404,8 @@ subenvPlus :: SBool req1 -> SBool req2 -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3)) -> r) -> r -subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\_ _ -> ENil ext) +-- don't destroy effects! +subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\a b -> use a $ use b $ ENil ext) subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k = subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl -> @@ -448,18 +423,31 @@ subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k = EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) (weakenExpr WSink e2)) (ESnd ext (EVar ext (typeOf e1) IZ))) -subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k = - subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> - k (SEYes (SpSparse sp1) sub3) - (withInj minj13 $ \inj13 -> - \e1 -> eunPair e1 $ \_ e1a e1b -> - EPair ext (inj13 e1a) (EJust ext e1b)) - (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t)))) - (\e1 e2 -> - ELet ext e1 $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) - (weakenExpr WSink e2)) - (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ)))) +subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k + | Just zero1 <- cheapZero (applySparse sp1 t) = + subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> + k (SEYes sp1 sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) e1b) + (Inj $ \e2 -> EPair ext (inj23 e2) zero1) + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (ESnd ext (EVar ext (typeOf e1) IZ))) + | otherwise = + subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> + k (SEYes (SpSparse sp1) sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) (EJust ext e1b)) + (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t)))) + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ)))) subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k = subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl -> @@ -505,23 +493,6 @@ assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" --------------------------------- ACCUMULATORS --------------------------------- -makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) -makeAccumulators _ SNil e = e -makeAccumulators w (t `SCons` envpro) e = - makeAccumulators (WPop w) envpro $ - EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e - -uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) -uninvertTup SNil _ e = EPair ext e (ENil ext) -uninvertTup (t `SCons` list) tcore e = - ELet ext (uninvertTup list (STPair tcore t) e) $ - let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding - in EPair ext - (EFst ext (EFst ext (EVar ext recT IZ))) - (EPair ext - (ESnd ext (EVar ext recT IZ)) - (ESnd ext (EFst ext (EVar ext recT IZ)))) - fromArrayValId :: Maybe (ValId t) -> Maybe Int fromArrayValId (Just (VIArr i _)) = Just i fromArrayValId _ = Nothing @@ -780,7 +751,7 @@ drev des accumMap (SpSparse sd) = subtape e1 sub' - (emaybe (evar IZ) + (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ) (inj2 (ENil ext)) (inj1 (weakenExpr (WCopy WSink) e2))) } @@ -794,7 +765,7 @@ drev des accumMap sd = \case (EVar ext (d1 t) (conv1Idx i)) (subenvNone (d2e (select SMerge des))) (let ty = applySparse sd (d2M t) - in EAccum ext (d2M t) (_ sd) (ENil ext) (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) + in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) Idx2Me tupI -> Ret BTop @@ -1094,42 +1065,39 @@ drev des accumMap sd = \case case lemAppendNil @e_binds of { Refl -> let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in - Ret (BTop `BPush` (shty, drevPrimal des she) - `BPush` (STArr ndim (STPair (d1 eltty) tapety) - ,EBuild ext ndim - (EVar ext shty IZ) - (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#ix :++: #sh :++: #d1env)) - e0)) $ - let w = autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #e0 (bindingsBinds e0) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#e0 :++: #ix :++: #sh :++: #d1env) - w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env')) - in EPair ext (weakenExpr w e1) (collectexpr w'))) - `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ)) - (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) - (SEYesR (SENo (SEYesR SETop))) - (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ)) - (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) + let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in + let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in + Ret (mergePrimalBindings + `BPush` (shty, weakenExpr (wSinks (d1e envPro)) (drevPrimal des she)) + `BPush` (STArr ndim (STPair (d1 eltty) tapety) + ,EBuild ext ndim + (EVar ext shty IZ) + (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #propr (d1e envPro) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes)) + (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#ix :++: #sh :++: #propr :++: #d1env)) + e0)) $ + let w = autoWeak (#ix (shty `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #e0 (bindingsBinds e0) + &. #propr (d1e envPro) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes)) + (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#e0 :++: #ix :++: #sh :++: #propr :++: #d1env) + w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env')) + in EPair ext (weakenExpr w e1) (collectexpr w'))) + `BPush` (STArr ndim tapety, emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) + (SEYesR (SENo (SEYesR (subenvAll (d1e envPro))))) + (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) - (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in + (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace envPro) in ESnd ext $ uninvertTup (d2e envPro) (STArr ndim STNil) $ - -- TODO: what's happening here is that because of the sparsity - -- rewrite, makeAccumulators needs primals where it previously - -- didn't. The build derivative is currently not saving those - -- primals, so the hole below cannot currently be filled. The - -- appropriate primals (waves hands) need to be stored, so that a - -- weakening can be provided here. - makeAccumulators @_ @_ @(TArr ndim TNil) (_ (subenvCompose subMergeUsed proSub)) envPro $ + makeAccumulators @_ @_ @(TArr ndim TNil) (WSink .> WSink .> WSink .> wRaiseAbove (d1e envPro) (d2ace (select SAccum des))) envPro $ EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ -- the cotangent for this element ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) @@ -1148,10 +1116,11 @@ drev des accumMap sd = \case &. #darr (auto1 @(TArr ndim sdElt)) &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) &. #sh (auto1 @shty) + &. #propr (d1e envPro) &. #d2acUsed (d2ace (select SAccum usedDes)) &. #d2acEnv (d2ace (select SAccum des))) (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) - ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv) + ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #propr :++: #d2acEnv) .> wPro (subList (bindingsBinds e0) subtapeE)) e2) }}} @@ -1167,32 +1136,34 @@ drev des accumMap sd = \case weakenExpr (WCopy WSink) e2) EReplicate1Inner _ en e - -- We're allowed to ignore en2 here because the output of 'ei' is discrete. - | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil) - <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil + -- We're allowed to differentiate 'en' as primal-only here because its output is discrete. + | SpArr sdElt <- sd , let STArr ndim eltty = typeOf e -> - Ret binds - subtape - (EReplicate1Inner ext en1 e1) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext (EFold1Inner ext Commut - (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) - (ezeroD2 eltty) - (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - (EVar ext (d2 (STArr (SS ndim) eltty)) IZ)) + -- This pessimistic sparsity union is because the array might have been empty, in which case we need to generate a zero. + sparsePlusS ST ST (d2M eltty) sdElt SpAbsent $ \sdElt' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sdElt') e of { Ret binds subtape e1 sub e2 -> + Ret binds + subtape + (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1) + sub + (ELet ext (EFold1Inner ext Commut + (sparsePlus (d2M eltty) sdElt' + (EVar ext (applySparse sdElt' (d2 eltty)) (IS IZ)) + (EVar ext (applySparse sdElt' (d2 eltty)) IZ)) + (inj2 (ENil ext)) + (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $ + weakenExpr (WCopy WSink) e2) + } EIdx0 _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd) e , STArr _ t <- typeOf e -> Ret e0 subtape (EIdx0 ext e1) sub - (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $ - weakenExpr (WCopy WSink) e2) + (ELet ext (EUnit ext (EVar ext (applySparse sd (d2 t)) IZ)) $ + weakenExpr (WCopy WSink) e2) EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead" {- @@ -1214,57 +1185,58 @@ drev des accumMap sd = \case -} EIdx _ e ei - -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. - | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) - <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil - , STArr n eltty <- typeOf e + -- We're allowed to differentiate ei as primal because its output is discrete. + | STArr n eltty <- typeOf e , Refl <- indexTupD1Id n - , Refl <- lemZeroInfoD2 eltty - , let tIxN = tTup (sreplicate n tIx) -> - Ret (binds `BPush` (STArr n (d1 eltty), e1) - `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) - `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1)) - (SEYesR (SEYesR (SENo subtape))) - (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) - sub - (ELet ext (EOneHot ext (d2M (STArr n eltty)) (SAPJust (SAPArrIdx SAPHere)) - (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) - (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext))) - (ENil ext)) - (EVar ext (d2 eltty) IZ)) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + , let tIxN = tTup (sreplicate n tIx) -> + sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 -> + Ret (binds `BPush` (STArr n (d1 eltty), e1) + `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) + `BPush` (tIxN, weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei))) + (SEYesR (SEYesR (SENo subtape))) + (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) + sub + (ELet ext + (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) + (SAPArrIdx SAPHere) + (EPair ext + (EPair ext (EVar ext tIxN (IS IZ)) + (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $ + makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext)))) + (ENil ext)) + (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + } EShape _ e - -- Allowed to ignore e2 here because the output of EShape is discrete, - -- hence we'd be passing a zero cotangent to e2 anyway. - | Ret e0 subtape e1 _ _ <- drev des accumMap e - , STArr n _ <- typeOf e + -- Allowed to differentiate e as primal because the output of EShape is + -- discrete, hence we'd be passing a zero cotangent to e anyway. + | STArr n _ <- typeOf e , Refl <- indexTupD1Id n -> - Ret e0 - subtape - (EShape ext e1) - (subenvNone (select SMerge des)) + Ret BTop + SETop + (EShape ext (drevPrimal des e)) + (subenvNone (d2eM (select SMerge des))) (ENil ext) ESum1Inner _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e + | SpArr sd' <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e , STArr (SS n) t <- typeOf e -> Ret (e0 `BPush` (STArr (SS n) t, e1) `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) (SEYesR (SENo subtape)) (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext (EReplicate1Inner ext - (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ)))) - (EVar ext (STArr n (d2 t)) IZ))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) - (EVar ext (d2 (STArr n t)) IZ)) + (ELet ext (EReplicate1Inner ext + (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) - EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e - EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e + EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e + EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e -- These should be the next to be implemented, I think EFold1Inner{} -> err_unsupported "EFold1Inner" @@ -1279,6 +1251,7 @@ drev des accumMap sd = \case EWith{} -> err_accum EZero{} -> err_monoid + EDeepZero{} -> err_monoid EPlus{} -> err_monoid EOneHot{} -> err_monoid @@ -1287,35 +1260,35 @@ drev des accumMap sd = \case err_monoid = error "Monoid operations unsupported in the source program" err_unsupported s = error $ "CHAD: unsupported " ++ s - deriv_extremum :: ScalIsNumeric t' ~ True - => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t'))) - -> Sparse (TArr n (D2s t')) sd' - -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto sd' (TArr n (TScal t')) - deriv_extremum extremum e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , at@(STArr (SS n) t@(STScal st)) <- typeOf e - , let at' = STArr n t - , let tIxN = tTup (sreplicate (SS n) tIx) = - Ret (e0 `BPush` (at, e1) - `BPush` (at', extremum (EVar ext at IZ))) - (SEYesR (SEYesR subtape)) - (EVar ext at' IZ) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext - (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $ - eif (EOp ext (OEq st) (EPair ext - (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ)) - (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ))))) - (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) - (ezeroD2 t))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) - (EVar ext (d2 at') IZ)) - contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs) contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub)) +deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True) + => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) + -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> Sparse (D2s t) sd + -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t)) +deriv_extremum extremum des accumMap sd e + | at@(STArr (SS n) t@(STScal st)) <- typeOf e + , let at' = STArr n t + , let tIxN = tTup (sreplicate (SS n) tIx) = + sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 -> + Ret (e0 `BPush` (at, e1) + `BPush` (at', extremum (EVar ext at IZ))) + (SEYesR (SEYesR subtape)) + (EVar ext at' IZ) + sub + (ELet ext + (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $ + eif (EOp ext (OEq st) (EPair ext + (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ)) + (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ))))) + (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) + (inj2 (ENil ext))) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + } + data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) data RetScoped env0 sto a s sd t = @@ -1351,7 +1324,8 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext)) SAccum - | Just (VIArr i _) <- argids + | chcSmartWith ?config + , Just (VIArr i _) <- argids , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap , Just Refl <- testEquality foundTy (STAccum (d2M argty)) , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr @@ -1380,9 +1354,9 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of &. #ac (auto1 @(TAccum (D2 a))) &. #tl (d2ace (select SAccum des)) in - RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub SpDense $ + RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $ let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in - EWith ext (d2M argty) (EZero ext (d2M argty) (d2zeroInfo argty (EVar ext (d1 argty) primalIdx))) $ + EWith ext (d2M argty) (EDeepZero ext (d2M argty) (d2deepZeroInfo argty (EVar ext (d1 argty) primalIdx))) $ weakenExpr (autoWeak library (#d :++: #body :++: #ac :++: #tl) (#ac :++: #d :++: (#body :++: #p) :++: #tl)) @@ -1396,20 +1370,6 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of -- TODO: proper primal-only transform that doesn't depend on D1 = Id drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t) drevPrimal des e - | Refl <- chadD1Id (typeOf e) - , Refl <- chadD1EId (descrList des) + | Refl <- d1Identity (typeOf e) + , Refl <- d1eIdentity (descrList des) = mapExt (const ext) e - where - chadD1Id :: STy a -> D1 a :~: a - chadD1Id STNil = Refl - chadD1Id (STPair a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl - chadD1Id (STEither a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl - chadD1Id (STLEither a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl - chadD1Id (STMaybe a) | Refl <- chadD1Id a = Refl - chadD1Id (STArr _ a) | Refl <- chadD1Id a = Refl - chadD1Id (STScal _) = Refl - chadD1Id STAccum{} = error "accumulators not allowed in source program" - - chadD1EId :: SList STy l -> D1E l :~: l - chadD1EId SNil = Refl - chadD1EId (SCons t l) | Refl <- chadD1Id t, Refl <- chadD1EId l = Refl |
