summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-28 22:40:41 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-28 22:40:41 +0100
commitc06b4bd71a94601d467b509a26c08020d1fbd794 (patch)
treeb16981c769231ef4af2c3ec5f002a01f857d95c6 /src/CHAD.hs
parenta3ba3bdc5c2f9606a0b98cdf53183841cca07eac (diff)
Pass around an accumMap (but it's empty still)
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs73
1 files changed, 38 insertions, 35 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index be308cd..6a4d5f5 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -36,6 +36,7 @@ import Data.Type.Bool (If)
import Data.Type.Equality (type (==))
import GHC.Stack (HasCallStack)
+import Analysis.Identity (ValId(..))
import AST
import AST.Bindings
import AST.Count
@@ -45,6 +46,8 @@ import CHAD.Accum
import CHAD.EnvDescr
import CHAD.Types
import Data
+import qualified Data.VarMap as VarMap
+import Data.VarMap (VarMap)
import Lemmas
@@ -558,9 +561,9 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =
drev :: forall env sto t.
(?config :: CHADConfig)
- => Descr env sto
- -> Ex env t -> Ret env sto t
-drev des = \case
+ => Descr env sto -> VarMap Int env
+ -> Expr ValId env t -> Ret env sto t
+drev des accumMap = \case
EVar _ t i ->
case conv2Idx des i of
Idx2Ac accI ->
@@ -584,10 +587,10 @@ drev des = \case
(subenvNone (select SMerge des))
(ENil ext)
- ELet _ (rhs :: Ex _ a) body
- | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs
+ ELet _ (rhs :: Expr _ _ a) body
+ | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs
, ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
- , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des (typeOf rhs) storage body
+ , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage body
, let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
, Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
, Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->
@@ -613,7 +616,7 @@ drev des = \case
EPair _ a b
| Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
- <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil
+ <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil
, let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) ->
subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B ->
Ret binds
@@ -632,7 +635,7 @@ drev des = \case
(EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ))
EFst _ e
- | Ret e0 subtape e1 sub e2 <- drev des e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap e
, STPair t1 t2 <- typeOf e ->
Ret e0
subtape
@@ -642,7 +645,7 @@ drev des = \case
weakenExpr (WCopy WSink) e2)
ESnd _ e
- | Ret e0 subtape e1 sub e2 <- drev des e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap e
, STPair t1 t2 <- typeOf e ->
Ret e0
subtape
@@ -654,7 +657,7 @@ drev des = \case
ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext)
EInl _ t2 e
- | Ret e0 subtape e1 sub e2 <- drev des e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
Ret e0
subtape
(EInl ext (d1 t2) e1)
@@ -667,7 +670,7 @@ drev des = \case
(EVar ext (STMaybe (STEither (d2 (typeOf e)) (d2 t2))) IZ))
EInr _ t1 e
- | Ret e0 subtape e1 sub e2 <- drev des e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
Ret e0
subtape
(EInr ext (d1 t1) e1)
@@ -679,13 +682,13 @@ drev des = \case
(weakenExpr (WCopy (wSinks' @[_,_])) e2))
(EVar ext (STMaybe (STEither (d2 t1) (d2 (typeOf e)))) IZ))
- ECase _ e (a :: Ex _ t) b
+ ECase _ e (a :: Expr _ _ t) b
| STEither t1 t2 <- typeOf e
- , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des e
+ , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e
, ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge
, ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge
- , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des t1 storage1 a
- , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des t2 storage2 b
+ , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 a
+ , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 b
, Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
, Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
, let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA)
@@ -762,7 +765,7 @@ drev des = \case
(ENil ext)
EOp _ op e
- | Ret e0 subtape e1 sub e2 <- drev des e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
case d2op op of
Linear d2opfun ->
Ret e0
@@ -783,15 +786,15 @@ drev des = \case
ECustom _ _ _ storety _ pr du a b
-- allowed to ignore a2 because 'a' is the part of the input that is inactive
| Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil)
- <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil ->
+ <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil ->
Ret (binds `BPush` (typeOf a1, a1)
`BPush` (typeOf b1, weakenExpr WSink b1)
- `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) pr)
+ `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))
`BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ)))
(SEYes (SENo (SENo (SENo subtape))))
(EFst ext (EVar ext (typeOf pr) (IS IZ)))
bsub
- (ELet ext (weakenExpr (WCopy (WCopy WClosed)) du) $
+ (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
weakenExpr (WCopy (WSink .> WSink)) b2)
EError _ t s ->
@@ -808,8 +811,8 @@ drev des = \case
(subenvNone (select SMerge des))
(ENil ext)
- EBuild _ (ndim :: SNat ndim) she (orige :: Ex _ eltty)
- | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des she -- allowed to ignore she2 here because she has a discrete result
+ EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty)
+ | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des accumMap she -- allowed to ignore she2 here because she has a discrete result
, let eltty = typeOf orige
, shty :: STy shty <- tTup (sreplicate ndim tIx)
, Refl <- indexTupD1Id ndim ->
@@ -817,7 +820,7 @@ drev des = \case
let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in
subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro ->
- case drev (prodes `DPush` (shty, SDiscr)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->
+ case drev (prodes `DPush` (shty, SDiscr)) (VarMap.sink1 (VarMap.subMap usedSub accumMap)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->
case assertSubenvEmpty sub of { Refl ->
let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in
let collectexpr = bindingsCollect e0 subtapeE in
@@ -881,7 +884,7 @@ drev des = \case
}}
EUnit _ e
- | Ret e0 subtape e1 sub e2 <- drev des e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
Ret e0
subtape
(EUnit ext e1)
@@ -895,7 +898,7 @@ drev des = \case
EReplicate1Inner _ en e
-- We're allowed to ignore en2 here because the output of 'ei' is discrete.
| Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil)
- <- retConcat des $ drev des en `SCons` drev des e `SCons` SNil
+ <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil
, let STArr ndim eltty = typeOf e ->
Ret binds
subtape
@@ -911,7 +914,7 @@ drev des = \case
(EVar ext (d2 (STArr (SS ndim) eltty)) IZ))
EIdx0 _ e
- | Ret e0 subtape e1 sub e2 <- drev des e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap e
, STArr _ t <- typeOf e ->
Ret e0
subtape
@@ -925,7 +928,7 @@ drev des = \case
EIdx1 _ e ei
-- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
| Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
- <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil
+ <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil
, STArr (SS n) eltty <- typeOf e ->
Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1)
`BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ)))
@@ -942,7 +945,7 @@ drev des = \case
EIdx _ e ei
-- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
| Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
- <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil
+ <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil
, STArr n eltty <- typeOf e
, Refl <- indexTupD1Id n
, let tIxN = tTup (sreplicate n tIx) ->
@@ -962,7 +965,7 @@ drev des = \case
EShape _ e
-- Allowed to ignore e2 here because the output of EShape is discrete,
-- hence we'd be passing a zero cotangent to e2 anyway.
- | Ret e0 subtape e1 _ _ <- drev des e
+ | Ret e0 subtape e1 _ _ <- drev des accumMap e
, STArr n _ <- typeOf e
, Refl <- indexTupD1Id n ->
Ret e0
@@ -972,7 +975,7 @@ drev des = \case
(ENil ext)
ESum1Inner _ e
- | Ret e0 subtape e1 sub e2 <- drev des e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap e
, STArr (SS n) t <- typeOf e ->
Ret (e0 `BPush` (STArr (SS n) t, e1)
`BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ)))
@@ -1010,9 +1013,9 @@ drev des = \case
deriv_extremum :: ScalIsNumeric t' ~ True
=> (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t')))
- -> Ex env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t'))
+ -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t'))
deriv_extremum extremum e
- | Ret e0 subtape e1 sub e2 <- drev des e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap e
, at@(STArr (SS n) t@(STScal st)) <- typeOf e
, let at' = STArr n t
, let tIxN = tTup (sreplicate (SS n) tIx) =
@@ -1052,11 +1055,11 @@ deriving instance Show (RetScoped env0 sto a s t)
drevScoped :: forall a s env sto t.
(?config :: CHADConfig)
- => Descr env sto -> STy a -> Storage s
- -> Ex (a : env) t
+ => Descr env sto -> VarMap Int env -> STy a -> Storage s
+ -> Expr ValId (a : env) t
-> RetScoped env sto a s t
-drevScoped des argty argsto expr
- | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) expr
+drevScoped des accumMap argty argsto expr
+ | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) (VarMap.sink1 accumMap) expr
= case argsto of
SMerge ->
case sub of