summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs167
1 files changed, 110 insertions, 57 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 7cd4c26..3dedec3 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -362,12 +362,6 @@ d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e
d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext
d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program"
-zeroTup :: SList STy env0 -> D1E env0 :> env -> Ex env (Tup (D2E env0))
-zeroTup SNil _ = ENil ext
-zeroTup (t `SCons` env) w =
- EPair ext (zeroTup env (WPop w))
- (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
-
----------------------------------- SPARSITY -----------------------------------
@@ -780,7 +774,7 @@ drev des accumMap (SpSparse sd) =
subtape
e1
sub'
- (emaybe (evar IZ)
+ (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ)
(inj2 (ENil ext))
(inj1 (weakenExpr (WCopy WSink) e2)))
}
@@ -794,7 +788,8 @@ drev des accumMap sd = \case
(EVar ext (d1 t) (conv1Idx i))
(subenvNone (d2e (select SMerge des)))
(let ty = applySparse sd (d2M t)
- in EAccum ext (d2M t) (_ sd) (ENil ext) (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))
+ in accumulateSparse SAI_D (d2M t) sd (EVar ext (fromSMTy ty) IZ) $ \w prj val idx ->
+ EAccum ext (d2M t) prj idx val (EVar ext (STAccum (d2M t)) (w @> IS accI)))
Idx2Me tupI ->
Ret BTop
@@ -1227,43 +1222,45 @@ drev des accumMap sd = \case
(EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
(EVar ext (tTup (sreplicate n tIx)) IZ))
sub
- (ELet ext (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) (SAPArrIdx SAPHere)
- (EPair ext (EPair ext (EVar ext tIxN (IS IZ))
- (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext))) (ENil ext))
- (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $
+ (ELet ext
+ (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty)))
+ (SAPArrIdx SAPHere)
+ (EPair ext
+ (EPair ext (EVar ext tIxN (IS IZ))
+ (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $
+ makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext))))
+ (ENil ext))
+ (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $
weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
}
EShape _ e
- -- Allowed to ignore e2 here because the output of EShape is discrete,
- -- hence we'd be passing a zero cotangent to e2 anyway.
- | Ret e0 subtape e1 _ _ <- drev des accumMap e
- , STArr n _ <- typeOf e
+ -- Allowed to differentiate e as primal because the output of EShape is
+ -- discrete, hence we'd be passing a zero cotangent to e anyway.
+ | STArr n _ <- typeOf e
, Refl <- indexTupD1Id n ->
- Ret e0
- subtape
- (EShape ext e1)
- (subenvNone (select SMerge des))
+ Ret BTop
+ SETop
+ (EShape ext (drevPrimal des e))
+ (subenvNone (d2eM (select SMerge des)))
(ENil ext)
ESum1Inner _ e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
+ | SpArr sd' <- sd
+ , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e
, STArr (SS n) t <- typeOf e ->
Ret (e0 `BPush` (STArr (SS n) t, e1)
`BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ)))
(SEYesR (SENo subtape))
(ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ)))
sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EJust ext (EReplicate1Inner ext
- (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ))))
- (EVar ext (STArr n (d2 t)) IZ))) $
- weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
- (EVar ext (d2 (STArr n t)) IZ))
+ (ELet ext (EReplicate1Inner ext
+ (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
+ (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
- EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
- EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
+ EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e
+ EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e
-- These should be the next to be implemented, I think
EFold1Inner{} -> err_unsupported "EFold1Inner"
@@ -1286,35 +1283,35 @@ drev des accumMap sd = \case
err_monoid = error "Monoid operations unsupported in the source program"
err_unsupported s = error $ "CHAD: unsupported " ++ s
- deriv_extremum :: ScalIsNumeric t' ~ True
- => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t')))
- -> Sparse (TArr n (D2s t')) sd'
- -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto sd' (TArr n (TScal t'))
- deriv_extremum extremum e
- | Ret e0 subtape e1 sub e2 <- drev des accumMap e
- , at@(STArr (SS n) t@(STScal st)) <- typeOf e
- , let at' = STArr n t
- , let tIxN = tTup (sreplicate (SS n) tIx) =
- Ret (e0 `BPush` (at, e1)
- `BPush` (at', extremum (EVar ext at IZ)))
- (SEYesR (SEYesR subtape))
- (EVar ext at' IZ)
- sub
- (EMaybe ext
- (zeroTup (subList (select SMerge des) sub))
- (ELet ext (EJust ext
- (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $
- eif (EOp ext (OEq st) (EPair ext
- (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ))
- (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ)))))
- (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ)))
- (ezeroD2 t))) $
- weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2)
- (EVar ext (d2 at') IZ))
-
contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs)
contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub))
+deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True)
+ => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
+ -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
+ -> Sparse (D2s t) sd
+ -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t))
+deriv_extremum extremum des accumMap sd e
+ | at@(STArr (SS n) t@(STScal st)) <- typeOf e
+ , let at' = STArr n t
+ , let tIxN = tTup (sreplicate (SS n) tIx) =
+ sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ ->
+ case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 ->
+ Ret (e0 `BPush` (at, e1)
+ `BPush` (at', extremum (EVar ext at IZ)))
+ (SEYesR (SEYesR subtape))
+ (EVar ext at' IZ)
+ sub
+ (ELet ext
+ (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $
+ eif (EOp ext (OEq st) (EPair ext
+ (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ))
+ (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ)))))
+ (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ)))
+ (inj2 (ENil ext))) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ }
+
data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s)
data RetScoped env0 sto a s sd t =
@@ -1379,7 +1376,7 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of
&. #ac (auto1 @(TAccum (D2 a)))
&. #tl (d2ace (select SAccum des))
in
- RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub SpDense $
+ RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $
let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in
EWith ext (d2M argty) (EZero ext (d2M argty) (d2zeroInfo argty (EVar ext (d1 argty) primalIdx))) $
weakenExpr (autoWeak library
@@ -1412,3 +1409,59 @@ drevPrimal des e
chadD1EId :: SList STy l -> D1E l :~: l
chadD1EId SNil = Refl
chadD1EId (SCons t l) | Refl <- chadD1Id t, Refl <- chadD1EId l = Refl
+
+accumulateSparse
+ :: SStillDense dense -> SMTy t -> Sparse t t' -> Ex env t'
+ -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' b -> Ex env' (AcIdx dense p t) -> Ex env' TNil)
+ -> Ex env TNil
+accumulateSparse dense topty topsp arg accum = case (dense, topty, topsp) of
+ (_, _, s) | Just Refl <- isDense topty s ->
+ accum WId SAPHere arg (ENil ext)
+ (_, SMTScal _, SpScal) ->
+ accum WId SAPHere arg (ENil ext) -- should be handled by isDense already, but meh
+ (_, _, SpSparse s) ->
+ emaybe arg
+ (ENil ext)
+ (accumulateSparse dense topty s (evar IZ) (\w -> accum (WPop w)))
+ (_, _, SpAbsent) ->
+ ENil ext
+ (SAI_D, SMTPair t1 t2, SpPair s1 s2) ->
+ eunPair arg $ \w1 e1 e2 ->
+ elet (accumulateSparse dense t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $
+ accumulateSparse dense t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj))
+ (SAI_S, SMTPair{}, SpPair{}) ->
+ error "TODO: accumulating into pair inside coproduct unimplemented"
+ -- There are two different ways this can be accomplished:
+ -- 1. Ensure we have the requisite ZeroInfo here. This means that an
+ -- accum-mode variable reference will (if its incoming cotangent is
+ -- sparse enough) need to store some ZeroInfo fragments computed from
+ -- the primal (not necessarily the entire primal). Doing this properly,
+ -- i.e. not just storing a full D1 but only the required ZeroInfo
+ -- fragments, is possible and not too inefficient but a bit of
+ -- engineering again.
+ -- 2. When creating an accumulator, don't initialise it with a generic
+ -- EZero based on a ZeroInfo, but instead a special "deep zero" based on
+ -- probably a full D1. This deep zero also initialises Left/Right/Just
+ -- modelled after the primal. With this, an accumulation needs no zero
+ -- info whatsoever (!) under the assumption that it receives a cotangent
+ -- that is compatible with the primal it is propagated back to.
+ (_, SMTLEither t1 t2, SpLEither s1 s2) ->
+ elcase arg
+ (ENil ext)
+ (accumulateSparse SAI_S t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj)))
+ (accumulateSparse SAI_S t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj)))
+ (_, SMTMaybe t, SpMaybe s) ->
+ emaybe arg
+ (ENil ext)
+ (accumulateSparse SAI_S t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj)))
+ (SAI_D, SMTArr n t, SpArr s) ->
+ let tn = tTup (sreplicate n tIx) in
+ elet arg $
+ elet (EBuild ext n (EShape ext (evar IZ)) $
+ accumulateSparse dense t s
+ (EIdx ext (evar (IS IZ)) (EVar ext tn IZ))
+ (\w prj val idx -> accum (WPop (WPop w)) (SAPArrIdx prj) val (EPair ext (EVar ext tn (w @> IZ)) idx))) $
+ ENil ext
+ (SAI_S, SMTArr{}, SpArr{}) ->
+ error "TODO: accumulating into array inside coproduct unimplemented"
+ -- See the pair case above, same reasoning