diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-28 22:40:41 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-28 22:40:41 +0100 |
commit | c06b4bd71a94601d467b509a26c08020d1fbd794 (patch) | |
tree | b16981c769231ef4af2c3ec5f002a01f857d95c6 /src/CHAD.hs | |
parent | a3ba3bdc5c2f9606a0b98cdf53183841cca07eac (diff) |
Pass around an accumMap (but it's empty still)
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 73 |
1 files changed, 38 insertions, 35 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index be308cd..6a4d5f5 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -36,6 +36,7 @@ import Data.Type.Bool (If) import Data.Type.Equality (type (==)) import GHC.Stack (HasCallStack) +import Analysis.Identity (ValId(..)) import AST import AST.Bindings import AST.Count @@ -45,6 +46,8 @@ import CHAD.Accum import CHAD.EnvDescr import CHAD.Types import Data +import qualified Data.VarMap as VarMap +import Data.VarMap (VarMap) import Lemmas @@ -558,9 +561,9 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = drev :: forall env sto t. (?config :: CHADConfig) - => Descr env sto - -> Ex env t -> Ret env sto t -drev des = \case + => Descr env sto -> VarMap Int env + -> Expr ValId env t -> Ret env sto t +drev des accumMap = \case EVar _ t i -> case conv2Idx des i of Idx2Ac accI -> @@ -584,10 +587,10 @@ drev des = \case (subenvNone (select SMerge des)) (ENil ext) - ELet _ (rhs :: Ex _ a) body - | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs + ELet _ (rhs :: Expr _ _ a) body + | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge - , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des (typeOf rhs) storage body + , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage body , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> @@ -613,7 +616,7 @@ drev des = \case EPair _ a b | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) - <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil + <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) -> subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B -> Ret binds @@ -632,7 +635,7 @@ drev des = \case (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)) EFst _ e - | Ret e0 subtape e1 sub e2 <- drev des e + | Ret e0 subtape e1 sub e2 <- drev des accumMap e , STPair t1 t2 <- typeOf e -> Ret e0 subtape @@ -642,7 +645,7 @@ drev des = \case weakenExpr (WCopy WSink) e2) ESnd _ e - | Ret e0 subtape e1 sub e2 <- drev des e + | Ret e0 subtape e1 sub e2 <- drev des accumMap e , STPair t1 t2 <- typeOf e -> Ret e0 subtape @@ -654,7 +657,7 @@ drev des = \case ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) EInl _ t2 e - | Ret e0 subtape e1 sub e2 <- drev des e -> + | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> Ret e0 subtape (EInl ext (d1 t2) e1) @@ -667,7 +670,7 @@ drev des = \case (EVar ext (STMaybe (STEither (d2 (typeOf e)) (d2 t2))) IZ)) EInr _ t1 e - | Ret e0 subtape e1 sub e2 <- drev des e -> + | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> Ret e0 subtape (EInr ext (d1 t1) e1) @@ -679,13 +682,13 @@ drev des = \case (weakenExpr (WCopy (wSinks' @[_,_])) e2)) (EVar ext (STMaybe (STEither (d2 t1) (d2 (typeOf e)))) IZ)) - ECase _ e (a :: Ex _ t) b + ECase _ e (a :: Expr _ _ t) b | STEither t1 t2 <- typeOf e - , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des e + , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge - , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des t1 storage1 a - , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des t2 storage2 b + , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 a + , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 b , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA) @@ -762,7 +765,7 @@ drev des = \case (ENil ext) EOp _ op e - | Ret e0 subtape e1 sub e2 <- drev des e -> + | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> case d2op op of Linear d2opfun -> Ret e0 @@ -783,15 +786,15 @@ drev des = \case ECustom _ _ _ storety _ pr du a b -- allowed to ignore a2 because 'a' is the part of the input that is inactive | Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil) - <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil -> + <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil -> Ret (binds `BPush` (typeOf a1, a1) `BPush` (typeOf b1, weakenExpr WSink b1) - `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) pr) + `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) (SEYes (SENo (SENo (SENo subtape)))) (EFst ext (EVar ext (typeOf pr) (IS IZ))) bsub - (ELet ext (weakenExpr (WCopy (WCopy WClosed)) du) $ + (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ weakenExpr (WCopy (WSink .> WSink)) b2) EError _ t s -> @@ -808,8 +811,8 @@ drev des = \case (subenvNone (select SMerge des)) (ENil ext) - EBuild _ (ndim :: SNat ndim) she (orige :: Ex _ eltty) - | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des she -- allowed to ignore she2 here because she has a discrete result + EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty) + | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des accumMap she -- allowed to ignore she2 here because she has a discrete result , let eltty = typeOf orige , shty :: STy shty <- tTup (sreplicate ndim tIx) , Refl <- indexTupD1Id ndim -> @@ -817,7 +820,7 @@ drev des = \case let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro -> - case drev (prodes `DPush` (shty, SDiscr)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> + case drev (prodes `DPush` (shty, SDiscr)) (VarMap.sink1 (VarMap.subMap usedSub accumMap)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> case assertSubenvEmpty sub of { Refl -> let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in let collectexpr = bindingsCollect e0 subtapeE in @@ -881,7 +884,7 @@ drev des = \case }} EUnit _ e - | Ret e0 subtape e1 sub e2 <- drev des e -> + | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> Ret e0 subtape (EUnit ext e1) @@ -895,7 +898,7 @@ drev des = \case EReplicate1Inner _ en e -- We're allowed to ignore en2 here because the output of 'ei' is discrete. | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil) - <- retConcat des $ drev des en `SCons` drev des e `SCons` SNil + <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil , let STArr ndim eltty = typeOf e -> Ret binds subtape @@ -911,7 +914,7 @@ drev des = \case (EVar ext (d2 (STArr (SS ndim) eltty)) IZ)) EIdx0 _ e - | Ret e0 subtape e1 sub e2 <- drev des e + | Ret e0 subtape e1 sub e2 <- drev des accumMap e , STArr _ t <- typeOf e -> Ret e0 subtape @@ -925,7 +928,7 @@ drev des = \case EIdx1 _ e ei -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) - <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil + <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil , STArr (SS n) eltty <- typeOf e -> Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1) `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))) @@ -942,7 +945,7 @@ drev des = \case EIdx _ e ei -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) - <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil + <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil , STArr n eltty <- typeOf e , Refl <- indexTupD1Id n , let tIxN = tTup (sreplicate n tIx) -> @@ -962,7 +965,7 @@ drev des = \case EShape _ e -- Allowed to ignore e2 here because the output of EShape is discrete, -- hence we'd be passing a zero cotangent to e2 anyway. - | Ret e0 subtape e1 _ _ <- drev des e + | Ret e0 subtape e1 _ _ <- drev des accumMap e , STArr n _ <- typeOf e , Refl <- indexTupD1Id n -> Ret e0 @@ -972,7 +975,7 @@ drev des = \case (ENil ext) ESum1Inner _ e - | Ret e0 subtape e1 sub e2 <- drev des e + | Ret e0 subtape e1 sub e2 <- drev des accumMap e , STArr (SS n) t <- typeOf e -> Ret (e0 `BPush` (STArr (SS n) t, e1) `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) @@ -1010,9 +1013,9 @@ drev des = \case deriv_extremum :: ScalIsNumeric t' ~ True => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t'))) - -> Ex env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t')) + -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t')) deriv_extremum extremum e - | Ret e0 subtape e1 sub e2 <- drev des e + | Ret e0 subtape e1 sub e2 <- drev des accumMap e , at@(STArr (SS n) t@(STScal st)) <- typeOf e , let at' = STArr n t , let tIxN = tTup (sreplicate (SS n) tIx) = @@ -1052,11 +1055,11 @@ deriving instance Show (RetScoped env0 sto a s t) drevScoped :: forall a s env sto t. (?config :: CHADConfig) - => Descr env sto -> STy a -> Storage s - -> Ex (a : env) t + => Descr env sto -> VarMap Int env -> STy a -> Storage s + -> Expr ValId (a : env) t -> RetScoped env sto a s t -drevScoped des argty argsto expr - | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) expr +drevScoped des accumMap argty argsto expr + | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) (VarMap.sink1 accumMap) expr = case argsto of SMerge -> case sub of |