From e8663e189c41637d348ce100cdab40e8d19ed62c Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Sat, 23 Nov 2024 09:43:33 +0100
Subject: drevScoped returns a data type, not CPS

---
 src/CHAD.hs | 59 +++++++++++++++++++++++++++++------------------------------
 1 file changed, 29 insertions(+), 30 deletions(-)

(limited to 'src')

diff --git a/src/CHAD.hs b/src/CHAD.hs
index c37d379..35c3dda 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -777,15 +777,11 @@ drev des = \case
 
   ELet _ (rhs :: Ex _ a) body
     | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs
-    , let chooseStorage :: (forall s. ((s == "discr") ~ False) => Storage s -> r) -> r
-          chooseStorage k  -- separate function because it needs a (higher-rank) type signature
-            | chcLetArrayAccum ?config && hasArrays (typeOf rhs) = k SAccum
-            | otherwise = k SMerge ->
-    chooseStorage $ \storage ->
-    drevScoped des (typeOf rhs) storage body $ \(body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 ->
-    let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 in
-    case lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) of { Refl ->
-    case lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) of { Refl ->
+    , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
+    , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des (typeOf rhs) storage body
+    , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
+    , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
+    , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->
     subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body ->
     let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in
     Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
@@ -805,7 +801,6 @@ drev des = \case
          plus_RHS_Body
            (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)
            (EFst ext (EVar ext bodyResType (IS IZ))))
-    }}
 
   EPair _ a b
     | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
@@ -1223,33 +1218,37 @@ drev des = \case
     hasArrays STScal{} = False
     hasArrays STAccum{} = error "Accumulators not allowed in source program"
 
-drevScoped :: forall a s env sto t r.
+data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s)
+
+data RetScoped env0 sto a s t =
+  forall shbinds tapebinds env0Merge.
+    RetScoped
+        (Bindings Ex (D1E (a : env0)) shbinds)  -- shared binds
+        (Subenv shbinds tapebinds)
+        (Ex (Append shbinds (D1E (a : env0))) (D1 t))
+        (Subenv (Select env0 sto "merge") env0Merge)
+           -- ^ merge contributions to the _enclosing_ merge environment
+        (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum")))
+           (If (s == "discr") (Tup (D2E env0Merge))
+                              (TPair (Tup (D2E env0Merge)) (D2 a))))
+          -- ^ the merge contributions, plus the cotangent to the argument
+          -- (if there is any)
+deriving instance Show (RetScoped env0 sto a s t)
+
+drevScoped :: forall a s env sto t.
               (?config :: CHADConfig)
            => Descr env sto -> STy a -> Storage s
            -> Ex (a : env) t
-              -- | This is mostly a 'Ret', but with a little twist.
-           -> (forall shbinds tapebinds env0Merge.
-                 Bindings Ex (D1E (a : env)) shbinds         -- ^ as usual
-              -> Subenv shbinds tapebinds                    -- ^ as usual
-              -> Ex (Append shbinds (D1E (a : env))) (D1 t)  -- ^ as usual
-              -> Subenv (Select env sto "merge") env0Merge
-                   -- ^ merge contributions to the enclosing merge environment
-              -> Ex (D2 t : Append tapebinds (D2AcE (Select env sto "accum")))
-                    (If (s == "discr") (Tup (D2E env0Merge))
-                                       (TPair (Tup (D2E env0Merge)) (D2 a)))
-                   -- ^ the merge contributions, plus the cotangent to the argument
-                   -- (if there is any)
-              -> r)
-           -> r
-drevScoped des argty argsto expr k
+           -> RetScoped env sto a s t
+drevScoped des argty argsto expr
   | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) expr
   = case argsto of
       SMerge ->
         case sub of
-          SEYes sub' -> k e0 subtape e1 sub' e2
-          SENo sub' -> k e0 subtape e1 sub' (EPair ext e2 (EZero argty))
+          SEYes sub' -> RetScoped e0 subtape e1 sub' e2
+          SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero argty))
       SAccum ->
-        k e0 subtape e1 sub $
+        RetScoped e0 subtape e1 sub $
           EWith (EZero argty) $
             weakenExpr (autoWeak (#d (auto1 @(D2 t))
                                   &. #body (subList (bindingsBinds e0) subtape)
@@ -1258,4 +1257,4 @@ drevScoped des argty argsto expr k
                                  (#d :++: #body :++: #ac :++: #tl)
                                  (#ac :++: #d :++: #body :++: #tl))
                        e2
-      SDiscr -> k e0 subtape e1 sub e2
+      SDiscr -> RetScoped e0 subtape e1 sub e2
-- 
cgit v1.2.3-70-g09d2