-- path: root/src/AST/SplitLets.hs
-- blob: 1de417cd0514ab6547914e75e92260507b7ea4e0
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module AST.SplitLets (splitLets) where

import Data.Type.Equality

import AST
import AST.Bindings
import Lemmas


-- | Entry point of the pass: run 'splitLets'' over an expression whose
-- environment is left unchanged.
--
-- The initial substitution maps every free variable to itself: a variable at
-- index @idx@ becomes an 'EVar' at the index obtained by applying the
-- weakening @wk@ to @idx@, so free variables keep pointing at the same
-- binding however many extra bindings the pass inserts around them.
splitLets :: Ex env t -> Ex env t
splitLets expr = splitLets' (\ty idx wk -> EVar ext ty (wk @> idx)) expr

-- | Worker for 'splitLets'.
--
-- The first argument is a substitution for the free variables of the input
-- expression: given the type of a variable, its index in the input
-- environment @env@, and a weakening from the output environment @env'@ into
-- some larger environment @env2@, it produces the expression (in @env2@) that
-- replaces that variable.
--
-- Constructors that bind a variable whose type may be a pair ('ELet',
-- 'ECase', 'EMaybe', 'EFold1Inner') are handled via 'split1' / 'split2',
-- which introduce additional bindings for the components of the bound value
-- (see 'split') and substitute occurrences of the original variable by an
-- expression rebuilt from those components (see 'subPointers').
-- All other constructors are traversed structurally.
splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t
splitLets' = \sub -> \case
  -- A variable is replaced by whatever the substitution says; no additional
  -- weakening is needed at this point, hence WId.
  EVar _ t i -> sub t i WId
  -- Binding forms: the bound variable(s) of the body are split.
  ELet _ (rhs :: Ex env t1) body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body)
  ECase x e a b ->
    -- The two branches bind the left / right component type of the scrutinee.
    let STEither t1 t2 = typeOf e
    in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b)
  EMaybe x a b e ->
    -- Only the second argument is processed with 'split1' (binder of type t1).
    let STMaybe t1 = typeOf e
    in EMaybe x (splitLets' sub a) (split1 sub t1 b) (splitLets' sub e)
  EFold1Inner x cm a b c ->
    -- The first expression argument has two bound variables, both of the
    -- element type of the array argument; both are split.
    let STArr _ t1 = typeOf c
    in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)

  -- Remaining constructors: recurse into sub-expressions, keep everything
  -- else as is. Leaves without sub-expressions are returned unchanged.
  EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b)
  EFst x e -> EFst x (splitLets' sub e)
  ESnd x e -> ESnd x (splitLets' sub e)
  ENil x -> ENil x
  EInl x t e -> EInl x t (splitLets' sub e)
  EInr x t e -> EInr x t (splitLets' sub e)
  ENothing x t -> ENothing x t
  EJust x e -> EJust x (splitLets' sub e)
  EConstArr x n t a -> EConstArr x n t a
  -- The last argument of EBuild has one extra variable in scope; it is not
  -- split, the substitution is merely lifted under the binder with 'sinkF'.
  EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b)
  ESum1Inner x e -> ESum1Inner x (splitLets' sub e)
  EUnit x e -> EUnit x (splitLets' sub e)
  EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b)
  EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e)
  EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e)
  EConst x t v -> EConst x t v
  EIdx0 x e -> EIdx0 x (splitLets' sub e)
  EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b)
  EIdx x e es -> EIdx x (splitLets' sub e) (splitLets' sub es)
  EShape x e -> EShape x (splitLets' sub e)
  EOp x op e -> EOp x op (splitLets' sub e)
  -- Only e1 and e2 are traversed; a, b and c are passed through untouched.
  ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2)
  -- As with EBuild, the last argument has one extra, unsplit variable in scope.
  EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2)
  EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3)
  EZero x t -> EZero x t
  EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b)
  EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b)
  EError x t s -> EError x t s
  where
    -- Lift a substitution under one binder that is kept as is: the new
    -- innermost variable (IZ) maps to itself, every outer variable is handed
    -- to the original substitution with the weakening extended by WSink.
    sinkF :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
          -> STy t -> Idx (b : env) t -> (b : env') :> env3 -> Ex env3 t
    sinkF _ t IZ w = EVar ext t (w @> IZ)
    sinkF f t (IS i) w = f t i (w .> WSink)

    -- Process a body with one bound variable of type 'bind'.
    -- 'split' is applied to the bound variable itself, yielding bindings for
    -- its components (bs) and a description of how to rebuild it (ptrs).
    -- The body is wrapped in those bindings using letBinds, and inside it:
    --   * the bound variable is replaced by the rebuilt expression;
    --   * any outer variable goes through 'sub', with a weakening that drops
    --     the bindings of bs and then the original binder.
    split1 :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
           -> STy bind -> Ex (bind : env) t -> Ex (bind : env') t
    split1 sub (tbind :: STy bind) body =
      let (ptrs, bs) = split (EVar ext tbind IZ) tbind
      in letBinds bs $
           splitLets' (\cases _ IZ w -> subPointers ptrs w
                              t (IS i) w -> sub t i (WPop @bind (wPops (bindingsBinds bs) w)))
                      body

    -- Two-binder variant of 'split1'. 'bind2' is the innermost variable (IZ),
    -- 'bind1' the one below it (IS IZ). Both are split independently in the
    -- environment (bind2 : bind1 : env'); the bindings for bind1 (bs1) are
    -- emitted first, then those for bind2 (bs2), which therefore have to be
    -- weakened past bs1. The pointers are weakened correspondingly:
    --   * ptrs2 refers to bs2 directly on top of the two binders, so bs1 has
    --     to be skipped underneath the bs2 entries;
    --   * ptrs1 refers to bs1, which now lies below bs2.
    -- Outer variables go through 'sub' after dropping bs2, bs1 and both binders.
    -- NOTE(review): the precise meaning of wCopies / wSinks / wPops /
    -- weakenBindings is defined in other modules; the description above is
    -- inferred from how they are combined here.
    split2 :: forall bind1 bind2 env' env t.
              (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
           -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t
    split2 sub tbind1 tbind2 body =
      let (ptrs1, bs1) = split (EVar ext tbind1 (IS IZ)) tbind1
          (ptrs2, bs2) = split (EVar ext tbind2 IZ) tbind2
      in letBinds bs1 $
         letBinds (fst (weakenBindings weakenExpr (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $
           splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1)))
                              _ (IS IZ) w -> subPointers ptrs1 (w .> wSinks (bindingsBinds bs2))
                              t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w)))))
                      body

-- | The list of types bound when a value of type @t@ is split by 'split':
-- nothing for 'TNil', the bindings of both components for a 'TPair' (those of
-- the second component listed first, matching the order in which 'split'
-- concatenates its bindings), and the type itself for everything else.
--
-- This is a closed type family, so the equations are tried in order and the
-- last one only applies to types that are neither 'TNil' nor a 'TPair'.
type family Split t where
  Split TNil = '[]
  Split (TPair a b) = Append (Split b) (Split a)
  Split t = '[t]

-- | A recipe for rebuilding a value of type @t@ from variables in @env@.
-- 'subPointers' turns such a recipe into an expression.
data Pointers env t where
  -- | The value is exactly the variable at the given index.
  Point :: STy t -> Idx env t -> Pointers env t
  -- | The value has type 'TNil' and needs no variable at all.
  PNil :: Pointers env TNil
  -- | The value is a pair; each component has its own recipe.
  PPair :: Pointers env a -> Pointers env b -> Pointers env (TPair a b)
  -- | A recipe that was built for an environment @env'@, together with the
  -- weakening that carries it over into @env@.
  PWeak :: env' :> env -> Pointers env' t -> Pointers env t

-- | Turn a 'Pointers' recipe into an expression in the environment @env'@
-- reached through the given weakening.
--
-- A 'Point' becomes a variable reference (its index pushed through the
-- weakening), 'PNil' becomes 'ENil', 'PPair' becomes an 'EPair' of the
-- recursively built components, and 'PWeak' composes its stored weakening
-- with the incoming one before continuing with the inner recipe.
subPointers :: Pointers env t -> env :> env' -> Ex env' t
subPointers ptrs wk = case ptrs of
  Point ty ix -> EVar ext ty (wk @> ix)
  PNil -> ENil ext
  PPair lhs rhs -> EPair ext (subPointers lhs wk) (subPointers rhs wk)
  PWeak stored inner -> subPointers inner (wk .> stored)

-- | Decompose an expression along the pair structure of its type.
--
-- Returns the bindings that name the non-pair components of the expression
-- (their types are given by 'Split'), together with a 'Pointers' recipe,
-- valid in the environment extended by those bindings, that describes how to
-- reassemble the original value from them.
--
--   * 'STNil': no bindings are produced; the recipe is 'PNil'.
--   * 'STPair': the first projection is split, then the second projection
--     (with the expression weakened past the bindings of the first); the
--     recipe of the first half is wrapped in 'PWeak' to skip over the
--     bindings of the second half, and the two binding lists are
--     concatenated.
--   * any other type: the expression itself is bound once and the recipe
--     points at that single binding.
split :: forall env t. Ex env t -> STy t
      -> (Pointers (Append (Split t) env) t, Bindings Ex env (Split t))
split expr ty = case ty of
  STNil -> (PNil, BTop)
  STPair (tyA :: STy a) (tyB :: STy b) ->
    -- The associativity lemma is needed to make the nested 'Append's of the
    -- two recursive results line up with @Split (TPair a b)@.
    case lemAppendAssoc @(Split b) @(Split a) @env of
      Refl ->
        let (ptrsA, bindsA) = split (EFst ext expr) tyA
            (ptrsB, bindsB) = split (ESnd ext (weakenExpr (sinkWithBindings bindsA) expr)) tyB
        in (PPair (PWeak (sinkWithBindings bindsB) ptrsA) ptrsB, bconcat bindsA bindsB)
  STEither{} -> leaf
  STMaybe{} -> leaf
  STArr{} -> leaf
  STScal{} -> leaf
  STAccum{} -> leaf
  where
    -- Non-pair case: bind the whole expression and point at that binding.
    leaf :: (Pointers (t : env) t, Bindings Ex env '[t])
    leaf = (Point ty IZ, BPush BTop (ty, expr))