diff options
Diffstat (limited to 'src/AST/SplitLets.hs')
| -rw-r--r-- | src/AST/SplitLets.hs | 154 |
1 files changed, 0 insertions, 154 deletions
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs deleted file mode 100644 index dcaf82f..0000000 --- a/src/AST/SplitLets.hs +++ /dev/null @@ -1,154 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -module AST.SplitLets (splitLets) where - -import Data.Type.Equality - -import AST -import AST.Bindings -import Lemmas - - -splitLets :: Ex env t -> Ex env t -splitLets = splitLets' (\t i w -> EVar ext t (w @> i)) - -splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t -splitLets' = \sub -> \case - EVar _ t i -> sub t i WId - ELet _ (rhs :: Ex env t1) body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) - ECase x e a b -> - let STEither t1 t2 = typeOf e - in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b) - EMaybe x a b e -> - let STMaybe t1 = typeOf e - in EMaybe x (splitLets' sub a) (split1 sub t1 b) (splitLets' sub e) - ELCase x e a b c -> - let STLEither t1 t2 = typeOf e - in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c) - EFold1Inner x cm a b c -> - let STArr _ t1 = typeOf c - in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) - - EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b) - EFst x e -> EFst x (splitLets' sub e) - ESnd x e -> ESnd x (splitLets' sub e) - ENil x -> ENil x - EInl x t e -> EInl x t (splitLets' sub e) - EInr x t e -> EInr x t (splitLets' sub e) - ENothing x t -> ENothing x t - EJust x e -> EJust x (splitLets' sub e) - ELNil x t1 t2 -> ELNil x t1 t2 - ELInl x t e -> ELInl x t (splitLets' sub e) - ELInr x t e -> ELInr x t (splitLets' sub e) - EConstArr x n t a -> EConstArr x n t a - EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b) - ESum1Inner x e -> ESum1Inner x (splitLets' sub e) - EUnit x e -> EUnit x (splitLets' sub e) - EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b) - EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e) - EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e) - EConst x t v -> EConst x t v - EIdx0 x e -> EIdx0 x (splitLets' sub e) - EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b) - EIdx x e es -> EIdx x (splitLets' sub e) (splitLets' sub es) - EShape x e -> EShape x (splitLets' sub e) - EOp x op e -> EOp x op (splitLets' sub e) - ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2) - ERecompute x e -> ERecompute x (splitLets' sub e) - EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) - EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3) - EZero x t ezi -> EZero x t (splitLets' sub ezi) - EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi) - EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) - EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) - EError x t s -> EError x t s - where - sinkF :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) - -> STy t -> Idx (b : env) t -> (b : env') :> env3 -> Ex env3 t - sinkF _ t IZ w = EVar ext t (w @> IZ) - sinkF f t (IS i) w = f t i (w .> WSink) - - split1 :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) - -> STy bind -> Ex (bind : env) t -> Ex (bind : env') t - split1 sub (tbind :: STy bind) body = - let (ptrs, bs) = split tbind - in letBinds bs $ - splitLets' (\cases _ IZ w -> subPointers ptrs w - t (IS i) w -> sub t i (WPop @bind (wPops (bindingsBinds bs) w))) - body - - split2 :: forall bind1 bind2 env' env t. - (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) - -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t - split2 sub tbind1 tbind2 body = - let (ptrs1', bs1') = split @env' tbind1 - bs1 = fst (weakenBindings weakenExpr WSink bs1') - (ptrs2, bs2) = split @(bind1 : env') tbind2 - in letBinds bs1 $ - letBinds (fst (weakenBindings weakenExpr (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ - splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1))) - _ (IS IZ) w -> subPointers ptrs1' (w .> wSinks (bindingsBinds bs2) .> wCopies (bindingsBinds bs1) (WSink @bind2 @(bind1 : env'))) - t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w))))) - body - -type family Split t where - Split (TPair a b) = SplitRec (TPair a b) - Split _ = '[] - -type family SplitRec t where - SplitRec TNil = '[] - SplitRec (TPair a b) = Append (SplitRec b) (SplitRec a) - SplitRec t = '[t] - -data Pointers env t where - Point :: STy t -> Idx env t -> Pointers env t - PNil :: Pointers env TNil - PPair :: Pointers env a -> Pointers env b -> Pointers env (TPair a b) - PWeak :: env' :> env -> Pointers env' t -> Pointers env t - -subPointers :: Pointers env t -> env :> env' -> Ex env' t -subPointers (Point t i) w = EVar ext t (w @> i) -subPointers PNil _ = ENil ext -subPointers (PPair a b) w = EPair ext (subPointers a w) (subPointers b w) -subPointers (PWeak w' p) w = subPointers p (w .> w') - -split :: forall env t. STy t - -> (Pointers (Append (Split t) (t : env)) t, Bindings Ex (t : env) (Split t)) -split typ = case typ of - STPair{} -> splitRec (EVar ext typ IZ) typ - STNil -> other - STEither{} -> other - STLEither{} -> other - STMaybe{} -> other - STArr{} -> other - STScal{} -> other - STAccum{} -> other - where - other :: (Pointers (t : env) t, Bindings Ex (t : env) '[]) - other = (Point typ IZ, BTop) - -splitRec :: forall env t. Ex env t -> STy t - -> (Pointers (Append (SplitRec t) env) t, Bindings Ex env (SplitRec t)) -splitRec rhs typ = case typ of - STNil -> (PNil, BTop) - STPair (a :: STy a) (b :: STy b) - | Refl <- lemAppendAssoc @(SplitRec b) @(SplitRec a) @env -> - let (p1, bs1) = splitRec (EFst ext rhs) a - (p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b - in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2) - STEither{} -> other - STLEither{} -> other - STMaybe{} -> other - STArr{} -> other - STScal{} -> other - STAccum{} -> other - where - other :: (Pointers (t : env) t, Bindings Ex env '[t]) - other = (Point typ IZ, BPush BTop (typ, rhs)) |
