aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Drev.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2026-02-22 12:59:40 +0100
committerTom Smeding <tom@tomsmeding.com>2026-02-22 12:59:40 +0100
commiteb25e315ac8bd1e75498cc92c7f3326fa582171a (patch)
tree71df3e761c127ad3ca301a0645ae7760099088ae /src/CHAD/Drev.hs
parentf5b1b405fa4ba63bdffe0f2998f655f0b06534bf (diff)
WIP: Store subset of D1 Gamma for recompute at binding sitesrecompute-primalstores
Diffstat (limited to 'src/CHAD/Drev.hs')
-rw-r--r--src/CHAD/Drev.hs94
1 files changed, 60 insertions, 34 deletions
diff --git a/src/CHAD/Drev.hs b/src/CHAD/Drev.hs
index 9f2921c..4e39f79 100644
--- a/src/CHAD/Drev.hs
+++ b/src/CHAD/Drev.hs
@@ -532,13 +532,13 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
(SEYesR accrevsub)
(VarMap.sink1 accumMap)
(\shbinds ->
- autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr)))
- (#acc :++: (#pro :++: #d :++: #shb :++: #tl))
- (#pro :++: #d :++: #shb :++: #acc :++: #tl)
+ autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #ace (d2ace (select SAccum descr)))
+ (#acc :++: (#pro :++: #d :++: #shb :++: #ace))
+ (#pro :++: #d :++: #shb :++: #acc :++: #ace)
.> WCopy (wf shbinds)
- .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl)))
- (#d :++: #shb :++: #acc :++: #tl)
- (#acc :++: (#d :++: #shb :++: #tl)))
+ .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #ace (d2ace (select SAccum storepl)))
+ (#d :++: #shb :++: #acc :++: #ace)
+ (#acc :++: (#d :++: #shb :++: #ace)))
SMerge -> case t of
-- Discrete values are left as-is
@@ -606,12 +606,13 @@ accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
---------------------------- RETURN TRIPLE FROM CHAD ---------------------------
data Ret env0 sto sd t =
- forall shbinds tapebinds contribs.
+ forall shbinds tapebinds contribs reprim.
Ret (Bindings Ex (D1E env0) shbinds) -- shared binds
(Subenv shbinds tapebinds)
(Ex (Append shbinds (D1E env0)) (D1 t))
(SubenvS (D2E (Select env0 sto "merge")) contribs)
- (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs))
+ (Ex (sd : Append tapebinds (Append reprim (D2AcE (Select env0 sto "accum")))) (Tup contribs))
+ (Subenv (D1E env0) reprim)
deriving instance Show (Ret env0 sto sd t)
type data TyTyPair = MkTyTyPair Ty Ty
@@ -633,11 +634,12 @@ data SingleRet env0 sto (pair :: TyTyPair) =
-- {-# COMPLETE Ret1 #-}
data RetPair env0 sto env shbinds tapebinds (pair :: TyTyPair) where
- RetPair :: forall sd t contribs -- existentials
+ RetPair :: forall sd t contribs reprim -- existentials
env0 sto env shbinds tapebinds. -- universals
Ex (Append shbinds env) (D1 t)
-> SubenvS (D2E (Select env0 sto "merge")) contribs
- -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)
+ -> Ex (sd : Append tapebinds (Append reprim (D2AcE (Select env0 sto "accum")))) (Tup contribs)
+ -> Subenv (D1E env0) reprim
-> RetPair env0 sto env shbinds tapebinds (MkTyTyPair sd t)
deriving instance Show (RetPair env0 sto env shbinds tapebinds pair)
@@ -649,11 +651,11 @@ data Rets env0 sto env list =
deriving instance Show (Rets env0 sto env list)
toSingleRet :: Ret env0 sto sd t -> SingleRet env0 sto (MkTyTyPair sd t)
-toSingleRet (Ret e0 subtape e1 sub e2) = SingleRet e0 subtape (RetPair e1 sub e2)
+toSingleRet (Ret e0 subtape e1 sub e2 rp) = SingleRet e0 subtape (RetPair e1 sub e2 rp)
weakenRetPair :: SList STy shbinds -> env :> env'
-> RetPair env0 sto env shbinds tapebinds pair -> RetPair env0 sto env' shbinds tapebinds pair
-weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2
+weakenRetPair bindslist w (RetPair e1 sub e2 rp) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2 rp
weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list
weakenRets w (Rets binds tapesub list) =
@@ -666,30 +668,33 @@ rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f.
-> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2
-> RetPair env0 sto (Append b1 env) b2 tapebinds2 pair
-> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) pair
-rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2)
+rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2 rp)
| Refl <- lemAppendAssoc @b2 @b1 @env =
RetPair e1 sub
(weakenExpr (autoWeak
(#d (auto1 @sd)
&. #t2 (subList b2 subtape2)
&. #t1 (subList b1 subtape1)
- &. #tl (d2ace (select SAccum descr)))
- (#d :++: (#t2 :++: #tl))
- (#d :++: ((#t2 :++: #t1) :++: #tl)))
+ &. #rp (subList (desD1E descr) rp)
+ &. #ace (d2ace (select SAccum descr)))
+ (#d :++: (#t2 :++: #rp :++: #ace))
+ (#d :++: ((#t2 :++: #t1) :++: #rp :++: #ace)))
e2)
+ rp
retConcat :: forall env0 sto list. Descr env0 sto -> SList (SingleRet env0 sto) list -> Rets env0 sto (D1E env0) list
retConcat _ SNil = Rets BTop SETop SNil
-retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2)) list)
+retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2 (rp :: Subenv _ reprim))) list)
| Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs
<- weakenRets (sinkWithBindings e0) (retConcat descr list)
, Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0)
- , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum"))
+ , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(Append reprim (D2AcE (Select env0 sto "accum")))
= Rets (bconcat e0 binds)
(subenvConcat subtape subtape2)
(SCons (RetPair (weakenExpr (sinkWithBindings binds) e1)
sub
- (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2))
+ (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2)
+ rp)
(slistMap (rebaseRetPair descr (bindingsBinds e0) (bindingsBinds binds)
subtape subtape2)
pairs))
@@ -697,25 +702,25 @@ retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Sube
freezeRet :: Descr env sto
-> Ret env sto (D2 t) t
-> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge"))))
-freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) =
+freezeRet descr (Ret e0 subtape e1 sub e2 rp :: Ret _ _ _ t) =
let (e0', wInsertD2Ac) = weakenBindingsE (WSink .> wSinks (d2ace (select SAccum descr))) e0
- e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2
tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub))
library = #d (auto1 @(D2 t))
&. #tape (subList (bindingsBinds e0) subtape)
&. #shbinds (bindingsBinds e0)
&. #d2ace (d2ace (select SAccum descr))
- &. #tl (desD1E descr)
+ &. #rp (subList (desD1E descr) rp)
+ &. #d1e (desD1E descr)
&. #contribs (SCons tContribs SNil)
in letBinds e0' $
EPair ext
(weakenExpr wInsertD2Ac e1)
(ELet ext (weakenExpr (autoWeak library
- (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl)
- (#shbinds :++: #d :++: #d2ace :++: #tl))
- e2') $
+ (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: (LPreW #rp #d1e (wUndoSubenv rp) :++: #d2ace))
+ (#shbinds :++: #d :++: #d2ace :++: #d1e))
+ e2) $
expandSubenvZeros
- (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl)
+ (autoWeak library #d1e (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #d1e)
.> wUndoSubenv (subenvD1E (selectSub SMerge descr)))
(select SMerge descr) sub (EVar ext tContribs IZ))
@@ -734,11 +739,12 @@ drev des _ sd | isAbsent sd =
(drevPrimal des e)
(subenvNone (d2e (select SMerge des)))
(ENil ext)
+ (subenvNone (desD1E des))
drev _ _ SpAbsent = error "Absent should be isAbsent"
drev des accumMap (SpSparse sd) =
\e ->
- case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 ->
+ case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 rp ->
subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ ->
Ret e0
subtape
@@ -747,6 +753,7 @@ drev des accumMap (SpSparse sd) =
(emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ)
(inj2 (ENil ext))
(inj1 (weakenExpr (WCopy WSink) e2)))
+ rp
}
drev des accumMap sd = \case
@@ -759,6 +766,7 @@ drev des accumMap sd = \case
(subenvNone (d2e (select SMerge des)))
(let ty = applySparse sd (d2M t)
in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))
+ (subenvNone (desD1E des))
Idx2Me tupI ->
Ret BTop
@@ -766,6 +774,7 @@ drev des accumMap sd = \case
(EVar ext (d1 t) (conv1Idx i))
(subenvOnehot (d2e (select SMerge des)) tupI sd)
(EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t)) IZ))
+ (subenvNone (desD1E des))
Idx2Di _ ->
Ret BTop
@@ -773,7 +782,9 @@ drev des accumMap sd = \case
(EVar ext (d1 t) (conv1Idx i))
(subenvNone (d2e (select SMerge des)))
(ENil ext)
+ (subenvNone (desD1E des))
+{-
ELet _ (rhs :: Expr _ _ a) body
| ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && typeHasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
, RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body
@@ -802,25 +813,36 @@ drev des accumMap sd = \case
plus_RHS_Body
(EVar ext (contribTupTy des subRHS) IZ)
(EFst ext (EVar ext bodyResType (IS IZ))))
+-}
EPair _ a b
| SpPair sd1 sd2 <- sd
- , Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
+ , Rets binds subtape (RetPair a1 subA a2 rpA `SCons` RetPair b1 subB b2 rpB `SCons` SNil)
<- retConcat des $ toSingleRet (drev des accumMap sd1 a) `SCons` toSingleRet (drev des accumMap sd2 b) `SCons` SNil
, let dt = STPair (applySparse sd1 (d2 (typeOf a))) (applySparse sd2 (d2 (typeOf b))) ->
subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B ->
+ subenvUnion rpA rpB $ \rp rpA' rpB' ->
Ret binds
subtape
(EPair ext a1 b1)
subBoth
(ELet ext (ELet ext (EFst ext (EVar ext dt IZ))
- (weakenExpr (WCopy WSink) a2)) $
+ (weakenExpr (WCopy (WSink .> _ rpA')) a2)) $
ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
(weakenExpr (WCopy (WSink .> WSink)) b2)) $
plus_A_B
(EVar ext (contribTupTy des subA) (IS IZ))
(EVar ext (contribTupTy des subB) IZ))
+ rp
+ where
+ rpWeak :: SList f tapebinds
+ -> SList g reprim1
+ -> SList h reprim2
+ -> Append tapebinds (Append reprim1 (D2AcE (Select env sto "accum"))) :>
+ Append tapebinds (Append reprim2 (D2AcE (Select env sto "accum")))
+ rpWeak tb rp1 rp2 = wCopies tb (weakenOver)
+{-
EFst _ e
| Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair sd SpAbsent) e
, STPair t1 _ <- typeOf e ->
@@ -1378,6 +1400,8 @@ drev des accumMap sd = \case
EFold1InnerD1{} -> err_targetlang "EFold1InnerD1"
EFold1InnerD2{} -> err_targetlang "EFold1InnerD2"
+-}
+ _ -> undefined
where
err_accum = error "Accumulator operations unsupported in the source program"
@@ -1388,6 +1412,7 @@ drev des accumMap sd = \case
contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs)
contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub))
+{-
deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True)
=> (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
-> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
@@ -1413,6 +1438,7 @@ deriv_extremum extremum des accumMap sd e
(inj2 (ENil ext))) $
weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
}
+-}
data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s)
@@ -1442,7 +1468,7 @@ drevScoped :: forall a s env sto sd t.
-> RetScoped env sto a s sd t
drevScoped des accumMap argty argsto argids sd expr = case argsto of
SMerge
- | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr
+ | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 rp <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr
, Refl <- lemAppendNil @tapebinds ->
case sub of
SEYes sp sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' sp e2
@@ -1453,7 +1479,7 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of
, Just (VIArr i _) <- argids
, Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap
, Just Refl <- testEquality foundTy (STAccum (d2M argty))
- , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr
+ , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 rp <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr
, Refl <- lemAppendNil @tapebinds ->
-- Our contribution to the binding's cotangent _here_ is zero (absent),
-- because we're contributing to an earlier binding of the same value
@@ -1472,7 +1498,7 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of
| let accumMap' = case argids of
Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap)
_ -> VarMap.sink1 accumMap
- , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr ->
+ , Ret e0 subtape e1 sub e2 rp <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr ->
let library = #d (auto1 @sd)
&. #p (auto1 @(D1 a))
&. #body (subList (bindingsBinds e0) subtape)
@@ -1488,7 +1514,7 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of
e2
SDiscr
- | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr
+ | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 rp <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr
, Refl <- lemAppendNil @tapebinds ->
RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2
@@ -1523,7 +1549,7 @@ drevLambda des accumMap (argty, argsto) sd origef k =
let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in
let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in
case prf1 prodes argty argsto of { Refl ->
- case drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 ->
+ case drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 rp ->
let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in
extractContrib prodes argty argsto subEf $ \argSp getSparseArg ->
let library = #fbinds (bindingsBinds ef0)