aboutsummaryrefslogtreecommitdiff
path: root/Sink.hs
blob: c258dc5de1b20b6c927de3e399a2a88cc9e246b6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
module Sink where

import AST


-- | A weakening from environment @src@ to environment @dst@.
--
-- It is a type-preserving renaming: every index into @src@ is sent to an
-- index into @dst@ that points at an entry of the same type @t@.
newtype src :> dst = Weaken
  { -- | Apply the weakening to a single variable index.
    (>:>) :: forall t. Idx src t -> Idx dst t
  }

-- | The identity weakening: each index is returned unchanged.
wId :: env :> env
wId = Weaken (\idx -> idx)

-- | Post-compose a weakening with one extra entry at the front of the target
-- environment. Every index produced by the given weakening is then wrapped in
-- one more 'Succ'.
wSucc :: env :> env' -> env :> (a ': env')
wSucc w = Weaken (\idx -> Succ (w >:> idx))

-- | Lift a weakening under a binder that adds entry @a@ to both environments.
--
-- The bound variable itself ('Zero') stays put. Any older variable is first
-- renamed by the inner weakening and then shifted past the binder with 'Succ'.
wSink :: env :> env' -> (a ': env) :> (a ': env')
wSink w = Weaken (\idx -> case idx of
                    Zero   -> Zero
                    Succ i -> Succ (w >:> i))

-- | Compose two weakenings. The right-hand one (@env1@ to @env2@) is applied
-- first, then the left-hand one (@env2@ to @env3@).
(.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3
outer .> inner = Weaken (\idx -> outer >:> (inner >:> idx))

-- | Apply a weakening to every variable occurring in an expression.
--
-- 'Lam' bodies and the body of 'Let' live in an environment extended by one
-- entry, so the weakening is lifted with 'wSink' before descending into them.
-- Every other constructor is rebuilt from its recursively weakened
-- sub-expressions. 'Lit', 'Const' and 'Undef' are reconstructed from their
-- payloads as-is.
sinkExp :: env :> env' -> Exp env a -> Exp env' a
sinkExp w (App x y)        = App (sinkExp w x) (sinkExp w y)
sinkExp w (Lam ty body)    = Lam ty (sinkExp (wSink w) body)
sinkExp w (Var ty idx)     = Var ty (w >:> idx)
sinkExp w (Let rhs body)   = Let (sinkExp w rhs) (sinkExp (wSink w) body)
sinkExp _ (Lit lit)        = Lit lit
sinkExp w (Cond x y z)     = Cond (sinkExp w x) (sinkExp w y) (sinkExp w z)
sinkExp _ (Const k)        = Const k
sinkExp w (Pair x y)       = Pair (sinkExp w x) (sinkExp w y)
sinkExp w (Fst x)          = Fst (sinkExp w x)
sinkExp w (Snd x)          = Snd (sinkExp w x)
sinkExp w (Build sh x y)   = Build sh (sinkExp w x) (sinkExp w y)
sinkExp w (Ifold sh x y z) = Ifold sh (sinkExp w x) (sinkExp w y) (sinkExp w z)
sinkExp w (Index x y)      = Index (sinkExp w x) (sinkExp w y)
sinkExp w (Shape x)        = Shape (sinkExp w x)
sinkExp _ (Undef ty)       = Undef ty

-- | Weaken an expression past one new entry at the front of its environment.
-- Every free variable index is shifted by one.
sinkExp1 :: Exp env a -> Exp (t ': env) a
sinkExp1 = sinkExp shiftOne
  where
    shiftOne = wSucc wId

-- | Weaken an expression past two new entries at the front of its environment.
-- Every free variable index is shifted by two.
sinkExp2 :: Exp env a -> Exp (t1 ': t2 ': env) a
sinkExp2 = sinkExp (wSucc shiftOne)
  where
    shiftOne = wSucc wId