summaryrefslogtreecommitdiff
path: root/check.hs
blob: 18842f6713e9ff464c5364b48fa770dfbdb28eaf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
module Check(checkProgram) where

import Control.Monad
import Data.Maybe
import qualified Data.Map.Strict as Map
--import Debug.Trace

import AST
import PShow


-- | Result of a checking pass: either an error message or a value.
type Error = Either String


-- | Run every semantic pass over a parsed program, in order:
--
--   1. resolve typedefs ('replaceTypes'),
--   2. reject type names that are still unresolved ('checkUndefinedTypes'),
--   3. type-check and annotate expressions ('typeCheck'),
--   4. hoist local variable declarations to the top of their blocks
--      ('bundleVarDecls').
--
-- The first failing pass determines the returned error.
checkProgram :: Program -> Error Program
checkProgram prog = checkUndefinedTypes resolved >> typeCheck resolved >>= bundleVarDecls
  where
    resolved = replaceTypes prog


-- | Resolve typedefs: drop every 'DecTypedef' from the program and replace
-- each use of a typedef name with the type it stands for. Chains of
-- typedefs are followed through pointer and function types.
--
-- Names that are not typedefs are left as 'TypeName' so 'checkUndefinedTypes'
-- can report them. A cyclic typedef (e.g. @A@ defined as @B@ and @B@ as @A@)
-- is left unresolved instead of looping forever, so it also surfaces as an
-- undefined type name. If a name is typedef'd more than once, the last
-- definition wins.
replaceTypes :: Program -> Program
replaceTypes (Program decls) = mapProgram' (Program (filter notTypedef decls)) mapper
    where
        mapper = defaultPM' {typeHandler' = resolveName}

        notTypedef :: Declaration -> Bool
        notTypedef (DecTypedef _ _) = False
        notTypedef _ = True

        -- Typedef name -> aliased type. Map.fromList keeps the last binding
        -- for a duplicated key.
        renames :: Map.Map Name Type
        renames = Map.fromList [(n, t) | DecTypedef t n <- decls]

        -- The mapper visits compound types after their children, so only a
        -- bare name can still need expanding here.
        resolveName :: Type -> Type
        resolveName t@(TypeName _) = expand [] t
        resolveName t = t

        -- Fully expand a type. `seen` holds the typedef names currently being
        -- expanded on this path; meeting one again means the typedefs are
        -- cyclic, so that name is left as-is.
        expand :: [Name] -> Type -> Type
        expand seen t@(TypeName n)
            | n `elem` seen = t
            | otherwise = maybe t (expand (n : seen)) (Map.lookup n renames)
        expand seen (TypePtr t) = TypePtr (expand seen t)
        expand seen (TypeFunc rt ats) = TypeFunc (expand seen rt) (map (expand seen) ats)
        expand _ t = t


-- | Fail on the first type in the program that is still a bare 'TypeName'.
--
-- 'checkProgram' runs this after 'replaceTypes', so any remaining name is
-- one that no typedef defined.
checkUndefinedTypes :: Program -> Error ()
checkUndefinedTypes prog = void $ mapProgram prog defaultPM {typeHandler = rejectName}
  where
    rejectName :: MapperHandler Type
    rejectName (TypeName n) = Left $ "Undefined type name '" ++ n ++ "'"
    rejectName other = Right other


-- | Type-check a program and annotate every expression with its type.
--
-- Each function body is checked against the function's declared return
-- type, with the global variables and the function's parameters in scope.
-- Calls are resolved against all 'DecFunction' declarations and all
-- 'DecExtern' declarations of function type. Checking stops at, and
-- returns, the first error found.
typeCheck :: Program -> Error Program
typeCheck (Program decls) = Program <$> mapM (goD topLevelNames) decls
  where
    -- Global variables and their types. Built with foldr, so when a name is
    -- declared twice the earlier declaration wins.
    topLevelNames :: Map.Map Name Type
    topLevelNames = foldr (uncurry Map.insert) Map.empty pairs
        where pairs = map ((,) <$> nameOf <*> typeOf) $ filter isVarDecl decls

    -- Callable functions: name -> (return type, parameter types).
    functionTypes :: Map.Map Name (Type,[Type])
    functionTypes = foldr (uncurry Map.insert) Map.empty pairs
        where pairs = map ((,) <$> nameOf <*> getTypes) $ filter isFunctionDecl decls

              getTypes (DecFunction rt _ args _) = (rt, map fst args)
              getTypes (DecExtern (TypeFunc rt ats) _) = (rt, ats)
              getTypes _ = error "typeCheck: getTypes applied to a non-function declaration"

    isVarDecl (DecVariable {}) = True
    isVarDecl _ = False

    isFunctionDecl (DecFunction {}) = True
    isFunctionDecl (DecExtern (TypeFunc {}) _) = True
    isFunctionDecl _ = False

    goD :: Map.Map Name Type -> Declaration -> Error Declaration
    goD names (DecFunction frt name args body) = do
        newbody <- goB frt (foldr (\(t,n) m -> Map.insert n t m) names args) body
        return $ DecFunction frt name args newbody
    goD _ (DecVariable (TypeFunc _ _) _ _) = Left $ "Cannot declare global variable with function type"
    goD _ dec = return dec

    -- Check a block's statements in order, threading the variable scope from
    -- one statement to the next. The extended scope is not returned, so
    -- declarations end with the block.
    goB :: Type  -- function return type
        -> Map.Map Name Type -> Block -> Error Block
    goB frt names (Block stmts) = Block . reverse . snd <$> foldM step (names, []) stmts
      where
        -- Checked statements are accumulated in reverse (O(1) cons instead of
        -- a quadratic tail-append) and put back in order above.
        step :: (Map.Map Name Type, [Statement]) -> Statement -> Error (Map.Map Name Type, [Statement])
        step (names', acc) st = do
            (newnames', newst) <- goS frt names' st
            return (newnames', newst : acc)

    -- Check one statement. Returns the scope visible to the following
    -- statements together with the annotated statement.
    goS :: Type  -- function return type
        -> Map.Map Name Type -> Statement -> Error (Map.Map Name Type, Statement)
    goS _ _ (StVarDeclaration (TypeFunc _ _) _ _) = Left $ "Cannot declare variable with function type"
    -- Declaring any name already in scope (global, parameter or local) is an error.
    goS _ names st@(StVarDeclaration t n Nothing) =
        maybe (return (Map.insert n t names, st))
              (const $ Left $ "Duplicate variable '" ++ n ++ "'")
              (Map.lookup n names)
    goS frt names (StVarDeclaration t n (Just e)) = do
        (newnames, _) <- goS frt names (StVarDeclaration t n Nothing)
        -- The initialiser is checked as an assignment, with the new variable
        -- already in scope.
        newe <- checkAssignment newnames n e
        return (newnames, StVarDeclaration t n (Just newe))
    goS _ names (StAssignment n e) = do
        re <- checkAssignment names n e
        return (names, StAssignment n re)
    goS _ names st@StEmpty = return (names, st)
    goS frt names (StBlock bl) = do
        newbl <- goB frt names bl
        return (names, StBlock newbl)
    goS _ names (StExpr e) = do
        re <- goE names e
        return (names, StExpr re)
    goS frt names (StIf e s1 s2) = do
        re <- goE names e
        (_, rs1) <- goS frt names s1
        (_, rs2) <- goS frt names s2
        return (names, StIf re rs1 rs2)
    goS frt names (StWhile e s) = do
        re <- goE names e
        (_, rs) <- goS frt names s
        return (names, StWhile re rs)
    goS TypeVoid names (StReturn Nothing) = return (names, StReturn Nothing)
    goS _ _ (StReturn Nothing) = Left $ "Non-void function should return a value"
    goS frt names (StReturn (Just e)) = do
        re <- goE names e
        let extype = checkedType re
        if canConvert extype frt
            then return (names, StReturn (Just re))
            else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '"
                        ++ pshow frt ++ "' in return statement"

    -- Check that expression @e@ may be assigned to variable @n@ and return
    -- the annotated expression. This is a plain helper rather than a failable
    -- do-pattern on goS's result, because 'Either' has no 'MonadFail' instance.
    checkAssignment :: Map.Map Name Type -> Name -> Expression -> Error Expression
    checkAssignment names n e = do
        dsttype <- maybe (Left $ "Undefined variable '" ++ n ++ "'") return (Map.lookup n names)
        re <- goE names e
        let extype = checkedType re
        if canConvert extype dsttype
            then return re
            else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '"
                        ++ pshow dsttype ++ "' in assignment to variable '" ++ n ++ "'"

    -- Type of an expression already processed by goE, which guarantees an
    -- annotation.
    checkedType :: Expression -> Type
    checkedType = fromMaybe (error "typeCheck: checked expression has no type annotation") . exTypeOf

    -- Postcondition: the expression (if any) has a type annotation.
    goE :: Map.Map Name Type -> Expression -> Error Expression
    goE _ (ExLit l@(LitInt i) _) = smallestIntType i >>= \t -> return $ ExLit l $ Just t
    goE _ (ExLit l@(LitUInt i) _) = smallestUIntType i >>= \t -> return $ ExLit l $ Just t
    goE _ (ExLit l@(LitFloat f) _) = return $ ExLit l $ Just (smallestFloatType f)
    goE _ (ExLit l@(LitString _) _) = return $ ExLit l $ Just (TypePtr (TypeInt 8))
    goE names (ExLit l@(LitVar n) _) = maybe (Left $ "Undefined variable '" ++ n ++ "'") (return . ExLit l . Just)
                                             (Map.lookup n names)
    goE names (ExLit (LitCall n args) _) = do
        (rettype, paramtypes) <- maybe (Left $ "Unknown function '" ++ n ++ "'") return $ Map.lookup n functionTypes
        rargs <- mapM (goE names) args
        when (length rargs /= length paramtypes) $
            Left $ "Expected " ++ show (length paramtypes) ++ " arguments to "
                   ++ "function '" ++ n ++ "', but got " ++ show (length rargs)
        zipWithM_ checkArg rargs paramtypes
        return $ ExLit (LitCall n rargs) (Just rettype)
      where
        -- Each argument must convert implicitly to its parameter's type.
        checkArg a paramtype =
            let argtype = checkedType a
            in unless (canConvert argtype paramtype) $
                   Left $ "Cannot convert type '" ++ pshow argtype ++ "' to '" ++ pshow paramtype
                          ++ "' in call of function '" ++ n ++ "'"
    goE names (ExCast totype ex) = do
        rex <- goE names ex
        let fromtype = checkedType rex
        if canCast fromtype totype
            then return $ ExCast totype rex
            else Left $ "Cannot cast type '" ++ pshow fromtype ++ "' to '" ++ pshow totype ++ "'"
    goE names (ExBinOp bo e1 e2 _) = do
        re1 <- goE names e1
        re2 <- goE names e2
        let t1 = checkedType re1
            t2 = checkedType re2
        maybe (Left $ "Cannot use operator '" ++ pshow bo ++ "' with argument types '"
                      ++ pshow t1 ++ "' and '" ++ pshow t2 ++ "'")
              (return . ExBinOp bo re1 re2 . Just)
              (resultTypeBO bo t1 t2)
    goE names (ExUnOp uo e _) = do
        re <- goE names e
        let t = checkedType re
        maybe (Left $ "Cannot use operator '" ++ pshow uo ++ "' with argument type '" ++ pshow t ++ "'")
              (return . ExUnOp uo re . Just)
              (resultTypeUO uo t)


-- | Hoist local variable declarations to the top of every block.
--
-- In each block, every @StVarDeclaration t n init@ contributes an
-- uninitialised declaration at the start of the block, in the original
-- order. At its original position, an initialiser (if any) is kept as an
-- 'StAssignment'; a declaration without an initialiser disappears there.
bundleVarDecls :: Program -> Error Program
bundleVarDecls prog = mapProgram prog defaultPM {blockHandler = hoist}
  where
    hoist :: MapperHandler Block
    hoist (Block stmts) = return (Block (declarations ++ concatMap lower stmts))
      where
        declarations = [StVarDeclaration t n Nothing | StVarDeclaration t n _ <- stmts]

        lower (StVarDeclaration _ n (Just ex)) = [StAssignment n ex]
        lower (StVarDeclaration _ _ Nothing) = []
        lower st = [st]


-- | Whether a value of the first type may be implicitly converted to the
-- second type.
--
-- Allowed conversions:
--
-- * identical types;
-- * widening within signed or within unsigned integers;
-- * 'TypeUInt' 1 to any signed integer;
-- * float to double;
-- * any integer to float or double.
canConvert :: Type -> Type -> Bool
canConvert from to
    | from == to = True
    | otherwise = case (from, to) of
        (TypeInt f, TypeInt t) -> f <= t
        (TypeUInt f, TypeUInt t) -> f <= t
        (TypeUInt 1, TypeInt _) -> True
        (TypeFloat, TypeDouble) -> True
        (TypeInt _, t) -> isFloating t
        (TypeUInt _, t) -> isFloating t
        _ -> False
  where
    isFloating TypeFloat = True
    isFloating TypeDouble = True
    isFloating _ = False

-- | Whether an explicit cast from the first type to the second is allowed.
-- Casts are permitted between any two numeric types (integers and floating
-- point), and between any two of integers, pointers and function types.
canCast :: Type -> Type -> Bool
canCast from to = both isNumeric || both isIntOrPointer
  where
    both p = p from && p to

    isNumeric t = case t of
        TypeInt _ -> True
        TypeUInt _ -> True
        TypeFloat -> True
        TypeDouble -> True
        _ -> False

    isIntOrPointer t = case t of
        TypeInt _ -> True
        TypeUInt _ -> True
        TypePtr _ -> True
        TypeFunc _ _ -> True
        _ -> False

-- | Arithmetic operators: their result has the (promoted) operand type.
arithBO :: [BinaryOperator]
arithBO = [Plus, Minus, Times, Divide, Modulo]

-- | Comparison operators.
compareBO :: [BinaryOperator]
compareBO = [Equal, Unequal, Greater, Less, GEqual, LEqual]

-- | Short-circuit boolean operators.
logicBO :: [BinaryOperator]
logicBO = [BoolAnd, BoolOr]

-- | Operators whose result 'resultTypeBO' types as 'TypeUInt' 1.
complogBO :: [BinaryOperator]
complogBO = concat [compareBO, logicBO]

-- | Result type of applying a binary operator to operands of the given
-- types, or 'Nothing' if the combination is not allowed.
--
-- * Comparison and logical operators produce 'TypeUInt' 1.
-- * Pointers support pointer difference (yielding 'TypeUInt' 64),
--   comparison of equal pointer types, pointer +/- integer, and indexing.
-- * Signed and unsigned integers combine only with their own kind.
-- * Float/double operands accept signed integers of at most 24 / 53 bits.
--   NOTE(review): presumably chosen to match the float/double significand
--   widths — confirm.
resultTypeBO :: BinaryOperator -> Type -> Type -> Maybe Type
resultTypeBO Minus (TypePtr t1) (TypePtr t2) | t1 == t2 = Just $ TypeUInt 64
resultTypeBO bo (TypePtr t1) (TypePtr t2) | t1 == t2 && bo `elem` complogBO = Just $ TypeUInt 1
resultTypeBO bo t@(TypePtr _) (TypeInt _) | bo `elem` [Plus, Minus] = Just t
resultTypeBO bo t@(TypePtr _) (TypeUInt _) | bo `elem` [Plus, Minus] = Just t
resultTypeBO bo (TypeInt _) t@(TypePtr _) | bo `elem` [Plus, Minus] = Just t
resultTypeBO bo (TypeUInt _) t@(TypePtr _) | bo `elem` [Plus, Minus] = Just t
resultTypeBO Index (TypePtr t) (TypeInt _) = Just t
resultTypeBO Index (TypePtr t) (TypeUInt _) = Just t
resultTypeBO _ (TypePtr _) _ = Nothing
resultTypeBO _ _ (TypePtr _) = Nothing

resultTypeBO bo (TypeInt s1) (TypeInt s2) | bo `elem` arithBO = Just $ TypeInt (max s1 s2)
resultTypeBO bo (TypeInt _) (TypeInt _) | bo `elem` complogBO = Just $ TypeUInt 1

resultTypeBO bo (TypeUInt s1) (TypeUInt s2) | bo `elem` arithBO = Just $ TypeUInt (max s1 s2)
resultTypeBO bo (TypeUInt _) (TypeUInt _) | bo `elem` complogBO = Just $ TypeUInt 1

resultTypeBO bo t1 t2 | bo `elem` complogBO && t1 == t2 = Just $ TypeUInt 1

-- Floating-point operands. Operators outside arithBO/complogBO (such as
-- Index) are rejected instead of being typed as a boolean.
resultTypeBO bo TypeFloat (TypeInt s) | s <= 24 = floatOpResult bo TypeFloat
resultTypeBO bo (TypeInt s) TypeFloat | s <= 24 = floatOpResult bo TypeFloat
resultTypeBO bo TypeDouble (TypeInt s) | s <= 53 = floatOpResult bo TypeDouble
resultTypeBO bo (TypeInt s) TypeDouble | s <= 53 = floatOpResult bo TypeDouble
resultTypeBO bo TypeFloat TypeFloat = floatOpResult bo TypeFloat
resultTypeBO bo TypeDouble TypeDouble = floatOpResult bo TypeDouble
resultTypeBO bo TypeFloat TypeDouble = floatOpResult bo TypeDouble
resultTypeBO bo TypeDouble TypeFloat = floatOpResult bo TypeDouble

resultTypeBO _ _ _ = Nothing

-- | Result of a binary operator whose operands promote to the floating-point
-- type @t@:
--
-- * arithmetic operators yield @t@;
-- * comparison and logical operators yield 'TypeUInt' 1;
-- * any other operator is not allowed.
floatOpResult :: BinaryOperator -> Type -> Maybe Type
floatOpResult bo t
    | bo `elem` arithBO = Just t
    | bo `elem` complogBO = Just $ TypeUInt 1
    | otherwise = Nothing

-- | Result type of applying a unary operator to an operand of the given
-- type, or 'Nothing' if the combination is not allowed.
--
-- * 'Not' always yields 'TypeUInt' 1.
-- * 'Address' yields a pointer to the operand type.
-- * 'Negate' and 'Invert' keep an integer type unchanged.
-- * 'Negate' also keeps float/double unchanged.
-- * 'Dereference' strips one pointer level.
resultTypeUO :: UnaryOperator -> Type -> Maybe Type
resultTypeUO op ty = case (op, ty) of
    (Not, _) -> Just (TypeUInt 1)
    (Address, _) -> Just (TypePtr ty)
    (Negate, TypeInt _) -> Just ty
    (Invert, TypeInt _) -> Just ty
    (Negate, TypeUInt _) -> Just ty
    (Invert, TypeUInt _) -> Just ty
    (Negate, TypeFloat) -> Just TypeFloat
    (Negate, TypeDouble) -> Just TypeDouble
    (Dereference, TypePtr inner) -> Just inner
    _ -> Nothing

-- | 'TypeFloat' if the literal survives a round trip through 'Float'
-- unchanged, otherwise 'TypeDouble'.
smallestFloatType :: Double -> Type
smallestFloatType d
    | fitsInFloat = TypeFloat
    | otherwise = TypeDouble
  where
    fitsInFloat = d == realToFrac (realToFrac d :: Float)

-- | Narrowest signed integer type (8, 16, 32 or 64 bits) whose two's-complement
-- range contains the literal. Fails if it does not fit in 64 bits.
smallestIntType :: Integer -> Error Type
smallestIntType i =
    case [ty | (ty, bits) <- candidates, fits bits] of
        (ty : _) -> return ty
        [] -> Left $ "Integer literal '" ++ pshow i ++ "' too wide for i64"
  where
    -- Each type paired with the number of magnitude bits (width - 1).
    candidates = [(TypeInt 8, 7), (TypeInt 16, 15), (TypeInt 32, 31), (TypeInt 64, 63)]

    fits :: Integer -> Bool
    fits bits = i >= -(2 ^ bits) && i < 2 ^ bits

-- | Narrowest unsigned integer type (8, 16, 32 or 64 bits) for an unsigned
-- literal. Fails if the value is too wide for 64 bits.
smallestUIntType :: Integer -> Error Type
smallestUIntType i =
    case [ty | (ty, bits) <- candidates, fits bits] of
        (ty : _) -> return ty
        [] -> Left $ "Unsigned integer literal '" ++ pshow i ++ "U' too wide for u64"
  where
    candidates = [(TypeUInt 8, 8), (TypeUInt 16, 16), (TypeUInt 32, 32), (TypeUInt 64, 64)]

    -- NOTE(review): the lower bound admits negative values (e.g. -255 is
    -- accepted as u8). Presumably the parser never produces a negative
    -- LitUInt — confirm.
    fits :: Integer -> Bool
    fits bits = i > -(2 ^ bits) && i < 2 ^ bits


-- | A fallible transformation of a single AST node; a 'Left' aborts the
-- whole traversal.
type MapperHandler a = a -> Error a

-- | Handlers used by 'mapProgram', one per kind of AST node. Typical use is
-- to start from 'defaultPM' and override only the handlers that matter.
data ProgramMapper = ProgramMapper
    { declarationHandler :: MapperHandler Declaration
    , blockHandler       :: MapperHandler Block
    , typeHandler        :: MapperHandler Type
    , literalHandler     :: MapperHandler Literal
    , binOpHandler       :: MapperHandler BinaryOperator
    , unOpHandler        :: MapperHandler UnaryOperator
    , expressionHandler  :: MapperHandler Expression
    , statementHandler   :: MapperHandler Statement
    , nameHandler        :: MapperHandler Name
    }

-- | An infallible transformation of a single AST node.
type MapperHandler' a = a -> a

-- | Pure counterpart of 'ProgramMapper', consumed by 'mapProgram''. Typical
-- use is to start from 'defaultPM'' and override the handlers that matter.
data ProgramMapper' = ProgramMapper'
    { declarationHandler' :: MapperHandler' Declaration
    , blockHandler'       :: MapperHandler' Block
    , typeHandler'        :: MapperHandler' Type
    , literalHandler'     :: MapperHandler' Literal
    , binOpHandler'       :: MapperHandler' BinaryOperator
    , unOpHandler'        :: MapperHandler' UnaryOperator
    , expressionHandler'  :: MapperHandler' Expression
    , statementHandler'   :: MapperHandler' Statement
    , nameHandler'        :: MapperHandler' Name
    }

-- | Mapper whose handlers all succeed and return their node unchanged.
defaultPM :: ProgramMapper
defaultPM = ProgramMapper
    { declarationHandler = return
    , blockHandler       = return
    , typeHandler        = return
    , literalHandler     = return
    , binOpHandler       = return
    , unOpHandler        = return
    , expressionHandler  = return
    , statementHandler   = return
    , nameHandler        = return
    }

-- | Pure mapper whose handlers are all the identity.
defaultPM' :: ProgramMapper'
defaultPM' = ProgramMapper'
    { declarationHandler' = id
    , blockHandler'       = id
    , typeHandler'        = id
    , literalHandler'     = id
    , binOpHandler'       = id
    , unOpHandler'        = id
    , expressionHandler'  = id
    , statementHandler'   = id
    , nameHandler'        = id
    }

-- | Pure variant of 'mapProgram'. Every pure handler is lifted with
-- 'return', so the underlying traversal cannot produce a 'Left'. The
-- 'error' branch only documents that invariant, replacing an opaque
-- lambda pattern-match failure.
mapProgram' :: Program -> ProgramMapper' -> Program
mapProgram' prog mapper =
    either (\e -> error $ "mapProgram': impossible failure: " ++ e) id $ mapProgram prog $ ProgramMapper
        {declarationHandler = return . declarationHandler' mapper
        ,blockHandler       = return . blockHandler' mapper
        ,typeHandler        = return . typeHandler' mapper
        ,literalHandler     = return . literalHandler' mapper
        ,binOpHandler       = return . binOpHandler' mapper
        ,unOpHandler        = return . unOpHandler' mapper
        ,expressionHandler  = return . expressionHandler' mapper
        ,statementHandler   = return . statementHandler' mapper
        ,nameHandler        = return . nameHandler' mapper}

-- | Rebuild a program bottom-up. Each node's children are mapped first, then
-- the node's own handler from @mapper@ is applied to the rebuilt node, once
-- per node. The first handler returning 'Left' aborts the traversal with
-- that error.
--
-- Note: the optional type annotations carried by expressions are copied
-- through unchanged; 'typeHandler' is not applied to them.
mapProgram :: Program -> ProgramMapper -> Error Program
mapProgram prog mapper = goP prog
    where
        h_d = declarationHandler mapper
        h_b = blockHandler mapper
        h_t = typeHandler mapper
        h_l = literalHandler mapper
        h_bo = binOpHandler mapper
        h_uo = unOpHandler mapper
        h_e = expressionHandler mapper
        h_s = statementHandler mapper
        h_n = nameHandler mapper

        -- goD already finishes with h_d, so h_d must not be applied again
        -- here (that previously ran the declaration handler twice).
        goP :: MapperHandler Program
        goP (Program decls) = Program <$> mapM goD decls

        goD :: MapperHandler Declaration
        goD (DecFunction t n a b) = do
            rt <- goT t
            rn <- goN n
            ra <- mapM (\(at, an) -> (,) <$> goT at <*> goN an) a
            rb <- goB b
            h_d $ DecFunction rt rn ra rb
        goD (DecVariable t n mv) = do
            rt <- goT t
            rn <- goN n
            rmv <- traverse goE mv
            h_d $ DecVariable rt rn rmv
        goD (DecTypedef t n) = do
            rt <- goT t
            rn <- goN n
            h_d $ DecTypedef rt rn
        goD (DecExtern t n) = do
            rt <- goT t
            rn <- goN n
            h_d $ DecExtern rt rn

        goT :: MapperHandler Type
        goT (TypePtr t) = goT t >>= (h_t . TypePtr)
        goT (TypeName n) = goN n >>= (h_t . TypeName)
        goT (TypeFunc t as) = do
            rt <- goT t
            ras <- mapM goT as
            h_t $ TypeFunc rt ras
        goT t = h_t t

        goN :: MapperHandler Name
        goN = h_n

        goB :: MapperHandler Block
        goB (Block sts) = (Block <$> mapM goS sts) >>= h_b

        goE :: MapperHandler Expression
        goE (ExLit l mt) = do
            rl <- goL l
            h_e $ ExLit rl mt
        goE (ExCast t e) = do
            rt <- goT t
            re <- goE e
            h_e $ ExCast rt re
        goE (ExBinOp bo e1 e2 mt) = do
            rbo <- goBO bo
            re1 <- goE e1
            re2 <- goE e2
            h_e $ ExBinOp rbo re1 re2 mt
        goE (ExUnOp uo e mt) = do
            ruo <- goUO uo
            re <- goE e
            h_e $ ExUnOp ruo re mt

        goS :: MapperHandler Statement
        goS StEmpty = h_s StEmpty
        goS (StBlock b) = goB b >>= (h_s . StBlock)
        goS (StExpr e) = goE e >>= (h_s . StExpr)
        goS (StVarDeclaration t n me) = do
            rt <- goT t
            rn <- goN n
            rme <- traverse goE me
            h_s $ StVarDeclaration rt rn rme
        goS (StAssignment n e) = do
            rn <- goN n
            re <- goE e
            h_s $ StAssignment rn re
        goS (StIf e s1 s2) = do
            re <- goE e
            rs1 <- goS s1
            rs2 <- goS s2
            h_s $ StIf re rs1 rs2
        goS (StWhile e s) = do
            re <- goE e
            rs <- goS s
            h_s $ StWhile re rs
        goS (StReturn Nothing) = h_s (StReturn Nothing)
        goS (StReturn (Just e)) = goE e >>= (h_s . StReturn . Just)

        goL :: MapperHandler Literal
        goL l@(LitString _) = h_l l
        goL l@(LitInt _) = h_l l
        goL l@(LitUInt _) = h_l l
        goL l@(LitFloat _) = h_l l
        goL (LitVar n) = goN n >>= (h_l . LitVar)
        goL (LitCall n a) = do
            rn <- goN n
            ra <- mapM goE a
            h_l $ LitCall rn ra

        goBO :: MapperHandler BinaryOperator
        goBO = h_bo

        goUO :: MapperHandler UnaryOperator
        goUO = h_uo