module Check(checkProgram) where

import Control.Monad
import Data.Maybe
import qualified Data.Map.Strict as Map
--import Debug.Trace

import AST
import PShow


-- | Result of a checking pass: either an error message or a value.
type Error a = Either String a


-- | Entry point of the checker. Replaces typedef'd type names, rejects any
-- type name that is still unresolved, type-checks and annotates the program,
-- and finally moves variable declarations to the start of their block.
checkProgram :: Program -> Error Program
checkProgram prog = do
    let processed = replaceTypes prog
    checkUndefinedTypes processed
    typeCheck processed >>= bundleVarDecls


-- | Drop all typedef declarations and substitute the types they name
-- everywhere else in the program.
replaceTypes :: Program -> Program
replaceTypes prog@(Program decls) = mapProgram' filtered mapper
  where
    filtered = Program $ filter notTypedef decls
    mapper = defaultPM' {typeHandler' = typeReplacer (findTypeRenames prog)}

-- | True for every declaration that is not a typedef.
notTypedef :: Declaration -> Bool
notTypedef (DecTypedef _ _) = False
notTypedef _ = True

-- | Type handler: a type name that has a typedef is replaced by its
-- (recursively expanded) definition; every other type is left alone, because
-- the mapper has already visited its children.
typeReplacer :: Map.Map Name Type -> Type -> Type
typeReplacer m t@(TypeName n) = maybe t (recurseAfter m) $ Map.lookup n m
typeReplacer _ t = t

-- | Expand typedef names inside a type taken from the typedef table. Such a
-- type has not been visited by the mapper, so pointers and function types
-- must be descended into here.
-- NOTE(review): mutually recursive typedefs would make this loop forever.
recurseAfter :: Map.Map Name Type -> Type -> Type
recurseAfter m t@(TypeName _) = typeReplacer m t
recurseAfter m (TypePtr t) = TypePtr $ recurseAfter m t
recurseAfter m (TypeFunc rt ats) =
    TypeFunc (recurseAfter m rt) (map (recurseAfter m) ats)
recurseAfter _ t = t

-- | Collect all typedefs as a map from new name to defined type. When a name
-- is defined more than once, the last definition wins.
findTypeRenames :: Program -> Map.Map Name Type
findTypeRenames (Program d) = Map.fromList [(n, t) | DecTypedef t n <- d]


-- | Fail if any type name survived typedef replacement.
checkUndefinedTypes :: Program -> Error ()
checkUndefinedTypes prog =
    void $ mapProgram prog $ defaultPM {typeHandler = check}
  where
    check :: MapperHandler Type
    check (TypeName n) = Left $ "Undefined type name '" ++ n ++ "'"
    check t = Right t


-- | Type-check all function bodies and annotate every expression in them
-- with its type. Declarations other than functions are passed through
-- unchanged (apart from rejecting globals of function type).
typeCheck :: Program -> Error Program
typeCheck (Program decls) = Program <$> mapM (goD topLevelNames) decls
  where
    -- Global variables. The right fold makes the first declaration win when
    -- a name occurs more than once.
    topLevelNames :: Map.Map Name Type
    topLevelNames = foldr (uncurry Map.insert) Map.empty
        [(nameOf d, typeOf d) | d@(DecVariable {}) <- decls]

    -- Return type and argument types of every function and function-typed
    -- extern; first declaration wins on duplicates.
    functionTypes :: Map.Map Name (Type,[Type])
    functionTypes = foldr (uncurry Map.insert) Map.empty (mapMaybe signature decls)
      where
        signature d@(DecFunction rt _ args _) = Just (nameOf d, (rt, map fst args))
        signature d@(DecExtern (TypeFunc rt ats) _) = Just (nameOf d, (rt, ats))
        signature _ = Nothing

    -- Type annotation of an expression that 'goE' has already processed.
    -- 'goE' guarantees the annotation, so the Left branch signals a bug in
    -- the checker rather than in the program being checked.
    checkedType :: Expression -> Error Type
    checkedType e =
        maybe (Left "Internal error: expression without type annotation")
              return
              (exTypeOf e)

    -- Bring a new variable into scope, rejecting names already in scope.
    declare :: Map.Map Name Type -> Type -> Name -> Error (Map.Map Name Type)
    declare names t n
        | Map.member n names = Left $ "Duplicate variable '" ++ n ++ "'"
        | otherwise = return $ Map.insert n t names

    -- Check the assignment of an expression to a variable in scope and
    -- return the annotated expression.
    checkAssign :: Map.Map Name Type -> Name -> Expression -> Error Expression
    checkAssign names n e = case Map.lookup n names of
        Nothing -> Left $ "Undefined variable '" ++ n ++ "'"
        Just dsttype -> do
            re <- goE names e
            extype <- checkedType re
            if canConvert extype dsttype
                then return re
                else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '" ++
                            pshow dsttype ++ "' in assignment to variable '" ++ n ++ "'"

    goD :: Map.Map Name Type -> Declaration -> Error Declaration
    goD names (DecFunction frt name args body) = do
        -- Arguments shadow globals; the first of two equal argument names wins.
        let scope = foldr (\(t,n) m -> Map.insert n t m) names args
        newbody <- goB frt scope body
        return $ DecFunction frt name args newbody
    goD _ (DecVariable (TypeFunc _ _) _ _) =
        Left "Cannot declare global variable with function type"
    goD _ dec = return dec

    goB :: Type  -- function return type
        -> Map.Map Name Type
        -> Block
        -> Error Block
    goB frt names (Block stmts) =
        Block . reverse . snd <$> foldM step (names, []) stmts
      where
        -- Thread the scope through the statements; results are consed and
        -- reversed once at the end instead of being appended one by one.
        step (scope, acc) st = do
            (scope', st') <- goS frt scope st
            return (scope', st' : acc)

    goS :: Type  -- function return type
        -> Map.Map Name Type
        -> Statement
        -> Error (Map.Map Name Type, Statement)
    goS _ _ (StVarDeclaration (TypeFunc _ _) _ _) =
        Left "Cannot declare variable with function type"
    goS _ names st@(StVarDeclaration t n Nothing) = do
        newnames <- declare names t n
        return (newnames, st)
    goS _ names (StVarDeclaration t n (Just e)) = do
        -- The variable is in scope while its own initialiser is checked.
        newnames <- declare names t n
        newe <- checkAssign newnames n e
        return (newnames, StVarDeclaration t n (Just newe))
    goS _ names (StAssignment n e) = do
        re <- checkAssign names n e
        return (names, StAssignment n re)
    goS _ names st@StEmpty = return (names, st)
    goS frt names (StBlock bl) = do
        newbl <- goB frt names bl
        return (names, StBlock newbl)
    goS _ names (StExpr e) = do
        re <- goE names e
        return (names, StExpr re)
    goS frt names (StIf e s1 s2) = do
        re <- goE names e
        -- Declarations made inside a branch do not escape it.
        (_, rs1) <- goS frt names s1
        (_, rs2) <- goS frt names s2
        return (names, StIf re rs1 rs2)
    goS frt names (StWhile e s) = do
        re <- goE names e
        (_, rs) <- goS frt names s
        return (names, StWhile re rs)
    goS TypeVoid names (StReturn Nothing) = return (names, StReturn Nothing)
    goS _ _ (StReturn Nothing) = Left "Non-void function should return a value"
    goS frt names (StReturn (Just e)) = do
        re <- goE names e
        extype <- checkedType re
        if canConvert extype frt
            then return (names, StReturn (Just re))
            else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '" ++
                        pshow frt ++ "' in return statement"

    -- Postcondition: the expression (if any) has a type annotation.
    goE :: Map.Map Name Type -> Expression -> Error Expression
    goE _ (ExLit l@(LitInt i) _) = ExLit l . Just <$> smallestIntType i
    goE _ (ExLit l@(LitUInt i) _) = ExLit l . Just <$> smallestUIntType i
    goE _ (ExLit l@(LitFloat f) _) = return $ ExLit l $ Just (smallestFloatType f)
    goE _ (ExLit l@(LitString _) _) = return $ ExLit l $ Just (TypePtr (TypeInt 8))
    goE names (ExLit l@(LitVar n) _) =
        maybe (Left $ "Undefined variable '" ++ n ++ "'")
              (return . ExLit l . Just)
              (Map.lookup n names)
    goE names (ExLit (LitCall n args) _) = do
        (rettype, argtypes) <-
            maybe (Left $ "Unknown function '" ++ n ++ "'") return $
                Map.lookup n functionTypes
        rargs <- mapM (goE names) args
        when (length rargs /= length argtypes) $
            Left $ "Expected " ++ show (length argtypes) ++ " arguments to " ++
                   "function '" ++ n ++ "', but got " ++ show (length rargs)
        -- Lengths are equal here, so zip pairs every argument with its type.
        forM_ (zip rargs argtypes) $ \(a, expected) -> do
            argtype <- checkedType a
            unless (canConvert argtype expected) $
                Left $ "Cannot convert type '" ++ pshow argtype ++ "' to '" ++
                       pshow expected ++ "' in call of function '" ++ n ++ "'"
        return $ ExLit (LitCall n rargs) (Just rettype)
    goE names (ExCast totype ex) = do
        rex <- goE names ex
        fromtype <- checkedType rex
        if canCast fromtype totype
            then return $ ExCast totype rex
            else Left $ "Cannot cast type '" ++ pshow fromtype ++ "' to '" ++
                        pshow totype ++ "'"
    goE names (ExBinOp bo e1 e2 _) = do
        re1 <- goE names e1
        re2 <- goE names e2
        t1 <- checkedType re1
        t2 <- checkedType re2
        maybe (Left $ "Cannot use operator '" ++ pshow bo ++
                      "' with argument types '" ++ pshow t1 ++ "' and '" ++
                      pshow t2 ++ "'")
              (return . ExBinOp bo re1 re2 . Just)
              (resultTypeBO bo t1 t2)
    goE names (ExUnOp uo e _) = do
        re <- goE names e
        t <- checkedType re
        maybe (Left $ "Cannot use operator '" ++ pshow uo ++
                      "' with argument type '" ++ pshow t ++ "'")
              (return . ExUnOp uo re . Just)
              (resultTypeUO uo t)


-- | In every block, move all variable declarations (without initialiser) to
-- the front; a declaration with initialiser leaves an assignment behind at
-- its original position.
bundleVarDecls :: Program -> Error Program
bundleVarDecls prog = mapProgram prog $ defaultPM {blockHandler = goBlock}
  where
    goBlock :: MapperHandler Block
    goBlock (Block stmts) = return $ Block $ hoisted ++ concatMap strip stmts
      where
        hoisted = [StVarDeclaration t n Nothing | StVarDeclaration t n _ <- stmts]

        strip (StVarDeclaration _ n (Just ex)) = [StAssignment n ex]
        strip (StVarDeclaration _ _ Nothing) = []
        strip st = [st]


-- | Whether a value of the first type may be implicitly converted to the
-- second: identical types, widening within the signed or unsigned integers,
-- u1 to any signed integer, float to double, and any integer to float/double.
canConvert :: Type -> Type -> Bool
canConvert x y | x == y = True
canConvert (TypeInt f) (TypeInt t) = f <= t
canConvert (TypeUInt f) (TypeUInt t) = f <= t
canConvert (TypeUInt 1) (TypeInt _) = True
canConvert TypeFloat TypeDouble = True
canConvert (TypeInt _) TypeFloat = True
canConvert (TypeInt _) TypeDouble = True
canConvert (TypeUInt _) TypeFloat = True
canConvert (TypeUInt _) TypeDouble = True
canConvert _ _ = False
-- | Whether an explicit cast between the two types is allowed: both must be
-- numbers, or both must be in the integer/pointer/function group.
canCast :: Type -> Type -> Bool
canCast t1 t2 = any (\f -> f t1 && f t2) [numberGroup, intptrGroup]
  where
    numberGroup (TypeInt _) = True
    numberGroup (TypeUInt _) = True
    numberGroup TypeFloat = True
    numberGroup TypeDouble = True
    numberGroup _ = False

    intptrGroup (TypeInt _) = True
    intptrGroup (TypeUInt _) = True
    intptrGroup (TypePtr _) = True
    intptrGroup (TypeFunc _ _) = True
    intptrGroup _ = False


-- | Operator classes used by 'resultTypeBO'.
arithBO, compareBO, logicBO, complogBO :: [BinaryOperator]
arithBO = [Plus, Minus, Times, Divide, Modulo]
compareBO = [Equal, Unequal, Greater, Less, GEqual, LEqual]
logicBO = [BoolAnd, BoolOr]
complogBO = compareBO ++ logicBO

-- | Result type of a binary operator applied to operands of the given types,
-- or 'Nothing' if the combination is not allowed. Equations are tried in
-- order, so the pointer rules take precedence over the numeric ones.
resultTypeBO :: BinaryOperator -> Type -> Type -> Maybe Type
resultTypeBO Minus (TypePtr t1) (TypePtr t2) | t1 == t2 = Just $ TypeUInt 64
resultTypeBO bo (TypePtr t1) (TypePtr t2)
    | t1 == t2 && bo `elem` complogBO = Just $ TypeUInt 1
resultTypeBO bo t@(TypePtr _) (TypeInt _) | bo `elem` [Plus, Minus] = Just t
resultTypeBO bo t@(TypePtr _) (TypeUInt _) | bo `elem` [Plus, Minus] = Just t
resultTypeBO bo (TypeInt _) t@(TypePtr _) | bo `elem` [Plus, Minus] = Just t
resultTypeBO bo (TypeUInt _) t@(TypePtr _) | bo `elem` [Plus, Minus] = Just t
resultTypeBO Index (TypePtr t) (TypeInt _) = Just t
resultTypeBO Index (TypePtr t) (TypeUInt _) = Just t
resultTypeBO _ (TypePtr _) _ = Nothing
resultTypeBO _ _ (TypePtr _) = Nothing
resultTypeBO bo (TypeInt s1) (TypeInt s2)
    | bo `elem` arithBO = Just $ TypeInt (max s1 s2)
resultTypeBO bo (TypeInt _) (TypeInt _) | bo `elem` complogBO = Just $ TypeUInt 1
resultTypeBO bo (TypeUInt s1) (TypeUInt s2)
    | bo `elem` arithBO = Just $ TypeUInt (max s1 s2)
resultTypeBO bo (TypeUInt _) (TypeUInt _) | bo `elem` complogBO = Just $ TypeUInt 1
resultTypeBO bo t1 t2 | bo `elem` complogBO && t1 == t2 = Just $ TypeUInt 1
-- An integer operand is only mixed with a floating type if its size is at
-- most 24 (float) or 53 (double).
resultTypeBO bo TypeFloat (TypeInt s) | s <= 24 = floatOpType bo TypeFloat
resultTypeBO bo (TypeInt s) TypeFloat | s <= 24 = floatOpType bo TypeFloat
resultTypeBO bo TypeDouble (TypeInt s) | s <= 53 = floatOpType bo TypeDouble
resultTypeBO bo (TypeInt s) TypeDouble | s <= 53 = floatOpType bo TypeDouble
resultTypeBO bo TypeFloat TypeFloat = floatOpType bo TypeFloat
resultTypeBO bo TypeDouble TypeDouble = floatOpType bo TypeDouble
resultTypeBO bo TypeFloat TypeDouble = floatOpType bo TypeDouble
resultTypeBO bo TypeDouble TypeFloat = floatOpType bo TypeDouble
resultTypeBO _ _ _ = Nothing

-- | Result of an operator on floating-point operands: the given type for
-- arithmetic, u1 for comparison/logic, and 'Nothing' for anything else
-- (such as 'Index', which needs a pointer operand).
floatOpType :: BinaryOperator -> Type -> Maybe Type
floatOpType bo t
    | bo `elem` arithBO = Just t
    | bo `elem` complogBO = Just $ TypeUInt 1
    | otherwise = Nothing

-- | Result type of a unary operator applied to an operand of the given type,
-- or 'Nothing' if the combination is not allowed.
resultTypeUO :: UnaryOperator -> Type -> Maybe Type
resultTypeUO Not _ = Just $ TypeUInt 1
resultTypeUO Address t = Just $ TypePtr t
resultTypeUO uo t@(TypeInt _) | uo `elem` [Negate, Invert] = Just t
resultTypeUO uo t@(TypeUInt _) | uo `elem` [Negate, Invert] = Just t
resultTypeUO Negate TypeFloat = Just TypeFloat
resultTypeUO Negate TypeDouble = Just TypeDouble
resultTypeUO Dereference (TypePtr t) = Just t
resultTypeUO _ _ = Nothing


-- | 'TypeFloat' if the value survives a round trip through 'Float' unchanged,
-- 'TypeDouble' otherwise.
smallestFloatType :: Double -> Type
smallestFloatType d =
    let truncfloat = realToFrac (realToFrac d :: Float) :: Double
    in if d == truncfloat then TypeFloat else TypeDouble

-- | Smallest signed integer type (8, 16, 32 or 64 bits) whose two's
-- complement range contains the literal.
smallestIntType :: Integer -> Error Type
smallestIntType i
    | i >= -2^7 && i < 2^7 = return $ TypeInt 8
    | i >= -2^15 && i < 2^15 = return $ TypeInt 16
    | i >= -2^31 && i < 2^31 = return $ TypeInt 32
    | i >= -2^63 && i < 2^63 = return $ TypeInt 64
    | otherwise = Left $ "Integer literal '" ++ pshow i ++ "' too wide for i64"

-- | Smallest unsigned integer type (8, 16, 32 or 64 bits) whose width fits
-- the magnitude of the literal.
-- NOTE(review): negative values are accepted by magnitude; presumably they
-- are meant to wrap around -- confirm against the parser and code generator.
smallestUIntType :: Integer -> Error Type
smallestUIntType i
    | i > -2^8 && i < 2^8 = return $ TypeUInt 8
    | i > -2^16 && i < 2^16 = return $ TypeUInt 16
    | i > -2^32 && i < 2^32 = return $ TypeUInt 32
    | i > -2^64 && i < 2^64 = return $ TypeUInt 64
    | otherwise = Left $ "Unsigned integer literal '" ++ pshow i ++ "U' too wide for u64"


-- | A rewriting step that may fail with an error message.
type MapperHandler a = a -> Error a

-- | One fallible handler per kind of AST node, for use with 'mapProgram'.
data ProgramMapper = ProgramMapper
    {declarationHandler :: MapperHandler Declaration
    ,blockHandler :: MapperHandler Block
    ,typeHandler :: MapperHandler Type
    ,literalHandler :: MapperHandler Literal
    ,binOpHandler :: MapperHandler BinaryOperator
    ,unOpHandler :: MapperHandler UnaryOperator
    ,expressionHandler :: MapperHandler Expression
    ,statementHandler :: MapperHandler Statement
    ,nameHandler :: MapperHandler Name}

-- | A rewriting step that cannot fail.
type MapperHandler' a = a -> a

-- | One pure handler per kind of AST node, for use with 'mapProgram''.
data ProgramMapper' = ProgramMapper'
    {declarationHandler' :: MapperHandler' Declaration
    ,blockHandler' :: MapperHandler' Block
    ,typeHandler' :: MapperHandler' Type
    ,literalHandler' :: MapperHandler' Literal
    ,binOpHandler' :: MapperHandler' BinaryOperator
    ,unOpHandler' :: MapperHandler' UnaryOperator
    ,expressionHandler' :: MapperHandler' Expression
    ,statementHandler' :: MapperHandler' Statement
    ,nameHandler' :: MapperHandler' Name}

-- | Mapper whose handlers all succeed without changing anything.
defaultPM :: ProgramMapper
defaultPM = ProgramMapper
    {declarationHandler = return
    ,blockHandler = return
    ,typeHandler = return
    ,literalHandler = return
    ,binOpHandler = return
    ,unOpHandler = return
    ,expressionHandler = return
    ,statementHandler = return
    ,nameHandler = return}

-- | Pure mapper whose handlers all leave their argument unchanged.
defaultPM' :: ProgramMapper'
defaultPM' = ProgramMapper'
    {declarationHandler' = id
    ,blockHandler' = id
    ,typeHandler' = id
    ,literalHandler' = id
    ,binOpHandler' = id
    ,unOpHandler' = id
    ,expressionHandler' = id
    ,statementHandler' = id
    ,nameHandler' = id}

-- | Pure variant of 'mapProgram': lifts every handler into 'Error' and
-- unwraps the result, which cannot be a 'Left' because no handler fails.
mapProgram' :: Program -> ProgramMapper' -> Program
mapProgram' prog mapper =
    either (\msg -> error $ "mapProgram': pure handler failed: " ++ msg) id $
        mapProgram prog $ ProgramMapper
            {declarationHandler = return . declarationHandler' mapper
            ,blockHandler = return . blockHandler' mapper
            ,typeHandler = return . typeHandler' mapper
            ,literalHandler = return . literalHandler' mapper
            ,binOpHandler = return . binOpHandler' mapper
            ,unOpHandler = return . unOpHandler' mapper
            ,expressionHandler = return . expressionHandler' mapper
            ,statementHandler = return . statementHandler' mapper
            ,nameHandler = return . nameHandler' mapper}

-- | Rewrite a program bottom-up: the children of a node are rewritten first,
-- then the node's own handler is applied exactly once to the rebuilt node.
-- The first handler that returns 'Left' aborts the traversal. Type
-- annotations on expressions are carried over unchanged and not visited.
mapProgram :: Program -> ProgramMapper -> Error Program
mapProgram prog mapper = goP prog
  where
    h_d = declarationHandler mapper
    h_b = blockHandler mapper
    h_t = typeHandler mapper
    h_l = literalHandler mapper
    h_bo = binOpHandler mapper
    h_uo = unOpHandler mapper
    h_e = expressionHandler mapper
    h_s = statementHandler mapper
    h_n = nameHandler mapper

    -- goD already ends in h_d, so it must not be applied again here.
    goP :: MapperHandler Program
    goP (Program decls) = Program <$> mapM goD decls

    goD :: MapperHandler Declaration
    goD (DecFunction t n a b) = do
        rt <- goT t
        rn <- goN n
        ra <- mapM (\(at,an) -> (,) <$> goT at <*> goN an) a
        rb <- goB b
        h_d $ DecFunction rt rn ra rb
    goD (DecVariable t n mv) = do
        rt <- goT t
        rn <- goN n
        rmv <- traverse goE mv
        h_d $ DecVariable rt rn rmv
    goD (DecTypedef t n) = do
        rt <- goT t
        rn <- goN n
        h_d $ DecTypedef rt rn
    goD (DecExtern t n) = do
        rt <- goT t
        rn <- goN n
        h_d $ DecExtern rt rn

    goT :: MapperHandler Type
    goT (TypePtr t) = goT t >>= (h_t . TypePtr)
    goT (TypeName n) = goN n >>= (h_t . TypeName)
    goT (TypeFunc t as) = do
        rt <- goT t
        ras <- mapM goT as
        h_t $ TypeFunc rt ras
    goT t = h_t t

    goN :: MapperHandler Name
    goN = h_n

    goB :: MapperHandler Block
    goB (Block sts) = (Block <$> mapM goS sts) >>= h_b

    goE :: MapperHandler Expression
    goE (ExLit l mt) = do
        rl <- goL l
        h_e $ ExLit rl mt
    goE (ExCast t e) = do
        rt <- goT t
        re <- goE e
        h_e $ ExCast rt re
    goE (ExBinOp bo e1 e2 mt) = do
        rbo <- goBO bo
        re1 <- goE e1
        re2 <- goE e2
        h_e $ ExBinOp rbo re1 re2 mt
    goE (ExUnOp uo e mt) = do
        ruo <- goUO uo
        re <- goE e
        h_e $ ExUnOp ruo re mt

    goS :: MapperHandler Statement
    goS StEmpty = h_s StEmpty
    goS (StBlock b) = goB b >>= (h_s . StBlock)
    goS (StExpr e) = goE e >>= (h_s . StExpr)
    goS (StVarDeclaration t n me) = do
        rt <- goT t
        rn <- goN n
        rme <- traverse goE me
        h_s $ StVarDeclaration rt rn rme
    goS (StAssignment n e) = do
        rn <- goN n
        re <- goE e
        h_s $ StAssignment rn re
    goS (StIf e s1 s2) = do
        re <- goE e
        rs1 <- goS s1
        rs2 <- goS s2
        h_s $ StIf re rs1 rs2
    goS (StWhile e s) = do
        re <- goE e
        rs <- goS s
        h_s $ StWhile re rs
    goS (StReturn Nothing) = h_s (StReturn Nothing)
    goS (StReturn (Just e)) = goE e >>= (h_s . StReturn . Just)

    goL :: MapperHandler Literal
    goL l@(LitString _) = h_l l
    goL l@(LitInt _) = h_l l
    goL l@(LitUInt _) = h_l l
    goL l@(LitFloat _) = h_l l
    goL (LitVar n) = goN n >>= (h_l . LitVar)
    goL (LitCall n a) = do
        rn <- goN n
        ra <- mapM goE a
        h_l $ LitCall rn ra

    goBO :: MapperHandler BinaryOperator
    goBO = h_bo

    goUO :: MapperHandler UnaryOperator
    goUO = h_uo