From afd2214b2039390e440d9ab82dfa97077b76d827 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 21 Apr 2025 21:58:17 +0200 Subject: splitLets: Don't split if unnecessary --- src/AST/SplitLets.hs | 47 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 14 deletions(-) (limited to 'src') diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 1de417c..dcba1ad 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -70,7 +70,7 @@ splitLets' = \sub -> \case split1 :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> STy bind -> Ex (bind : env) t -> Ex (bind : env') t split1 sub (tbind :: STy bind) body = - let (ptrs, bs) = split (EVar ext tbind IZ) tbind + let (ptrs, bs) = split tbind in letBinds bs $ splitLets' (\cases _ IZ w -> subPointers ptrs w t (IS i) w -> sub t i (WPop @bind (wPops (bindingsBinds bs) w))) @@ -80,19 +80,24 @@ splitLets' = \sub -> \case (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t split2 sub tbind1 tbind2 body = - let (ptrs1, bs1) = split (EVar ext tbind1 (IS IZ)) tbind1 - (ptrs2, bs2) = split (EVar ext tbind2 IZ) tbind2 + let (ptrs1', bs1') = split @env' tbind1 + bs1 = fst (weakenBindings weakenExpr WSink bs1') + (ptrs2, bs2) = split @(bind1 : env') tbind2 in letBinds bs1 $ letBinds (fst (weakenBindings weakenExpr (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1))) - _ (IS IZ) w -> subPointers ptrs1 (w .> wSinks (bindingsBinds bs2)) + _ (IS IZ) w -> subPointers ptrs1' (w .> wSinks (bindingsBinds bs2) .> wCopies (bindingsBinds bs1) (WSink @bind2 @(bind1 : env'))) t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w))))) body type family Split t where - Split TNil = '[] - Split (TPair a b) = Append (Split b) (Split a) - Split t = '[t] + Split (TPair a b) = SplitRec (TPair a b) + Split _ = '[] + +type family SplitRec t where + SplitRec TNil = '[] + SplitRec (TPair a b) = Append (SplitRec b) (SplitRec a) + SplitRec t = '[t] data Pointers env t where Point :: STy t -> Idx env t -> Pointers env t @@ -106,14 +111,28 @@ subPointers PNil _ = ENil ext subPointers (PPair a b) w = EPair ext (subPointers a w) (subPointers b w) subPointers (PWeak w' p) w = subPointers p (w .> w') -split :: forall env t. Ex env t -> STy t - -> (Pointers (Append (Split t) env) t, Bindings Ex env (Split t)) -split i = \case +split :: forall env t. STy t + -> (Pointers (Append (Split t) (t : env)) t, Bindings Ex (t : env) (Split t)) +split typ = case typ of + STPair{} -> splitRec (EVar ext typ IZ) typ + STNil -> other + STEither{} -> other + STMaybe{} -> other + STArr{} -> other + STScal{} -> other + STAccum{} -> other + where + other :: (Pointers (t : env) t, Bindings Ex (t : env) '[]) + other = (Point typ IZ, BTop) + +splitRec :: forall env t. Ex env t -> STy t + -> (Pointers (Append (SplitRec t) env) t, Bindings Ex env (SplitRec t)) +splitRec rhs = \case STNil -> (PNil, BTop) STPair (a :: STy a) (b :: STy b) - | Refl <- lemAppendAssoc @(Split b) @(Split a) @env -> - let (p1, bs1) = split (EFst ext i) a - (p2, bs2) = split (ESnd ext (sinkWithBindings bs1 `weakenExpr` i)) b + | Refl <- lemAppendAssoc @(SplitRec b) @(SplitRec a) @env -> + let (p1, bs1) = splitRec (EFst ext rhs) a + (p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2) t@STEither{} -> other t t@STMaybe{} -> other t @@ -122,4 +141,4 @@ split i = \case t@STAccum{} -> other t where other :: STy t -> (Pointers (t : env) t, Bindings Ex env '[t]) - other t = (Point t IZ, BPush BTop (t, i)) + other t = (Point t IZ, BPush BTop (t, rhs)) -- cgit v1.2.3-70-g09d2