From 883bfad4bf0c5f383c986ebbb9e9dab61e3c2098 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Sat, 23 Nov 2024 12:11:45 +0100
Subject: Use accum storage for Case too

---
 src/CHAD.hs | 38 +++++++++++---------------------------
 1 file changed, 11 insertions(+), 27 deletions(-)

(limited to 'src')

diff --git a/src/CHAD.hs b/src/CHAD.hs
index 35c3dda..5ee13fa 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -500,20 +500,6 @@ assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl
 assertSubenvEmpty SETop = Refl
 assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
 
-popFromScope
-  :: Descr env0 sto
-  -> STy a
-  -> Subenv (Select (a : env0) ("merge" : sto) "merge") envSub
-  -> Ex env (Tup (D2E envSub))
-  -> (forall envSub'.
-         Subenv (Select env0 sto "merge") envSub'
-      -> Ex env (TPair (Tup (D2E envSub')) (D2 a))
-      -> r)
-  -> r
-popFromScope _ ty sub e k = case sub of
-  SEYes sub' -> k sub' e
-  SENo sub' -> k sub' $ EPair ext e (zero ty)
-
 
 --------------------------------- ACCUMULATORS ---------------------------------
 
@@ -867,13 +853,13 @@ drev des = \case
               (EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")
               (weakenExpr (WCopy (wSinks' @[_,_])) e2)))
 
-  ECase{} | chcCaseArrayAccum ?config -> error "chcCaseArrayAccum unsupported"
-
   ECase _ e (a :: Ex _ t) b
     | STEither t1 t2 <- typeOf e
     , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des e
-    , Ret (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drev (des `DPush` (t1, SMerge)) a
-    , Ret (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drev (des `DPush` (t2, SMerge)) b
+    , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge
+    , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge
+    , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des t1 storage1 a
+    , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des t2 storage2 b
     , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
     , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
     , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA)
@@ -884,9 +870,7 @@ drev des = \case
     , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0
     , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0
     ->
-    popFromScope des t1 subA a2 $ \subA' a2' ->
-    popFromScope des t2 subB b2 $ \subB' b2' ->
-    subenvPlus (select SMerge des) subA' subB' $ \subAB sAB_A sAB_B _ ->
+    subenvPlus (select SMerge des) subA subB $ \subAB sAB_A sAB_B _ ->
     subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E ->
     let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STEither (d2 t1) (d2 t2)) in
     Ret (e0 `BPush`
@@ -912,12 +896,12 @@ drev des = \case
                                              &. #tl (d2ace (select SAccum des)))
                                             (#d :++: #ta0 :++: #tl)
                                             (#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl))
-                                  a2') $
+                                  a2) $
                     EPair ext
                      (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $
-                        EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))
+                        EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ))
                      (EInl ext (d2 t2)
-                       (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))))
+                       (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ))))
               (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds b0) subtapeB) IZ
                in letBinds rebinds $
                     ELet ext
@@ -931,12 +915,12 @@ drev des = \case
                                              &. #tl (d2ace (select SAccum des)))
                                             (#d :++: #tb0 :++: #tl)
                                             (#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl))
-                                  b2') $
+                                  b2) $
                     EPair ext
                       (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $
-                         EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ))
+                         EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))
                       (EInr ext (d2 t1)
-                        (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ))))) $
+                        (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $
          ELet ext
            (ELet ext (EInr ext STNil (ESnd ext (EVar ext tCaseRet IZ))) $
               weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $
-- 
cgit v1.2.3-70-g09d2