1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
|
module Check(checkProgram) where
import Control.Monad
import Data.Maybe
import qualified Data.Map.Strict as Map
--import Debug.Trace
import AST
import PShow
-- | Result of a checking pass: either a human-readable error message or a value.
type Error = Either String
-- | Run all checking passes over a parsed program.
--
-- Typedefs are expanded first. Any type name that remains afterwards is an
-- error. The expanded program is then type-checked, and finally variable
-- declarations are hoisted to the top of their blocks.
checkProgram :: Program -> Error Program
checkProgram prog =
  checkUndefinedTypes expanded >> typeCheck expanded >>= bundleVarDecls
  where
    expanded = replaceTypes prog
-- | Remove all typedef declarations and replace every 'TypeName' that names a
-- typedef with the type it was defined as.
--
-- Expansion is transitive: a typedef may refer to other typedefs, including
-- underneath pointers and inside function types. A name that is not a typedef,
-- or that is part of a typedef cycle (e.g. @typedef foo* foo@), is left as a
-- 'TypeName'. 'checkUndefinedTypes' then reports it instead of this pass
-- looping forever.
replaceTypes :: Program -> Program
replaceTypes prog@(Program decls) = mapProgram' filtered mapper
  where
    filtered = Program $ filter notTypedef decls
    mapper = defaultPM' {typeHandler' = typeReplacer (findTypeRenames prog)}

    notTypedef :: Declaration -> Bool
    notTypedef (DecTypedef _ _) = False
    notTypedef _ = True

    -- Type handler for the mapper. The mapper already visits child types
    -- first, so only a 'TypeName' at the current node needs expanding here.
    typeReplacer :: Map.Map Name Type -> Type -> Type
    typeReplacer m t@(TypeName _) = expand m [] t
    typeReplacer _ t = t

    -- Fully expand typedef names inside a typedef's definition. @seen@ holds
    -- the names currently being expanded, so a cycle stops at the repeated name.
    expand :: Map.Map Name Type -> [Name] -> Type -> Type
    expand m seen t@(TypeName n)
      | n `elem` seen = t
      | otherwise = maybe t (expand m (n : seen)) (Map.lookup n m)
    expand m seen (TypePtr t) = TypePtr $ expand m seen t
    expand m seen (TypeFunc rt ats) = TypeFunc (expand m seen rt) (map (expand m seen) ats)
    expand _ _ t = t
-- | Map every typedef name to the type it abbreviates. When a name is defined
-- more than once, the last definition wins.
findTypeRenames :: Program -> Map.Map Name Type
findTypeRenames (Program decls) = Map.fromList [(name, ty) | DecTypedef ty name <- decls]
-- | Fail on the first 'TypeName' still present in the program. It is meant to
-- run after typedef expansion, when any remaining name is unknown.
checkUndefinedTypes :: Program -> Error ()
checkUndefinedTypes prog = void $ mapProgram prog defaultPM {typeHandler = rejectNamed}
  where
    rejectNamed :: MapperHandler Type
    rejectNamed ty = case ty of
      TypeName n -> Left $ "Undefined type name '" ++ n ++ "'"
      _ -> Right ty
-- | Type-check a program whose typedefs have already been expanded.
--
-- Global variables ('DecVariable') and callable declarations ('DecFunction',
-- and 'DecExtern' with a function type) are collected from the whole program
-- first. Function bodies are then checked statement by statement. A local
-- variable is in scope from its declaration to the end of its block, and
-- declaring a name that is already visible is an error.
--
-- On success, every expression inside a function body carries a type
-- annotation. Otherwise the first error found is returned.
--
-- NOTE(review): initialisers of global variables are returned unchanged (not
-- checked or annotated) — confirm later stages do not rely on annotations there.
typeCheck :: Program -> Error Program
typeCheck (Program decls) = Program <$> mapM (goD topLevelNames) decls
  where
    -- Types of all global variables. With 'foldr', the first declaration of a
    -- duplicated name wins.
    topLevelNames :: Map.Map Name Type
    topLevelNames = foldr (uncurry Map.insert) Map.empty pairs
      where pairs = map ((,) <$> nameOf <*> typeOf) $ filter isVarDecl decls

    -- (return type, parameter types) of every callable top-level declaration.
    functionTypes :: Map.Map Name (Type, [Type])
    functionTypes = foldr (uncurry Map.insert) Map.empty pairs
      where pairs = mapMaybe (\d -> (,) (nameOf d) <$> signature d) decls

    signature :: Declaration -> Maybe (Type, [Type])
    signature (DecFunction rt _ args _) = Just (rt, map fst args)
    signature (DecExtern (TypeFunc rt ats) _) = Just (rt, ats)
    signature _ = Nothing

    isVarDecl :: Declaration -> Bool
    isVarDecl (DecVariable {}) = True
    isVarDecl _ = False

    -- The annotation that 'goE' attaches to every expression it returns.
    annotatedType :: Expression -> Type
    annotatedType = fromMaybe (error "typeCheck: checked expression lacks a type annotation") . exTypeOf

    goD :: Map.Map Name Type -> Declaration -> Error Declaration
    goD names (DecFunction frt name args body) = do
      -- Parameters shadow globals of the same name inside the body.
      newbody <- goB frt (foldr (\(t, n) m -> Map.insert n t m) names args) body
      return $ DecFunction frt name args newbody
    goD _ (DecVariable (TypeFunc _ _) _ _) = Left "Cannot declare global variable with function type"
    goD _ dec = return dec

    -- Check the statements of a block in order, threading the variables
    -- declared so far into the following statements.
    goB :: Type -- function return type
        -> Map.Map Name Type -> Block -> Error Block
    goB frt names (Block stmts) = Block <$> goStmts names stmts
      where
        goStmts _ [] = return []
        goStmts scope (st : rest) = do
          (scope', st') <- goS frt scope st
          (st' :) <$> goStmts scope' rest

    -- Check an expression being stored into variable @n@ and return the
    -- annotated expression.
    checkAssign :: Map.Map Name Type -> Name -> Expression -> Error Expression
    checkAssign names n e = case Map.lookup n names of
      Nothing -> Left $ "Undefined variable '" ++ n ++ "'"
      Just dsttype -> do
        re <- goE names e
        let extype = annotatedType re
        if canConvert extype dsttype
          then return re
          else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '"
                      ++ pshow dsttype ++ "' in assignment to variable '" ++ n ++ "'"

    -- Returns the variables in scope after the statement, plus the checked
    -- statement. Only declarations extend the scope.
    goS :: Type -- function return type
        -> Map.Map Name Type -> Statement -> Error (Map.Map Name Type, Statement)
    goS _ _ (StVarDeclaration (TypeFunc _ _) _ _) = Left "Cannot declare variable with function type"
    goS _ names st@(StVarDeclaration t n Nothing)
      | Map.member n names = Left $ "Duplicate variable '" ++ n ++ "'"
      | otherwise = return (Map.insert n t names, st)
    goS frt names (StVarDeclaration t n (Just e)) = do
      (newnames, _) <- goS frt names (StVarDeclaration t n Nothing)
      -- The initialiser is checked with the new variable already in scope.
      newe <- checkAssign newnames n e
      return (newnames, StVarDeclaration t n (Just newe))
    goS _ names (StAssignment n e) = do
      re <- checkAssign names n e
      return (names, StAssignment n re)
    goS _ names st@StEmpty = return (names, st)
    goS frt names (StBlock bl) = do
      newbl <- goB frt names bl
      return (names, StBlock newbl)
    goS _ names (StExpr e) = do
      re <- goE names e
      return (names, StExpr re)
    goS frt names (StIf e s1 s2) = do
      re <- goE names e
      (_, rs1) <- goS frt names s1
      (_, rs2) <- goS frt names s2
      return (names, StIf re rs1 rs2)
    goS frt names (StWhile e s) = do
      re <- goE names e
      (_, rs) <- goS frt names s
      return (names, StWhile re rs)
    goS TypeVoid names st@(StReturn Nothing) = return (names, st)
    goS _ _ (StReturn Nothing) = Left "Non-void function should return a value"
    goS frt names (StReturn (Just e)) = do
      re <- goE names e
      let extype = annotatedType re
      if canConvert extype frt
        then return (names, StReturn (Just re))
        else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '"
                    ++ pshow frt ++ "' in return statement"

    -- Postcondition: the expression (if any) has a type annotation.
    goE :: Map.Map Name Type -> Expression -> Error Expression
    goE _ (ExLit l@(LitInt i) _) = ExLit l . Just <$> smallestIntType i
    goE _ (ExLit l@(LitUInt i) _) = ExLit l . Just <$> smallestUIntType i
    goE _ (ExLit l@(LitFloat f) _) = return $ ExLit l $ Just (smallestFloatType f)
    goE _ (ExLit l@(LitString _) _) = return $ ExLit l $ Just (TypePtr (TypeInt 8))
    goE names (ExLit l@(LitVar n) _) =
      maybe (Left $ "Undefined variable '" ++ n ++ "'") (return . ExLit l . Just) (Map.lookup n names)
    goE names (ExLit (LitCall n args) _) = do
      (rettype, partypes) <- maybe (Left $ "Unknown function '" ++ n ++ "'") return $ Map.lookup n functionTypes
      rargs <- mapM (goE names) args
      when (length rargs /= length partypes) $
        Left $ "Expected " ++ show (length partypes) ++ " arguments to "
               ++ "function '" ++ n ++ "', but got " ++ show (length rargs)
      forM_ (zip rargs partypes) $ \(a, partype) ->
        let argtype = annotatedType a
        in unless (canConvert argtype partype) $
             Left $ "Cannot convert type '" ++ pshow argtype ++ "' to '" ++ pshow partype
                    ++ "' in call of function '" ++ n ++ "'"
      return $ ExLit (LitCall n rargs) (Just rettype)
    goE names (ExCast totype ex) = do
      rex <- goE names ex
      let fromtype = annotatedType rex
      if canCast fromtype totype
        then return $ ExCast totype rex
        else Left $ "Cannot cast type '" ++ pshow fromtype ++ "' to '" ++ pshow totype ++ "'"
    goE names (ExBinOp bo e1 e2 _) = do
      re1 <- goE names e1
      re2 <- goE names e2
      let t1 = annotatedType re1
          t2 = annotatedType re2
      maybe (Left $ "Cannot use operator '" ++ pshow bo ++ "' with argument types '"
                    ++ pshow t1 ++ "' and '" ++ pshow t2 ++ "'")
            (return . ExBinOp bo re1 re2 . Just)
            (resultTypeBO bo t1 t2)
    goE names (ExUnOp uo e _) = do
      re <- goE names e
      let t = annotatedType re
      maybe (Left $ "Cannot use operator '" ++ pshow uo ++ "' with argument type '" ++ pshow t ++ "'")
            (return . ExUnOp uo re . Just)
            (resultTypeUO uo t)
-- | Hoist every variable declaration to the top of the block it appears in.
--
-- Each declaration is moved to the start of its block without an initialiser.
-- Where it originally stood, an initialised declaration becomes a plain
-- assignment and an uninitialised one disappears. Relative order is kept.
bundleVarDecls :: Program -> Error Program
bundleVarDecls prog = mapProgram prog $ defaultPM {blockHandler = hoist}
  where
    hoist :: MapperHandler Block
    hoist (Block stmts) = return $ Block $ bareDecls ++ concatMap demote stmts
      where
        bareDecls = [StVarDeclaration t n Nothing | StVarDeclaration t n _ <- stmts]
        demote (StVarDeclaration _ n (Just ex)) = [StAssignment n ex]
        demote (StVarDeclaration _ _ Nothing) = []
        demote st = [st]
-- | Whether a value of the first type may be implicitly converted to the
-- second (used for assignments, returns and call arguments).
--
-- Allowed: identical types, widening within the same integer signedness,
-- @u1@ to any signed integer, any integer to @float@/@double@, and @float@
-- to @double@.
canConvert :: Type -> Type -> Bool
canConvert from to
  | from == to = True
  | otherwise = case (from, to) of
      (TypeInt f, TypeInt t) -> f <= t
      (TypeUInt f, TypeUInt t) -> f <= t
      (TypeUInt 1, TypeInt _) -> True
      (_, TypeFloat) -> isIntegral from
      (_, TypeDouble) -> isIntegral from || from == TypeFloat
      _ -> False
  where
    isIntegral (TypeInt _) = True
    isIntegral (TypeUInt _) = True
    isIntegral _ = False
-- | Whether an explicit cast between two types is allowed. Both types must be
-- numeric (integer or floating point), or both must be integer-like (integer,
-- pointer or function).
canCast :: Type -> Type -> Bool
canCast from to = both isNumber || both isIntOrPtr
  where
    both p = p from && p to

    isIntegral (TypeInt _) = True
    isIntegral (TypeUInt _) = True
    isIntegral _ = False

    isNumber t = isIntegral t || t == TypeFloat || t == TypeDouble

    isIntOrPtr t = case t of
      TypePtr _ -> True
      TypeFunc _ _ -> True
      _ -> isIntegral t
-- | Binary operator groups used by 'resultTypeBO'.
arithBO, compareBO, logicBO, complogBO :: [BinaryOperator]
-- Arithmetic operators.
arithBO = [Plus, Minus, Times, Divide, Modulo]
-- Comparison operators.
compareBO = [Equal, Unequal, Greater, Less, GEqual, LEqual]
-- Short-circuit boolean operators.
logicBO = [BoolAnd, BoolOr]
-- Every operator whose result is a boolean.
complogBO = concat [compareBO, logicBO]
-- | Result type of a binary operator applied to operands of the given types,
-- or 'Nothing' when the combination is not allowed.
--
-- Clauses are tried in order: pointer rules first (anything else involving a
-- pointer is rejected), then same-signedness integer rules, then the generic
-- "compare two values of equal type" rule, then floating-point rules.
resultTypeBO :: BinaryOperator -> Type -> Type -> Maybe Type
-- Pointers: difference, comparison/logic, offsetting by an integer, indexing.
resultTypeBO Minus (TypePtr t1) (TypePtr t2) | t1 == t2 = Just $ TypeUInt 64
resultTypeBO bo (TypePtr t1) (TypePtr t2) | t1 == t2 && bo `elem` complogBO = Just $ TypeUInt 1
resultTypeBO bo t@(TypePtr _) (TypeInt _) | bo `elem` [Plus, Minus] = Just t
resultTypeBO bo t@(TypePtr _) (TypeUInt _) | bo `elem` [Plus, Minus] = Just t
resultTypeBO bo (TypeInt _) t@(TypePtr _) | bo `elem` [Plus, Minus] = Just t
resultTypeBO bo (TypeUInt _) t@(TypePtr _) | bo `elem` [Plus, Minus] = Just t
resultTypeBO Index (TypePtr t) (TypeInt _) = Just t
resultTypeBO Index (TypePtr t) (TypeUInt _) = Just t
resultTypeBO _ (TypePtr _) _ = Nothing
resultTypeBO _ _ (TypePtr _) = Nothing
-- Integers of equal signedness; arithmetic yields the wider operand type.
resultTypeBO bo (TypeInt s1) (TypeInt s2) | bo `elem` arithBO = Just $ TypeInt (max s1 s2)
resultTypeBO bo (TypeInt _) (TypeInt _) | bo `elem` complogBO = Just $ TypeUInt 1
resultTypeBO bo (TypeUInt s1) (TypeUInt s2) | bo `elem` arithBO = Just $ TypeUInt (max s1 s2)
resultTypeBO bo (TypeUInt _) (TypeUInt _) | bo `elem` complogBO = Just $ TypeUInt 1
-- Comparing or combining two values of the same type gives a boolean.
resultTypeBO bo t1 t2 | bo `elem` complogBO && t1 == t2 = Just $ TypeUInt 1
-- Floating point. A signed integer operand is accepted only when its width
-- fits the significand (24 bits for float, 53 for double), so all of its
-- values are exactly representable.
resultTypeBO bo TypeFloat (TypeInt s) | s <= 24 = floatOpResult bo TypeFloat
resultTypeBO bo (TypeInt s) TypeFloat | s <= 24 = floatOpResult bo TypeFloat
resultTypeBO bo TypeDouble (TypeInt s) | s <= 53 = floatOpResult bo TypeDouble
resultTypeBO bo (TypeInt s) TypeDouble | s <= 53 = floatOpResult bo TypeDouble
resultTypeBO bo TypeFloat TypeFloat = floatOpResult bo TypeFloat
resultTypeBO bo TypeDouble TypeDouble = floatOpResult bo TypeDouble
resultTypeBO bo TypeFloat TypeDouble = floatOpResult bo TypeDouble
resultTypeBO bo TypeDouble TypeFloat = floatOpResult bo TypeDouble
resultTypeBO _ _ _ = Nothing

-- | Result of a floating-point binary operation after both operands have been
-- promoted to @ft@. Arithmetic keeps @ft@, and comparison/logic operators give
-- @u1@. Any other operator (e.g. 'Index') is rejected rather than silently
-- typed as a boolean.
floatOpResult :: BinaryOperator -> Type -> Maybe Type
floatOpResult bo ft
  | bo `elem` arithBO = Just ft
  | bo `elem` complogBO = Just $ TypeUInt 1
  | otherwise = Nothing
-- | Result type of a unary operator applied to an operand of the given type,
-- or 'Nothing' when the operator does not apply to that type.
--
-- @Not@ accepts any operand and gives @u1@. @Address@ gives a pointer to the
-- operand type. @Negate@ accepts integers and floating point, @Invert@ accepts
-- integers only, and @Dereference@ accepts pointers.
resultTypeUO :: UnaryOperator -> Type -> Maybe Type
resultTypeUO uo t = case uo of
  Not -> Just $ TypeUInt 1
  Address -> Just $ TypePtr t
  Negate | integral || floating -> Just t
  Invert | integral -> Just t
  Dereference | TypePtr pointee <- t -> Just pointee
  _ -> Nothing
  where
    integral = case t of
      TypeInt _ -> True
      TypeUInt _ -> True
      _ -> False
    floating = t == TypeFloat || t == TypeDouble
-- | @float@ if the literal survives a round trip through 'Float' unchanged,
-- otherwise @double@.
smallestFloatType :: Double -> Type
smallestFloatType d
  | roundTrips = TypeFloat
  | otherwise = TypeDouble
  where
    narrowed = realToFrac d :: Float
    roundTrips = realToFrac narrowed == d
-- | Narrowest signed integer type (8, 16, 32 or 64 bits) whose two's-complement
-- range contains the literal. Fails if not even 64 bits suffice.
smallestIntType :: Integer -> Error Type
smallestIntType i =
  case [width | width <- [8, 16, 32, 64], fits width] of
    (width : _) -> return $ TypeInt width
    [] -> Left $ "Integer literal '" ++ pshow i ++ "' too wide for i64"
  where
    fits width = let bound = 2 ^ (width - 1) in i >= negate bound && i < bound
-- | Narrowest unsigned integer type (8, 16, 32 or 64 bits) for an unsigned
-- literal. Fails if not even 64 bits suffice.
--
-- NOTE(review): the check is on the magnitude, so negative values down to
-- @-(2^w - 1)@ are also accepted for width @w@ — confirm this is intended.
smallestUIntType :: Integer -> Error Type
smallestUIntType i =
  case filter fits [8, 16, 32, 64] of
    (width : _) -> return $ TypeUInt width
    [] -> Left $ "Unsigned integer literal '" ++ pshow i ++ "U' too wide for u64"
  where
    fits width = abs i < 2 ^ width
-- | A handler that may rewrite a node or abort the traversal with an error.
type MapperHandler a = a -> Error a

-- | One handler per syntax category, for use with 'mapProgram'.
data ProgramMapper = ProgramMapper
  { declarationHandler :: MapperHandler Declaration    -- ^ top-level declarations
  , blockHandler       :: MapperHandler Block          -- ^ statement blocks
  , typeHandler        :: MapperHandler Type           -- ^ every type node
  , literalHandler     :: MapperHandler Literal        -- ^ literals, variables, calls
  , binOpHandler       :: MapperHandler BinaryOperator -- ^ binary operators
  , unOpHandler        :: MapperHandler UnaryOperator  -- ^ unary operators
  , expressionHandler  :: MapperHandler Expression     -- ^ expressions
  , statementHandler   :: MapperHandler Statement      -- ^ statements
  , nameHandler        :: MapperHandler Name           -- ^ identifiers
  }
-- | A pure, infallible node rewrite.
type MapperHandler' a = a -> a

-- | Pure counterpart of 'ProgramMapper', for use with 'mapProgram''.
data ProgramMapper' = ProgramMapper'
  { declarationHandler' :: MapperHandler' Declaration    -- ^ top-level declarations
  , blockHandler'       :: MapperHandler' Block          -- ^ statement blocks
  , typeHandler'        :: MapperHandler' Type           -- ^ every type node
  , literalHandler'     :: MapperHandler' Literal        -- ^ literals, variables, calls
  , binOpHandler'       :: MapperHandler' BinaryOperator -- ^ binary operators
  , unOpHandler'        :: MapperHandler' UnaryOperator  -- ^ unary operators
  , expressionHandler'  :: MapperHandler' Expression     -- ^ expressions
  , statementHandler'   :: MapperHandler' Statement      -- ^ statements
  , nameHandler'        :: MapperHandler' Name           -- ^ identifiers
  }
-- | Mapper that accepts every node unchanged; override individual fields to
-- act on one syntax category.
defaultPM :: ProgramMapper
defaultPM = ProgramMapper
  { declarationHandler = return
  , blockHandler = return
  , typeHandler = return
  , literalHandler = return
  , binOpHandler = return
  , unOpHandler = return
  , expressionHandler = return
  , statementHandler = return
  , nameHandler = return
  }
-- | Pure mapper that leaves every node unchanged; override individual fields
-- to rewrite one syntax category.
defaultPM' :: ProgramMapper'
defaultPM' = ProgramMapper'
  { declarationHandler' = id
  , blockHandler' = id
  , typeHandler' = id
  , literalHandler' = id
  , binOpHandler' = id
  , unOpHandler' = id
  , expressionHandler' = id
  , statementHandler' = id
  , nameHandler' = id
  }
-- | Pure variant of 'mapProgram'. Every handler is lifted into 'Right', so the
-- traversal cannot fail. The 'Left' branch only guards against a future
-- change to 'mapProgram' and reports an internal error instead of a bare
-- pattern-match failure.
mapProgram' :: Program -> ProgramMapper' -> Program
mapProgram' prog mapper =
  either (\err -> error $ "mapProgram': pure mapper produced an error: " ++ err) id $
    mapProgram prog lifted
  where
    lifted = ProgramMapper
      { declarationHandler = return . declarationHandler' mapper
      , blockHandler = return . blockHandler' mapper
      , typeHandler = return . typeHandler' mapper
      , literalHandler = return . literalHandler' mapper
      , binOpHandler = return . binOpHandler' mapper
      , unOpHandler = return . unOpHandler' mapper
      , expressionHandler = return . expressionHandler' mapper
      , statementHandler = return . statementHandler' mapper
      , nameHandler = return . nameHandler' mapper
      }
-- | Rebuild a program bottom-up. The matching handler of @mapper@ runs on each
-- node after that node's children have been mapped, and the first 'Left'
-- returned by any handler aborts the traversal.
--
-- Each node is passed to its handler exactly once. The optional type
-- annotations stored on 'ExLit', 'ExBinOp' and 'ExUnOp' are carried over
-- unchanged and are not passed to the type handler.
mapProgram :: Program -> ProgramMapper -> Error Program
mapProgram prog mapper = goP prog
  where
    h_d = declarationHandler mapper
    h_b = blockHandler mapper
    h_t = typeHandler mapper
    h_l = literalHandler mapper
    h_bo = binOpHandler mapper
    h_uo = unOpHandler mapper
    h_e = expressionHandler mapper
    h_s = statementHandler mapper
    h_n = nameHandler mapper

    goP :: MapperHandler Program
    -- goD already applies the declaration handler; applying h_d here as well
    -- would run it twice on every declaration.
    goP (Program decls) = Program <$> mapM goD decls

    goD :: MapperHandler Declaration
    goD (DecFunction t n a b) = do
      rt <- goT t
      rn <- goN n
      ra <- mapM (\(at, an) -> (,) <$> goT at <*> goN an) a
      rb <- goB b
      h_d $ DecFunction rt rn ra rb
    goD (DecVariable t n mv) = do
      rt <- goT t
      rn <- goN n
      rmv <- traverse goE mv
      h_d $ DecVariable rt rn rmv
    goD (DecTypedef t n) = do
      rt <- goT t
      rn <- goN n
      h_d $ DecTypedef rt rn
    goD (DecExtern t n) = do
      rt <- goT t
      rn <- goN n
      h_d $ DecExtern rt rn

    goT :: MapperHandler Type
    goT (TypePtr t) = goT t >>= h_t . TypePtr
    goT (TypeName n) = goN n >>= h_t . TypeName
    goT (TypeFunc t as) = do
      rt <- goT t
      ras <- mapM goT as
      h_t $ TypeFunc rt ras
    goT t = h_t t

    goN :: MapperHandler Name
    goN = h_n

    goB :: MapperHandler Block
    goB (Block sts) = (Block <$> mapM goS sts) >>= h_b

    goE :: MapperHandler Expression
    goE (ExLit l mt) = do
      rl <- goL l
      h_e $ ExLit rl mt
    goE (ExCast t e) = do
      rt <- goT t
      re <- goE e
      h_e $ ExCast rt re
    goE (ExBinOp bo e1 e2 mt) = do
      rbo <- goBO bo
      re1 <- goE e1
      re2 <- goE e2
      h_e $ ExBinOp rbo re1 re2 mt
    goE (ExUnOp uo e mt) = do
      ruo <- goUO uo
      re <- goE e
      h_e $ ExUnOp ruo re mt

    goS :: MapperHandler Statement
    goS StEmpty = h_s StEmpty
    goS (StBlock b) = goB b >>= h_s . StBlock
    goS (StExpr e) = goE e >>= h_s . StExpr
    goS (StVarDeclaration t n me) = do
      rt <- goT t
      rn <- goN n
      rme <- traverse goE me
      h_s $ StVarDeclaration rt rn rme
    goS (StAssignment n e) = do
      rn <- goN n
      re <- goE e
      h_s $ StAssignment rn re
    goS (StIf e s1 s2) = do
      re <- goE e
      rs1 <- goS s1
      rs2 <- goS s2
      h_s $ StIf re rs1 rs2
    goS (StWhile e s) = do
      re <- goE e
      rs <- goS s
      h_s $ StWhile re rs
    goS (StReturn Nothing) = h_s (StReturn Nothing)
    goS (StReturn (Just e)) = goE e >>= h_s . StReturn . Just

    goL :: MapperHandler Literal
    goL l@(LitString _) = h_l l
    goL l@(LitInt _) = h_l l
    goL l@(LitUInt _) = h_l l
    goL l@(LitFloat _) = h_l l
    goL (LitVar n) = goN n >>= h_l . LitVar
    goL (LitCall n a) = do
      rn <- goN n
      ra <- mapM goE a
      h_l $ LitCall rn ra

    goBO :: MapperHandler BinaryOperator
    goBO = h_bo

    goUO :: MapperHandler UnaryOperator
    goUO = h_uo
|