summaryrefslogtreecommitdiff
path: root/check.hs
diff options
context:
space:
mode:
Diffstat (limited to 'check.hs')
-rw-r--r--check.hs380
1 files changed, 380 insertions, 0 deletions
diff --git a/check.hs b/check.hs
new file mode 100644
index 0000000..5f7dd12
--- /dev/null
+++ b/check.hs
@@ -0,0 +1,380 @@
+module Check(checkProgram) where
+
+import Control.Monad
+import Data.Maybe
+import qualified Data.Map.Strict as Map
+--import Debug.Trace
+
+import AST
+import PShow
+
+
+type Error a = Either String a
+
+
+checkProgram :: Program -> Error Program
+checkProgram prog = do
+ let processed = replaceTypes prog
+ checkUndefinedTypes processed
+ typeCheck processed >>= bundleVarDecls
+
+
+replaceTypes :: Program -> Program
+replaceTypes prog@(Program decls) = mapProgram' filtered mapper
+ where
+ filtered = Program $ filter notTypedef decls
+ mapper = defaultPM' {typeHandler' = typeReplacer (findTypeRenames prog)}
+
+ notTypedef :: Declaration -> Bool
+ notTypedef (DecTypedef _ _) = False
+ notTypedef _ = True
+
+ typeReplacer :: Map.Map Name Type -> Type -> Type
+ typeReplacer m t@(TypeName n) = maybe t id $ Map.lookup n m
+ typeReplacer _ t = t
+
+ findTypeRenames :: Program -> Map.Map Name Type
+ findTypeRenames (Program d) = foldl go Map.empty d
+ where
+ go :: Map.Map Name Type -> Declaration -> Map.Map Name Type
+ go m (DecTypedef t n) = Map.insert n t m
+ go m _ = m
+
+
+checkUndefinedTypes :: Program -> Error ()
+checkUndefinedTypes prog = fmap (const ()) $ mapProgram prog $ defaultPM {typeHandler = check}
+ where
+ check :: MapperHandler Type
+ check (TypeName n) = Left $ "Undefined type name '" ++ n ++ "'"
+ check t = Right t
+
+
+typeCheck :: Program -> Error Program
+typeCheck (Program decls) = Program <$> mapM (goD topLevelNames) decls
+ where
+ topLevelNames :: Map.Map Name Type
+ topLevelNames = foldr (uncurry Map.insert) Map.empty pairs
+ where pairs = map ((,) <$> nameOf <*> typeOf) $ filter isVarDecl decls
+
+ functionTypes :: Map.Map Name (Type,[Type])
+ functionTypes = foldr (uncurry Map.insert) Map.empty pairs
+ where pairs = map ((,) <$> nameOf <*> getTypes) $ filter isFunctionDecl decls
+ getTypes (DecFunction rt _ args _) = (rt, map fst args)
+ getTypes _ = undefined
+
+ isVarDecl (DecVariable {}) = True
+ isVarDecl _ = False
+
+ isFunctionDecl (DecFunction {}) = True
+ isFunctionDecl _ = False
+
+ goD :: Map.Map Name Type -> Declaration -> Error Declaration
+ goD names (DecFunction frt name args body) = do
+ newbody <- goB frt (foldr (\(t,n) m -> Map.insert n t m) names args) body
+ return $ DecFunction frt name args newbody
+ goD _ dec = return dec
+
+ goB :: Type -- function return type
+ -> Map.Map Name Type -> Block -> Error Block
+ goB frt names (Block stmts) = Block . snd <$> foldl foldfunc (return (names, [])) stmts
+ where
+ foldfunc :: Error (Map.Map Name Type, [Statement]) -> Statement -> Error (Map.Map Name Type, [Statement])
+ foldfunc ep st = do
+ (names', lst) <- ep
+ (newnames', newst) <- goS frt names' st
+ return (newnames', lst ++ [newst]) -- TODO: fix slow tail-append
+
+ goS :: Type -- function return type
+ -> Map.Map Name Type -> Statement -> Error (Map.Map Name Type, Statement)
+ goS _ names st@(StVarDeclaration t n Nothing) = return (Map.insert n t names, st)
+ goS frt names (StVarDeclaration t n (Just e)) = do
+ (newnames, _) <- goS frt names (StVarDeclaration t n Nothing)
+ (_, StAssignment _ newe) <- goS frt newnames (StAssignment n e)
+ return (newnames, StVarDeclaration t n (Just newe))
+ goS _ names (StAssignment n e) = maybe (Left $ "Undefined variable '" ++ n ++ "'") go (Map.lookup n names)
+ where go dsttype = do
+ re <- goE names e
+ let (Just extype) = exTypeOf re
+ if canConvert extype dsttype
+ then return (names, StAssignment n re)
+ else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '"
+ ++ pshow dsttype ++ "' in assignment to variable '" ++ n ++ "'"
+ goS _ names st@StEmpty = return (names, st)
+ goS frt names (StBlock bl) = do
+ newbl <- goB frt names bl
+ return (names, StBlock newbl)
+ goS _ names (StExpr e) = do
+ re <- goE names e
+ return (names, StExpr re)
+ goS frt names (StIf e s1 s2) = do
+ re <- goE names e
+ (_, rs1) <- goS frt names s1
+ (_, rs2) <- goS frt names s2
+ return (names, StIf re rs1 rs2)
+ goS frt names (StWhile e s) = do
+ re <- goE names e
+ (_, rs) <- goS frt names s
+ return (names, StWhile re rs)
+ goS frt names (StReturn e) = do
+ re <- goE names e
+ let (Just extype) = exTypeOf re
+ if canConvert extype frt
+ then return (names, StReturn re)
+ else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '"
+ ++ pshow frt ++ "' in return statement"
+
+ -- Postcondition: the expression (if any) has a type annotation.
+ goE :: Map.Map Name Type -> Expression -> Error Expression
+ goE _ (ExLit l@(LitInt i) _) = return $ ExLit l $ Just (smallestIntType i)
+ goE _ (ExLit l@(LitString _) _) = return $ ExLit l $ Just (TypePtr (TypeInt 8))
+ goE names (ExLit l@(LitVar n) _) = maybe (Left $ "Undefined variable '" ++ n ++ "'") (return . ExLit l . Just)
+ (Map.lookup n names)
+ goE names (ExLit l@(LitCall n args) _) = do
+ ft <- maybe (Left $ "Unknown function '" ++ n ++ "'") return $ Map.lookup n functionTypes
+ rargs <- mapM (goE names) args
+ when (length rargs /= length (snd ft))
+ $ Left ("Expected " ++ show (length (snd ft)) ++ "arguments to "
+ ++ "function '" ++ n ++ "', but got " ++ show (length rargs))
+ >> return ()
+ flip mapM_ rargs $
+ \a -> let argtype = fromJust (exTypeOf a)
+ in if canConvert argtype (fst ft)
+ then return a
+ else Left $ "Cannot convert type '" ++ pshow argtype ++ "' to '" ++ pshow (fst ft)
+ ++ "' in call of function '" ++ pshow n ++ "'"
+ return $ ExLit l (Just (fst ft))
+ goE names (ExBinOp bo e1 e2 _) = do
+ re1 <- goE names e1
+ re2 <- goE names e2
+ maybe (Left $ "Cannot use operator '" ++ pshow bo ++ "' with argument types '"
+ ++ pshow (fromJust $ exTypeOf re1) ++ "' and '" ++ pshow (fromJust $ exTypeOf re2) ++ "'")
+ (return . ExBinOp bo re1 re2 . Just)
+ $ typeCompatibleBO bo (fromJust $ exTypeOf re1) (fromJust $ exTypeOf re2)
+ goE names (ExUnOp uo e _) = do
+ re <- goE names e
+ maybe (Left $ "Cannot use operator '" ++ pshow uo ++ "' with argument type '" ++ pshow (fromJust $ exTypeOf re))
+ (return . ExUnOp uo re . Just)
+ $ typeCompatibleUO uo (fromJust $ exTypeOf re)
+
+
+bundleVarDecls :: Program -> Error Program
+bundleVarDecls prog = mapProgram prog $ defaultPM {blockHandler = goBlock}
+ where
+ goBlock :: MapperHandler Block
+ goBlock (Block stmts) =
+ let isVarDecl (StVarDeclaration {}) = True
+ isVarDecl _ = False
+
+ removeDecls [] = []
+ removeDecls ((StVarDeclaration _ n (Just ex)):rest) = StAssignment n ex : removeDecls rest
+ removeDecls ((StVarDeclaration _ _ Nothing):rest) = removeDecls rest
+ removeDecls (st:rest) = st : removeDecls rest
+
+ onlyDecl (StVarDeclaration t n _) = StVarDeclaration t n Nothing
+ onlyDecl _ = undefined
+
+ vdecls = map onlyDecl $ filter isVarDecl stmts
+ in return $ Block $ vdecls ++ removeDecls stmts
+
+
+canConvert :: Type -> Type -> Bool
+canConvert x y | x == y = True
+canConvert (TypeInt f) (TypeInt t) = f <= t
+canConvert (TypeUInt f) (TypeUInt t) = f <= t
+canConvert TypeFloat TypeDouble = True
+canConvert _ _ = False
+
+arithBO, compareBO, logicBO, complogBO :: [BinaryOperator]
+arithBO = [Plus, Minus, Times, Divide, Modulo]
+compareBO = [Equal, Unequal, Greater, Less, GEqual, LEqual]
+logicBO = [BoolAnd, BoolOr]
+complogBO = compareBO ++ logicBO
+
+typeCompatibleBO :: BinaryOperator -> Type -> Type -> Maybe Type
+typeCompatibleBO Minus (TypePtr t1) (TypePtr t2) | t1 == t2 = Just $ TypeInt 1
+typeCompatibleBO bo (TypePtr t1) (TypePtr t2) | t1 == t2 && bo `elem` complogBO = Just $ TypeInt 1
+typeCompatibleBO _ (TypePtr _) _ = Nothing
+typeCompatibleBO _ _ (TypePtr _) = Nothing
+
+typeCompatibleBO bo (TypeInt s1) (TypeInt s2) | bo `elem` arithBO = Just $ TypeInt (max s1 s2)
+typeCompatibleBO bo (TypeInt _) (TypeInt _) | bo `elem` complogBO = Just $ TypeInt 1
+
+typeCompatibleBO bo (TypeUInt s1) (TypeUInt s2) | bo `elem` arithBO = Just $ TypeUInt (max s1 s2)
+typeCompatibleBO bo (TypeUInt _) (TypeUInt _) | bo `elem` complogBO = Just $ TypeInt 1
+
+typeCompatibleBO bo t1 t2 | bo `elem` complogBO && t1 == t2 = Just $ TypeInt 1
+
+typeCompatibleBO bo TypeFloat (TypeInt s) | s <= 24 = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1
+typeCompatibleBO bo (TypeInt s) TypeFloat | s <= 24 = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1
+typeCompatibleBO bo TypeDouble (TypeInt s) | s <= 53 = Just $ if bo `elem` arithBO then TypeDouble else TypeInt 1
+typeCompatibleBO bo (TypeInt s) TypeDouble | s <= 53 = Just $ if bo `elem` arithBO then TypeDouble else TypeInt 1
+typeCompatibleBO bo TypeFloat TypeDouble = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1
+typeCompatibleBO bo TypeDouble TypeFloat = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1
+
+typeCompatibleBO _ _ _ = Nothing
+
+typeCompatibleUO :: UnaryOperator -> Type -> Maybe Type
+typeCompatibleUO Not _ = Just $ TypeInt 1
+typeCompatibleUO Address t = Just $ TypePtr t
+typeCompatibleUO uo t@(TypeInt _) | uo `elem` [Negate, Invert] = Just t
+typeCompatibleUO uo t@(TypeUInt _) | uo `elem` [Negate, Invert] = Just t
+typeCompatibleUO Negate TypeFloat = Just TypeFloat
+typeCompatibleUO Negate TypeDouble = Just TypeDouble
+typeCompatibleUO Dereference t@(TypePtr _) = Just t
+typeCompatibleUO _ _ = Nothing
+
+smallestIntType :: Integer -> Type
+smallestIntType i
+ | i >= -2^7 && i < 2^7 = TypeInt 8
+ | i >= -2^15 && i < 2^15 = TypeInt 16
+ | i >= -2^31 && i < 2^31 = TypeInt 32
+ | otherwise = TypeInt 64
+
+-- smallestUIntType :: Integer -> Type
+-- smallestUIntType i
+-- | i >= 0 && i < 2^8 = TypeUInt 8
+-- | i >= 0 && i < 2^16 = TypeUInt 16
+-- | i >= 0 && i < 2^32 = TypeUInt 32
+-- | otherwise = TypeUInt 64
+
+
+type MapperHandler a = a -> Error a
+
+data ProgramMapper = ProgramMapper
+ {declarationHandler :: MapperHandler Declaration
+ ,blockHandler :: MapperHandler Block
+ ,typeHandler :: MapperHandler Type
+ ,literalHandler :: MapperHandler Literal
+ ,binOpHandler :: MapperHandler BinaryOperator
+ ,unOpHandler :: MapperHandler UnaryOperator
+ ,expressionHandler :: MapperHandler Expression
+ ,statementHandler :: MapperHandler Statement
+ ,nameHandler :: MapperHandler Name}
+
+type MapperHandler' a = a -> a
+
+data ProgramMapper' = ProgramMapper'
+ {declarationHandler' :: MapperHandler' Declaration
+ ,blockHandler' :: MapperHandler' Block
+ ,typeHandler' :: MapperHandler' Type
+ ,literalHandler' :: MapperHandler' Literal
+ ,binOpHandler' :: MapperHandler' BinaryOperator
+ ,unOpHandler' :: MapperHandler' UnaryOperator
+ ,expressionHandler' :: MapperHandler' Expression
+ ,statementHandler' :: MapperHandler' Statement
+ ,nameHandler' :: MapperHandler' Name}
+
+defaultPM :: ProgramMapper
+defaultPM = ProgramMapper return return return return return return return return return
+
+defaultPM' :: ProgramMapper'
+defaultPM' = ProgramMapper' id id id id id id id id id
+
+mapProgram' :: Program -> ProgramMapper' -> Program
+mapProgram' prog mapper = (\(Right r) -> r) $ mapProgram prog $ ProgramMapper
+ {declarationHandler = return . declarationHandler' mapper
+ ,blockHandler = return . blockHandler' mapper
+ ,typeHandler = return . typeHandler' mapper
+ ,literalHandler = return . literalHandler' mapper
+ ,binOpHandler = return . binOpHandler' mapper
+ ,unOpHandler = return . unOpHandler' mapper
+ ,expressionHandler = return . expressionHandler' mapper
+ ,statementHandler = return . statementHandler' mapper
+ ,nameHandler = return . nameHandler' mapper}
+
+mapProgram :: Program -> ProgramMapper -> Error Program
+mapProgram prog mapper = goP prog
+ where
+ h_d = declarationHandler mapper
+ h_b = blockHandler mapper
+ h_t = typeHandler mapper
+ h_l = literalHandler mapper
+ h_bo = binOpHandler mapper
+ h_uo = unOpHandler mapper
+ h_e = expressionHandler mapper
+ h_s = statementHandler mapper
+ h_n = nameHandler mapper
+
+ goP :: MapperHandler Program
+ goP (Program decls) = Program <$> sequence (map (\d -> goD d >>= h_d) decls)
+
+ goD :: MapperHandler Declaration
+ goD (DecFunction t n a b) = do
+ rt <- goT t
+ rn <- goN n
+ ra <- sequence $ map (\(at,an) -> (,) <$> goT at <*> goN an) a
+ rb <- goB b
+ h_d $ DecFunction rt rn ra rb
+ goD (DecVariable t n mv) = do
+ rt <- goT t
+ rn <- goN n
+ rmv <- sequence $ fmap goE mv
+ h_d $ DecVariable rt rn rmv
+ goD (DecTypedef t n) = do
+ rt <- goT t
+ rn <- goN n
+ h_d $ DecTypedef rt rn
+
+ goT :: MapperHandler Type
+ goT (TypePtr t) = goT t >>= (h_t . TypePtr)
+ goT (TypeName n) = goN n >>= (h_t . TypeName)
+ goT t = h_t t
+
+ goN :: MapperHandler Name
+ goN = h_n
+
+ goB :: MapperHandler Block
+ goB (Block sts) = (Block <$> sequence (map goS sts)) >>= h_b
+
+ goE :: MapperHandler Expression
+ goE (ExLit l mt) = do
+ rl <- goL l
+ h_e $ ExLit rl mt
+ goE (ExBinOp bo e1 e2 mt) = do
+ rbo <- goBO bo
+ re1 <- goE e1
+ re2 <- goE e2
+ h_e $ ExBinOp rbo re1 re2 mt
+ goE (ExUnOp uo e mt) = do
+ ruo <- goUO uo
+ re <- goE e
+ h_e $ ExUnOp ruo re mt
+
+ goS :: MapperHandler Statement
+ goS StEmpty = h_s StEmpty
+ goS (StBlock b) = goB b >>= (h_s . StBlock)
+ goS (StExpr e) = goE e >>= (h_s . StExpr)
+ goS (StVarDeclaration t n me) = do
+ rt <- goT t
+ rn <- goN n
+ rme <- sequence $ fmap goE me
+ h_s $ StVarDeclaration rt rn rme
+ goS (StAssignment n e) = do
+ rn <- goN n
+ re <- goE e
+ h_s $ StAssignment rn re
+ goS (StIf e s1 s2) = do
+ re <- goE e
+ rs1 <- goS s1
+ rs2 <- goS s2
+ h_s $ StIf re rs1 rs2
+ goS (StWhile e s) = do
+ re <- goE e
+ rs <- goS s
+ h_s $ StWhile re rs
+ goS (StReturn e) = goE e >>= (h_s . StReturn)
+
+ goL :: MapperHandler Literal
+ goL l@(LitString _) = h_l l
+ goL l@(LitInt _) = h_l l
+ goL (LitVar n) = goN n >>= (h_l . LitVar)
+ goL (LitCall n a) = do
+ rn <- goN n
+ ra <- sequence $ map goE a
+ h_l $ LitCall rn ra
+
+ goBO :: MapperHandler BinaryOperator
+ goBO = h_bo
+
+ goUO :: MapperHandler UnaryOperator
+ goUO = h_uo