diff options
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 167 |
1 files changed, 110 insertions, 57 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 7cd4c26..3dedec3 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -362,12 +362,6 @@ d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" -zeroTup :: SList STy env0 -> D1E env0 :> env -> Ex env (Tup (D2E env0)) -zeroTup SNil _ = ENil ext -zeroTup (t `SCons` env) w = - EPair ext (zeroTup env (WPop w)) - (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) - ----------------------------------- SPARSITY ----------------------------------- @@ -780,7 +774,7 @@ drev des accumMap (SpSparse sd) = subtape e1 sub' - (emaybe (evar IZ) + (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ) (inj2 (ENil ext)) (inj1 (weakenExpr (WCopy WSink) e2))) } @@ -794,7 +788,8 @@ drev des accumMap sd = \case (EVar ext (d1 t) (conv1Idx i)) (subenvNone (d2e (select SMerge des))) (let ty = applySparse sd (d2M t) - in EAccum ext (d2M t) (_ sd) (ENil ext) (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) + in accumulateSparse SAI_D (d2M t) sd (EVar ext (fromSMTy ty) IZ) $ \w prj val idx -> + EAccum ext (d2M t) prj idx val (EVar ext (STAccum (d2M t)) (w @> IS accI))) Idx2Me tupI -> Ret BTop @@ -1227,43 +1222,45 @@ drev des accumMap sd = \case (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) sub - (ELet ext (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) (SAPArrIdx SAPHere) - (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) - (EBuild ext n (EVar ext tIxN (IS (IS IZ))) (ENil ext))) (ENil ext)) - (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $ + (ELet ext + (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) + (SAPArrIdx SAPHere) + (EPair ext + (EPair ext (EVar ext tIxN (IS IZ)) + (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $ + makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext)))) + (ENil ext)) + (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) } EShape _ e - -- Allowed to ignore e2 here because the output of EShape is discrete, - -- hence we'd be passing a zero cotangent to e2 anyway. - | Ret e0 subtape e1 _ _ <- drev des accumMap e - , STArr n _ <- typeOf e + -- Allowed to differentiate e as primal because the output of EShape is + -- discrete, hence we'd be passing a zero cotangent to e anyway. + | STArr n _ <- typeOf e , Refl <- indexTupD1Id n -> - Ret e0 - subtape - (EShape ext e1) - (subenvNone (select SMerge des)) + Ret BTop + SETop + (EShape ext (drevPrimal des e)) + (subenvNone (d2eM (select SMerge des))) (ENil ext) ESum1Inner _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e + | SpArr sd' <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e , STArr (SS n) t <- typeOf e -> Ret (e0 `BPush` (STArr (SS n) t, e1) `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) (SEYesR (SENo subtape)) (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext (EReplicate1Inner ext - (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ)))) - (EVar ext (STArr n (d2 t)) IZ))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) - (EVar ext (d2 (STArr n t)) IZ)) + (ELet ext (EReplicate1Inner ext + (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) - EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e - EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e + EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e + EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e -- These should be the next to be implemented, I think EFold1Inner{} -> err_unsupported "EFold1Inner" @@ -1286,35 +1283,35 @@ drev des accumMap sd = \case err_monoid = error "Monoid operations unsupported in the source program" err_unsupported s = error $ "CHAD: unsupported " ++ s - deriv_extremum :: ScalIsNumeric t' ~ True - => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t'))) - -> Sparse (TArr n (D2s t')) sd' - -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto sd' (TArr n (TScal t')) - deriv_extremum extremum e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , at@(STArr (SS n) t@(STScal st)) <- typeOf e - , let at' = STArr n t - , let tIxN = tTup (sreplicate (SS n) tIx) = - Ret (e0 `BPush` (at, e1) - `BPush` (at', extremum (EVar ext at IZ))) - (SEYesR (SEYesR subtape)) - (EVar ext at' IZ) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext - (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $ - eif (EOp ext (OEq st) (EPair ext - (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ)) - (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ))))) - (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) - (ezeroD2 t))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) - (EVar ext (d2 at') IZ)) - contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs) contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub)) +deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True) + => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) + -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> Sparse (D2s t) sd + -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t)) +deriv_extremum extremum des accumMap sd e + | at@(STArr (SS n) t@(STScal st)) <- typeOf e + , let at' = STArr n t + , let tIxN = tTup (sreplicate (SS n) tIx) = + sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 -> + Ret (e0 `BPush` (at, e1) + `BPush` (at', extremum (EVar ext at IZ))) + (SEYesR (SEYesR subtape)) + (EVar ext at' IZ) + sub + (ELet ext + (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $ + eif (EOp ext (OEq st) (EPair ext + (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ)) + (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ))))) + (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) + (inj2 (ENil ext))) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + } + data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) data RetScoped env0 sto a s sd t = @@ -1379,7 +1376,7 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of &. #ac (auto1 @(TAccum (D2 a))) &. #tl (d2ace (select SAccum des)) in - RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub SpDense $ + RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $ let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in EWith ext (d2M argty) (EZero ext (d2M argty) (d2zeroInfo argty (EVar ext (d1 argty) primalIdx))) $ weakenExpr (autoWeak library @@ -1412,3 +1409,59 @@ drevPrimal des e chadD1EId :: SList STy l -> D1E l :~: l chadD1EId SNil = Refl chadD1EId (SCons t l) | Refl <- chadD1Id t, Refl <- chadD1EId l = Refl + +accumulateSparse + :: SStillDense dense -> SMTy t -> Sparse t t' -> Ex env t' + -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' b -> Ex env' (AcIdx dense p t) -> Ex env' TNil) + -> Ex env TNil +accumulateSparse dense topty topsp arg accum = case (dense, topty, topsp) of + (_, _, s) | Just Refl <- isDense topty s -> + accum WId SAPHere arg (ENil ext) + (_, SMTScal _, SpScal) -> + accum WId SAPHere arg (ENil ext) -- should be handled by isDense already, but meh + (_, _, SpSparse s) -> + emaybe arg + (ENil ext) + (accumulateSparse dense topty s (evar IZ) (\w -> accum (WPop w))) + (_, _, SpAbsent) -> + ENil ext + (SAI_D, SMTPair t1 t2, SpPair s1 s2) -> + eunPair arg $ \w1 e1 e2 -> + elet (accumulateSparse dense t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $ + accumulateSparse dense t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj)) + (SAI_S, SMTPair{}, SpPair{}) -> + error "TODO: accumulating into pair inside coproduct unimplemented" + -- There are two different ways this can be accomplished: + -- 1. Ensure we have the requisite ZeroInfo here. This means that an + -- accum-mode variable reference will (if its incoming cotangent is + -- sparse enough) need to store some ZeroInfo fragments computed from + -- the primal (not necessarily the entire primal). Doing this properly, + -- i.e. not just storing a full D1 but only the required ZeroInfo + -- fragments, is possible and not too inefficient but a bit of + -- engineering again. + -- 2. When creating an accumulator, don't initialise it with a generic + -- EZero based on a ZeroInfo, but instead a special "deep zero" based on + -- probably a full D1. This deep zero also initialises Left/Right/Just + -- modelled after the primal. With this, an accumulation needs no zero + -- info whatsoever (!) under the assumption that it receives a cotangent + -- that is compatible with the primal it is propagated back to. + (_, SMTLEither t1 t2, SpLEither s1 s2) -> + elcase arg + (ENil ext) + (accumulateSparse SAI_S t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj))) + (accumulateSparse SAI_S t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj))) + (_, SMTMaybe t, SpMaybe s) -> + emaybe arg + (ENil ext) + (accumulateSparse SAI_S t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj))) + (SAI_D, SMTArr n t, SpArr s) -> + let tn = tTup (sreplicate n tIx) in + elet arg $ + elet (EBuild ext n (EShape ext (evar IZ)) $ + accumulateSparse dense t s + (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) + (\w prj val idx -> accum (WPop (WPop w)) (SAPArrIdx prj) val (EPair ext (EVar ext tn (w @> IZ)) idx))) $ + ENil ext + (SAI_S, SMTArr{}, SpArr{}) -> + error "TODO: accumulating into array inside coproduct unimplemented" + -- See the pair case above, same reasoning |