summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs346
1 files changed, 153 insertions, 193 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 241825e..143376a 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -34,7 +34,6 @@ module CHAD (
import Data.Functor.Const
import Data.Some
-import Data.Type.Bool (If)
import Data.Type.Equality (type (==), testEquality)
import GHC.Stack (HasCallStack)
@@ -45,6 +44,7 @@ import AST.Count
import AST.Env
import AST.Sparse
import AST.Weaken.Auto
+import CHAD.Accum
import CHAD.EnvDescr
import CHAD.Types
import Data
@@ -348,34 +348,8 @@ opt2UnSparse = go . opt2
go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary"
------------------------------------- MONOIDS -----------------------------------
-
-d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t))
-d2zeroInfo STNil _ = ENil ext
-d2zeroInfo (STPair a b) e =
- eunPair e $ \_ e1 e2 ->
- EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2)
-d2zeroInfo STEither{} _ = ENil ext
-d2zeroInfo STLEither{} _ = ENil ext
-d2zeroInfo STMaybe{} _ = ENil ext
-d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e
-d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext
-d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program"
-
-zeroTup :: SList STy env0 -> D1E env0 :> env -> Ex env (Tup (D2E env0))
-zeroTup SNil _ = ENil ext
-zeroTup (t `SCons` env) w =
- EPair ext (zeroTup env (WPop w))
- (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
-
-
----------------------------------- SPARSITY -----------------------------------
-subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env')
-subenvD1E SETop = SETop
-subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub)
-subenvD1E (SENo sub) = SENo (subenvD1E sub)
-
expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a)
expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e
expandSparse t (SpSparse sp) epr e =
@@ -430,7 +404,8 @@ subenvPlus :: SBool req1 -> SBool req2
-> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3))
-> r)
-> r
-subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\_ _ -> ENil ext)
+-- don't destroy effects!
+subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\a b -> use a $ use b $ ENil ext)
subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k =
subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl ->
@@ -448,18 +423,31 @@ subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k =
EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
(weakenExpr WSink e2))
(ESnd ext (EVar ext (typeOf e1) IZ)))
-subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k =
- subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl ->
- k (SEYes (SpSparse sp1) sub3)
- (withInj minj13 $ \inj13 ->
- \e1 -> eunPair e1 $ \_ e1a e1b ->
- EPair ext (inj13 e1a) (EJust ext e1b))
- (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t))))
- (\e1 e2 ->
- ELet ext e1 $
- EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
- (weakenExpr WSink e2))
- (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ))))
+subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k
+ | Just zero1 <- cheapZero (applySparse sp1 t) =
+ subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl ->
+ k (SEYes sp1 sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) e1b)
+ (Inj $ \e2 -> EPair ext (inj23 e2) zero1)
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (ESnd ext (EVar ext (typeOf e1) IZ)))
+ | otherwise =
+ subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl ->
+ k (SEYes (SpSparse sp1) sub3)
+ (withInj minj13 $ \inj13 ->
+ \e1 -> eunPair e1 $ \_ e1a e1b ->
+ EPair ext (inj13 e1a) (EJust ext e1b))
+ (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t))))
+ (\e1 e2 ->
+ ELet ext e1 $
+ EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ))
+ (weakenExpr WSink e2))
+ (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ))))
subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k =
subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl ->
@@ -505,23 +493,6 @@ assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
--------------------------------- ACCUMULATORS ---------------------------------
-makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
-makeAccumulators _ SNil e = e
-makeAccumulators w (t `SCons` envpro) e =
- makeAccumulators (WPop w) envpro $
- EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e
-
-uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
-uninvertTup SNil _ e = EPair ext e (ENil ext)
-uninvertTup (t `SCons` list) tcore e =
- ELet ext (uninvertTup list (STPair tcore t) e) $
- let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding
- in EPair ext
- (EFst ext (EFst ext (EVar ext recT IZ)))
- (EPair ext
- (ESnd ext (EVar ext recT IZ))
- (ESnd ext (EFst ext (EVar ext recT IZ))))
-
fromArrayValId :: Maybe (ValId t) -> Maybe Int
fromArrayValId (Just (VIArr i _)) = Just i
fromArrayValId _ = Nothing
@@ -780,7 +751,7 @@ drev des accumMap (SpSparse sd) =
subtape
e1
sub'
- (emaybe (evar IZ)
+ (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ)
(inj2 (ENil ext))
(inj1 (weakenExpr (WCopy WSink) e2)))
}
@@ -794,7 +765,7 @@ drev des accumMap sd = \case
(EVar ext (d1 t) (conv1Idx i))
(subenvNone (d2e (select SMerge des)))
(let ty = applySparse sd (d2M t)
- in EAccum ext (d2M t) (_ sd) (ENil ext) (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))
+ in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))
Idx2Me tupI ->
Ret BTop
@@ -1094,42 +1065,39 @@ drev des accumMap sd = \case
case lemAppendNil @e_binds of { Refl ->
let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in
let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in
- Ret (BTop `BPush` (shty, drevPrimal des she)
- `BPush` (STArr ndim (STPair (d1 eltty) tapety)
- ,EBuild ext ndim
- (EVar ext shty IZ)
- (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil)
- &. #sh (shty `SCons` SNil)
- &. #d1env (desD1E des)
- &. #d1env' (desD1E usedDes))
- (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#ix :++: #sh :++: #d1env))
- e0)) $
- let w = autoWeak (#ix (shty `SCons` SNil)
- &. #sh (shty `SCons` SNil)
- &. #e0 (bindingsBinds e0)
- &. #d1env (desD1E des)
- &. #d1env' (desD1E usedDes))
- (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#e0 :++: #ix :++: #sh :++: #d1env)
- w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env'))
- in EPair ext (weakenExpr w e1) (collectexpr w')))
- `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ))
- (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ)))
- (SEYesR (SENo (SEYesR SETop)))
- (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ))
- (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))
+ let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in
+ let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in
+ Ret (mergePrimalBindings
+ `BPush` (shty, weakenExpr (wSinks (d1e envPro)) (drevPrimal des she))
+ `BPush` (STArr ndim (STPair (d1 eltty) tapety)
+ ,EBuild ext ndim
+ (EVar ext shty IZ)
+ (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil)
+ &. #sh (shty `SCons` SNil)
+ &. #propr (d1e envPro)
+ &. #d1env (desD1E des)
+ &. #d1env' (desD1E usedDes))
+ (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#ix :++: #sh :++: #propr :++: #d1env))
+ e0)) $
+ let w = autoWeak (#ix (shty `SCons` SNil)
+ &. #sh (shty `SCons` SNil)
+ &. #e0 (bindingsBinds e0)
+ &. #propr (d1e envPro)
+ &. #d1env (desD1E des)
+ &. #d1env' (desD1E usedDes))
+ (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#e0 :++: #ix :++: #sh :++: #propr :++: #d1env)
+ w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env'))
+ in EPair ext (weakenExpr w e1) (collectexpr w')))
+ `BPush` (STArr ndim tapety, emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ)))
+ (SEYesR (SENo (SEYesR (subenvAll (d1e envPro)))))
+ (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))
(subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub)))
- (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in
+ (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace envPro) in
ESnd ext $
uninvertTup (d2e envPro) (STArr ndim STNil) $
- -- TODO: what's happening here is that because of the sparsity
- -- rewrite, makeAccumulators needs primals where it previously
- -- didn't. The build derivative is currently not saving those
- -- primals, so the hole below cannot currently be filled. The
- -- appropriate primals (waves hands) need to be stored, so that a
- -- weakening can be provided here.
- makeAccumulators @_ @_ @(TArr ndim TNil) (_ (subenvCompose subMergeUsed proSub)) envPro $
+ makeAccumulators @_ @_ @(TArr ndim TNil) (WSink .> WSink .> WSink .> wRaiseAbove (d1e envPro) (d2ace (select SAccum des))) envPro $
EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $
-- the cotangent for this element
ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ))
@@ -1148,10 +1116,11 @@ drev des accumMap sd = \case
&. #darr (auto1 @(TArr ndim sdElt))
&. #tapearr (auto1 @(TArr ndim (Tape e_tape)))
&. #sh (auto1 @shty)
+ &. #propr (d1e envPro)
&. #d2acUsed (d2ace (select SAccum usedDes))
&. #d2acEnv (d2ace (select SAccum des)))
(#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
- ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #d2acEnv)
+ ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #propr :++: #d2acEnv)
.> wPro (subList (bindingsBinds e0) subtapeE))
e2)
}}}
@@ -1167,32 +1136,34 @@ drev des accumMap sd = \case
weakenExpr (WCopy WSink) e2)
EReplicate1Inner _ en e
- -- We're allowed to ignore en2 here because the output of 'ei' is discrete.
- | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil)
- <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil
+ -- We're allowed to differentiate 'en' as primal-only here because its output is discrete.
+ | SpArr sdElt <- sd
, let STArr ndim eltty = typeOf e ->
- Ret binds
- subtape
- (EReplicate1Inner ext en1 e1)
- sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EJust ext (EFold1Inner ext Commut
- (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
- (ezeroD2 eltty)
- (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $
- weakenExpr (WCopy (WSink .> WSink)) e2)
- (EVar ext (d2 (STArr (SS ndim) eltty)) IZ))
+ -- This pessimistic sparsity union is because the array might have been empty, in which case we need to generate a zero.
+ sparsePlusS ST ST (d2M eltty) sdElt SpAbsent $ \sdElt' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sdElt') e of { Ret binds subtape e1 sub e2 ->
+ Ret binds
+ subtape
+ (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1)
+ sub
+ (ELet ext (EFold1Inner ext Commut
+ (sparsePlus (d2M eltty) sdElt'
+ (EVar ext (applySparse sdElt' (d2 eltty)) (IS IZ))
+ (EVar ext (applySparse sdElt' (d2 eltty)) IZ))
+ (inj2 (ENil ext))
+ (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $
+ weakenExpr (WCopy WSink) e2)
+ }
EIdx0 _ e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd) e
, STArr _ t <- typeOf e ->
Ret e0
subtape
(EIdx0 ext e1)
sub
- (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $
- weakenExpr (WCopy WSink) e2)
+ (ELet ext (EUnit ext (EVar ext (applySparse sd (d2 t)) IZ)) $
+ weakenExpr (WCopy WSink) e2)
EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead"
{-
@@ -1214,57 +1185,58 @@ drev des accumMap sd = \case
-}
EIdx _ e ei
- -- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
- | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
- <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil
- , STArr n eltty <- typeOf e
+ -- We're allowed to differentiate ei as primal because its output is discrete.
+ | STArr n eltty <- typeOf e
, Refl <- indexTupD1Id n
- , Refl <- lemZeroInfoD2 eltty
- , let tIxN = tTup (sreplicate n tIx) ->
- Ret (binds `BPush` (STArr n (d1 eltty), e1)
- `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ))
- `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1))
- (SEYesR (SEYesR (SENo subtape)))
- (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) IZ))
- sub
- (ELet ext (EOneHot ext (d2M (STArr n eltty)) (SAPJust (SAPArrIdx SAPHere))
- (EPair ext (EPair ext (EVar ext tIxN (IS IZ))
- (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext)))
- (ENil ext))
- (EVar ext (d2 eltty) IZ)) $
- weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ , let tIxN = tTup (sreplicate n tIx) ->
+ sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 ->
+ Ret (binds `BPush` (STArr n (d1 eltty), e1)
+ `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ))
+ `BPush` (tIxN, weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei)))
+ (SEYesR (SEYesR (SENo subtape)))
+ (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))
+ sub
+ (ELet ext
+ (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty)))
+ (SAPArrIdx SAPHere)
+ (EPair ext
+ (EPair ext (EVar ext tIxN (IS IZ))
+ (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $
+ makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext))))
+ (ENil ext))
+ (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ }
EShape _ e
- -- Allowed to ignore e2 here because the output of EShape is discrete,
- -- hence we'd be passing a zero cotangent to e2 anyway.
- | Ret e0 subtape e1 _ _ <- drev des accumMap e
- , STArr n _ <- typeOf e
+ -- Allowed to differentiate e as primal because the output of EShape is
+ -- discrete, hence we'd be passing a zero cotangent to e anyway.
+ | STArr n _ <- typeOf e
, Refl <- indexTupD1Id n ->
- Ret e0
- subtape
- (EShape ext e1)
- (subenvNone (select SMerge des))
+ Ret BTop
+ SETop
+ (EShape ext (drevPrimal des e))
+ (subenvNone (d2eM (select SMerge des)))
(ENil ext)
ESum1Inner _ e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
+ | SpArr sd' <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e
, STArr (SS n) t <- typeOf e ->
Ret (e0 `BPush` (STArr (SS n) t, e1)
`BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ)))
(SEYesR (SENo subtape))
(ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ)))
sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EJust ext (EReplicate1Inner ext
- (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ))))
- (EVar ext (STArr n (d2 t)) IZ))) $
- weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
- (EVar ext (d2 (STArr n t)) IZ))
+ (ELet ext (EReplicate1Inner ext
+ (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
+ (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
- EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
- EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
+ EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e
+ EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e
-- These should be the next to be implemented, I think
EFold1Inner{} -> err_unsupported "EFold1Inner"
@@ -1279,6 +1251,7 @@ drev des accumMap sd = \case
EWith{} -> err_accum
EZero{} -> err_monoid
+ EDeepZero{} -> err_monoid
EPlus{} -> err_monoid
EOneHot{} -> err_monoid
@@ -1287,35 +1260,35 @@ drev des accumMap sd = \case
err_monoid = error "Monoid operations unsupported in the source program"
err_unsupported s = error $ "CHAD: unsupported " ++ s
- deriv_extremum :: ScalIsNumeric t' ~ True
- => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t')))
- -> Sparse (TArr n (D2s t')) sd'
- -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto sd' (TArr n (TScal t'))
- deriv_extremum extremum e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
- , at@(STArr (SS n) t@(STScal st)) <- typeOf e
- , let at' = STArr n t
- , let tIxN = tTup (sreplicate (SS n) tIx) =
- Ret (e0 `BPush` (at, e1)
- `BPush` (at', extremum (EVar ext at IZ)))
- (SEYesR (SEYesR subtape))
- (EVar ext at' IZ)
- sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EJust ext
- (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $
- eif (EOp ext (OEq st) (EPair ext
- (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ))
- (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ)))))
- (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ)))
- (ezeroD2 t))) $
- weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2)
- (EVar ext (d2 at') IZ))
-
contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs)
contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub))
+deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True)
+ => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
+ -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
+ -> Sparse (D2s t) sd
+ -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t))
+deriv_extremum extremum des accumMap sd e
+ | at@(STArr (SS n) t@(STScal st)) <- typeOf e
+ , let at' = STArr n t
+ , let tIxN = tTup (sreplicate (SS n) tIx) =
+ sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 ->
+ Ret (e0 `BPush` (at, e1)
+ `BPush` (at', extremum (EVar ext at IZ)))
+ (SEYesR (SEYesR subtape))
+ (EVar ext at' IZ)
+ sub
+ (ELet ext
+ (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $
+ eif (EOp ext (OEq st) (EPair ext
+ (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ))
+ (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ)))))
+ (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ)))
+ (inj2 (ENil ext))) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ }
+
data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s)
data RetScoped env0 sto a s sd t =
@@ -1351,7 +1324,8 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of
SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext))
SAccum
- | Just (VIArr i _) <- argids
+ | chcSmartWith ?config
+ , Just (VIArr i _) <- argids
, Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap
, Just Refl <- testEquality foundTy (STAccum (d2M argty))
, Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr
@@ -1380,9 +1354,9 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of
&. #ac (auto1 @(TAccum (D2 a)))
&. #tl (d2ace (select SAccum des))
in
- RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub SpDense $
+ RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $
let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in
- EWith ext (d2M argty) (EZero ext (d2M argty) (d2zeroInfo argty (EVar ext (d1 argty) primalIdx))) $
+ EWith ext (d2M argty) (EDeepZero ext (d2M argty) (d2deepZeroInfo argty (EVar ext (d1 argty) primalIdx))) $
weakenExpr (autoWeak library
(#d :++: #body :++: #ac :++: #tl)
(#ac :++: #d :++: (#body :++: #p) :++: #tl))
@@ -1396,20 +1370,6 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of
-- TODO: proper primal-only transform that doesn't depend on D1 = Id
drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t)
drevPrimal des e
- | Refl <- chadD1Id (typeOf e)
- , Refl <- chadD1EId (descrList des)
+ | Refl <- d1Identity (typeOf e)
+ , Refl <- d1eIdentity (descrList des)
= mapExt (const ext) e
- where
- chadD1Id :: STy a -> D1 a :~: a
- chadD1Id STNil = Refl
- chadD1Id (STPair a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl
- chadD1Id (STEither a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl
- chadD1Id (STLEither a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl
- chadD1Id (STMaybe a) | Refl <- chadD1Id a = Refl
- chadD1Id (STArr _ a) | Refl <- chadD1Id a = Refl
- chadD1Id (STScal _) = Refl
- chadD1Id STAccum{} = error "accumulators not allowed in source program"
-
- chadD1EId :: SList STy l -> D1E l :~: l
- chadD1EId SNil = Refl
- chadD1EId (SCons t l) | Refl <- chadD1Id t, Refl <- chadD1EId l = Refl