1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
|
module Check(checkProgram) where
import Control.Monad
import Data.Maybe
import qualified Data.Map.Strict as Map
--import Debug.Trace
import AST
import PShow
-- | Outcome of a checking pass: 'Left' carries an error message, 'Right' the
-- successfully processed value.
type Error = Either String
-- | Run the semantic passes over a parsed program, in order:
-- expand typedefs ('replaceTypes'), reject any remaining unknown type name
-- ('checkUndefinedTypes'), type-check and annotate expressions ('typeCheck'),
-- and hoist local variable declarations ('bundleVarDecls').
checkProgram :: Program -> Error Program
checkProgram prog = checkUndefinedTypes expanded >> (typeCheck expanded >>= bundleVarDecls)
  where
    expanded = replaceTypes prog
-- | Expand typedefs: drop every typedef declaration from the program and run
-- 'typeReplacer' over all remaining types, using the substitution map built by
-- 'findTypeRenames' from the original (unfiltered) program.
replaceTypes :: Program -> Program
replaceTypes prog@(Program decls) =
  mapProgram' (Program (filter (not . isTypedef) decls)) expander
  where
    expander = defaultPM' {typeHandler' = typeReplacer (findTypeRenames prog)}

    isTypedef :: Declaration -> Bool
    isTypedef (DecTypedef {}) = True
    isTypedef _ = False
-- | Replace a type name with its entry in the typedef map. Any other type, and
-- any name that is not in the map, is returned unchanged. The function does not
-- recurse; 'mapProgram' is what visits types nested inside pointers.
typeReplacer :: Map.Map Name Type -> Type -> Type
typeReplacer renames ty = case ty of
  TypeName n -> Map.findWithDefault ty n renames
  _ -> ty
-- | Build the typedef substitution map: each typedef'd name maps to the type it
-- abbreviates.
--
-- Typedefs are processed in declaration order. Each target type is resolved
-- against the typedefs already seen, looking through pointers. So if @A@
-- abbreviates @int@ and @B@ abbreviates a pointer to @A@, then @B@ maps directly
-- to a pointer to @int@. It does not map to a pointer to the name @A@, which
-- 'replaceTypes' would leave behind and 'checkUndefinedTypes' would then reject.
--
-- Only earlier typedefs are consulted, so a cyclic typedef cannot cause an
-- infinite loop. A reference to a later or unknown typedef is kept as a
-- 'TypeName'. When the same name is typedef'd twice, the later definition wins.
findTypeRenames :: Program -> Map.Map Name Type
findTypeRenames (Program d) = foldl go Map.empty d
  where
    go :: Map.Map Name Type -> Declaration -> Map.Map Name Type
    go m (DecTypedef t n) = Map.insert n (resolve m t) m
    go m _ = m

    -- Substitute already-known typedefs inside a type. Values in the map are
    -- themselves fully resolved, so a single lookup suffices.
    resolve :: Map.Map Name Type -> Type -> Type
    resolve m (TypePtr t) = TypePtr (resolve m t)
    resolve m t@(TypeName n) = Map.findWithDefault t n m
    resolve _ t = t
-- | Fail if any 'TypeName' is still present in a type visited by 'mapProgram'.
-- This pass runs after typedef expansion ('replaceTypes'), at which point every
-- remaining type name is unknown.
checkUndefinedTypes :: Program -> Error ()
checkUndefinedTypes prog = void $ mapProgram prog (defaultPM {typeHandler = rejectName})
  where
    rejectName :: MapperHandler Type
    rejectName t = case t of
      TypeName n -> Left $ "Undefined type name '" ++ n ++ "'"
      _ -> Right t
-- | Type-check every function body and annotate each expression with its type.
--
-- The types of all top-level variables and the signatures of all functions are
-- collected up front. A body may therefore refer to any of them, regardless of
-- declaration order. Only function declarations are inspected; every other
-- declaration is returned unchanged.
--
-- NOTE(review): initializers of global variables ('DecVariable' with 'Just')
-- are neither checked nor annotated here. Confirm that later passes do not rely
-- on annotations there.
typeCheck :: Program -> Error Program
typeCheck (Program decls) = Program <$> mapM (goD topLevelNames) decls
  where
    -- Declared types of all top-level variables. Because of 'foldr', the first
    -- declaration of a duplicated name wins.
    topLevelNames :: Map.Map Name Type
    topLevelNames =
      foldr (uncurry Map.insert) Map.empty
        [(nameOf d, typeOf d) | d@(DecVariable {}) <- decls]

    -- Return type and parameter types of every function, keyed by name
    -- (again, the first declaration wins).
    functionTypes :: Map.Map Name (Type, [Type])
    functionTypes =
      foldr (uncurry Map.insert) Map.empty
        [(nameOf d, (rt, map fst args)) | d@(DecFunction rt _ args _) <- decls]

    -- Type of an expression that 'goE' has already annotated.
    tyOf :: Expression -> Type
    tyOf = fromMaybe (error "typeCheck: checked expression lacks a type annotation") . exTypeOf

    goD :: Map.Map Name Type -> Declaration -> Error Declaration
    goD names (DecFunction frt name args body) =
      DecFunction frt name args <$> goB frt scope body
      where
        -- Parameters shadow top-level variables of the same name.
        scope = foldr (\(t, n) m -> Map.insert n t m) names args
    goD _ dec = return dec

    -- Check a block statement by statement, threading the scope so that each
    -- local declaration is visible to the statements that follow it. The scope
    -- is not returned, so declarations do not leak out of the block.
    goB
      :: Type -- function return type
      -> Map.Map Name Type
      -> Block
      -> Error Block
    goB frt names (Block stmts) = Block . reverse . snd <$> foldM step (names, []) stmts
      where
        -- Cons onto a reversed accumulator and reverse once at the end. This
        -- avoids quadratic appends; 'foldM' stops at the first error.
        step (scope, acc) st = do
          (scope', st') <- goS frt scope st
          return (scope', st' : acc)

    -- Check one statement. Returns the scope seen by the statements that
    -- follow (only a variable declaration extends it), along with the
    -- annotated statement.
    goS
      :: Type -- function return type
      -> Map.Map Name Type
      -> Statement
      -> Error (Map.Map Name Type, Statement)
    goS _ names st@(StVarDeclaration t n Nothing) = return (Map.insert n t names, st)
    goS _ names (StVarDeclaration t n (Just e)) = do
      -- The new variable is already in scope inside its own initializer.
      let names' = Map.insert n t names
      re <- checkAssign names' n e
      return (names', StVarDeclaration t n (Just re))
    goS _ names (StAssignment n e) = do
      re <- checkAssign names n e
      return (names, StAssignment n re)
    goS _ names st@StEmpty = return (names, st)
    goS frt names (StBlock bl) = do
      newbl <- goB frt names bl
      return (names, StBlock newbl)
    goS _ names (StExpr e) = do
      re <- goE names e
      return (names, StExpr re)
    goS frt names (StIf e s1 s2) = do
      re <- goE names e
      (_, rs1) <- goS frt names s1
      (_, rs2) <- goS frt names s2
      return (names, StIf re rs1 rs2)
    goS frt names (StWhile e s) = do
      re <- goE names e
      (_, rs) <- goS frt names s
      return (names, StWhile re rs)
    goS frt names (StReturn e) = do
      re <- goE names e
      let extype = tyOf re
      if canConvert extype frt
        then return (names, StReturn re)
        else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '"
               ++ pshow frt ++ "' in return statement"

    -- Check that @n@ is a known variable and that @e@ converts to its type.
    -- Returns the annotated expression. This helper is shared by plain
    -- assignments and initialized declarations, so neither needs to
    -- pattern-match on the result of the other (a failable pattern that would
    -- require 'MonadFail', which 'Either' does not provide).
    checkAssign :: Map.Map Name Type -> Name -> Expression -> Error Expression
    checkAssign names n e = do
      dsttype <- maybe (Left $ "Undefined variable '" ++ n ++ "'") return (Map.lookup n names)
      re <- goE names e
      let extype = tyOf re
      if canConvert extype dsttype
        then return re
        else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '"
               ++ pshow dsttype ++ "' in assignment to variable '" ++ n ++ "'"

    -- Postcondition: the returned expression, and every subexpression in it,
    -- carries a type annotation.
    goE :: Map.Map Name Type -> Expression -> Error Expression
    goE _ (ExLit l@(LitInt i) _) = return $ ExLit l $ Just (smallestIntType i)
    goE _ (ExLit l@(LitString _) _) = return $ ExLit l $ Just (TypePtr (TypeInt 8))
    goE names (ExLit l@(LitVar n) _) =
      maybe (Left $ "Undefined variable '" ++ n ++ "'") (return . ExLit l . Just)
        (Map.lookup n names)
    goE names (ExLit (LitCall n args) _) = do
      (rt, paramTypes) <-
        maybe (Left $ "Unknown function '" ++ n ++ "'") return (Map.lookup n functionTypes)
      rargs <- mapM (goE names) args
      when (length rargs /= length paramTypes) $
        Left $ "Expected " ++ show (length paramTypes) ++ " arguments to "
          ++ "function '" ++ n ++ "', but got " ++ show (length rargs)
      -- Each argument must convert to the type of its own parameter (not to
      -- the function's return type).
      forM_ (zip rargs paramTypes) $ \(a, ptype) ->
        let argtype = tyOf a
        in unless (canConvert argtype ptype) $
             Left $ "Cannot convert type '" ++ pshow argtype ++ "' to '" ++ pshow ptype
               ++ "' in call of function '" ++ n ++ "'"
      -- Keep the checked (annotated) arguments, not the original ones.
      return $ ExLit (LitCall n rargs) (Just rt)
    goE names (ExBinOp bo e1 e2 _) = do
      re1 <- goE names e1
      re2 <- goE names e2
      let t1 = tyOf re1
          t2 = tyOf re2
      case typeCompatibleBO bo t1 t2 of
        Just rt -> return $ ExBinOp bo re1 re2 (Just rt)
        Nothing -> Left $ "Cannot use operator '" ++ pshow bo ++ "' with argument types '"
                     ++ pshow t1 ++ "' and '" ++ pshow t2 ++ "'"
    goE names (ExUnOp uo e _) = do
      re <- goE names e
      let t = tyOf re
      case typeCompatibleUO uo t of
        Just rt -> return $ ExUnOp uo re (Just rt)
        Nothing -> Left $ "Cannot use operator '" ++ pshow uo ++ "' with argument type '"
                     ++ pshow t ++ "'"
-- | Hoist local variable declarations to the start of their enclosing block.
--
-- In every block, each 'StVarDeclaration' is copied to the front without its
-- initializer, keeping the original order. At its original position, an
-- initialized declaration becomes a plain 'StAssignment' and an uninitialized
-- one is dropped. All other statements stay where they were.
bundleVarDecls :: Program -> Error Program
bundleVarDecls prog = mapProgram prog (defaultPM {blockHandler = hoist})
  where
    hoist :: MapperHandler Block
    hoist (Block stmts) = return $ Block $ declsOnly ++ concatMap rewrite stmts
      where
        -- The bare declarations, in source order.
        declsOnly = [StVarDeclaration t n Nothing | StVarDeclaration t n _ <- stmts]
        -- What remains of each statement at its original position.
        rewrite (StVarDeclaration _ n (Just ex)) = [StAssignment n ex]
        rewrite (StVarDeclaration _ _ Nothing) = []
        rewrite st = [st]
-- | Whether a value of the first type may be implicitly converted to the second.
-- This holds for identical types, for widening within signed integers or within
-- unsigned integers, and for float to double. Nothing else converts.
canConvert :: Type -> Type -> Bool
canConvert from to
  | from == to = True
  | otherwise = case (from, to) of
      (TypeInt f, TypeInt t) -> f <= t
      (TypeUInt f, TypeUInt t) -> f <= t
      (TypeFloat, TypeDouble) -> True
      _ -> False
-- | Arithmetic binary operators.
arithBO :: [BinaryOperator]
arithBO = [Plus, Minus, Times, Divide, Modulo]

-- | Comparison binary operators.
compareBO :: [BinaryOperator]
compareBO = [Equal, Unequal, Greater, Less, GEqual, LEqual]

-- | Logical (boolean) binary operators.
logicBO :: [BinaryOperator]
logicBO = [BoolAnd, BoolOr]

-- | Comparison and logical operators together; 'typeCompatibleBO' gives these
-- a @TypeInt 1@ result.
complogBO :: [BinaryOperator]
complogBO = concat [compareBO, logicBO]
-- | Result type of a binary operator applied to operands of the given types,
-- or 'Nothing' if the combination is rejected.
--
-- * Pointers: only two pointers to the same type may be combined, either by
--   'Minus' or by a comparison/logical operator. Any other use of a pointer
--   operand fails.
-- * Integers of the same signedness: arithmetic yields the wider operand type,
--   and comparison/logic yields @TypeInt 1@. Mixing signed and unsigned fails.
-- * Any two equal types may be compared or combined logically.
-- * Floating point: arithmetic on two floats (or two doubles) keeps that type.
--   Mixing float with double widens to double. Mixing with a signed integer is
--   allowed only up to 24 bits (float) or 53 bits (double), presumably the
--   significand widths.
typeCompatibleBO :: BinaryOperator -> Type -> Type -> Maybe Type
-- NOTE(review): pointer difference is typed as @TypeInt 1@, the same type used
-- for boolean results. A full-width integer looks intended; confirm against
-- code generation before changing it.
typeCompatibleBO Minus (TypePtr t1) (TypePtr t2) | t1 == t2 = Just $ TypeInt 1
typeCompatibleBO bo (TypePtr t1) (TypePtr t2) | t1 == t2 && bo `elem` complogBO = Just $ TypeInt 1
typeCompatibleBO _ (TypePtr _) _ = Nothing
typeCompatibleBO _ _ (TypePtr _) = Nothing
typeCompatibleBO bo (TypeInt s1) (TypeInt s2) | bo `elem` arithBO = Just $ TypeInt (max s1 s2)
typeCompatibleBO bo (TypeInt _) (TypeInt _) | bo `elem` complogBO = Just $ TypeInt 1
typeCompatibleBO bo (TypeUInt s1) (TypeUInt s2) | bo `elem` arithBO = Just $ TypeUInt (max s1 s2)
typeCompatibleBO bo (TypeUInt _) (TypeUInt _) | bo `elem` complogBO = Just $ TypeInt 1
typeCompatibleBO bo t1 t2 | bo `elem` complogBO && t1 == t2 = Just $ TypeInt 1
-- Arithmetic on two operands of the same floating type keeps that type. Without
-- these clauses it fell through to 'Nothing', even though float/int arithmetic
-- was accepted.
typeCompatibleBO bo TypeFloat TypeFloat | bo `elem` arithBO = Just TypeFloat
typeCompatibleBO bo TypeDouble TypeDouble | bo `elem` arithBO = Just TypeDouble
typeCompatibleBO bo TypeFloat (TypeInt s) | s <= 24 = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1
typeCompatibleBO bo (TypeInt s) TypeFloat | s <= 24 = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1
typeCompatibleBO bo TypeDouble (TypeInt s) | s <= 53 = Just $ if bo `elem` arithBO then TypeDouble else TypeInt 1
typeCompatibleBO bo (TypeInt s) TypeDouble | s <= 53 = Just $ if bo `elem` arithBO then TypeDouble else TypeInt 1
-- Mixing float and double widens to double, matching 'canConvert', which only
-- allows float -> double. Yielding float here would silently narrow the double
-- operand.
typeCompatibleBO bo TypeFloat TypeDouble = Just $ if bo `elem` arithBO then TypeDouble else TypeInt 1
typeCompatibleBO bo TypeDouble TypeFloat = Just $ if bo `elem` arithBO then TypeDouble else TypeInt 1
typeCompatibleBO _ _ _ = Nothing
-- | Result type of a unary operator applied to an operand of the given type,
-- or 'Nothing' if the operator does not accept that type.
--
-- * 'Not' accepts any operand and yields @TypeInt 1@.
-- * 'Address' yields a pointer to the operand type.
-- * 'Negate' and 'Invert' keep the integer type of their operand, and 'Negate'
--   also accepts float and double.
-- * 'Dereference' requires a pointer and yields the type it points to.
typeCompatibleUO :: UnaryOperator -> Type -> Maybe Type
typeCompatibleUO Not _ = Just $ TypeInt 1
typeCompatibleUO Address t = Just $ TypePtr t
typeCompatibleUO uo t@(TypeInt _) | uo `elem` [Negate, Invert] = Just t
typeCompatibleUO uo t@(TypeUInt _) | uo `elem` [Negate, Invert] = Just t
typeCompatibleUO Negate TypeFloat = Just TypeFloat
typeCompatibleUO Negate TypeDouble = Just TypeDouble
-- Dereferencing strips one level of pointer. Returning the pointer type itself
-- would make e.g. a deref of an int pointer unassignable to an int.
typeCompatibleUO Dereference (TypePtr t) = Just t
typeCompatibleUO _ _ = Nothing
-- | Narrowest signed integer type (8, 16 or 32 bits) whose range contains the
-- given literal value. Anything outside the 32-bit range is typed as 64 bits,
-- without checking that it actually fits.
smallestIntType :: Integer -> Type
smallestIntType i
  | fits 8 = TypeInt 8
  | fits 16 = TypeInt 16
  | fits 32 = TypeInt 32
  | otherwise = TypeInt 64
  where
    -- Two's-complement range of a signed integer with the given bit width.
    fits :: Int -> Bool
    fits bits = negate bound <= i && i < bound
      where
        bound = 2 ^ (bits - 1)
-- smallestUIntType :: Integer -> Type
-- smallestUIntType i
-- | i >= 0 && i < 2^8 = TypeUInt 8
-- | i >= 0 && i < 2^16 = TypeUInt 16
-- | i >= 0 && i < 2^32 = TypeUInt 32
-- | otherwise = TypeUInt 64
-- | A fallible node handler: inspect or rewrite one node, or fail with a message.
type MapperHandler a = a -> Error a

-- | One fallible handler per kind of AST node, consumed by 'mapProgram'.
data ProgramMapper = ProgramMapper
  { declarationHandler :: MapperHandler Declaration -- ^ top-level declarations
  , blockHandler :: MapperHandler Block -- ^ statement blocks
  , typeHandler :: MapperHandler Type -- ^ types, including pointee types
  , literalHandler :: MapperHandler Literal -- ^ literals, variables and calls
  , binOpHandler :: MapperHandler BinaryOperator -- ^ binary operators
  , unOpHandler :: MapperHandler UnaryOperator -- ^ unary operators
  , expressionHandler :: MapperHandler Expression -- ^ expressions
  , statementHandler :: MapperHandler Statement -- ^ statements
  , nameHandler :: MapperHandler Name -- ^ identifiers
  }
-- | A pure node handler: rewrite one node, without the possibility of failure.
type MapperHandler' a = a -> a

-- | Pure counterpart of 'ProgramMapper', consumed by 'mapProgram''.
data ProgramMapper' = ProgramMapper'
  { declarationHandler' :: MapperHandler' Declaration -- ^ top-level declarations
  , blockHandler' :: MapperHandler' Block -- ^ statement blocks
  , typeHandler' :: MapperHandler' Type -- ^ types, including pointee types
  , literalHandler' :: MapperHandler' Literal -- ^ literals, variables and calls
  , binOpHandler' :: MapperHandler' BinaryOperator -- ^ binary operators
  , unOpHandler' :: MapperHandler' UnaryOperator -- ^ unary operators
  , expressionHandler' :: MapperHandler' Expression -- ^ expressions
  , statementHandler' :: MapperHandler' Statement -- ^ statements
  , nameHandler' :: MapperHandler' Name -- ^ identifiers
  }
-- | Mapper whose every handler succeeds and returns its input unchanged.
-- Override individual fields with record-update syntax.
defaultPM :: ProgramMapper
defaultPM = ProgramMapper
  { declarationHandler = Right
  , blockHandler = Right
  , typeHandler = Right
  , literalHandler = Right
  , binOpHandler = Right
  , unOpHandler = Right
  , expressionHandler = Right
  , statementHandler = Right
  , nameHandler = Right
  }
-- | Pure mapper whose every handler is the identity.
-- Override individual fields with record-update syntax.
defaultPM' :: ProgramMapper'
defaultPM' = ProgramMapper'
  { declarationHandler' = id
  , blockHandler' = id
  , typeHandler' = id
  , literalHandler' = id
  , binOpHandler' = id
  , unOpHandler' = id
  , expressionHandler' = id
  , statementHandler' = id
  , nameHandler' = id
  }
-- | Pure variant of 'mapProgram'. Each pure handler of @mapper@ is lifted into
-- 'Error', and the fallible traversal is then run.
--
-- Every lifted handler returns 'Right', so the traversal cannot fail. The
-- 'Left' branch is unreachable and is reported as an internal error, rather
-- than as the anonymous pattern-match failure of an incomplete lambda.
mapProgram' :: Program -> ProgramMapper' -> Program
mapProgram' prog mapper =
  either (\e -> error $ "mapProgram': unexpected failure: " ++ e) id $
    mapProgram prog $ ProgramMapper
      { declarationHandler = return . declarationHandler' mapper
      , blockHandler = return . blockHandler' mapper
      , typeHandler = return . typeHandler' mapper
      , literalHandler = return . literalHandler' mapper
      , binOpHandler = return . binOpHandler' mapper
      , unOpHandler = return . unOpHandler' mapper
      , expressionHandler = return . expressionHandler' mapper
      , statementHandler = return . statementHandler' mapper
      , nameHandler = return . nameHandler' mapper
      }
-- | Fallible, bottom-up traversal of a whole program.
--
-- Every node is rebuilt from its already-mapped children and then passed,
-- exactly once, to the matching handler of @mapper@. The first 'Left' returned
-- by any handler aborts the traversal. Types are visited inside pointers, and
-- names inside type names. The optional type annotation carried by each
-- expression is passed through unchanged; the type handler does not visit it.
mapProgram :: Program -> ProgramMapper -> Error Program
mapProgram prog mapper = goP prog
  where
    h_d = declarationHandler mapper
    h_b = blockHandler mapper
    h_t = typeHandler mapper
    h_l = literalHandler mapper
    h_bo = binOpHandler mapper
    h_uo = unOpHandler mapper
    h_e = expressionHandler mapper
    h_s = statementHandler mapper
    h_n = nameHandler mapper

    -- goD already finishes by applying h_d, so h_d must not be applied again
    -- here; that would run the declaration handler twice per declaration.
    goP :: MapperHandler Program
    goP (Program decls) = Program <$> traverse goD decls

    goD :: MapperHandler Declaration
    goD (DecFunction t n a b) = do
      rt <- goT t
      rn <- goN n
      ra <- traverse (\(at, an) -> (,) <$> goT at <*> goN an) a
      rb <- goB b
      h_d $ DecFunction rt rn ra rb
    goD (DecVariable t n mv) = do
      rt <- goT t
      rn <- goN n
      rmv <- traverse goE mv
      h_d $ DecVariable rt rn rmv
    goD (DecTypedef t n) = do
      rt <- goT t
      rn <- goN n
      h_d $ DecTypedef rt rn

    goT :: MapperHandler Type
    goT (TypePtr t) = goT t >>= (h_t . TypePtr)
    goT (TypeName n) = goN n >>= (h_t . TypeName)
    goT t = h_t t

    goN :: MapperHandler Name
    goN = h_n

    goB :: MapperHandler Block
    goB (Block sts) = (Block <$> traverse goS sts) >>= h_b

    goE :: MapperHandler Expression
    goE (ExLit l mt) = do
      rl <- goL l
      h_e $ ExLit rl mt
    goE (ExBinOp bo e1 e2 mt) = do
      rbo <- goBO bo
      re1 <- goE e1
      re2 <- goE e2
      h_e $ ExBinOp rbo re1 re2 mt
    goE (ExUnOp uo e mt) = do
      ruo <- goUO uo
      re <- goE e
      h_e $ ExUnOp ruo re mt

    goS :: MapperHandler Statement
    goS StEmpty = h_s StEmpty
    goS (StBlock b) = goB b >>= (h_s . StBlock)
    goS (StExpr e) = goE e >>= (h_s . StExpr)
    goS (StVarDeclaration t n me) = do
      rt <- goT t
      rn <- goN n
      rme <- traverse goE me
      h_s $ StVarDeclaration rt rn rme
    goS (StAssignment n e) = do
      rn <- goN n
      re <- goE e
      h_s $ StAssignment rn re
    goS (StIf e s1 s2) = do
      re <- goE e
      rs1 <- goS s1
      rs2 <- goS s2
      h_s $ StIf re rs1 rs2
    goS (StWhile e s) = do
      re <- goE e
      rs <- goS s
      h_s $ StWhile re rs
    goS (StReturn e) = goE e >>= (h_s . StReturn)

    goL :: MapperHandler Literal
    goL l@(LitString _) = h_l l
    goL l@(LitInt _) = h_l l
    goL (LitVar n) = goN n >>= (h_l . LitVar)
    goL (LitCall n a) = do
      rn <- goN n
      ra <- traverse goE a
      h_l $ LitCall rn ra

    goBO :: MapperHandler BinaryOperator
    goBO = h_bo

    goUO :: MapperHandler UnaryOperator
    goUO = h_uo
|