summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-04-05 18:26:31 +0200
committerTom Smeding <t.j.smeding@uu.nl>2025-04-05 18:26:31 +0200
commitb6c1d3a9d0651aa25ea5f03d514a214a3347f7a4 (patch)
tree49764a3f3b78bb2848cdc871a1217f7ae1a04120
parentebe8d8219e12fc9ac7ca58b367bc91e640ed0556 (diff)
Split product lets before chad
-rw-r--r--chad-fast.cabal1
-rw-r--r--src/AST/Bindings.hs2
-rw-r--r--src/AST/SplitLets.hs125
-rw-r--r--src/AST/Weaken.hs4
-rw-r--r--src/CHAD/Top.hs7
5 files changed, 136 insertions, 3 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index ebd5a48..1aadc6b 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -18,6 +18,7 @@ library
AST.Count
AST.Env
AST.Pretty
+ AST.SplitLets
AST.Types
AST.UnMonoid
AST.Weaken
diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs
index 2e63b42..3d99afe 100644
--- a/src/AST/Bindings.hs
+++ b/src/AST/Bindings.hs
@@ -45,7 +45,7 @@ weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env'
weakenOver SNil w = w
weakenOver (SCons _ ts) w = WCopy (weakenOver ts w)
-sinkWithBindings :: Bindings f env binds -> env' :> Append binds env'
+sinkWithBindings :: forall env' env binds f. Bindings f env binds -> env' :> Append binds env'
sinkWithBindings BTop = WId
sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
new file mode 100644
index 0000000..1de417c
--- /dev/null
+++ b/src/AST/SplitLets.hs
@@ -0,0 +1,125 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+module AST.SplitLets (splitLets) where
+
+import Data.Type.Equality
+
+import AST
+import AST.Bindings
+import Lemmas
+
+
+splitLets :: Ex env t -> Ex env t
+splitLets = splitLets' (\t i w -> EVar ext t (w @> i))
+
+splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t
+splitLets' = \sub -> \case
+ EVar _ t i -> sub t i WId
+ ELet _ (rhs :: Ex env t1) body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body)
+ ECase x e a b ->
+ let STEither t1 t2 = typeOf e
+ in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b)
+ EMaybe x a b e ->
+ let STMaybe t1 = typeOf e
+ in EMaybe x (splitLets' sub a) (split1 sub t1 b) (splitLets' sub e)
+ EFold1Inner x cm a b c ->
+ let STArr _ t1 = typeOf c
+ in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
+
+ EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b)
+ EFst x e -> EFst x (splitLets' sub e)
+ ESnd x e -> ESnd x (splitLets' sub e)
+ ENil x -> ENil x
+ EInl x t e -> EInl x t (splitLets' sub e)
+ EInr x t e -> EInr x t (splitLets' sub e)
+ ENothing x t -> ENothing x t
+ EJust x e -> EJust x (splitLets' sub e)
+ EConstArr x n t a -> EConstArr x n t a
+ EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b)
+ ESum1Inner x e -> ESum1Inner x (splitLets' sub e)
+ EUnit x e -> EUnit x (splitLets' sub e)
+ EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b)
+ EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e)
+ EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e)
+ EConst x t v -> EConst x t v
+ EIdx0 x e -> EIdx0 x (splitLets' sub e)
+ EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b)
+ EIdx x e es -> EIdx x (splitLets' sub e) (splitLets' sub es)
+ EShape x e -> EShape x (splitLets' sub e)
+ EOp x op e -> EOp x op (splitLets' sub e)
+ ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2)
+ EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2)
+ EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3)
+ EZero x t -> EZero x t
+ EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b)
+ EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b)
+ EError x t s -> EError x t s
+ where
+ sinkF :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
+ -> STy t -> Idx (b : env) t -> (b : env') :> env3 -> Ex env3 t
+ sinkF _ t IZ w = EVar ext t (w @> IZ)
+ sinkF f t (IS i) w = f t i (w .> WSink)
+
+ split1 :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
+ -> STy bind -> Ex (bind : env) t -> Ex (bind : env') t
+ split1 sub (tbind :: STy bind) body =
+ let (ptrs, bs) = split (EVar ext tbind IZ) tbind
+ in letBinds bs $
+ splitLets' (\cases _ IZ w -> subPointers ptrs w
+ t (IS i) w -> sub t i (WPop @bind (wPops (bindingsBinds bs) w)))
+ body
+
+ split2 :: forall bind1 bind2 env' env t.
+ (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
+ -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t
+ split2 sub tbind1 tbind2 body =
+ let (ptrs1, bs1) = split (EVar ext tbind1 (IS IZ)) tbind1
+ (ptrs2, bs2) = split (EVar ext tbind2 IZ) tbind2
+ in letBinds bs1 $
+ letBinds (fst (weakenBindings weakenExpr (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $
+ splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1)))
+ _ (IS IZ) w -> subPointers ptrs1 (w .> wSinks (bindingsBinds bs2))
+ t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w)))))
+ body
+
+type family Split t where
+ Split TNil = '[]
+ Split (TPair a b) = Append (Split b) (Split a)
+ Split t = '[t]
+
+data Pointers env t where
+ Point :: STy t -> Idx env t -> Pointers env t
+ PNil :: Pointers env TNil
+ PPair :: Pointers env a -> Pointers env b -> Pointers env (TPair a b)
+ PWeak :: env' :> env -> Pointers env' t -> Pointers env t
+
+subPointers :: Pointers env t -> env :> env' -> Ex env' t
+subPointers (Point t i) w = EVar ext t (w @> i)
+subPointers PNil _ = ENil ext
+subPointers (PPair a b) w = EPair ext (subPointers a w) (subPointers b w)
+subPointers (PWeak w' p) w = subPointers p (w .> w')
+
+split :: forall env t. Ex env t -> STy t
+ -> (Pointers (Append (Split t) env) t, Bindings Ex env (Split t))
+split i = \case
+ STNil -> (PNil, BTop)
+ STPair (a :: STy a) (b :: STy b)
+ | Refl <- lemAppendAssoc @(Split b) @(Split a) @env ->
+ let (p1, bs1) = split (EFst ext i) a
+ (p2, bs2) = split (ESnd ext (sinkWithBindings bs1 `weakenExpr` i)) b
+ in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2)
+ t@STEither{} -> other t
+ t@STMaybe{} -> other t
+ t@STArr{} -> other t
+ t@STScal{} -> other t
+ t@STAccum{} -> other t
+ where
+ other :: STy t -> (Pointers (t : env) t, Bindings Ex env '[t])
+ other t = (Point t IZ, BPush BTop (t, i))
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index bd2c244..d882e28 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -126,3 +126,7 @@ wCopies bs w =
wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env
wRaiseAbove SNil _ = WClosed
wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env)
+
+wPops :: SList f bs -> Append bs env1 :> env2 -> env1 :> env2
+wPops SNil w = w
+wPops (_ `SCons` bs) w = wPops bs (WPop w)
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
index ced7550..ea7449d 100644
--- a/src/CHAD/Top.hs
+++ b/src/CHAD/Top.hs
@@ -12,6 +12,7 @@ module CHAD.Top where
import Analysis.Identity
import AST
+import AST.SplitLets
import AST.Weaken.Auto
import CHAD
import CHAD.Accum
@@ -87,7 +88,7 @@ chad config env (term :: Ex env t)
&. #tl (d1e env))
(#d :++: #acenv :++: #tl)
(#acenv :++: #d :++: #tl)) $
- freezeRet descr (drev descr VarMap.empty (identityAnalysis env term))) $
+ freezeRet descr (drev descr VarMap.empty term')) $
EPair ext (EFst ext (EFst ext (EVar ext tvar IZ)))
(reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ))
(ESnd ext (EFst ext (EVar ext tvar IZ)))))
@@ -95,7 +96,9 @@ chad config env (term :: Ex env t)
| False <- chcArgArrayAccum config
, Refl <- mergeEnvNoAccum env
, Refl <- mergeEnvOnlyMerge env
- = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (identityAnalysis env term))
+ = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty term')
+ where
+ term' = identityAnalysis env (splitLets term)
chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
chad' config env term