1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
|
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
module AST where
import Data.GADT.Compare
import Data.Type.Equality
import qualified Data.Vector as V
import Data.Vector (Vector)
-- | Typed expression AST using de Bruijn indices. @env@ is a type-level
-- list of the types of the variables in scope; 'Lam' and 'Let' push their
-- bound variable onto the front of it, so 'Zero' refers to the most
-- recently bound one. @a@ is the type of the expression.
data Exp env a where
  -- | Function application.
  App   :: Exp env (a -> b) -> Exp env a -> Exp env b
  -- | Lambda with an explicit argument type; the body sees the argument as
  -- the newest variable.
  Lam   :: Type t -> Exp (t ': env) a -> Exp env (t -> a)
  -- | Variable reference, annotated with its type.
  Var   :: Type a -> Idx env a -> Exp env a
  -- | Let binding; the bound value is the newest variable in the body.
  Let   :: Exp env t -> Exp (t ': env) a -> Exp env a
  -- | Embedded constant value.
  Lit   :: Literal a -> Exp env a
  -- | Conditional on a boolean scrutinee.
  Cond  :: Exp env Bool -> Exp env a -> Exp env a -> Exp env a
  -- | Primitive operation, see 'Constant'.
  Const :: Constant a -> Exp env a
  -- | Pair construction.
  Pair  :: Exp env a -> Exp env b -> Exp env (a, b)
  -- | First projection of a pair.
  Fst   :: Exp env (a, b) -> Exp env a
  -- | Second projection of a pair.
  Snd   :: Exp env (a, b) -> Exp env b
  -- | Array construction from a shape value and a function @sh -> a@.
  Build :: ShapeType sh -> Exp env sh -> Exp env (sh -> a) -> Exp env (Array sh a)
  -- | Indexed fold: a step function @(s, sh) -> s@, an initial @s@ and a
  -- value of type @sh@ (presumably the bound to iterate over -- confirm
  -- against the evaluator).
  Ifold :: ShapeType sh -> Exp env ((s, sh) -> s) -> Exp env s -> Exp env sh -> Exp env s
  -- | Array indexing.
  Index :: Exp env (Array sh a) -> Exp env sh -> Exp env a
  -- | Shape of an array.
  Shape :: Exp env (Array sh a) -> Exp env sh
  -- | Placeholder of the given type; its meaning is up to the consumer.
  Undef :: Type a -> Exp env a
-- | Primitive operations usable through 'Const'. Binary operators take
-- both operands as a single pair argument.
data Constant a where
  -- Int arithmetic
  CAddI  :: Constant ((Int, Int) -> Int)
  CSubI  :: Constant ((Int, Int) -> Int)
  CMulI  :: Constant ((Int, Int) -> Int)
  CDivI  :: Constant ((Int, Int) -> Int)
  -- Double arithmetic
  CAddF  :: Constant ((Double, Double) -> Double)
  CSubF  :: Constant ((Double, Double) -> Double)
  CMulF  :: Constant ((Double, Double) -> Double)
  CDivF  :: Constant ((Double, Double) -> Double)
  -- Unary Double functions
  CLog   :: Constant (Double -> Double)
  CExp   :: Constant (Double -> Double)
  -- Conversions between Int and Double
  CtoF   :: Constant (Int -> Double)
  CRound :: Constant (Double -> Int)
  -- Predicates on pairs
  CLtI   :: Constant ((Int, Int) -> Bool)
  CLeI   :: Constant ((Int, Int) -> Bool)
  CLtF   :: Constant ((Double, Double) -> Bool)
  -- | Equality test at the type given by the 'Type' argument.
  CEq    :: Type a -> Constant ((a, a) -> Bool)
  -- Boolean connectives
  CAnd   :: Constant ((Bool, Bool) -> Bool)
  COr    :: Constant ((Bool, Bool) -> Bool)
  CNot   :: Constant (Bool -> Bool)
-- | Value-level witness for a type. Matching on a constructor refines the
-- index @a@; this is what the 'GEq' instance and 'typeHasShow' /
-- 'typeHasEq' rely on.
data Type a where
  TInt    :: Type Int
  TBool   :: Type Bool
  TDouble :: Type Double
  TArray  :: ShapeType sh -> Type a -> Type (Array sh a)
  TNil    :: Type ()
  TPair   :: Type a -> Type b -> Type (a, b)
  TFun    :: Type a -> Type b -> Type (a -> b)
-- | Typed de Bruijn index into @env@: 'Zero' points at the head of the
-- list, and each 'Succ' skips one entry.
data Idx env a where
  Zero :: Idx (a ': env) a
  Succ :: Idx env a -> Idx (t ': env) a
-- | Constant values that can be embedded in an expression via 'Lit'.
data Literal a where
  LInt    :: Int -> Literal Int
  LBool   :: Bool -> Literal Bool
  LDouble :: Double -> Literal Double
  LArray  :: Array sh a -> Literal (Array sh a)
  -- | A concrete shape used as a value of its index type @sh@.
  LShape  :: Shape sh -> Literal sh
  LNil    :: Literal ()
  LPair   :: Literal a -> Literal b -> Literal (a, b)
-- | A concrete shape. 'Z' is the empty shape of index type @()@, and
-- @sh :. n@ appends one 'Int' component, so the index type is a
-- left-nested tuple such as @(((), Int), Int)@. No fixity declaration is
-- given, so ':.' is @infixl 9@.
data Shape sh where
  Z    :: Shape ()
  (:.) :: Shape sh -> Int -> Shape (sh, Int)
-- | The structure of a shape's index type without its 'Int' values: one
-- 'STC' per component, ending in 'STZ' (compare 'shapeType').
data ShapeType sh where
  STZ :: ShapeType ()
  STC :: ShapeType sh -> ShapeType (sh, Int)
-- | An array value: its shape, a witness for its element type, and the
-- elements held in a single 'Vector'. The element ordering within the
-- vector is not fixed by this declaration (TODO: document the layout).
data Array sh a where
  Array :: Shape sh -> Type a -> Vector a -> Array sh a
-- Structural 'Show' instances for the AST and its auxiliary types.
-- 'Shape' and 'Array' use hand-written instances in this module instead.
deriving instance Show (Exp env a)
deriving instance Show (Constant a)
deriving instance Show (Type a)
deriving instance Show (Idx env a)
deriving instance Show (Literal a)
deriving instance Show (ShapeType sh)
-- | Renders a shape as Haskell-like source, e.g. @(Z :. 2) :. 3@. Every
-- ':.' node is parenthesised in any context with precedence above 0.
-- The extent is rendered at precedence 10 so that a negative extent comes
-- out as @Z :. (-1)@ rather than the unparsable @Z :. -1@.
instance Show (Shape sh) where
  showsPrec _ Z = showString "Z"
  showsPrec p (sh :. n) =
    showParen (p > 0) $
      showsPrec 10 sh . showString " :. " . showsPrec 10 n
-- | Renders as @Array <shape> <type> <elements>@. When 'typeHasShow'
-- reports no 'Show' instance for the element type, the elements are
-- summarised as @[_ * <count>]@.
instance Show (Array sh a) where
  showsPrec d (Array shape ty elems) =
    showParen (d > 10) $
      foldr (.) id
        [ showString "Array "
        , showsPrec 11 shape
        , showString " "
        , showsPrec 11 ty
        , showString " "
        , showElems
        ]
    where
      showElems :: ShowS
      showElems = case typeHasShow ty of
        Just Has -> showsPrec 11 elems
        Nothing -> showString ("[_ * " ++ show (V.length elems) ++ "]")
-- Structural equality. Comparing arrays compares their element vectors,
-- which needs 'Eq' on the element type.
deriving instance Eq (Type a)
deriving instance Eq (ShapeType sh)
deriving instance Eq (Shape sh)
deriving instance Eq a => Eq (Array sh a)
-- | Structural equality on expressions that also recovers equality of
-- their result types. Sub-terms are compared left to right and the first
-- mismatch short-circuits to 'Nothing'. Each constructor has its own
-- fall-through clause so that, under -Wincomplete-patterns, adding a
-- constructor to 'Exp' is flagged here.
instance GEq (Exp env) where
  geq (App f x) (App f' x') = do
    Refl <- geq f f'
    Refl <- geq x x'
    pure Refl
  geq App{} _ = Nothing

  geq (Lam ty body) (Lam ty' body') = do
    Refl <- geq ty ty'
    Refl <- geq body body'
    pure Refl
  geq Lam{} _ = Nothing

  geq (Var ty ix) (Var ty' ix') = do
    Refl <- geq ty ty'
    Refl <- geq ix ix'
    pure Refl
  geq Var{} _ = Nothing

  geq (Let rhs body) (Let rhs' body') = do
    Refl <- geq rhs rhs'
    Refl <- geq body body'
    pure Refl
  geq Let{} _ = Nothing

  -- The literal's index is the expression's type, so its 'geq' answers directly.
  geq (Lit lit) (Lit lit') = geq lit lit'
  geq Lit{} _ = Nothing

  geq (Cond c t e) (Cond c' t' e') = do
    Refl <- geq c c'
    Refl <- geq t t'
    Refl <- geq e e'
    pure Refl
  geq Cond{} _ = Nothing

  -- Likewise for constants.
  geq (Const k) (Const k') = geq k k'
  geq Const{} _ = Nothing

  geq (Pair l r) (Pair l' r') = do
    Refl <- geq l l'
    Refl <- geq r r'
    pure Refl
  geq Pair{} _ = Nothing

  geq (Fst p) (Fst p') = do
    Refl <- geq p p'
    pure Refl
  geq Fst{} _ = Nothing

  geq (Snd p) (Snd p') = do
    Refl <- geq p p'
    pure Refl
  geq Snd{} _ = Nothing

  geq (Build sht sh f) (Build sht' sh' f') = do
    Refl <- geq sht sht'
    Refl <- geq sh sh'
    Refl <- geq f f'
    pure Refl
  geq Build{} _ = Nothing

  geq (Ifold sht step z sh) (Ifold sht' step' z' sh') = do
    Refl <- geq sht sht'
    Refl <- geq step step'
    Refl <- geq z z'
    Refl <- geq sh sh'
    pure Refl
  geq Ifold{} _ = Nothing

  geq (Index arr ix) (Index arr' ix') = do
    Refl <- geq arr arr'
    Refl <- geq ix ix'
    pure Refl
  geq Index{} _ = Nothing

  geq (Shape arr) (Shape arr') = do
    Refl <- geq arr arr'
    pure Refl
  geq Shape{} _ = Nothing

  -- The type witness is indexed by the expression's type.
  geq (Undef ty) (Undef ty') = geq ty ty'
  geq Undef{} _ = Nothing
-- | Constants are equal when they are the same operation. For 'CEq' the
-- compared types must match as well.
instance GEq Constant where
  -- Int arithmetic
  geq CAddI CAddI = Just Refl
  geq CAddI _ = Nothing
  geq CSubI CSubI = Just Refl
  geq CSubI _ = Nothing
  geq CMulI CMulI = Just Refl
  geq CMulI _ = Nothing
  geq CDivI CDivI = Just Refl
  geq CDivI _ = Nothing
  -- Double arithmetic
  geq CAddF CAddF = Just Refl
  geq CAddF _ = Nothing
  geq CSubF CSubF = Just Refl
  geq CSubF _ = Nothing
  geq CMulF CMulF = Just Refl
  geq CMulF _ = Nothing
  geq CDivF CDivF = Just Refl
  geq CDivF _ = Nothing
  -- Unary Double functions and conversions
  geq CLog CLog = Just Refl
  geq CLog _ = Nothing
  geq CExp CExp = Just Refl
  geq CExp _ = Nothing
  geq CtoF CtoF = Just Refl
  geq CtoF _ = Nothing
  geq CRound CRound = Just Refl
  geq CRound _ = Nothing
  -- Predicates
  geq CLtI CLtI = Just Refl
  geq CLtI _ = Nothing
  geq CLeI CLeI = Just Refl
  geq CLeI _ = Nothing
  geq CLtF CLtF = Just Refl
  geq CLtF _ = Nothing
  geq (CEq ty) (CEq ty') = do
    Refl <- geq ty ty'
    pure Refl
  geq CEq{} _ = Nothing
  -- Boolean connectives
  geq CAnd CAnd = Just Refl
  geq CAnd _ = Nothing
  geq COr COr = Just Refl
  geq COr _ = Nothing
  geq CNot CNot = Just Refl
  geq CNot _ = Nothing
-- | Decides equality of type witnesses. A successful match proves that the
-- indices are equal.
instance GEq Type where
  geq TInt TInt = Just Refl
  geq TInt _ = Nothing

  geq TBool TBool = Just Refl
  geq TBool _ = Nothing

  geq TDouble TDouble = Just Refl
  geq TDouble _ = Nothing

  geq (TArray sht elt) (TArray sht' elt') = do
    Refl <- geq sht sht'
    Refl <- geq elt elt'
    pure Refl
  geq TArray{} _ = Nothing

  geq TNil TNil = Just Refl
  geq TNil _ = Nothing

  geq (TPair l r) (TPair l' r') = do
    Refl <- geq l l'
    Refl <- geq r r'
    pure Refl
  geq TPair{} _ = Nothing

  geq (TFun arg res) (TFun arg' res') = do
    Refl <- geq arg arg'
    Refl <- geq res res'
    pure Refl
  geq TFun{} _ = Nothing
-- | Two indices into the same environment are equal when they point at
-- the same position. Equal positions imply equal element types.
instance GEq (Idx env) where
  geq Zero Zero = Just Refl
  geq (Succ rest) (Succ rest') = do
    Refl <- geq rest rest'
    pure Refl
  geq _ _ = Nothing
-- | Literal equality by value. Scalars compare with '=='. Arrays compare
-- shape, element type and elements; an array whose element type has no
-- 'Eq' (per 'typeHasEq') raises an 'error'.
instance GEq Literal where
  geq (LInt x) (LInt y) = if x == y then Just Refl else Nothing
  geq LInt{} _ = Nothing

  geq (LBool x) (LBool y) = if x == y then Just Refl else Nothing
  geq LBool{} _ = Nothing

  geq (LDouble x) (LDouble y) = if x == y then Just Refl else Nothing
  geq LDouble{} _ = Nothing

  geq (LArray (Array sh elt vec)) (LArray (Array sh' elt' vec')) = do
    Refl <- geq sh sh'
    Refl <- geq elt elt'
    case typeHasEq elt of
      Just Has -> if vec == vec' then Just Refl else Nothing
      Nothing -> error "GEq Literal: Literal array of incomparable values"
  geq LArray{} _ = Nothing

  -- The shape's index is the literal's type, so its 'geq' answers directly.
  geq (LShape sh) (LShape sh') = geq sh sh'
  geq LShape{} _ = Nothing

  geq LNil LNil = Just Refl
  geq LNil _ = Nothing

  geq (LPair l r) (LPair l' r') = do
    Refl <- geq l l'
    Refl <- geq r r'
    pure Refl
  geq LPair{} _ = Nothing
-- | Shapes are equal when they have the same components. The outermost
-- 'Int' is checked before recursing.
instance GEq Shape where
  geq Z Z = Just Refl
  geq (inner :. n) (inner' :. n')
    | n /= n' = Nothing
    | otherwise = do
        Refl <- geq inner inner'
        pure Refl
  geq _ _ = Nothing
-- | Shape types are equal when they have the same number of 'STC' layers.
instance GEq ShapeType where
  geq STZ STZ = Just Refl
  geq (STC inner) (STC inner') = do
    Refl <- geq inner inner'
    pure Refl
  geq _ _ = Nothing
-- | Applies @onChild@ to every direct 'Exp' child of a term and combines
-- the results with '<>' in field order. Terms without 'Exp' children
-- ('Var', 'Lit', 'Const', 'Undef') give 'mempty'.
--
-- The term must be neither 'Lam' nor 'Let' (both call 'error'). Their
-- bodies live in the extended environment @t ': env@, which @onChild@
-- (fixed at @env@) cannot accept.
recurseMon :: Monoid s => (forall t. Exp env t -> s) -> Exp env a -> s
recurseMon onChild = \case
  -- Binders are not supported.
  Lam _ _ -> error "recurseMon: Given Lam"
  Let _ _ -> error "recurseMon: Given Let"
  -- No sub-expressions.
  Var _ _ -> mempty
  Lit _ -> mempty
  Const _ -> mempty
  Undef _ -> mempty
  -- One sub-expression.
  Fst p -> onChild p
  Snd p -> onChild p
  Shape arr -> onChild arr
  -- Several sub-expressions.
  App fun arg -> onChild fun <> onChild arg
  Pair l r -> onChild l <> onChild r
  Index arr ix -> onChild arr <> onChild ix
  Build _ sh gen -> onChild sh <> onChild gen
  Cond c t e -> onChild c <> onChild t <> onChild e
  Ifold _ step z sh -> onChild step <> onChild z <> onChild sh
-- | Forgets the 'Int' components of a shape, keeping only its structure.
shapeType :: Shape sh -> ShapeType sh
shapeType = \case
  Z -> STZ
  inner :. _ -> STC (shapeType inner)
-- | The 'Type' of a shape's index: 'TNil' for 'Z', with a 'TPair' and
-- 'TInt' added for each component.
shapeType' :: Shape sh -> Type sh
shapeType' = \case
  Z -> TNil
  inner :. _ -> TPair (shapeType' inner) TInt
-- | Converts a 'ShapeType' into the 'Type' of the corresponding index:
-- 'TNil' for 'STZ', with a 'TPair' and 'TInt' added for each 'STC'.
shapeTypeType :: ShapeType sh -> Type sh
shapeTypeType = \case
  STZ -> TNil
  STC inner -> TPair (shapeTypeType inner) TInt
-- | The type of a literal. Array literals carry their element type; shape
-- literals derive theirs from the shape.
literalType :: Literal a -> Type a
literalType = \case
  LInt _ -> TInt
  LBool _ -> TBool
  LDouble _ -> TDouble
  LArray (Array sh elt _) -> TArray (shapeType sh) elt
  LShape sh -> shapeType' sh
  LNil -> TNil
  LPair l r -> TPair (literalType l) (literalType r)
-- | The function type of each primitive operation.
constType :: Constant a -> Type a
constType = \case
  CAddI -> binOp TInt TInt
  CSubI -> binOp TInt TInt
  CMulI -> binOp TInt TInt
  CDivI -> binOp TInt TInt
  CAddF -> binOp TDouble TDouble
  CSubF -> binOp TDouble TDouble
  CMulF -> binOp TDouble TDouble
  CDivF -> binOp TDouble TDouble
  CLog -> TFun TDouble TDouble
  CExp -> TFun TDouble TDouble
  CtoF -> TFun TInt TDouble
  CRound -> TFun TDouble TInt
  CLtI -> binOp TInt TBool
  CLeI -> binOp TInt TBool
  CLtF -> binOp TDouble TBool
  CEq ty -> binOp ty TBool
  CAnd -> binOp TBool TBool
  COr -> binOp TBool TBool
  CNot -> TFun TBool TBool
  where
    -- Binary operators take both operands, of the same type, as one pair.
    binOp :: Type x -> Type r -> Type ((x, x) -> r)
    binOp operand result = TFun (TPair operand operand) result
-- | Computes the type of an expression. The result comes from the type
-- annotations stored in the term ('Lam', 'Var', 'Build', 'Undef'), from
-- literals and constants, or by taking apart the type of a
-- sub-expression.
--
-- Sub-expression types are destructured with @case@ rather than lazy
-- @let@ pattern bindings. With GADTs, a @case@ match refines the index,
-- so each match is total. A @let@ binding on a GADT constructor cannot
-- use that refinement and counts as an incomplete uni-pattern.
typeof :: Exp env a -> Type a
typeof (App f _) = case typeof f of TFun _ res -> res
typeof (Lam ty body) = TFun ty (typeof body)
typeof (Var ty _) = ty
typeof (Let _ body) = typeof body
typeof (Lit lit) = literalType lit
typeof (Cond _ t _) = typeof t
typeof (Const k) = constType k
typeof (Pair l r) = TPair (typeof l) (typeof r)
typeof (Fst p) = case typeof p of TPair t _ -> t
typeof (Snd p) = case typeof p of TPair _ t -> t
typeof (Build sht _ gen) = case typeof gen of TFun _ elt -> TArray sht elt
typeof (Ifold _ _ z _) = typeof z
typeof (Index arr _) = case typeof arr of TArray _ elt -> elt
typeof (Shape arr) = case typeof arr of TArray sht _ -> shapeTypeType sht
typeof (Undef ty) = ty
-- | Numeric position of a de Bruijn index: 'Zero' is 0, and each 'Succ'
-- adds one.
idxToInt :: Idx env a -> Int
idxToInt = go 0
  where
    -- The signature is required: the environment changes on each step.
    go :: Int -> Idx e t -> Int
    go acc Zero = acc
    go acc (Succ rest) = go (acc + 1) rest
-- | Number of 'STC' layers in a shape type, i.e. its number of 'Int'
-- components.
shtToInt :: ShapeType sh -> Int
shtToInt = go 0
  where
    -- The signature is required: the index type shrinks on each step.
    go :: Int -> ShapeType s -> Int
    go acc STZ = acc
    go acc (STC inner) = go (acc + 1) inner
-- | A reified constraint: pattern matching on 'Has' brings the instance
-- @c a@ into scope. Produced by 'typeHasShow' and 'typeHasEq'.
data Has c a where
  Has :: c a => Has c a
-- | Decides whether values of the given type have a 'Show' instance.
-- Function types, and pairs containing a component without one, do not.
-- Arrays always do, because the 'Array' instance prints a placeholder
-- when its elements cannot be shown.
typeHasShow :: Type a -> Maybe (Has Show a)
typeHasShow = \case
  TInt -> Just Has
  TBool -> Just Has
  TDouble -> Just Has
  TArray{} -> Just Has
  TNil -> Just Has
  TPair l r -> do
    Has <- typeHasShow l
    Has <- typeHasShow r
    Just Has
  TFun{} -> Nothing
-- | Decides whether values of the given type have an 'Eq' instance.
-- Function types do not. Arrays and pairs do exactly when their element
-- or component types do.
typeHasEq :: Type a -> Maybe (Has Eq a)
typeHasEq = \case
  TInt -> Just Has
  TBool -> Just Has
  TDouble -> Just Has
  TArray _ elt -> do
    Has <- typeHasEq elt
    Just Has
  TNil -> Just Has
  TPair l r -> do
    Has <- typeHasEq l
    Has <- typeHasEq r
    Just Has
  TFun{} -> Nothing
|