diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-23 09:43:33 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-23 09:43:33 +0100 |
commit | e8663e189c41637d348ce100cdab40e8d19ed62c (patch) | |
tree | 44933274659e4f14a72d9bd2adb84f3b0dcf9576 | |
parent | 11b9ead68b7cdf63c272fd83fa293c8110792904 (diff) |
drevScoped returns a data type, not CPS
-rw-r--r-- | src/CHAD.hs | 59 |
1 files changed, 29 insertions, 30 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index c37d379..35c3dda 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -777,15 +777,11 @@ drev des = \case ELet _ (rhs :: Ex _ a) body | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs - , let chooseStorage :: (forall s. ((s == "discr") ~ False) => Storage s -> r) -> r - chooseStorage k -- separate function because it needs a (higher-rank) type signature - | chcLetArrayAccum ?config && hasArrays (typeOf rhs) = k SAccum - | otherwise = k SMerge -> - chooseStorage $ \storage -> - drevScoped des (typeOf rhs) storage body $ \(body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 -> - let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 in - case lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) of { Refl -> - case lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) of { Refl -> + , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge + , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des (typeOf rhs) storage body + , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 + , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) + , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') @@ -805,7 +801,6 @@ drev des = \case plus_RHS_Body (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) (EFst ext (EVar ext bodyResType (IS IZ)))) - }} EPair _ a b | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) @@ -1223,33 +1218,37 @@ drev des = \case hasArrays STScal{} = False hasArrays STAccum{} = error "Accumulators not allowed in source program" -drevScoped :: forall a s env sto t r. +data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) + +data RetScoped env0 sto a s t = + forall shbinds tapebinds env0Merge. + RetScoped + (Bindings Ex (D1E (a : env0)) shbinds) -- shared binds + (Subenv shbinds tapebinds) + (Ex (Append shbinds (D1E (a : env0))) (D1 t)) + (Subenv (Select env0 sto "merge") env0Merge) + -- ^ merge contributions to the _enclosing_ merge environment + (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) + (If (s == "discr") (Tup (D2E env0Merge)) + (TPair (Tup (D2E env0Merge)) (D2 a)))) + -- ^ the merge contributions, plus the cotangent to the argument + -- (if there is any) +deriving instance Show (RetScoped env0 sto a s t) + +drevScoped :: forall a s env sto t. (?config :: CHADConfig) => Descr env sto -> STy a -> Storage s -> Ex (a : env) t - -- | This is mostly a 'Ret', but with a little twist. - -> (forall shbinds tapebinds env0Merge. - Bindings Ex (D1E (a : env)) shbinds -- ^ as usual - -> Subenv shbinds tapebinds -- ^ as usual - -> Ex (Append shbinds (D1E (a : env))) (D1 t) -- ^ as usual - -> Subenv (Select env sto "merge") env0Merge - -- ^ merge contributions to the enclosing merge environment - -> Ex (D2 t : Append tapebinds (D2AcE (Select env sto "accum"))) - (If (s == "discr") (Tup (D2E env0Merge)) - (TPair (Tup (D2E env0Merge)) (D2 a))) - -- ^ the merge contributions, plus the cotangent to the argument - -- (if there is any) - -> r) - -> r -drevScoped des argty argsto expr k + -> RetScoped env sto a s t +drevScoped des argty argsto expr | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) expr = case argsto of SMerge -> case sub of - SEYes sub' -> k e0 subtape e1 sub' e2 - SENo sub' -> k e0 subtape e1 sub' (EPair ext e2 (EZero argty)) + SEYes sub' -> RetScoped e0 subtape e1 sub' e2 + SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero argty)) SAccum -> - k e0 subtape e1 sub $ + RetScoped e0 subtape e1 sub $ EWith (EZero argty) $ weakenExpr (autoWeak (#d (auto1 @(D2 t)) &. #body (subList (bindingsBinds e0) subtape) @@ -1258,4 +1257,4 @@ drevScoped des argty argsto expr k (#d :++: #body :++: #ac :++: #tl) (#ac :++: #d :++: #body :++: #tl)) e2 - SDiscr -> k e0 subtape e1 sub e2 + SDiscr -> RetScoped e0 subtape e1 sub e2 |