summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-22 13:59:49 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-22 13:59:49 +0100
commit11b9ead68b7cdf63c272fd83fa293c8110792904 (patch)
tree0bc8c60132be450addd0f4d53b54aae54f631820
parentb8c162ce9cb1faeec621b751fff9aff46e022417 (diff)
Factor let storage-dependent scoping logic into separate function
-rw-r--r--src/CHAD.hs105
1 files changed, 57 insertions, 48 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index b35836a..c37d379 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -33,6 +33,8 @@ module CHAD (
import Data.Functor.Const
import Data.Kind (Type)
+import Data.Type.Bool (If)
+import Data.Type.Equality (type (==))
import GHC.Stack (HasCallStack)
import GHC.TypeLits (Symbol)
@@ -774,66 +776,36 @@ drev des = \case
(ENil ext)
ELet _ (rhs :: Ex _ a) body
- | chcLetArrayAccum ?config && hasArrays (typeOf rhs)
- , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2
- <- drev des rhs
- , Ret (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2
- <- drev (des `DPush` (typeOf rhs, SAccum)) body
- , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
- , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
- , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->
+ | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs
+ , let chooseStorage :: (forall s. ((s == "discr") ~ False) => Storage s -> r) -> r
+ chooseStorage k -- separate function because it needs a (higher-rank) type signature
+ | chcLetArrayAccum ?config && hasArrays (typeOf rhs) = k SAccum
+ | otherwise = k SMerge ->
+ chooseStorage $ \storage ->
+ drevScoped des (typeOf rhs) storage body $ \(body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 ->
+ let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 in
+ case lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) of { Refl ->
+ case lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) of { Refl ->
subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body ->
let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in
Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
(subenvConcat (SENo @d1_a subtapeRHS) subtapeBody)
(weakenExpr wbody0' body1)
subBoth
- (ELet ext
- (EWith (EZero (typeOf rhs)) $
- weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #body (subList (bindingsBinds body0) subtapeBody)
- &. #ac (auto1 @(TAccum (D2 a)))
- &. #rhs (subList (bindingsBinds rhs0) subtapeRHS)
- &. #tl (d2ace (select SAccum des)))
- (#d :++: #body :++: #ac :++: #tl)
- (#ac :++: #d :++: (#body :++: #rhs) :++: #tl))
- body2) $
- ELet ext
- (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
- weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $
- plus_RHS_Body
- (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)
- (EFst ext (EVar ext bodyResType (IS IZ))))
-
- ELet _ rhs body
- | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2
- <- drev des rhs
- , Ret (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2
- <- drev (des `DPush` (typeOf rhs, SMerge)) body
- , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
- , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
- , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->
- popFromScope des (typeOf rhs) subBody body2 $ \subBody' body2' ->
- subenvPlus (select SMerge des) subRHS subBody' $ \subBoth _ _ plus_RHS_Body ->
- let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody'))) (d2 (typeOf rhs)) in
- Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
- (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody)
- (weakenExpr wbody0' body1)
- subBoth
- (ELet ext
- (weakenExpr (autoWeak (#d (auto1 @(D2 t))
- &. #body (subList (bindingsBinds body0) subtapeBody)
- &. #rhs (subList (bindingsBinds rhs0) subtapeRHS)
- &. #tl (d2ace (select SAccum des)))
- (#d :++: #body :++: #tl)
- (#d :++: (#body :++: #rhs) :++: #tl))
- body2') $
+ (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ &. #body (subList (bindingsBinds body0) subtapeBody)
+ &. #rhs (subList (bindingsBinds rhs0) subtapeRHS)
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: #body :++: #tl)
+ (#d :++: (#body :++: #rhs) :++: #tl))
+ body2) $
ELet ext
(ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $
plus_RHS_Body
(EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)
(EFst ext (EVar ext bodyResType (IS IZ))))
+ }}
EPair _ a b
| Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
@@ -1250,3 +1222,40 @@ drev des = \case
hasArrays STArr{} = True
hasArrays STScal{} = False
hasArrays STAccum{} = error "Accumulators not allowed in source program"
+
+drevScoped :: forall a s env sto t r.
+ (?config :: CHADConfig)
+ => Descr env sto -> STy a -> Storage s
+ -> Ex (a : env) t
+ -- | This is mostly a 'Ret', but with a little twist.
+ -> (forall shbinds tapebinds env0Merge.
+ Bindings Ex (D1E (a : env)) shbinds -- ^ as usual
+ -> Subenv shbinds tapebinds -- ^ as usual
+ -> Ex (Append shbinds (D1E (a : env))) (D1 t) -- ^ as usual
+ -> Subenv (Select env sto "merge") env0Merge
+ -- ^ merge contributions to the enclosing merge environment
+ -> Ex (D2 t : Append tapebinds (D2AcE (Select env sto "accum")))
+ (If (s == "discr") (Tup (D2E env0Merge))
+ (TPair (Tup (D2E env0Merge)) (D2 a)))
+ -- ^ the merge contributions, plus the cotangent to the argument
+ -- (if there is any)
+ -> r)
+ -> r
+drevScoped des argty argsto expr k
+ | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) expr
+ = case argsto of
+ SMerge ->
+ case sub of
+ SEYes sub' -> k e0 subtape e1 sub' e2
+ SENo sub' -> k e0 subtape e1 sub' (EPair ext e2 (EZero argty))
+ SAccum ->
+ k e0 subtape e1 sub $
+ EWith (EZero argty) $
+ weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ &. #body (subList (bindingsBinds e0) subtape)
+ &. #ac (auto1 @(TAccum (D2 a)))
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: #body :++: #ac :++: #tl)
+ (#ac :++: #d :++: #body :++: #tl))
+ e2
+ SDiscr -> k e0 subtape e1 sub e2