diff options
Diffstat (limited to 'check.hs')
-rw-r--r-- | check.hs | 380 |
1 files changed, 380 insertions, 0 deletions
diff --git a/check.hs b/check.hs new file mode 100644 index 0000000..5f7dd12 --- /dev/null +++ b/check.hs @@ -0,0 +1,380 @@ +module Check(checkProgram) where + +import Control.Monad +import Data.Maybe +import qualified Data.Map.Strict as Map +--import Debug.Trace + +import AST +import PShow + + +type Error a = Either String a + + +checkProgram :: Program -> Error Program +checkProgram prog = do + let processed = replaceTypes prog + checkUndefinedTypes processed + typeCheck processed >>= bundleVarDecls + + +replaceTypes :: Program -> Program +replaceTypes prog@(Program decls) = mapProgram' filtered mapper + where + filtered = Program $ filter notTypedef decls + mapper = defaultPM' {typeHandler' = typeReplacer (findTypeRenames prog)} + + notTypedef :: Declaration -> Bool + notTypedef (DecTypedef _ _) = False + notTypedef _ = True + + typeReplacer :: Map.Map Name Type -> Type -> Type + typeReplacer m t@(TypeName n) = maybe t id $ Map.lookup n m + typeReplacer _ t = t + + findTypeRenames :: Program -> Map.Map Name Type + findTypeRenames (Program d) = foldl go Map.empty d + where + go :: Map.Map Name Type -> Declaration -> Map.Map Name Type + go m (DecTypedef t n) = Map.insert n t m + go m _ = m + + +checkUndefinedTypes :: Program -> Error () +checkUndefinedTypes prog = fmap (const ()) $ mapProgram prog $ defaultPM {typeHandler = check} + where + check :: MapperHandler Type + check (TypeName n) = Left $ "Undefined type name '" ++ n ++ "'" + check t = Right t + + +typeCheck :: Program -> Error Program +typeCheck (Program decls) = Program <$> mapM (goD topLevelNames) decls + where + topLevelNames :: Map.Map Name Type + topLevelNames = foldr (uncurry Map.insert) Map.empty pairs + where pairs = map ((,) <$> nameOf <*> typeOf) $ filter isVarDecl decls + + functionTypes :: Map.Map Name (Type,[Type]) + functionTypes = foldr (uncurry Map.insert) Map.empty pairs + where pairs = map ((,) <$> nameOf <*> getTypes) $ filter isFunctionDecl decls + getTypes (DecFunction rt _ args _) = (rt, map fst args) + getTypes _ = undefined + + isVarDecl (DecVariable {}) = True + isVarDecl _ = False + + isFunctionDecl (DecFunction {}) = True + isFunctionDecl _ = False + + goD :: Map.Map Name Type -> Declaration -> Error Declaration + goD names (DecFunction frt name args body) = do + newbody <- goB frt (foldr (\(t,n) m -> Map.insert n t m) names args) body + return $ DecFunction frt name args newbody + goD _ dec = return dec + + goB :: Type -- function return type + -> Map.Map Name Type -> Block -> Error Block + goB frt names (Block stmts) = Block . snd <$> foldl foldfunc (return (names, [])) stmts + where + foldfunc :: Error (Map.Map Name Type, [Statement]) -> Statement -> Error (Map.Map Name Type, [Statement]) + foldfunc ep st = do + (names', lst) <- ep + (newnames', newst) <- goS frt names' st + return (newnames', lst ++ [newst]) -- TODO: fix slow tail-append + + goS :: Type -- function return type + -> Map.Map Name Type -> Statement -> Error (Map.Map Name Type, Statement) + goS _ names st@(StVarDeclaration t n Nothing) = return (Map.insert n t names, st) + goS frt names (StVarDeclaration t n (Just e)) = do + (newnames, _) <- goS frt names (StVarDeclaration t n Nothing) + (_, StAssignment _ newe) <- goS frt newnames (StAssignment n e) + return (newnames, StVarDeclaration t n (Just newe)) + goS _ names (StAssignment n e) = maybe (Left $ "Undefined variable '" ++ n ++ "'") go (Map.lookup n names) + where go dsttype = do + re <- goE names e + let (Just extype) = exTypeOf re + if canConvert extype dsttype + then return (names, StAssignment n re) + else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '" + ++ pshow dsttype ++ "' in assignment to variable '" ++ n ++ "'" + goS _ names st@StEmpty = return (names, st) + goS frt names (StBlock bl) = do + newbl <- goB frt names bl + return (names, StBlock newbl) + goS _ names (StExpr e) = do + re <- goE names e + return (names, StExpr re) + goS frt names (StIf e s1 s2) = do + re <- goE names e + (_, rs1) <- goS frt names s1 + (_, rs2) <- goS frt names s2 + return (names, StIf re rs1 rs2) + goS frt names (StWhile e s) = do + re <- goE names e + (_, rs) <- goS frt names s + return (names, StWhile re rs) + goS frt names (StReturn e) = do + re <- goE names e + let (Just extype) = exTypeOf re + if canConvert extype frt + then return (names, StReturn re) + else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '" + ++ pshow frt ++ "' in return statement" + + -- Postcondition: the expression (if any) has a type annotation. + goE :: Map.Map Name Type -> Expression -> Error Expression + goE _ (ExLit l@(LitInt i) _) = return $ ExLit l $ Just (smallestIntType i) + goE _ (ExLit l@(LitString _) _) = return $ ExLit l $ Just (TypePtr (TypeInt 8)) + goE names (ExLit l@(LitVar n) _) = maybe (Left $ "Undefined variable '" ++ n ++ "'") (return . ExLit l . Just) + (Map.lookup n names) + goE names (ExLit l@(LitCall n args) _) = do + ft <- maybe (Left $ "Unknown function '" ++ n ++ "'") return $ Map.lookup n functionTypes + rargs <- mapM (goE names) args + when (length rargs /= length (snd ft)) + $ Left ("Expected " ++ show (length (snd ft)) ++ "arguments to " + ++ "function '" ++ n ++ "', but got " ++ show (length rargs)) + >> return () + flip mapM_ rargs $ + \a -> let argtype = fromJust (exTypeOf a) + in if canConvert argtype (fst ft) + then return a + else Left $ "Cannot convert type '" ++ pshow argtype ++ "' to '" ++ pshow (fst ft) + ++ "' in call of function '" ++ pshow n ++ "'" + return $ ExLit l (Just (fst ft)) + goE names (ExBinOp bo e1 e2 _) = do + re1 <- goE names e1 + re2 <- goE names e2 + maybe (Left $ "Cannot use operator '" ++ pshow bo ++ "' with argument types '" + ++ pshow (fromJust $ exTypeOf re1) ++ "' and '" ++ pshow (fromJust $ exTypeOf re2) ++ "'") + (return . ExBinOp bo re1 re2 . Just) + $ typeCompatibleBO bo (fromJust $ exTypeOf re1) (fromJust $ exTypeOf re2) + goE names (ExUnOp uo e _) = do + re <- goE names e + maybe (Left $ "Cannot use operator '" ++ pshow uo ++ "' with argument type '" ++ pshow (fromJust $ exTypeOf re)) + (return . ExUnOp uo re . Just) + $ typeCompatibleUO uo (fromJust $ exTypeOf re) + + +bundleVarDecls :: Program -> Error Program +bundleVarDecls prog = mapProgram prog $ defaultPM {blockHandler = goBlock} + where + goBlock :: MapperHandler Block + goBlock (Block stmts) = + let isVarDecl (StVarDeclaration {}) = True + isVarDecl _ = False + + removeDecls [] = [] + removeDecls ((StVarDeclaration _ n (Just ex)):rest) = StAssignment n ex : removeDecls rest + removeDecls ((StVarDeclaration _ _ Nothing):rest) = removeDecls rest + removeDecls (st:rest) = st : removeDecls rest + + onlyDecl (StVarDeclaration t n _) = StVarDeclaration t n Nothing + onlyDecl _ = undefined + + vdecls = map onlyDecl $ filter isVarDecl stmts + in return $ Block $ vdecls ++ removeDecls stmts + + +canConvert :: Type -> Type -> Bool +canConvert x y | x == y = True +canConvert (TypeInt f) (TypeInt t) = f <= t +canConvert (TypeUInt f) (TypeUInt t) = f <= t +canConvert TypeFloat TypeDouble = True +canConvert _ _ = False + +arithBO, compareBO, logicBO, complogBO :: [BinaryOperator] +arithBO = [Plus, Minus, Times, Divide, Modulo] +compareBO = [Equal, Unequal, Greater, Less, GEqual, LEqual] +logicBO = [BoolAnd, BoolOr] +complogBO = compareBO ++ logicBO + +typeCompatibleBO :: BinaryOperator -> Type -> Type -> Maybe Type +typeCompatibleBO Minus (TypePtr t1) (TypePtr t2) | t1 == t2 = Just $ TypeInt 1 +typeCompatibleBO bo (TypePtr t1) (TypePtr t2) | t1 == t2 && bo `elem` complogBO = Just $ TypeInt 1 +typeCompatibleBO _ (TypePtr _) _ = Nothing +typeCompatibleBO _ _ (TypePtr _) = Nothing + +typeCompatibleBO bo (TypeInt s1) (TypeInt s2) | bo `elem` arithBO = Just $ TypeInt (max s1 s2) +typeCompatibleBO bo (TypeInt _) (TypeInt _) | bo `elem` complogBO = Just $ TypeInt 1 + +typeCompatibleBO bo (TypeUInt s1) (TypeUInt s2) | bo `elem` arithBO = Just $ TypeUInt (max s1 s2) +typeCompatibleBO bo (TypeUInt _) (TypeUInt _) | bo `elem` complogBO = Just $ TypeInt 1 + +typeCompatibleBO bo t1 t2 | bo `elem` complogBO && t1 == t2 = Just $ TypeInt 1 + +typeCompatibleBO bo TypeFloat (TypeInt s) | s <= 24 = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1 +typeCompatibleBO bo (TypeInt s) TypeFloat | s <= 24 = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1 +typeCompatibleBO bo TypeDouble (TypeInt s) | s <= 53 = Just $ if bo `elem` arithBO then TypeDouble else TypeInt 1 +typeCompatibleBO bo (TypeInt s) TypeDouble | s <= 53 = Just $ if bo `elem` arithBO then TypeDouble else TypeInt 1 +typeCompatibleBO bo TypeFloat TypeDouble = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1 +typeCompatibleBO bo TypeDouble TypeFloat = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1 + +typeCompatibleBO _ _ _ = Nothing + +typeCompatibleUO :: UnaryOperator -> Type -> Maybe Type +typeCompatibleUO Not _ = Just $ TypeInt 1 +typeCompatibleUO Address t = Just $ TypePtr t +typeCompatibleUO uo t@(TypeInt _) | uo `elem` [Negate, Invert] = Just t +typeCompatibleUO uo t@(TypeUInt _) | uo `elem` [Negate, Invert] = Just t +typeCompatibleUO Negate TypeFloat = Just TypeFloat +typeCompatibleUO Negate TypeDouble = Just TypeDouble +typeCompatibleUO Dereference t@(TypePtr _) = Just t +typeCompatibleUO _ _ = Nothing + +smallestIntType :: Integer -> Type +smallestIntType i + | i >= -2^7 && i < 2^7 = TypeInt 8 + | i >= -2^15 && i < 2^15 = TypeInt 16 + | i >= -2^31 && i < 2^31 = TypeInt 32 + | otherwise = TypeInt 64 + +-- smallestUIntType :: Integer -> Type +-- smallestUIntType i +-- | i >= 0 && i < 2^8 = TypeUInt 8 +-- | i >= 0 && i < 2^16 = TypeUInt 16 +-- | i >= 0 && i < 2^32 = TypeUInt 32 +-- | otherwise = TypeUInt 64 + + +type MapperHandler a = a -> Error a + +data ProgramMapper = ProgramMapper + {declarationHandler :: MapperHandler Declaration + ,blockHandler :: MapperHandler Block + ,typeHandler :: MapperHandler Type + ,literalHandler :: MapperHandler Literal + ,binOpHandler :: MapperHandler BinaryOperator + ,unOpHandler :: MapperHandler UnaryOperator + ,expressionHandler :: MapperHandler Expression + ,statementHandler :: MapperHandler Statement + ,nameHandler :: MapperHandler Name} + +type MapperHandler' a = a -> a + +data ProgramMapper' = ProgramMapper' + {declarationHandler' :: MapperHandler' Declaration + ,blockHandler' :: MapperHandler' Block + ,typeHandler' :: MapperHandler' Type + ,literalHandler' :: MapperHandler' Literal + ,binOpHandler' :: MapperHandler' BinaryOperator + ,unOpHandler' :: MapperHandler' UnaryOperator + ,expressionHandler' :: MapperHandler' Expression + ,statementHandler' :: MapperHandler' Statement + ,nameHandler' :: MapperHandler' Name} + +defaultPM :: ProgramMapper +defaultPM = ProgramMapper return return return return return return return return return + +defaultPM' :: ProgramMapper' +defaultPM' = ProgramMapper' id id id id id id id id id + +mapProgram' :: Program -> ProgramMapper' -> Program +mapProgram' prog mapper = (\(Right r) -> r) $ mapProgram prog $ ProgramMapper + {declarationHandler = return . declarationHandler' mapper + ,blockHandler = return . blockHandler' mapper + ,typeHandler = return . typeHandler' mapper + ,literalHandler = return . literalHandler' mapper + ,binOpHandler = return . binOpHandler' mapper + ,unOpHandler = return . unOpHandler' mapper + ,expressionHandler = return . expressionHandler' mapper + ,statementHandler = return . statementHandler' mapper + ,nameHandler = return . nameHandler' mapper} + +mapProgram :: Program -> ProgramMapper -> Error Program +mapProgram prog mapper = goP prog + where + h_d = declarationHandler mapper + h_b = blockHandler mapper + h_t = typeHandler mapper + h_l = literalHandler mapper + h_bo = binOpHandler mapper + h_uo = unOpHandler mapper + h_e = expressionHandler mapper + h_s = statementHandler mapper + h_n = nameHandler mapper + + goP :: MapperHandler Program + goP (Program decls) = Program <$> sequence (map (\d -> goD d >>= h_d) decls) + + goD :: MapperHandler Declaration + goD (DecFunction t n a b) = do + rt <- goT t + rn <- goN n + ra <- sequence $ map (\(at,an) -> (,) <$> goT at <*> goN an) a + rb <- goB b + h_d $ DecFunction rt rn ra rb + goD (DecVariable t n mv) = do + rt <- goT t + rn <- goN n + rmv <- sequence $ fmap goE mv + h_d $ DecVariable rt rn rmv + goD (DecTypedef t n) = do + rt <- goT t + rn <- goN n + h_d $ DecTypedef rt rn + + goT :: MapperHandler Type + goT (TypePtr t) = goT t >>= (h_t . TypePtr) + goT (TypeName n) = goN n >>= (h_t . TypeName) + goT t = h_t t + + goN :: MapperHandler Name + goN = h_n + + goB :: MapperHandler Block + goB (Block sts) = (Block <$> sequence (map goS sts)) >>= h_b + + goE :: MapperHandler Expression + goE (ExLit l mt) = do + rl <- goL l + h_e $ ExLit rl mt + goE (ExBinOp bo e1 e2 mt) = do + rbo <- goBO bo + re1 <- goE e1 + re2 <- goE e2 + h_e $ ExBinOp rbo re1 re2 mt + goE (ExUnOp uo e mt) = do + ruo <- goUO uo + re <- goE e + h_e $ ExUnOp ruo re mt + + goS :: MapperHandler Statement + goS StEmpty = h_s StEmpty + goS (StBlock b) = goB b >>= (h_s . StBlock) + goS (StExpr e) = goE e >>= (h_s . StExpr) + goS (StVarDeclaration t n me) = do + rt <- goT t + rn <- goN n + rme <- sequence $ fmap goE me + h_s $ StVarDeclaration rt rn rme + goS (StAssignment n e) = do + rn <- goN n + re <- goE e + h_s $ StAssignment rn re + goS (StIf e s1 s2) = do + re <- goE e + rs1 <- goS s1 + rs2 <- goS s2 + h_s $ StIf re rs1 rs2 + goS (StWhile e s) = do + re <- goE e + rs <- goS s + h_s $ StWhile re rs + goS (StReturn e) = goE e >>= (h_s . StReturn) + + goL :: MapperHandler Literal + goL l@(LitString _) = h_l l + goL l@(LitInt _) = h_l l + goL (LitVar n) = goN n >>= (h_l . LitVar) + goL (LitCall n a) = do + rn <- goN n + ra <- sequence $ map goE a + h_l $ LitCall rn ra + + goBO :: MapperHandler BinaryOperator + goBO = h_bo + + goUO :: MapperHandler UnaryOperator + goUO = h_uo |