summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-23 12:11:45 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-23 12:11:45 +0100
commit883bfad4bf0c5f383c986ebbb9e9dab61e3c2098 (patch)
tree52bc4a336e26ce84fefe7d1a4b109ea6501a0987
parent84f6845803511e24770fbf1dffc6a9a007371edf (diff)
Use accum storage for Case too
-rw-r--r--src/CHAD.hs38
1 files changed, 11 insertions, 27 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 35c3dda..5ee13fa 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -500,20 +500,6 @@ assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl
assertSubenvEmpty SETop = Refl
assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
-popFromScope
- :: Descr env0 sto
- -> STy a
- -> Subenv (Select (a : env0) ("merge" : sto) "merge") envSub
- -> Ex env (Tup (D2E envSub))
- -> (forall envSub'.
- Subenv (Select env0 sto "merge") envSub'
- -> Ex env (TPair (Tup (D2E envSub')) (D2 a))
- -> r)
- -> r
-popFromScope _ ty sub e k = case sub of
- SEYes sub' -> k sub' e
- SENo sub' -> k sub' $ EPair ext e (zero ty)
-
--------------------------------- ACCUMULATORS ---------------------------------
@@ -867,13 +853,13 @@ drev des = \case
(EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")
(weakenExpr (WCopy (wSinks' @[_,_])) e2)))
- ECase{} | chcCaseArrayAccum ?config -> error "chcCaseArrayAccum unsupported"
-
ECase _ e (a :: Ex _ t) b
| STEither t1 t2 <- typeOf e
, Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des e
- , Ret (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drev (des `DPush` (t1, SMerge)) a
- , Ret (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drev (des `DPush` (t2, SMerge)) b
+ , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge
+ , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge
+ , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des t1 storage1 a
+ , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des t2 storage2 b
, Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
, Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
, let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA)
@@ -884,9 +870,7 @@ drev des = \case
, let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0
, let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0
->
- popFromScope des t1 subA a2 $ \subA' a2' ->
- popFromScope des t2 subB b2 $ \subB' b2' ->
- subenvPlus (select SMerge des) subA' subB' $ \subAB sAB_A sAB_B _ ->
+ subenvPlus (select SMerge des) subA subB $ \subAB sAB_A sAB_B _ ->
subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E ->
let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STEither (d2 t1) (d2 t2)) in
Ret (e0 `BPush`
@@ -912,12 +896,12 @@ drev des = \case
&. #tl (d2ace (select SAccum des)))
(#d :++: #ta0 :++: #tl)
(#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl))
- a2') $
+ a2) $
EPair ext
(expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $
- EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))
+ EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ))
(EInl ext (d2 t2)
- (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))))
+ (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ))))
(let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds b0) subtapeB) IZ
in letBinds rebinds $
ELet ext
@@ -931,12 +915,12 @@ drev des = \case
&. #tl (d2ace (select SAccum des)))
(#d :++: #tb0 :++: #tl)
(#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl))
- b2') $
+ b2) $
EPair ext
(expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $
- EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ))
+ EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))
(EInr ext (d2 t1)
- (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ))))) $
+ (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $
ELet ext
(ELet ext (EInr ext STNil (ESnd ext (EVar ext tCaseRet IZ))) $
weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $