diff options
Diffstat (limited to 'src/AST/SplitLets.hs')
| -rw-r--r-- | src/AST/SplitLets.hs | 36 | 
1 files changed, 35 insertions, 1 deletions
| diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index 82ec1d6..f75e795 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -22,7 +22,7 @@ splitLets = splitLets' (\t i w -> EVar ext t (w @> i))  splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t  splitLets' = \sub -> \case    EVar _ t i -> sub t i WId -  ELet _ (rhs :: Ex env t1) body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) +  ELet _ rhs body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body)    ECase x e a b ->      let STEither t1 t2 = typeOf e      in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b) @@ -35,6 +35,13 @@ splitLets' = \sub -> \case    EFold1Inner x cm a b c ->      let STArr _ t1 = typeOf c      in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) +  EFold1InnerD1 x cm a b c -> +    let STArr _ t1 = typeOf c +    in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) +  EFold1InnerD2 x cm t2 a b c d e -> +    let STArr _ t1 = typeOf b +        STArr _ (STPair _ ttape) = typeOf d +    in EFold1InnerD2 x cm t2 (split4 sub ttape t1 t1 (fromSMTy t2) a) (splitLets' sub b) (splitLets' sub c) (splitLets' sub d) (splitLets' sub e)    EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b)    EFst x e -> EFst x (splitLets' sub e) @@ -98,6 +105,33 @@ splitLets' = \sub -> \case                                t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w)))))                        body +    -- TODO: abstract this to splitN lol wtf +    split4 :: forall bind1 bind2 bind3 bind4 env' env t. +              (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) +           -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t +    split4 sub tbind1 tbind2 tbind3 tbind4 body = +      let (ptrs1, bs1') = split @env' tbind1 +          (ptrs2, bs2') = split @(bind1 : env') tbind2 +          (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3 +          (ptrs4, bs4) = split @(bind3 : bind2 : bind1 : env') tbind4 +          bs1 = fst (weakenBindingsE (WSink .> WSink .> WSink) bs1') +          bs2 = fst (weakenBindingsE (WSink .> WSink) bs2') +          bs3 = fst (weakenBindingsE WSink bs3') +          b1 = bindingsBinds bs1 +          b2 = bindingsBinds bs2 +          b3 = bindingsBinds bs3 +          b4 = bindingsBinds bs4 +      in letBinds bs1 $ +         letBinds (fst (weakenBindingsE (                                                sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs2)) $ +         letBinds (fst (weakenBindingsE (                        sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs3)) $ +         letBinds (fst (weakenBindingsE (sinkWithBindings bs3 .> sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs4)) $ +           splitLets' (\cases _ IZ w ->                subPointers ptrs4 (w .> wCopies b4 (wSinks b3 .> wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1)) +                              _ (IS IZ) w ->           subPointers ptrs3 (w .> wSinks b4 .> wCopies b3 (wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink)) +                              _ (IS (IS IZ)) w ->      subPointers ptrs2 (w .> wSinks b4 .> wSinks b3 .> wCopies b2 (wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink .> WSink)) +                              _ (IS (IS (IS IZ))) w -> subPointers ptrs1 (w .> wSinks b4 .> wSinks b3 .> wSinks b2 .> wCopies b1 (WSink @bind4 .> WSink @bind3 .> WSink @bind2 @(bind1 : env'))) +                              t (IS (IS (IS (IS i)))) w -> sub t i (WPop @bind1 (WPop @bind2 (WPop @bind3 (WPop @bind4 (wPops b1 (wPops b2 (wPops b3 (wPops b4 w))))))))) +                      body +  type family Split t where    Split (TPair a b) = SplitRec (TPair a b)    Split _ = '[] | 
