aboutsummaryrefslogtreecommitdiff
path: root/AST.hs
blob: 3e1d2f6e6cccc7db8e4ad9b53d825b81e76195ca (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
module AST where

import Data.GADT.Compare
import Data.Type.Equality
import qualified Data.Vector as V
import Data.Vector (Vector)


-- | Well-typed expressions of the object language.
--
-- @env@ is the type-level list of the types of the variables in scope; binders
-- ('Lam', 'Let') push the bound type onto the head of that list.  @a@ is the
-- type of the expression itself.
data Exp env a where
    -- | Function application.
    App :: Exp env (a -> b) -> Exp env a -> Exp env b
    -- | Lambda abstraction.  Carries the argument's 'Type'; the body is typed
    -- in the environment extended with that argument.
    Lam :: Type t -> Exp (t ': env) a -> Exp env (t -> a)
    -- | Variable reference: its 'Type' together with its position in @env@.
    Var :: Type a -> Idx env a -> Exp env a
    -- | Local binding.  The second expression is typed in the environment
    -- extended with the type of the first.
    Let :: Exp env t -> Exp (t ': env) a -> Exp env a
    -- | Literal value.
    Lit :: Literal a -> Exp env a
    -- | Conditional: a 'Bool' expression and two branches of the same type.
    Cond :: Exp env Bool -> Exp env a -> Exp env a -> Exp env a
    -- | Built-in primitive; see 'Constant'.
    Const :: Constant a -> Exp env a
    -- | Pair construction.
    Pair :: Exp env a -> Exp env b -> Exp env (a, b)
    -- | First component of a pair.
    Fst :: Exp env (a, b) -> Exp env a
    -- | Second component of a pair.
    Snd :: Exp env (a, b) -> Exp env b
    -- | Produces an array from an expression of the index type @sh@ and a
    -- function from @sh@ to elements (presumably the extent and the per-index
    -- generator -- TODO confirm against the evaluator).
    Build :: ShapeType sh -> Exp env sh -> Exp env (sh -> a) -> Exp env (Array sh a)
    -- | Takes a function on @(s, sh)@ pairs, an expression of type @s@ and an
    -- expression of the index type @sh@, and yields an @s@ (presumably a fold
    -- over all indices below the given extent, starting from the given value
    -- -- TODO confirm against the evaluator).
    Ifold :: ShapeType sh -> Exp env ((s, sh) -> s) -> Exp env s -> Exp env sh -> Exp env s
    -- | Yields an element of an array given a value of its index type.
    Index :: Exp env (Array sh a) -> Exp env sh -> Exp env a
    -- | Yields a value of an array's index type @sh@ (presumably the array's
    -- extent -- TODO confirm).
    Shape :: Exp env (Array sh a) -> Exp env sh
    -- | An expression of the given type with no operands; its meaning is not
    -- defined in this module.
    Undef :: Type a -> Exp env a

-- | Built-in primitives, indexed by their function type.  Primitives with two
-- operands take them as a pair.  Only the types are fixed here; the names
-- suggest the usual arithmetic, comparison and boolean operations, but their
-- meaning is given by whatever evaluates 'Const', not by this module.
data Constant a where
    -- Binary primitives on 'Int'.
    CAddI :: Constant ((Int, Int) -> Int)
    CSubI :: Constant ((Int, Int) -> Int)
    CMulI :: Constant ((Int, Int) -> Int)
    CDivI :: Constant ((Int, Int) -> Int)
    -- Binary primitives on 'Double'.
    CAddF :: Constant ((Double, Double) -> Double)
    CSubF :: Constant ((Double, Double) -> Double)
    CMulF :: Constant ((Double, Double) -> Double)
    CDivF :: Constant ((Double, Double) -> Double)
    -- Unary primitives on 'Double'.
    CLog :: Constant (Double -> Double)
    CExp :: Constant (Double -> Double)
    -- Conversions between 'Int' and 'Double'.
    CtoF :: Constant (Int -> Double)
    CRound :: Constant (Double -> Int)

    -- Primitives producing or consuming 'Bool'.  'CEq' carries the 'Type' of
    -- the two values it compares.
    CLtI :: Constant ((Int, Int) -> Bool)
    CLeI :: Constant ((Int, Int) -> Bool)
    CLtF :: Constant ((Double, Double) -> Bool)
    CEq  :: Type a -> Constant ((a, a) -> Bool)
    CAnd :: Constant ((Bool, Bool) -> Bool)
    COr  :: Constant ((Bool, Bool) -> Bool)
    CNot :: Constant (Bool -> Bool)

-- | Runtime representation of an object-language type.  A value of type
-- @Type a@ determines @a@, so pattern matching on it refines the type index.
data Type a where
    TInt :: Type Int
    TBool :: Type Bool
    TDouble :: Type Double
    -- | Array type: the representation of the index type plus the element type.
    TArray :: ShapeType sh -> Type a -> Type (Array sh a)
    -- | The unit type.
    TNil :: Type ()
    TPair :: Type a -> Type b -> Type (a, b)
    TFun :: Type a -> Type b -> Type (a -> b)

-- | Typed position in an environment: a proof that the type @a@ occurs in the
-- type-level list @env@.
data Idx env a where
    -- | The head of the environment.
    Zero :: Idx (a ': env) a
    -- | A position in the tail of the environment, skipping one entry.
    Succ :: Idx env a -> Idx (t ': env) a

-- | Literal values, indexed by the type of the value they hold.
data Literal a where
    LInt :: Int -> Literal Int
    LBool :: Bool -> Literal Bool
    LDouble :: Double -> Literal Double
    -- | A whole array as a literal.
    LArray :: Array sh a -> Literal (Array sh a)
    -- | A 'Shape' used as a value of the index type @sh@.
    LShape :: Shape sh -> Literal sh
    -- | The unit value.
    LNil :: Literal ()
    LPair :: Literal a -> Literal b -> Literal (a, b)

-- | A shape value: a left-nested sequence of 'Int's.  The type index mirrors
-- the structure as nested pairs, @()@ for 'Z' and @(sh, Int)@ for each ':.'.
data Shape sh where
    -- | The empty shape.
    Z :: Shape ()
    -- | Extends a shape with one more 'Int' on the right.
    (:.) :: Shape sh -> Int -> Shape (sh, Int)

-- | Runtime representation of a shape type: the same structure as 'Shape'
-- but without the 'Int' payloads, so it records only the nesting depth.
data ShapeType sh where
    -- | Representation of the empty shape type @()@.
    STZ :: ShapeType ()
    -- | Representation of a shape type extended with one more 'Int'.
    STC :: ShapeType sh -> ShapeType (sh, Int)

-- | An array value: its shape, the runtime representation of its element
-- type, and its elements in a flat 'Vector'.
--
-- NOTE(review): nothing here ties the vector's length to the shape;
-- presumably callers maintain that invariant -- confirm.
data Array sh a where
    Array :: Shape sh -> Type a -> Vector a -> Array sh a

-- Stock-derived 'Show' instances for the GADTs above (standalone deriving is
-- needed because they are declared in GADT syntax).  The instances for 'Exp'
-- and 'Literal' rely on the hand-written 'Show' instances of 'Shape' and
-- 'Array', which are reachable through 'LShape' and 'LArray'.
deriving instance Show (Exp env a)
deriving instance Show (Constant a)
deriving instance Show (Type a)
deriving instance Show (Idx env a)
deriving instance Show (Literal a)
deriving instance Show (ShapeType a)

-- | Renders a shape in constructor syntax; for example @Z :. 2 :. 3@ is shown
-- as @(Z :. 2) :. 3@.
--
-- A ':.' application is parenthesised in every context whose precedence is
-- above 0, and its left operand is shown at precedence 10, so nested shapes
-- are always bracketed.  The dimension is shown at precedence 10 as well, so
-- that a negative value is wrapped in parentheses (@Z :. (-1)@) and the output
-- stays a valid Haskell expression; plain 'shows' would produce @Z :. -1@,
-- which does not parse.
instance Show (Shape sh) where
    showsPrec _ Z = showString "Z"
    showsPrec p (sh :. n) = showParen (p > 0) $
        showsPrec 10 sh . showString " :. " . showsPrec 10 n

-- | Renders an array as @Array \<shape\> \<element type\> \<elements\>@, with
-- the usual parentheses when the context's precedence exceeds 10.
--
-- The elements are printed with the vector's own 'Show' instance when
-- 'typeHasShow' provides a 'Show' dictionary for the element type; otherwise
-- a placeholder of the form @[_ * n]@ is printed, @n@ being the number of
-- elements in the vector.
instance Show (Array sh a) where
    showsPrec prec (Array extent elemTy elems) =
        showParen (prec > 10) (header . payload)
      where
        -- Constructor name, shape and element type, each followed by a space.
        header :: ShowS
        header =
            showString "Array "
                . showsPrec 11 extent . showString " "
                . showsPrec 11 elemTy . showString " "

        -- The elements themselves, or a placeholder if they cannot be shown.
        payload :: ShowS
        payload = case typeHasShow elemTy of
            Nothing -> showString ("[_ * " ++ show (V.length elems) ++ "]")
            Just Has -> showsPrec 11 elems

-- Stock-derived structural equality.  'Array' compares its shape, element
-- type representation and element vector, and therefore needs 'Eq' on the
-- element type.
deriving instance Eq (Type a)
deriving instance Eq (Shape sh)
deriving instance Eq (ShapeType sh)
deriving instance Eq a => Eq (Array sh a)

-- | Structural equality on expressions that share an environment, yielding a
-- proof that the two expression types coincide when the terms are equal.
--
-- Each constructor has one equation comparing the fields pairwise and one
-- fallback equation returning 'Nothing'.  The guards are checked left to
-- right; an earlier match can refine the types that a later comparison relies
-- on (for 'Lam', the argument types must agree before the bodies, which live
-- in the extended environment, can be compared).
instance GEq (Exp env) where
    geq (App f1 x1) (App f2 x2)
      | Just Refl <- geq f1 f2
      , Just Refl <- geq x1 x2
      = Just Refl
    geq App{} _ = Nothing

    geq (Lam ty1 body1) (Lam ty2 body2)
      | Just Refl <- geq ty1 ty2
      , Just Refl <- geq body1 body2
      = Just Refl
    geq Lam{} _ = Nothing

    geq (Var ty1 ix1) (Var ty2 ix2)
      | Just Refl <- geq ty1 ty2
      , Just Refl <- geq ix1 ix2
      = Just Refl
    geq Var{} _ = Nothing

    geq (Let rhs1 body1) (Let rhs2 body2)
      | Just Refl <- geq rhs1 rhs2
      , Just Refl <- geq body1 body2
      = Just Refl
    geq Let{} _ = Nothing

    geq (Lit lit1) (Lit lit2)
      | Just Refl <- geq lit1 lit2
      = Just Refl
    geq Lit{} _ = Nothing

    geq (Cond c1 t1 e1) (Cond c2 t2 e2)
      | Just Refl <- geq c1 c2
      , Just Refl <- geq t1 t2
      , Just Refl <- geq e1 e2
      = Just Refl
    geq Cond{} _ = Nothing

    geq (Const k1) (Const k2)
      | Just Refl <- geq k1 k2
      = Just Refl
    geq Const{} _ = Nothing

    geq (Pair l1 r1) (Pair l2 r2)
      | Just Refl <- geq l1 l2
      , Just Refl <- geq r1 r2
      = Just Refl
    geq Pair{} _ = Nothing

    geq (Fst p1) (Fst p2)
      | Just Refl <- geq p1 p2
      = Just Refl
    geq Fst{} _ = Nothing

    geq (Snd p1) (Snd p2)
      | Just Refl <- geq p1 p2
      = Just Refl
    geq Snd{} _ = Nothing

    geq (Build sht1 sh1 fn1) (Build sht2 sh2 fn2)
      | Just Refl <- geq sht1 sht2
      , Just Refl <- geq sh1 sh2
      , Just Refl <- geq fn1 fn2
      = Just Refl
    geq Build{} _ = Nothing

    geq (Ifold sht1 fn1 acc1 sh1) (Ifold sht2 fn2 acc2 sh2)
      | Just Refl <- geq sht1 sht2
      , Just Refl <- geq fn1 fn2
      , Just Refl <- geq acc1 acc2
      , Just Refl <- geq sh1 sh2
      = Just Refl
    geq Ifold{} _ = Nothing

    geq (Index arr1 ix1) (Index arr2 ix2)
      | Just Refl <- geq arr1 arr2
      , Just Refl <- geq ix1 ix2
      = Just Refl
    geq Index{} _ = Nothing

    geq (Shape arr1) (Shape arr2)
      | Just Refl <- geq arr1 arr2
      = Just Refl
    geq Shape{} _ = Nothing

    geq (Undef ty1) (Undef ty2)
      | Just Refl <- geq ty1 ty2
      = Just Refl
    geq Undef{} _ = Nothing

-- | Two primitives are equal exactly when they are the same constructor; for
-- 'CEq' the carried types must be equal as well.  Every constructor has its
-- own fallback equation rather than a shared wildcard.
instance GEq Constant where
    -- Int arithmetic
    geq CAddI CAddI = Just Refl
    geq CAddI _ = Nothing
    geq CSubI CSubI = Just Refl
    geq CSubI _ = Nothing
    geq CMulI CMulI = Just Refl
    geq CMulI _ = Nothing
    geq CDivI CDivI = Just Refl
    geq CDivI _ = Nothing

    -- Double arithmetic
    geq CAddF CAddF = Just Refl
    geq CAddF _ = Nothing
    geq CSubF CSubF = Just Refl
    geq CSubF _ = Nothing
    geq CMulF CMulF = Just Refl
    geq CMulF _ = Nothing
    geq CDivF CDivF = Just Refl
    geq CDivF _ = Nothing
    geq CLog CLog = Just Refl
    geq CLog _ = Nothing
    geq CExp CExp = Just Refl
    geq CExp _ = Nothing

    -- Conversions
    geq CtoF CtoF = Just Refl
    geq CtoF _ = Nothing
    geq CRound CRound = Just Refl
    geq CRound _ = Nothing

    -- Comparisons
    geq CLtI CLtI = Just Refl
    geq CLtI _ = Nothing
    geq CLeI CLeI = Just Refl
    geq CLeI _ = Nothing
    geq CLtF CLtF = Just Refl
    geq CLtF _ = Nothing
    geq (CEq ty1) (CEq ty2)
      | Just Refl <- geq ty1 ty2
      = Just Refl
    geq CEq{} _ = Nothing

    -- Boolean connectives
    geq CAnd CAnd = Just Refl
    geq CAnd _ = Nothing
    geq COr COr = Just Refl
    geq COr _ = Nothing
    geq CNot CNot = Just Refl
    geq CNot _ = Nothing

-- | Structural equality on type representations, producing a proof that the
-- represented types are the same.
instance GEq Type where
    geq TInt TInt = Just Refl
    geq TInt _ = Nothing

    geq TBool TBool = Just Refl
    geq TBool _ = Nothing

    geq TDouble TDouble = Just Refl
    geq TDouble _ = Nothing

    geq (TArray sht1 elt1) (TArray sht2 elt2)
      | Just Refl <- geq sht1 sht2
      , Just Refl <- geq elt1 elt2
      = Just Refl
    geq TArray{} _ = Nothing

    geq TNil TNil = Just Refl
    geq TNil _ = Nothing

    geq (TPair l1 r1) (TPair l2 r2)
      | Just Refl <- geq l1 l2
      , Just Refl <- geq r1 r2
      = Just Refl
    geq TPair{} _ = Nothing

    geq (TFun arg1 res1) (TFun arg2 res2)
      | Just Refl <- geq arg1 arg2
      , Just Refl <- geq res1 res2
      = Just Refl
    geq TFun{} _ = Nothing

-- | Two positions in the same environment are equal when they consist of the
-- same number of 'Succ' steps; equal positions select the same type.
instance GEq (Idx env) where
    geq Zero Zero = Just Refl
    geq Zero _ = Nothing

    geq (Succ i) (Succ j)
      | Just Refl <- geq i j
      = Just Refl
    geq Succ{} _ = Nothing

-- | Equality on literals: same constructor and equal payloads.
--
-- Array literals are compared by shape, then element type, then element
-- vector.  The vector comparison needs an 'Eq' dictionary for the element
-- type, obtained from 'typeHasEq'; if there is none (the element type
-- contains a function type) this calls 'error' instead of returning a result.
instance GEq Literal where
    geq (LInt m) (LInt n)
      | m == n = Just Refl
    geq LInt{} _ = Nothing

    geq (LBool p) (LBool q)
      | p == q = Just Refl
    geq LBool{} _ = Nothing

    geq (LDouble x) (LDouble y)
      | x == y = Just Refl
    geq LDouble{} _ = Nothing

    geq (LArray (Array sh1 ty1 vec1)) (LArray (Array sh2 ty2 vec2))
      | Just Refl <- geq sh1 sh2
      , Just Refl <- geq ty1 ty2
      = case typeHasEq ty1 of
          Nothing -> error "GEq Literal: Literal array of incomparable values"
          Just Has
            | vec1 == vec2 -> Just Refl
            | otherwise -> Nothing
    geq LArray{} _ = Nothing

    geq (LShape sh1) (LShape sh2)
      | Just Refl <- geq sh1 sh2
      = Just Refl
    geq LShape{} _ = Nothing

    geq LNil LNil = Just Refl
    geq LNil _ = Nothing

    geq (LPair l1 r1) (LPair l2 r2)
      | Just Refl <- geq l1 l2
      , Just Refl <- geq r1 r2
      = Just Refl
    geq LPair{} _ = Nothing

-- | Shapes are equal when they have the same length and the same 'Int' at
-- every position.  The outermost 'Int's are compared before recursing.
instance GEq Shape where
    geq Z Z = Just Refl
    geq Z _ = Nothing

    geq (rest1 :. dim1) (rest2 :. dim2)
      | dim1 == dim2
      , Just Refl <- geq rest1 rest2
      = Just Refl
    geq (_ :. _) _ = Nothing

-- | Shape type representations are equal when they contain the same number
-- of 'STC' layers.
instance GEq ShapeType where
    geq STZ STZ = Just Refl
    geq STZ _ = Nothing

    geq (STC inner1) (STC inner2)
      | Just Refl <- geq inner1 inner2
      = Just Refl
    geq STC{} _ = Nothing

-- | Applies the given function to every immediate subexpression of a term and
-- combines the results, left to right, with '<>'.  Terms without
-- subexpressions yield 'mempty'.
--
-- Requires that the given term is neither 'Lam' nor 'Let': their bodies live
-- in an extended environment, which the supplied function (fixed to @env@)
-- cannot accept.  Passing either one calls 'error'.
recurseMon :: Monoid s => (forall t. Exp env t -> s) -> Exp env a -> s
-- Binders are rejected.
recurseMon _ (Lam _ _) = error "recurseMon: Given Lam"
recurseMon _ (Let _ _) = error "recurseMon: Given Let"
-- Leaves contribute nothing.
recurseMon _ (Var _ _) = mempty
recurseMon _ (Lit _) = mempty
recurseMon _ (Const _) = mempty
recurseMon _ (Undef _) = mempty
-- Inner nodes: visit each child expression in field order.
recurseMon f (App fun arg) = f fun <> f arg
recurseMon f (Cond c t e) = f c <> f t <> f e
recurseMon f (Pair l r) = f l <> f r
recurseMon f (Fst p) = f p
recurseMon f (Snd p) = f p
recurseMon f (Build _ sh fn) = f sh <> f fn
recurseMon f (Ifold _ fn acc sh) = f fn <> f acc <> f sh
recurseMon f (Index arr ix) = f arr <> f ix
recurseMon f (Shape arr) = f arr

-- | Forgets the 'Int's of a shape, keeping only its structure.
shapeType :: Shape sh -> ShapeType sh
shapeType shape = case shape of
    Z -> STZ
    rest :. _ -> STC (shapeType rest)

-- | The 'Type' of a shape viewed as a value: 'TNil' for 'Z', and a pair of
-- the remaining shape's type and 'TInt' for each ':.'.
shapeType' :: Shape sh -> Type sh
shapeType' shape = case shape of
    Z -> TNil
    rest :. _ -> TPair (shapeType' rest) TInt

-- | Converts a shape type representation into the general 'Type'
-- representation of the same type: nested pairs ending in 'TNil', with
-- 'TInt' as every second component.
shapeTypeType :: ShapeType sh -> Type sh
shapeTypeType sht = case sht of
    STZ -> TNil
    STC inner -> TPair (shapeTypeType inner) TInt

-- | The type representation of a literal.  For arrays it is assembled from
-- the shape and element type stored in the array; for shapes it is derived
-- with @shapeType'@.
literalType :: Literal a -> Type a
literalType lit = case lit of
    LInt{} -> TInt
    LBool{} -> TBool
    LDouble{} -> TDouble
    LArray (Array extent elemTy _) -> TArray (shapeType extent) elemTy
    LShape extent -> shapeType' extent
    LNil{} -> TNil
    LPair l r -> TPair (literalType l) (literalType r)

-- | The type representation of a built-in primitive.  Always a 'TFun'; the
-- primitives with two operands take a 'TPair' as argument.
constType :: Constant a -> Type a
constType prim = case prim of
    CAddI -> binary TInt TInt
    CSubI -> binary TInt TInt
    CMulI -> binary TInt TInt
    CDivI -> binary TInt TInt
    CAddF -> binary TDouble TDouble
    CSubF -> binary TDouble TDouble
    CMulF -> binary TDouble TDouble
    CDivF -> binary TDouble TDouble
    CLog -> TFun TDouble TDouble
    CExp -> TFun TDouble TDouble
    CtoF -> TFun TInt TDouble
    CRound -> TFun TDouble TInt
    CLtI -> binary TInt TBool
    CLeI -> binary TInt TBool
    CLtF -> binary TDouble TBool
    CEq operand -> binary operand TBool
    CAnd -> binary TBool TBool
    COr -> binary TBool TBool
    CNot -> TFun TBool TBool
  where
    -- Type of a primitive taking a pair of equally-typed operands.
    binary :: Type x -> Type r -> Type ((x, x) -> r)
    binary operand result = TFun (TPair operand operand) result

-- | Reconstructs the type representation of an expression from the
-- annotations stored in the term ('Lam', 'Var', 'Build', 'Undef'), from
-- literals and primitives, and from the types of subexpressions.
--
-- For 'Cond' the type of the first branch is used; for 'Ifold' the type of
-- the expression in the third field.  Where a subexpression's type is taken
-- apart, the lazy pattern cannot fail because the type index already fixes
-- the constructor.
typeof :: Exp env a -> Type a
typeof expr = case expr of
    App fun _ -> let TFun _ res = typeof fun in res
    Lam argTy body -> TFun argTy (typeof body)
    Var ty _ -> ty
    Let _ body -> typeof body
    Lit lit -> literalType lit
    Cond _ thenE _ -> typeof thenE
    Const prim -> constType prim
    Pair l r -> TPair (typeof l) (typeof r)
    Fst p -> let TPair l _ = typeof p in l
    Snd p -> let TPair _ r = typeof p in r
    Build sht _ fn -> let TFun _ elemTy = typeof fn in TArray sht elemTy
    Ifold _ _ acc _ -> typeof acc
    Index arr _ -> let TArray _ elemTy = typeof arr in elemTy
    Shape arr -> let TArray sht _ = typeof arr in shapeTypeType sht
    Undef ty -> ty

-- | The number of 'Succ' constructors in an index; 'Zero' maps to 0.
idxToInt :: Idx env a -> Int
idxToInt = go 0
  where
    -- Counts the 'Succ' layers in an accumulator, forced at every step.
    go :: Int -> Idx env' b -> Int
    go steps Zero = steps
    go steps (Succ rest) =
        let steps' = steps + 1
        in steps' `seq` go steps' rest

-- | The number of 'STC' constructors in a shape type; 'STZ' maps to 0.
shtToInt :: ShapeType sh -> Int
shtToInt = go 0
  where
    -- Counts the 'STC' layers in an accumulator, forced at every step.
    go :: Int -> ShapeType sh' -> Int
    go layers STZ = layers
    go layers (STC inner) =
        let layers' = layers + 1
        in layers' `seq` go layers' inner

-- | Evidence that the constraint @c@ holds for @a@.  Constructing 'Has'
-- requires an instance @c a@; pattern matching on it brings that instance
-- into scope.
data Has c a where
    Has :: c a => Has c a

-- | Provides a 'Show' dictionary for the represented type, if there is one.
--
-- Function types have none, and neither does a pair with a component that
-- has none.  Arrays always qualify, because the 'Show' instance of 'Array'
-- places no requirement on the element type.
typeHasShow :: Type a -> Maybe (Has Show a)
typeHasShow ty = case ty of
    TInt -> Just Has
    TBool -> Just Has
    TDouble -> Just Has
    TArray{} -> Just Has
    TNil -> Just Has
    TPair l r -> case typeHasShow l of
        Nothing -> Nothing
        Just Has -> case typeHasShow r of
            Nothing -> Nothing
            Just Has -> Just Has
    TFun{} -> Nothing

-- | Provides an 'Eq' dictionary for the represented type, if there is one.
--
-- Function types have none.  An array type has one exactly when its element
-- type does, and a pair type exactly when both components do.
typeHasEq :: Type a -> Maybe (Has Eq a)
typeHasEq ty = case ty of
    TInt -> Just Has
    TBool -> Just Has
    TDouble -> Just Has
    TArray _ elemTy -> case typeHasEq elemTy of
        Nothing -> Nothing
        Just Has -> Just Has
    TNil -> Just Has
    TPair l r -> case typeHasEq l of
        Nothing -> Nothing
        Just Has -> case typeHasEq r of
            Nothing -> Nothing
            Just Has -> Just Has
    TFun{} -> Nothing