diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-22 13:59:49 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-22 13:59:49 +0100 |
commit | 11b9ead68b7cdf63c272fd83fa293c8110792904 (patch) | |
tree | 0bc8c60132be450addd0f4d53b54aae54f631820 | |
parent | b8c162ce9cb1faeec621b751fff9aff46e022417 (diff) |
Factor let storage-dependent scoping logic into separate function
-rw-r--r-- | src/CHAD.hs | 105 |
1 files changed, 57 insertions, 48 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index b35836a..c37d379 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -33,6 +33,8 @@ module CHAD ( import Data.Functor.Const import Data.Kind (Type) +import Data.Type.Bool (If) +import Data.Type.Equality (type (==)) import GHC.Stack (HasCallStack) import GHC.TypeLits (Symbol) @@ -774,66 +776,36 @@ drev des = \case (ENil ext) ELet _ (rhs :: Ex _ a) body - | chcLetArrayAccum ?config && hasArrays (typeOf rhs) - , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 - <- drev des rhs - , Ret (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 - <- drev (des `DPush` (typeOf rhs, SAccum)) body - , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 - , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) - , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> + | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs + , let chooseStorage :: (forall s. ((s == "discr") ~ False) => Storage s -> r) -> r + chooseStorage k -- separate function because it needs a (higher-rank) type signature + | chcLetArrayAccum ?config && hasArrays (typeOf rhs) = k SAccum + | otherwise = k SMerge -> + chooseStorage $ \storage -> + drevScoped des (typeOf rhs) storage body $ \(body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 -> + let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 in + case lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) of { Refl -> + case lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) of { Refl -> subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody) (weakenExpr wbody0' body1) subBoth - (ELet ext - (EWith (EZero (typeOf rhs)) $ - weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #body (subList (bindingsBinds body0) subtapeBody) - &. #ac (auto1 @(TAccum (D2 a))) - &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) - &. #tl (d2ace (select SAccum des))) - (#d :++: #body :++: #ac :++: #tl) - (#ac :++: #d :++: (#body :++: #rhs) :++: #tl)) - body2) $ - ELet ext - (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ - weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ - plus_RHS_Body - (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) - (EFst ext (EVar ext bodyResType (IS IZ)))) - - ELet _ rhs body - | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 - <- drev des rhs - , Ret (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 - <- drev (des `DPush` (typeOf rhs, SMerge)) body - , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 - , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) - , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> - popFromScope des (typeOf rhs) subBody body2 $ \subBody' body2' -> - subenvPlus (select SMerge des) subRHS subBody' $ \subBoth _ _ plus_RHS_Body -> - let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody'))) (d2 (typeOf rhs)) in - Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') - (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody) - (weakenExpr wbody0' body1) - subBoth - (ELet ext - (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #body (subList (bindingsBinds body0) subtapeBody) - &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) - &. #tl (d2ace (select SAccum des))) - (#d :++: #body :++: #tl) - (#d :++: (#body :++: #rhs) :++: #tl)) - body2') $ + (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) + &. #body (subList (bindingsBinds body0) subtapeBody) + &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) + &. #tl (d2ace (select SAccum des))) + (#d :++: #body :++: #tl) + (#d :++: (#body :++: #rhs) :++: #tl)) + body2) $ ELet ext (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ plus_RHS_Body (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) (EFst ext (EVar ext bodyResType (IS IZ)))) + }} EPair _ a b | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) @@ -1250,3 +1222,40 @@ drev des = \case hasArrays STArr{} = True hasArrays STScal{} = False hasArrays STAccum{} = error "Accumulators not allowed in source program" + +drevScoped :: forall a s env sto t r. + (?config :: CHADConfig) + => Descr env sto -> STy a -> Storage s + -> Ex (a : env) t + -- | This is mostly a 'Ret', but with a little twist. + -> (forall shbinds tapebinds env0Merge. + Bindings Ex (D1E (a : env)) shbinds -- ^ as usual + -> Subenv shbinds tapebinds -- ^ as usual + -> Ex (Append shbinds (D1E (a : env))) (D1 t) -- ^ as usual + -> Subenv (Select env sto "merge") env0Merge + -- ^ merge contributions to the enclosing merge environment + -> Ex (D2 t : Append tapebinds (D2AcE (Select env sto "accum"))) + (If (s == "discr") (Tup (D2E env0Merge)) + (TPair (Tup (D2E env0Merge)) (D2 a))) + -- ^ the merge contributions, plus the cotangent to the argument + -- (if there is any) + -> r) + -> r +drevScoped des argty argsto expr k + | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) expr + = case argsto of + SMerge -> + case sub of + SEYes sub' -> k e0 subtape e1 sub' e2 + SENo sub' -> k e0 subtape e1 sub' (EPair ext e2 (EZero argty)) + SAccum -> + k e0 subtape e1 sub $ + EWith (EZero argty) $ + weakenExpr (autoWeak (#d (auto1 @(D2 t)) + &. #body (subList (bindingsBinds e0) subtape) + &. #ac (auto1 @(TAccum (D2 a))) + &. #tl (d2ace (select SAccum des))) + (#d :++: #body :++: #ac :++: #tl) + (#ac :++: #d :++: #body :++: #tl)) + e2 + SDiscr -> k e0 subtape e1 sub e2 |