{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-- | Splitting of pair-typed binders.
--
-- For every variable introduced by 'ELet', by the binding branches of
-- 'ECase' and 'EMaybe', and by the two-argument body of 'EFold1Inner', the
-- pass emits additional bindings (via 'letBinds') that project out the
-- components of the variable with 'EFst' / 'ESnd', one binding per component
-- that is neither a pair nor nil.  Occurrences of the original variable in
-- the body are then replaced by an expression that rebuilds the value from
-- those component variables (see 'subPointers').
module AST.SplitLets (splitLets) where

import Data.Type.Equality

import AST
import AST.Bindings
import Lemmas


-- | Entry point: run the pass on an expression without changing its
-- environment or type.  The initial substitution maps every variable to
-- itself, weakened into whatever environment it ends up being used in.
splitLets :: Ex env t -> Ex env t
splitLets = splitLets' (\ty idx wk -> EVar ext ty (wk @> idx))

-- | Worker for 'splitLets'.
--
-- The first argument is a substitution: given the type and index of a
-- variable of the input environment @env@, plus a weakening from the output
-- environment @env'@ into the environment at the use site, it produces the
-- replacement expression.  Constructors that bind no variables are rebuilt
-- with their sub-expressions processed under the same substitution.
splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t
splitLets' subst = \case
  -- A variable occurrence is looked up at the current position (no extra
  -- weakening).
  EVar _ ty idx -> subst ty idx WId

  -- Constructors whose binders get split.
  ELet _ (bound :: Ex env t1) body ->
    ELet ext (splitLets' subst bound) (split1 subst (typeOf bound) body)
  ECase xt e l r ->
    let STEither tl tr = typeOf e
    in ECase xt (splitLets' subst e) (split1 subst tl l) (split1 subst tr r)
  EMaybe xt e1 e2 e ->
    let STMaybe tj = typeOf e
    in EMaybe xt (splitLets' subst e1) (split1 subst tj e2) (splitLets' subst e)
  EFold1Inner xt cm e1 e2 e3 ->
    -- Both bound variables of the first argument have the element type of
    -- the array argument.
    let STArr _ telt = typeOf e3
    in EFold1Inner xt cm (split2 subst telt telt e1) (splitLets' subst e2) (splitLets' subst e3)

  -- Purely structural cases.
  EPair xt e1 e2 -> EPair xt (splitLets' subst e1) (splitLets' subst e2)
  EFst xt e -> EFst xt (splitLets' subst e)
  ESnd xt e -> ESnd xt (splitLets' subst e)
  ENil xt -> ENil xt
  EInl xt ty e -> EInl xt ty (splitLets' subst e)
  EInr xt ty e -> EInr xt ty (splitLets' subst e)
  ENothing xt ty -> ENothing xt ty
  EJust xt e -> EJust xt (splitLets' subst e)

  EConstArr xt n ty arr -> EConstArr xt n ty arr
  -- The second argument of EBuild binds a variable that is kept as-is
  -- (handled by 'sinkF', not split).
  EBuild xt n e1 e2 -> EBuild xt n (splitLets' subst e1) (splitLets' (sinkF subst) e2)
  ESum1Inner xt e -> ESum1Inner xt (splitLets' subst e)
  EUnit xt e -> EUnit xt (splitLets' subst e)
  EReplicate1Inner xt e1 e2 -> EReplicate1Inner xt (splitLets' subst e1) (splitLets' subst e2)
  EMaximum1Inner xt e -> EMaximum1Inner xt (splitLets' subst e)
  EMinimum1Inner xt e -> EMinimum1Inner xt (splitLets' subst e)

  EConst xt ty v -> EConst xt ty v
  EIdx0 xt e -> EIdx0 xt (splitLets' subst e)
  EIdx1 xt e1 e2 -> EIdx1 xt (splitLets' subst e1) (splitLets' subst e2)
  EIdx xt e es -> EIdx xt (splitLets' subst e) (splitLets' subst es)
  EShape xt e -> EShape xt (splitLets' subst e)
  EOp xt op e -> EOp xt op (splitLets' subst e)

  -- Only the last two arguments of ECustom are traversed; the others are
  -- passed through untouched.
  ECustom xt s t p a b c e1 e2 ->
    ECustom xt s t p a b c (splitLets' subst e1) (splitLets' subst e2)

  -- As with EBuild, the variable bound in the second argument of EWith is
  -- kept as-is.
  EWith xt ty e1 e2 -> EWith xt ty (splitLets' subst e1) (splitLets' (sinkF subst) e2)
  EAccum xt ty p e1 e2 e3 ->
    EAccum xt ty p (splitLets' subst e1) (splitLets' subst e2) (splitLets' subst e3)
  EZero xt ty -> EZero xt ty
  EPlus xt ty e1 e2 -> EPlus xt ty (splitLets' subst e1) (splitLets' subst e2)
  EOneHot xt ty p e1 e2 -> EOneHot xt ty p (splitLets' subst e1) (splitLets' subst e2)

  EError xt ty msg -> EError xt ty msg
  where
    -- Extend a substitution under one binder that is NOT split: the new
    -- variable maps to itself, every older variable goes through the given
    -- substitution with the weakening extended by 'WSink'.
    sinkF :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
          -> STy t -> Idx (b : env) t -> (b : env') :> env3 -> Ex env3 t
    sinkF _ ty IZ wk = EVar ext ty (wk @> IZ)
    sinkF outer ty (IS idx) wk = outer ty idx (wk .> WSink)

    -- Process a body that sits under one binder of type @bind@.  The bound
    -- variable is split into component bindings; inside the body it is
    -- replaced by the reconstruction from those components, while older
    -- variables are looked up through the outer substitution after popping
    -- the component bindings and the binder itself from the weakening.
    split1 :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
           -> STy bind -> Ex (bind : env) t -> Ex (bind : env') t
    split1 outer (tbind :: STy bind) body =
      let (pointers, binds) = split (EVar ext tbind IZ) tbind
      in letBinds binds $
           splitLets'
             (\cases
                _ IZ wk -> subPointers pointers wk
                ty (IS idx) wk -> outer ty idx (WPop @bind (wPops (bindingsBinds binds) wk)))
             body

    -- Like 'split1', but for a body under two binders (the innermost one of
    -- type @bind2@, the one below it of type @bind1@).  The components of
    -- the @bind1@ variable are bound first, then those of the @bind2@
    -- variable (whose bindings are weakened past the first group).
    split2 :: forall bind1 bind2 env' env t.
              (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
           -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t
    split2 outer tbind1 tbind2 body =
      let (ptrs1, bs1) = split (EVar ext tbind1 (IS IZ)) tbind1
          (ptrs2, bs2) = split (EVar ext tbind2 IZ) tbind2
          bound1 = bindingsBinds bs1
          bound2 = bindingsBinds bs2
      in letBinds bs1 $
         letBinds (fst (weakenBindings weakenExpr (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $
           splitLets'
             (\cases
                -- Innermost variable: its pointers were computed without the
                -- first group of bindings in scope, so weaken past them.
                _ IZ wk ->
                  subPointers ptrs2
                    (wk .> wCopies bound2 (wSinks @(bind2 : bind1 : env') bound1))
                -- Second variable: only the second group lies in between.
                _ (IS IZ) wk ->
                  subPointers ptrs1 (wk .> wSinks bound2)
                -- Older variables: drop both groups and both binders.
                ty (IS (IS idx)) wk ->
                  outer ty idx
                    (WPop @bind1 (WPop @bind2 (wPops bound1 (wPops bound2 wk)))))
             body

-- | The list of component types a value of type @t@ is split into: nothing
-- for 'TNil', the components of both halves for 'TPair' (those of the second
-- half first), and @t@ itself for every other type.  The equations overlap,
-- so their order matters.
type family Split t where
  Split TNil = '[]
  Split (TPair a b) = Append (Split b) (Split a)
  Split t = '[t]

-- | A description of how to rebuild a value of type @t@ from variables in
-- @env@; interpreted by 'subPointers'.
data Pointers env t where
  -- | A single variable of the given type.
  Point :: STy t -> Idx env t -> Pointers env t
  -- | The nil value; refers to no variable.
  PNil :: Pointers env TNil
  -- | A pair built from the two component descriptions.
  PPair :: Pointers env a -> Pointers env b -> Pointers env (TPair a b)
  -- | A description valid in @env'@, together with a weakening from @env'@
  -- into @env@.
  PWeak :: env' :> env -> Pointers env' t -> Pointers env t

-- | Turn a 'Pointers' description into an expression in @env'@, given a
-- weakening from the environment of the description into @env'@.
subPointers :: Pointers env t -> env :> env' -> Ex env' t
subPointers pointers wk = case pointers of
  Point ty idx -> EVar ext ty (wk @> idx)
  PNil -> ENil ext
  PPair l r -> EPair ext (subPointers l wk) (subPointers r wk)
  -- Compose the stored weakening with the one supplied by the caller.
  PWeak inner p -> subPointers p (wk .> inner)

-- | Split an expression of type @t@ into bindings for its components.
--
-- Returns the bindings (one per entry of @'Split' t@) together with a
-- 'Pointers' description, valid in the environment extended by those
-- bindings, of how to rebuild the original value.
--
-- * For 'STNil' nothing is bound.
-- * For 'STPair' the first projection is split first, then the second
--   projection of the expression weakened past the first group of bindings;
--   the pointers of the first half are weakened past the second group.
-- * Every other type yields exactly one binding holding the expression
--   itself.
split :: forall env t. Ex env t -> STy t -> (Pointers (Append (Split t) env) t, Bindings Ex env (Split t))
split expr ty = case ty of
  STNil -> (PNil, BTop)
  STPair (ta :: STy a) (tb :: STy b)
    | Refl <- lemAppendAssoc @(Split b) @(Split a) @env ->
        let (p1, bs1) = split (EFst ext expr) ta
            (p2, bs2) = split (ESnd ext (weakenExpr (sinkWithBindings bs1) expr)) tb
        in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2)
  -- Each remaining constructor is matched explicitly rather than with a
  -- wildcard (presumably so that @'Split' t@ can reduce to @'[t]@ in each
  -- branch).
  t@STEither{} -> other t
  t@STMaybe{} -> other t
  t@STArr{} -> other t
  t@STScal{} -> other t
  t@STAccum{} -> other t
  where
    -- A single binding for the whole expression, referred to by index zero.
    other :: STy t -> (Pointers (t : env) t, Bindings Ex env '[t])
    other t = (Point t IZ, BPush BTop (t, expr))