summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs117
1 files changed, 5 insertions, 112 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 3dedec3..621aa3e 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -34,7 +34,6 @@ module CHAD (
import Data.Functor.Const
import Data.Some
-import Data.Type.Bool (If)
import Data.Type.Equality (type (==), testEquality)
import GHC.Stack (HasCallStack)
@@ -45,6 +44,7 @@ import AST.Count
import AST.Env
import AST.Sparse
import AST.Weaken.Auto
+import CHAD.Accum
import CHAD.EnvDescr
import CHAD.Types
import Data
@@ -348,28 +348,8 @@ opt2UnSparse = go . opt2
go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary"
------------------------------------- MONOIDS -----------------------------------
-
-d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t))
-d2zeroInfo STNil _ = ENil ext
-d2zeroInfo (STPair a b) e =
- eunPair e $ \_ e1 e2 ->
- EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2)
-d2zeroInfo STEither{} _ = ENil ext
-d2zeroInfo STLEither{} _ = ENil ext
-d2zeroInfo STMaybe{} _ = ENil ext
-d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e
-d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext
-d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program"
-
-
----------------------------------- SPARSITY -----------------------------------
-subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env')
-subenvD1E SETop = SETop
-subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub)
-subenvD1E (SENo sub) = SENo (subenvD1E sub)
-
expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a)
expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e
expandSparse t (SpSparse sp) epr e =
@@ -499,23 +479,6 @@ assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
--------------------------------- ACCUMULATORS ---------------------------------
-makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
-makeAccumulators _ SNil e = e
-makeAccumulators w (t `SCons` envpro) e =
- makeAccumulators (WPop w) envpro $
- EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e
-
-uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
-uninvertTup SNil _ e = EPair ext e (ENil ext)
-uninvertTup (t `SCons` list) tcore e =
- ELet ext (uninvertTup list (STPair tcore t) e) $
- let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding
- in EPair ext
- (EFst ext (EFst ext (EVar ext recT IZ)))
- (EPair ext
- (ESnd ext (EVar ext recT IZ))
- (ESnd ext (EFst ext (EVar ext recT IZ))))
-
fromArrayValId :: Maybe (ValId t) -> Maybe Int
fromArrayValId (Just (VIArr i _)) = Just i
fromArrayValId _ = Nothing
@@ -788,8 +751,7 @@ drev des accumMap sd = \case
(EVar ext (d1 t) (conv1Idx i))
(subenvNone (d2e (select SMerge des)))
(let ty = applySparse sd (d2M t)
- in accumulateSparse SAI_D (d2M t) sd (EVar ext (fromSMTy ty) IZ) $ \w prj val idx ->
- EAccum ext (d2M t) prj idx val (EVar ext (STAccum (d2M t)) (w @> IS accI)))
+ in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))
Idx2Me tupI ->
Ret BTop
@@ -1275,6 +1237,7 @@ drev des accumMap sd = \case
EWith{} -> err_accum
EZero{} -> err_monoid
+ EDeepZero{} -> err_monoid
EPlus{} -> err_monoid
EOneHot{} -> err_monoid
@@ -1392,76 +1355,6 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of
-- TODO: proper primal-only transform that doesn't depend on D1 = Id
drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t)
drevPrimal des e
- | Refl <- chadD1Id (typeOf e)
- , Refl <- chadD1EId (descrList des)
+ | Refl <- d1Identity (typeOf e)
+ , Refl <- d1eIdentity (descrList des)
= mapExt (const ext) e
- where
- chadD1Id :: STy a -> D1 a :~: a
- chadD1Id STNil = Refl
- chadD1Id (STPair a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl
- chadD1Id (STEither a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl
- chadD1Id (STLEither a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl
- chadD1Id (STMaybe a) | Refl <- chadD1Id a = Refl
- chadD1Id (STArr _ a) | Refl <- chadD1Id a = Refl
- chadD1Id (STScal _) = Refl
- chadD1Id STAccum{} = error "accumulators not allowed in source program"
-
- chadD1EId :: SList STy l -> D1E l :~: l
- chadD1EId SNil = Refl
- chadD1EId (SCons t l) | Refl <- chadD1Id t, Refl <- chadD1EId l = Refl
-
-accumulateSparse
- :: SStillDense dense -> SMTy t -> Sparse t t' -> Ex env t'
- -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' b -> Ex env' (AcIdx dense p t) -> Ex env' TNil)
- -> Ex env TNil
-accumulateSparse dense topty topsp arg accum = case (dense, topty, topsp) of
- (_, _, s) | Just Refl <- isDense topty s ->
- accum WId SAPHere arg (ENil ext)
- (_, SMTScal _, SpScal) ->
- accum WId SAPHere arg (ENil ext) -- should be handled by isDense already, but meh
- (_, _, SpSparse s) ->
- emaybe arg
- (ENil ext)
- (accumulateSparse dense topty s (evar IZ) (\w -> accum (WPop w)))
- (_, _, SpAbsent) ->
- ENil ext
- (SAI_D, SMTPair t1 t2, SpPair s1 s2) ->
- eunPair arg $ \w1 e1 e2 ->
- elet (accumulateSparse dense t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $
- accumulateSparse dense t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj))
- (SAI_S, SMTPair{}, SpPair{}) ->
- error "TODO: accumulating into pair inside coproduct unimplemented"
- -- There are two different ways this can be accomplished:
- -- 1. Ensure we have the requisite ZeroInfo here. This means that an
- -- accum-mode variable reference will (if its incoming cotangent is
- -- sparse enough) need to store some ZeroInfo fragments computed from
- -- the primal (not necessarily the entire primal). Doing this properly,
- -- i.e. not just storing a full D1 but only the required ZeroInfo
- -- fragments, is possible and not too inefficient but a bit of
- -- engineering again.
- -- 2. When creating an accumulator, don't initialise it with a generic
- -- EZero based on a ZeroInfo, but instead a special "deep zero" based on
- -- probably a full D1. This deep zero also initialises Left/Right/Just
- -- modelled after the primal. With this, an accumulation needs no zero
- -- info whatsoever (!) under the assumption that it receives a cotangent
- -- that is compatible with the primal it is propagated back to.
- (_, SMTLEither t1 t2, SpLEither s1 s2) ->
- elcase arg
- (ENil ext)
- (accumulateSparse SAI_S t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj)))
- (accumulateSparse SAI_S t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj)))
- (_, SMTMaybe t, SpMaybe s) ->
- emaybe arg
- (ENil ext)
- (accumulateSparse SAI_S t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj)))
- (SAI_D, SMTArr n t, SpArr s) ->
- let tn = tTup (sreplicate n tIx) in
- elet arg $
- elet (EBuild ext n (EShape ext (evar IZ)) $
- accumulateSparse dense t s
- (EIdx ext (evar (IS IZ)) (EVar ext tn IZ))
- (\w prj val idx -> accum (WPop (WPop w)) (SAPArrIdx prj) val (EPair ext (EVar ext tn (w @> IZ)) idx))) $
- ENil ext
- (SAI_S, SMTArr{}, SpArr{}) ->
- error "TODO: accumulating into array inside coproduct unimplemented"
- -- See the pair case above, same reasoning