summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-23 09:43:33 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-23 09:43:33 +0100
commite8663e189c41637d348ce100cdab40e8d19ed62c (patch)
tree44933274659e4f14a72d9bd2adb84f3b0dcf9576
parent11b9ead68b7cdf63c272fd83fa293c8110792904 (diff)
drevScoped returns a data type, not CPS
-rw-r--r--src/CHAD.hs59
1 files changed, 29 insertions, 30 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index c37d379..35c3dda 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -777,15 +777,11 @@ drev des = \case
ELet _ (rhs :: Ex _ a) body
| Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs
- , let chooseStorage :: (forall s. ((s == "discr") ~ False) => Storage s -> r) -> r
- chooseStorage k -- separate function because it needs a (higher-rank) type signature
- | chcLetArrayAccum ?config && hasArrays (typeOf rhs) = k SAccum
- | otherwise = k SMerge ->
- chooseStorage $ \storage ->
- drevScoped des (typeOf rhs) storage body $ \(body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 ->
- let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 in
- case lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) of { Refl ->
- case lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) of { Refl ->
+ , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
+ , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des (typeOf rhs) storage body
+ , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
+ , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
+ , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->
subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body ->
let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in
Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
@@ -805,7 +801,6 @@ drev des = \case
plus_RHS_Body
(EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)
(EFst ext (EVar ext bodyResType (IS IZ))))
- }}
EPair _ a b
| Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
@@ -1223,33 +1218,37 @@ drev des = \case
hasArrays STScal{} = False
hasArrays STAccum{} = error "Accumulators not allowed in source program"
-drevScoped :: forall a s env sto t r.
+data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s)
+
+data RetScoped env0 sto a s t =
+ forall shbinds tapebinds env0Merge.
+ RetScoped
+ (Bindings Ex (D1E (a : env0)) shbinds) -- shared binds
+ (Subenv shbinds tapebinds)
+ (Ex (Append shbinds (D1E (a : env0))) (D1 t))
+ (Subenv (Select env0 sto "merge") env0Merge)
+ -- ^ merge contributions to the _enclosing_ merge environment
+ (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum")))
+ (If (s == "discr") (Tup (D2E env0Merge))
+ (TPair (Tup (D2E env0Merge)) (D2 a))))
+ -- ^ the merge contributions, plus the cotangent to the argument
+ -- (if there is any)
+deriving instance Show (RetScoped env0 sto a s t)
+
+drevScoped :: forall a s env sto t.
(?config :: CHADConfig)
=> Descr env sto -> STy a -> Storage s
-> Ex (a : env) t
- -- | This is mostly a 'Ret', but with a little twist.
- -> (forall shbinds tapebinds env0Merge.
- Bindings Ex (D1E (a : env)) shbinds -- ^ as usual
- -> Subenv shbinds tapebinds -- ^ as usual
- -> Ex (Append shbinds (D1E (a : env))) (D1 t) -- ^ as usual
- -> Subenv (Select env sto "merge") env0Merge
- -- ^ merge contributions to the enclosing merge environment
- -> Ex (D2 t : Append tapebinds (D2AcE (Select env sto "accum")))
- (If (s == "discr") (Tup (D2E env0Merge))
- (TPair (Tup (D2E env0Merge)) (D2 a)))
- -- ^ the merge contributions, plus the cotangent to the argument
- -- (if there is any)
- -> r)
- -> r
-drevScoped des argty argsto expr k
+ -> RetScoped env sto a s t
+drevScoped des argty argsto expr
| Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) expr
= case argsto of
SMerge ->
case sub of
- SEYes sub' -> k e0 subtape e1 sub' e2
- SENo sub' -> k e0 subtape e1 sub' (EPair ext e2 (EZero argty))
+ SEYes sub' -> RetScoped e0 subtape e1 sub' e2
+ SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero argty))
SAccum ->
- k e0 subtape e1 sub $
+ RetScoped e0 subtape e1 sub $
EWith (EZero argty) $
weakenExpr (autoWeak (#d (auto1 @(D2 t))
&. #body (subList (bindingsBinds e0) subtape)
@@ -1258,4 +1257,4 @@ drevScoped des argty argsto expr k
(#d :++: #body :++: #ac :++: #tl)
(#ac :++: #d :++: #body :++: #tl))
e2
- SDiscr -> k e0 subtape e1 sub e2
+ SDiscr -> RetScoped e0 subtape e1 sub e2