diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:49:45 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:50:25 +0100 |
| commit | 174af2ba568de66e0d890825b8bda930b8e7bb96 (patch) | |
| tree | 5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/CHAD/AST/SplitLets.hs | |
| parent | 92bca235e3aaa287286b6af082d3fce585825a35 (diff) | |
Move module hierarchy under CHAD.
Diffstat (limited to 'src/CHAD/AST/SplitLets.hs')
| -rw-r--r-- | src/CHAD/AST/SplitLets.hs | 191 |
1 files changed, 191 insertions, 0 deletions
diff --git a/src/CHAD/AST/SplitLets.hs b/src/CHAD/AST/SplitLets.hs new file mode 100644 index 0000000..34267e4 --- /dev/null +++ b/src/CHAD/AST/SplitLets.hs @@ -0,0 +1,191 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +module CHAD.AST.SplitLets (splitLets) where + +import Data.Type.Equality + +import CHAD.AST +import CHAD.AST.Bindings +import CHAD.Lemmas + + +splitLets :: Ex env t -> Ex env t +splitLets = splitLets' (\t i w -> EVar ext t (w @> i)) + +splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t +splitLets' = \sub -> \case + EVar _ t i -> sub t i WId + ELet _ rhs body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) + ECase x e a b -> + let STEither t1 t2 = typeOf e + in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b) + EMaybe x a b e -> + let STMaybe t1 = typeOf e + in EMaybe x (splitLets' sub a) (split1 sub t1 b) (splitLets' sub e) + ELCase x e a b c -> + let STLEither t1 t2 = typeOf e + in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c) + EFold1Inner x cm a b c -> + let STArr _ t1 = typeOf c + in EFold1Inner x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) + EFold1InnerD1 x cm a b c -> + let STArr _ t1 = typeOf c + in EFold1InnerD1 x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) + EFold1InnerD2 x cm a b c -> + let STArr _ tB = typeOf b + STArr _ t2 = typeOf c + in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (splitLets' sub c) + + EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b) + EFst x e -> EFst x (splitLets' sub e) + ESnd x e -> ESnd x (splitLets' sub e) + ENil x -> ENil x + EInl x t e -> EInl x t (splitLets' sub e) + EInr x t e -> EInr x t (splitLets' sub e) + ENothing x t -> ENothing x t + EJust x e -> EJust x (splitLets' sub e) + ELNil x t1 t2 -> ELNil x t1 t2 + ELInl x t e -> ELInl x t (splitLets' sub e) + ELInr x t e -> ELInr x t (splitLets' sub e) + EConstArr x n t a -> EConstArr x n t a + EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b) + EMap x a b -> EMap x (splitLets' (sinkF sub) a) (splitLets' sub b) + ESum1Inner x e -> ESum1Inner x (splitLets' sub e) + EUnit x e -> EUnit x (splitLets' sub e) + EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b) + EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e) + EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e) + EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b) + EZip x a b -> EZip x (splitLets' sub a) (splitLets' sub b) + EConst x t v -> EConst x t v + EIdx0 x e -> EIdx0 x (splitLets' sub e) + EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b) + EIdx x e es -> EIdx x (splitLets' sub e) (splitLets' sub es) + EShape x e -> EShape x (splitLets' sub e) + EOp x op e -> EOp x op (splitLets' sub e) + ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2) + ERecompute x e -> ERecompute x (splitLets' sub e) + EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) + EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3) + EZero x t ezi -> EZero x t (splitLets' sub ezi) + EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi) + EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) + EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) + EError x t s -> EError x t s + where + sinkF :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) + -> STy t -> Idx (b : env) t -> (b : env') :> env3 -> Ex env3 t + sinkF _ t IZ w = EVar ext t (w @> IZ) + sinkF f t (IS i) w = f t i (w .> WSink) + + split1 :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) + -> STy bind -> Ex (bind : env) t -> Ex (bind : env') t + split1 sub (tbind :: STy bind) body = + let (ptrs, bs) = split tbind + in letBinds bs $ + splitLets' (\cases _ IZ w -> subPointers ptrs w + t (IS i) w -> sub t i (WPop @bind (wPops (bindingsBinds bs) w))) + body + + split2 :: forall bind1 bind2 env' env t. + (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) + -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t + split2 sub tbind1 tbind2 body = + let (ptrs1', bs1') = split @env' tbind1 + bs1 = fst (weakenBindingsE WSink bs1') + (ptrs2, bs2) = split @(bind1 : env') tbind2 + in letBinds bs1 $ + letBinds (fst (weakenBindingsE (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ + splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1))) + _ (IS IZ) w -> subPointers ptrs1' (w .> wSinks (bindingsBinds bs2) .> wCopies (bindingsBinds bs1) (WSink @bind2 @(bind1 : env'))) + t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w))))) + body + + -- TODO: abstract this to splitN lol wtf + _split4 :: forall bind1 bind2 bind3 bind4 env' env t. + (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) + -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t + _split4 sub tbind1 tbind2 tbind3 tbind4 body = + let (ptrs1, bs1') = split @env' tbind1 + (ptrs2, bs2') = split @(bind1 : env') tbind2 + (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3 + (ptrs4, bs4) = split @(bind3 : bind2 : bind1 : env') tbind4 + bs1 = fst (weakenBindingsE (WSink .> WSink .> WSink) bs1') + bs2 = fst (weakenBindingsE (WSink .> WSink) bs2') + bs3 = fst (weakenBindingsE WSink bs3') + b1 = bindingsBinds bs1 + b2 = bindingsBinds bs2 + b3 = bindingsBinds bs3 + b4 = bindingsBinds bs4 + in letBinds bs1 $ + letBinds (fst (weakenBindingsE ( sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs2)) $ + letBinds (fst (weakenBindingsE ( sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs3)) $ + letBinds (fst (weakenBindingsE (sinkWithBindings bs3 .> sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs4)) $ + splitLets' (\cases _ IZ w -> subPointers ptrs4 (w .> wCopies b4 (wSinks b3 .> wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1)) + _ (IS IZ) w -> subPointers ptrs3 (w .> wSinks b4 .> wCopies b3 (wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink)) + _ (IS (IS IZ)) w -> subPointers ptrs2 (w .> wSinks b4 .> wSinks b3 .> wCopies b2 (wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink .> WSink)) + _ (IS (IS (IS IZ))) w -> subPointers ptrs1 (w .> wSinks b4 .> wSinks b3 .> wSinks b2 .> wCopies b1 (WSink @bind4 .> WSink @bind3 .> WSink @bind2 @(bind1 : env'))) + t (IS (IS (IS (IS i)))) w -> sub t i (WPop @bind1 (WPop @bind2 (WPop @bind3 (WPop @bind4 (wPops b1 (wPops b2 (wPops b3 (wPops b4 w))))))))) + body + +type family Split t where + Split (TPair a b) = SplitRec (TPair a b) + Split _ = '[] + +type family SplitRec t where + SplitRec TNil = '[] + SplitRec (TPair a b) = Append (SplitRec b) (SplitRec a) + SplitRec t = '[t] + +data Pointers env t where + Point :: STy t -> Idx env t -> Pointers env t + PNil :: Pointers env TNil + PPair :: Pointers env a -> Pointers env b -> Pointers env (TPair a b) + PWeak :: env' :> env -> Pointers env' t -> Pointers env t + +subPointers :: Pointers env t -> env :> env' -> Ex env' t +subPointers (Point t i) w = EVar ext t (w @> i) +subPointers PNil _ = ENil ext +subPointers (PPair a b) w = EPair ext (subPointers a w) (subPointers b w) +subPointers (PWeak w' p) w = subPointers p (w .> w') + +split :: forall env t. STy t + -> (Pointers (Append (Split t) (t : env)) t, Bindings Ex (t : env) (Split t)) +split typ = case typ of + STPair{} -> splitRec (EVar ext typ IZ) typ + STNil -> other + STEither{} -> other + STLEither{} -> other + STMaybe{} -> other + STArr{} -> other + STScal{} -> other + STAccum{} -> other + where + other :: (Pointers (t : env) t, Bindings Ex (t : env) '[]) + other = (Point typ IZ, BTop) + +splitRec :: forall env t. Ex env t -> STy t + -> (Pointers (Append (SplitRec t) env) t, Bindings Ex env (SplitRec t)) +splitRec rhs typ = case typ of + STNil -> (PNil, BTop) + STPair (a :: STy a) (b :: STy b) + | Refl <- lemAppendAssoc @(SplitRec b) @(SplitRec a) @env -> + let (p1, bs1) = splitRec (EFst ext rhs) a + (p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b + in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2) + STEither{} -> other + STLEither{} -> other + STMaybe{} -> other + STArr{} -> other + STScal{} -> other + STAccum{} -> other + where + other :: (Pointers (t : env) t, Bindings Ex env '[t]) + other = (Point typ IZ, BPush BTop (typ, rhs)) |
