summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-04-06 17:07:22 +0200
committerTom Smeding <t.j.smeding@uu.nl>2025-04-06 17:07:22 +0200
commit0a9e6dfc1accf9dc0254f0c720f633dab6e71f42 (patch)
tree754eaeecf01e554d7ad904c27a9b665879441ca0 /src/CHAD.hs
parentb6c1d3a9d0651aa25ea5f03d514a214a3347f7a4 (diff)
Populate accumMapHEADmaster
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs174
1 files changed, 103 insertions, 71 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 6a4d5f5..8db0410 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -36,7 +36,7 @@ import Data.Type.Bool (If)
import Data.Type.Equality (type (==))
import GHC.Stack (HasCallStack)
-import Analysis.Identity (ValId(..))
+import Analysis.Identity (ValId(..), validSplitEither)
import AST
import AST.Bindings
import AST.Count
@@ -294,18 +294,18 @@ data Idx2 env sto t
| Idx2Di (Idx (Select env sto "discr") t)
conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t
-conv2Idx (DPush _ (_, SAccum)) IZ = Idx2Ac IZ
-conv2Idx (DPush _ (_, SMerge)) IZ = Idx2Me IZ
-conv2Idx (DPush _ (_, SDiscr)) IZ = Idx2Di IZ
-conv2Idx (DPush des (_, SAccum)) (IS i) =
+conv2Idx (DPush _ (_, _, SAccum)) IZ = Idx2Ac IZ
+conv2Idx (DPush _ (_, _, SMerge)) IZ = Idx2Me IZ
+conv2Idx (DPush _ (_, _, SDiscr)) IZ = Idx2Di IZ
+conv2Idx (DPush des (_, _, SAccum)) (IS i) =
case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j)
Idx2Me j -> Idx2Me j
Idx2Di j -> Idx2Di j
-conv2Idx (DPush des (_, SMerge)) (IS i) =
+conv2Idx (DPush des (_, _, SMerge)) (IS i) =
case conv2Idx des i of Idx2Ac j -> Idx2Ac j
Idx2Me j -> Idx2Me (IS j)
Idx2Di j -> Idx2Di j
-conv2Idx (DPush des (_, SDiscr)) (IS i) =
+conv2Idx (DPush des (_, _, SDiscr)) (IS i) =
case conv2Idx des i of Idx2Ac j -> Idx2Ac j
Idx2Me j -> Idx2Me j
Idx2Di j -> Idx2Di (IS j)
@@ -376,6 +376,10 @@ assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
--------------------------------- ACCUMULATORS ---------------------------------
+fromArrayValId :: Maybe (ValId t) -> Maybe Int
+fromArrayValId (Just (VIArr i _)) = Just i
+fromArrayValId _ = Nothing
+
accumPromote :: forall dt env sto proxy r.
proxy dt
-> Descr env sto
@@ -384,8 +388,7 @@ accumPromote :: forall dt env sto proxy r.
=> Descr env stoRepl
-- ^ A revised environment description that switches
-- arrays (used in the OccEnv) that are currently on
- -- "merge" storage, to "accum" storage. Any other "merge"
- -- entries are deleted.
+ -- "merge" storage, to "accum" storage.
-> SList STy envPro
-- ^ New entries on top of the original dual environment,
-- that house the accumulators for the promoted arrays in
@@ -393,6 +396,12 @@ accumPromote :: forall dt env sto proxy r.
-> Subenv (Select env sto "merge") envPro
-- ^ The promoted entries were merge entries in the
-- original environment.
+ -> Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum"))
+ -- ^ All entries that were accumulators are still
+ -- accumulators.
+ -> VarMap Int (D2AcE (Select env stoRepl "accum"))
+ -- ^ Accumulator map for _only_ the the newly allocated
+ -- accumulators.
-> (forall shbinds.
SList STy shbinds
-> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum")))
@@ -402,57 +411,70 @@ accumPromote :: forall dt env sto proxy r.
-- extended with some accumulators.
-> r)
-> r
-accumPromote _ DTop k = k DTop SNil SETop (\_ -> WId)
-accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =
- accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub wf ->
- case sto of
- -- Accumulators are left as-is
- SAccum ->
- k (storepl `DPush` (t, SAccum))
- envpro
- prosub
- (\shbinds ->
- autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum descr)))
- (#acc :++: (#pro :++: #d :++: #shb :++: #tl))
- (#pro :++: #d :++: #shb :++: #acc :++: #tl)
- .> WCopy (wf shbinds)
- .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum storepl)))
- (#d :++: #shb :++: #acc :++: #tl)
- (#acc :++: (#d :++: #shb :++: #tl)))
-
- SMerge -> case t of
- -- Discrete values are left as-is
- _ | isDiscrete t ->
- k (storepl `DPush` (t, SDiscr))
- envpro
- (SENo prosub)
- wf
-
- -- Values with "merge" storage are promoted to an accumulator in envPro
- _ ->
- k (storepl `DPush` (t, SAccum))
- (t `SCons` envpro)
- (SEYes prosub)
- (\(shbinds :: SList _ shbinds) ->
- let shbindsC = slistMap (\_ -> Const ()) shbinds
- in
- -- wf:
- -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
- -- WCopy wf:
- -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
- -- WPICK: ^ THESE TWO ||
- -- goal: | ARE EQUAL ||
- -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
- WCopy (wf shbinds)
- .> WPick @(TAccum t) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
- (WId @(D2AcE (Select env1 stoRepl "accum"))))
-
- -- Discrete values are left as-is, nothing to do
- SDiscr ->
- k (storepl `DPush` (t, SDiscr))
+accumPromote _ DTop k = k DTop SNil SETop SETop VarMap.empty (\_ -> WId)
+accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of
+ -- Accumulators are left as-is
+ SAccum ->
+ accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
+ k (storepl `DPush` (t, vid, SAccum))
+ envpro
+ prosub
+ (SEYes accrevsub)
+ (VarMap.sink1 accumMap)
+ (\shbinds ->
+ autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum descr)))
+ (#acc :++: (#pro :++: #d :++: #shb :++: #tl))
+ (#pro :++: #d :++: #shb :++: #acc :++: #tl)
+ .> WCopy (wf shbinds)
+ .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum storepl)))
+ (#d :++: #shb :++: #acc :++: #tl)
+ (#acc :++: (#d :++: #shb :++: #tl)))
+
+ SMerge -> case t of
+ -- Discrete values are left as-is
+ _ | isDiscrete t ->
+ accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap' wf ->
+ k (storepl `DPush` (t, vid, SDiscr))
envpro
- prosub
+ (SENo prosub)
+ accrevsub
+ accumMap'
wf
+
+ -- Values with "merge" storage are promoted to an accumulator in envPro
+ _ ->
+ accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
+ k (storepl `DPush` (t, vid, SAccum))
+ (t `SCons` envpro)
+ (SEYes prosub)
+ (SENo accrevsub)
+ (let accumMap' = VarMap.sink1 accumMap
+ in case fromArrayValId vid of
+ Just i -> VarMap.insert i (STAccum t) IZ accumMap'
+ Nothing -> accumMap')
+ (\(shbinds :: SList _ shbinds) ->
+ let shbindsC = slistMap (\_ -> Const ()) shbinds
+ in
+ -- wf:
+ -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ -- WCopy wf:
+ -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ -- WPICK: ^ THESE TWO ||
+ -- goal: | ARE EQUAL ||
+ -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ WCopy (wf shbinds)
+ .> WPick @(TAccum t) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
+ (WId @(D2AcE (Select env1 stoRepl "accum"))))
+
+ -- Discrete values are left as-is, nothing to do
+ SDiscr ->
+ accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf ->
+ k (storepl `DPush` (t, vid, SDiscr))
+ envpro
+ prosub
+ accrevsub
+ accumMap
+ wf
where
isDiscrete :: STy t' -> Bool
isDiscrete = \case
@@ -561,7 +583,7 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =
drev :: forall env sto t.
(?config :: CHADConfig)
- => Descr env sto -> VarMap Int env
+ => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
-> Expr ValId env t -> Ret env sto t
drev des accumMap = \case
EVar _ t i ->
@@ -590,7 +612,7 @@ drev des accumMap = \case
ELet _ (rhs :: Expr _ _ a) body
| Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs
, ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
- , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage body
+ , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) body
, let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
, Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
, Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->
@@ -687,8 +709,9 @@ drev des accumMap = \case
, Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e
, ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge
, ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge
- , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 a
- , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 b
+ , let (bindids1, bindids2) = validSplitEither (extOf e)
+ , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 bindids1 a
+ , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 bindids2 b
, Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
, Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
, let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA)
@@ -819,8 +842,9 @@ drev des accumMap = \case
deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') ->
let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in
subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
- accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro ->
- case drev (prodes `DPush` (shty, SDiscr)) (VarMap.sink1 (VarMap.subMap usedSub accumMap)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->
+ accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
+ let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
+ case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->
case assertSubenvEmpty sub of { Refl ->
let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in
let collectexpr = bindingsCollect e0 subtapeE in
@@ -1055,17 +1079,22 @@ deriving instance Show (RetScoped env0 sto a s t)
drevScoped :: forall a s env sto t.
(?config :: CHADConfig)
- => Descr env sto -> VarMap Int env -> STy a -> Storage s
+ => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum"))
+ -> STy a -> Storage s -> Maybe (ValId a)
-> Expr ValId (a : env) t
-> RetScoped env sto a s t
-drevScoped des accumMap argty argsto expr
- | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) (VarMap.sink1 accumMap) expr
- = case argsto of
- SMerge ->
+drevScoped des accumMap argty argsto argids expr = case argsto of
+ SMerge
+ | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr ->
case sub of
SEYes sub' -> RetScoped e0 subtape e1 sub' e2
SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty))
- SAccum ->
+
+ SAccum
+ | let accumMap' = case argids of
+ Just (VIArr i _) -> VarMap.insert i (STAccum argty) IZ (VarMap.sink1 accumMap)
+ _ -> VarMap.sink1 accumMap
+ , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' expr ->
RetScoped e0 subtape e1 sub $
EWith ext argty (EZero ext argty) $
weakenExpr (autoWeak (#d (auto1 @(D2 t))
@@ -1075,4 +1104,7 @@ drevScoped des accumMap argty argsto expr
(#d :++: #body :++: #ac :++: #tl)
(#ac :++: #d :++: #body :++: #tl))
e2
- SDiscr -> RetScoped e0 subtape e1 sub e2
+
+ SDiscr
+ | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr ->
+ RetScoped e0 subtape e1 sub e2