diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-04-06 17:07:22 +0200 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-04-06 17:07:22 +0200 |
commit | 0a9e6dfc1accf9dc0254f0c720f633dab6e71f42 (patch) | |
tree | 754eaeecf01e554d7ad904c27a9b665879441ca0 /src/CHAD.hs | |
parent | b6c1d3a9d0651aa25ea5f03d514a214a3347f7a4 (diff) |
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 174 |
1 files changed, 103 insertions, 71 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 6a4d5f5..8db0410 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -36,7 +36,7 @@ import Data.Type.Bool (If) import Data.Type.Equality (type (==)) import GHC.Stack (HasCallStack) -import Analysis.Identity (ValId(..)) +import Analysis.Identity (ValId(..), validSplitEither) import AST import AST.Bindings import AST.Count @@ -294,18 +294,18 @@ data Idx2 env sto t | Idx2Di (Idx (Select env sto "discr") t) conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t -conv2Idx (DPush _ (_, SAccum)) IZ = Idx2Ac IZ -conv2Idx (DPush _ (_, SMerge)) IZ = Idx2Me IZ -conv2Idx (DPush _ (_, SDiscr)) IZ = Idx2Di IZ -conv2Idx (DPush des (_, SAccum)) (IS i) = +conv2Idx (DPush _ (_, _, SAccum)) IZ = Idx2Ac IZ +conv2Idx (DPush _ (_, _, SMerge)) IZ = Idx2Me IZ +conv2Idx (DPush _ (_, _, SDiscr)) IZ = Idx2Di IZ +conv2Idx (DPush des (_, _, SAccum)) (IS i) = case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j) Idx2Me j -> Idx2Me j Idx2Di j -> Idx2Di j -conv2Idx (DPush des (_, SMerge)) (IS i) = +conv2Idx (DPush des (_, _, SMerge)) (IS i) = case conv2Idx des i of Idx2Ac j -> Idx2Ac j Idx2Me j -> Idx2Me (IS j) Idx2Di j -> Idx2Di j -conv2Idx (DPush des (_, SDiscr)) (IS i) = +conv2Idx (DPush des (_, _, SDiscr)) (IS i) = case conv2Idx des i of Idx2Ac j -> Idx2Ac j Idx2Me j -> Idx2Me j Idx2Di j -> Idx2Di (IS j) @@ -376,6 +376,10 @@ assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" --------------------------------- ACCUMULATORS --------------------------------- +fromArrayValId :: Maybe (ValId t) -> Maybe Int +fromArrayValId (Just (VIArr i _)) = Just i +fromArrayValId _ = Nothing + accumPromote :: forall dt env sto proxy r. proxy dt -> Descr env sto @@ -384,8 +388,7 @@ accumPromote :: forall dt env sto proxy r. => Descr env stoRepl -- ^ A revised environment description that switches -- arrays (used in the OccEnv) that are currently on - -- "merge" storage, to "accum" storage. Any other "merge" - -- entries are deleted. + -- "merge" storage, to "accum" storage. -> SList STy envPro -- ^ New entries on top of the original dual environment, -- that house the accumulators for the promoted arrays in @@ -393,6 +396,12 @@ accumPromote :: forall dt env sto proxy r. -> Subenv (Select env sto "merge") envPro -- ^ The promoted entries were merge entries in the -- original environment. + -> Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum")) + -- ^ All entries that were accumulators are still + -- accumulators. + -> VarMap Int (D2AcE (Select env stoRepl "accum")) + -- ^ Accumulator map for _only_ the the newly allocated + -- accumulators. -> (forall shbinds. SList STy shbinds -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) @@ -402,57 +411,70 @@ accumPromote :: forall dt env sto proxy r. -- extended with some accumulators. -> r) -> r -accumPromote _ DTop k = k DTop SNil SETop (\_ -> WId) -accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = - accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub wf -> - case sto of - -- Accumulators are left as-is - SAccum -> - k (storepl `DPush` (t, SAccum)) - envpro - prosub - (\shbinds -> - autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum descr))) - (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) - (#pro :++: #d :++: #shb :++: #acc :++: #tl) - .> WCopy (wf shbinds) - .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum storepl))) - (#d :++: #shb :++: #acc :++: #tl) - (#acc :++: (#d :++: #shb :++: #tl))) - - SMerge -> case t of - -- Discrete values are left as-is - _ | isDiscrete t -> - k (storepl `DPush` (t, SDiscr)) - envpro - (SENo prosub) - wf - - -- Values with "merge" storage are promoted to an accumulator in envPro - _ -> - k (storepl `DPush` (t, SAccum)) - (t `SCons` envpro) - (SEYes prosub) - (\(shbinds :: SList _ shbinds) -> - let shbindsC = slistMap (\_ -> Const ()) shbinds - in - -- wf: - -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - -- WCopy wf: - -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - -- WPICK: ^ THESE TWO || - -- goal: | ARE EQUAL || - -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - WCopy (wf shbinds) - .> WPick @(TAccum t) @(D2 dt : shbinds) (Const () `SCons` shbindsC) - (WId @(D2AcE (Select env1 stoRepl "accum")))) - - -- Discrete values are left as-is, nothing to do - SDiscr -> - k (storepl `DPush` (t, SDiscr)) +accumPromote _ DTop k = k DTop SNil SETop SETop VarMap.empty (\_ -> WId) +accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of + -- Accumulators are left as-is + SAccum -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> + k (storepl `DPush` (t, vid, SAccum)) + envpro + prosub + (SEYes accrevsub) + (VarMap.sink1 accumMap) + (\shbinds -> + autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum descr))) + (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) + (#pro :++: #d :++: #shb :++: #acc :++: #tl) + .> WCopy (wf shbinds) + .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum storepl))) + (#d :++: #shb :++: #acc :++: #tl) + (#acc :++: (#d :++: #shb :++: #tl))) + + SMerge -> case t of + -- Discrete values are left as-is + _ | isDiscrete t -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap' wf -> + k (storepl `DPush` (t, vid, SDiscr)) envpro - prosub + (SENo prosub) + accrevsub + accumMap' wf + + -- Values with "merge" storage are promoted to an accumulator in envPro + _ -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> + k (storepl `DPush` (t, vid, SAccum)) + (t `SCons` envpro) + (SEYes prosub) + (SENo accrevsub) + (let accumMap' = VarMap.sink1 accumMap + in case fromArrayValId vid of + Just i -> VarMap.insert i (STAccum t) IZ accumMap' + Nothing -> accumMap') + (\(shbinds :: SList _ shbinds) -> + let shbindsC = slistMap (\_ -> Const ()) shbinds + in + -- wf: + -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + -- WCopy wf: + -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + -- WPICK: ^ THESE TWO || + -- goal: | ARE EQUAL || + -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + WCopy (wf shbinds) + .> WPick @(TAccum t) @(D2 dt : shbinds) (Const () `SCons` shbindsC) + (WId @(D2AcE (Select env1 stoRepl "accum")))) + + -- Discrete values are left as-is, nothing to do + SDiscr -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> + k (storepl `DPush` (t, vid, SDiscr)) + envpro + prosub + accrevsub + accumMap + wf where isDiscrete :: STy t' -> Bool isDiscrete = \case @@ -561,7 +583,7 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = drev :: forall env sto t. (?config :: CHADConfig) - => Descr env sto -> VarMap Int env + => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) -> Expr ValId env t -> Ret env sto t drev des accumMap = \case EVar _ t i -> @@ -590,7 +612,7 @@ drev des accumMap = \case ELet _ (rhs :: Expr _ _ a) body | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge - , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage body + , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) body , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> @@ -687,8 +709,9 @@ drev des accumMap = \case , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge - , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 a - , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 b + , let (bindids1, bindids2) = validSplitEither (extOf e) + , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 bindids1 a + , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 bindids2 b , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA) @@ -819,8 +842,9 @@ drev des accumMap = \case deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> - accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro -> - case drev (prodes `DPush` (shty, SDiscr)) (VarMap.sink1 (VarMap.subMap usedSub accumMap)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> + accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> + let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in + case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> case assertSubenvEmpty sub of { Refl -> let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in let collectexpr = bindingsCollect e0 subtapeE in @@ -1055,17 +1079,22 @@ deriving instance Show (RetScoped env0 sto a s t) drevScoped :: forall a s env sto t. (?config :: CHADConfig) - => Descr env sto -> VarMap Int env -> STy a -> Storage s + => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> STy a -> Storage s -> Maybe (ValId a) -> Expr ValId (a : env) t -> RetScoped env sto a s t -drevScoped des accumMap argty argsto expr - | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) (VarMap.sink1 accumMap) expr - = case argsto of - SMerge -> +drevScoped des accumMap argty argsto argids expr = case argsto of + SMerge + | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> case sub of SEYes sub' -> RetScoped e0 subtape e1 sub' e2 SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty)) - SAccum -> + + SAccum + | let accumMap' = case argids of + Just (VIArr i _) -> VarMap.insert i (STAccum argty) IZ (VarMap.sink1 accumMap) + _ -> VarMap.sink1 accumMap + , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' expr -> RetScoped e0 subtape e1 sub $ EWith ext argty (EZero ext argty) $ weakenExpr (autoWeak (#d (auto1 @(D2 t)) @@ -1075,4 +1104,7 @@ drevScoped des accumMap argty argsto expr (#d :++: #body :++: #ac :++: #tl) (#ac :++: #d :++: #body :++: #tl)) e2 - SDiscr -> RetScoped e0 subtape e1 sub e2 + + SDiscr + | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> + RetScoped e0 subtape e1 sub e2 |