From 285847e5e404bea2941f1ce4b15fb3c8c27993c2 Mon Sep 17 00:00:00 2001 From: tomsmeding Date: Mon, 23 Jan 2017 20:38:30 +0100 Subject: Code now typechecks --- Makefile | 2 +- ast.hs | 47 +++++++--- codegen.hs | 295 +++++++++++++++++++++++++++++++++++++++++++++++-------------- main.hs | 3 +- parser.hs | 10 +-- 5 files changed, 270 insertions(+), 87 deletions(-) diff --git a/Makefile b/Makefile index cb0faab..499a974 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ GHC = ghc -GHCFLAGS = -Wall -O3 +GHCFLAGS = -Wall -Wno-type-defaults -O3 TARGET = main diff --git a/ast.hs b/ast.hs index bd845df..e3db600 100644 --- a/ast.hs +++ b/ast.hs @@ -1,7 +1,4 @@ -module AST( - Name, - Program(..), Declaration(..), Block(..), Type(..), Literal(..), - BinaryOperator(..), UnaryOperator(..), Expression(..), Statement(..)) where +module AST where import Data.List @@ -52,10 +49,10 @@ data UnaryOperator = Negate | Not | Invert | Dereference | Address deriving (Show, Eq) -data Expression - = ExLit Literal - | ExBinOp BinaryOperator Expression Expression - | ExUnOp UnaryOperator Expression +data Expression -- (Maybe Type)'s are type annotations by the type checker + = ExLit Literal (Maybe Type) + | ExBinOp BinaryOperator Expression Expression (Maybe Type) + | ExUnOp UnaryOperator Expression (Maybe Type) deriving (Show) data Statement @@ -70,9 +67,24 @@ data Statement deriving (Show) -indent :: Int -> String -> String -indent sz str = intercalate "\n" $ map (prefix++) $ lines str - where prefix = replicate sz ' ' +exLit_ :: Literal -> Expression +exLit_ l = ExLit l Nothing + +exBinOp_ :: BinaryOperator -> Expression -> Expression -> Expression +exBinOp_ bo a b = ExBinOp bo a b Nothing + +exUnOp_ :: UnaryOperator -> Expression -> Expression +exUnOp_ uo e = ExUnOp uo e Nothing + +exTypeOf :: Expression -> Maybe Type +exTypeOf (ExLit _ mt) = mt +exTypeOf (ExBinOp _ _ _ mt) = mt +exTypeOf (ExUnOp _ _ mt) = mt + +-- exSetType :: Type -> Expression -> Expression +-- exSetType t (ExLit l _) = ExLit l t +-- exSetType t (ExBinOp bo e1 e2 _) = ExBinOp bo e1 e2 t +-- exSetType t (ExUnOp bo e _) = ExUnOp bo e t instance PShow Program where @@ -92,6 +104,10 @@ instance PShow Declaration where instance PShow Block where pshow (Block []) = "{}" pshow (Block stmts) = concat ["{\n", indent 4 $ intercalate "\n" (map pshow stmts), "\n}"] + where + indent :: Int -> String -> String + indent sz str = intercalate "\n" $ map (prefix++) $ lines str + where prefix = replicate sz ' ' instance PShow Type where pshow (TypeInt sz) = 'i' : pshow sz @@ -130,9 +146,12 @@ instance PShow UnaryOperator where pshow Address = "&" instance PShow Expression where - pshow (ExLit lit) = pshow lit - pshow (ExBinOp op a b) = concat [pshow a, " ", pshow op, " ", pshow b] - pshow (ExUnOp op a) = concat [pshow op, pshow a] + pshow (ExLit lit Nothing) = pshow lit + pshow (ExLit lit (Just t)) = concat ["(", pshow lit, " :: ", pshow t, ")"] + pshow (ExBinOp op a b Nothing) = concat ["(", pshow a, " ", pshow op, " ", pshow b, ")"] + pshow (ExBinOp op a b (Just t)) = concat ["(", pshow a, " ", pshow op, " ", pshow b, " :: ", pshow t, ")"] + pshow (ExUnOp op a Nothing) = concat [pshow op, pshow a] + pshow (ExUnOp op a (Just t)) = concat ["(", pshow op, pshow a, " :: ", pshow t, ")"] instance PShow Statement where pshow StEmpty = ";" diff --git a/codegen.hs b/codegen.hs index f2c35b4..f314c63 100644 --- a/codegen.hs +++ b/codegen.hs @@ -1,5 +1,7 @@ -module Codegen(module Codegen, A.Module) where +module Codegen(codegen, A.Module) where +import Control.Monad +import Data.Maybe import qualified Data.Map.Strict as Map -- import qualified LLVM.General.AST.Type as A -- import qualified LLVM.General.AST.Global as A @@ -8,8 +10,10 @@ import qualified Data.Map.Strict as Map -- import qualified LLVM.General.AST.Name as A -- import qualified LLVM.General.AST.Instruction as A import qualified LLVM.General.AST as A +import Debug.Trace import AST +import PShow type Error a = Either String a @@ -52,52 +56,207 @@ preprocess prog@(Program decls) = mapProgram' filtered mapper generateDefs :: Program -> Error [A.Definition] generateDefs prog = do checkUndefinedTypes prog - checkUndefinedVars prog - fail "TODO" + checked <- typeCheck prog + collected <- collectVarDecls checked + void $ trace "Collected:" $ return [] + void $ trace (pshow collected) $ return [] + void $ fail "TODO" return [] checkUndefinedTypes :: Program -> Error () checkUndefinedTypes prog = fmap (const ()) $ mapProgram prog $ defaultPM {typeHandler = check} - where - check :: Type -> Error Type - check (TypeName n) = Left $ "Undefined type name '" ++ n ++ "'" - check t = Right t - --- checkUndefinedVars :: Program -> Error () --- checkUndefinedVars prog = do - - --- mapTypes' :: Program -> (Type -> Type) -> Program --- mapTypes' prog f = (\(Right res) -> res) $ mapTypes prog (return . f) - --- mapTypes :: Program -> (Type -> Error Type) -> Error Program --- mapTypes (Program decls) f = Program <$> sequence (map goD decls) --- where --- handler :: Type -> Error Type --- handler (TypePtr t) = f t >>= f . TypePtr --- handler t = f t - --- goD :: Declaration -> Error Declaration --- goD (DecFunction t n a b) = do --- rt <- handler t --- ra <- sequence $ map (\(at,an) -> (\art -> (art,an)) <$> handler at) a --- rb <- goB b --- return $ DecFunction rt n ra rb --- goD (DecVariable t n v) = (\rt -> DecVariable rt n v) <$> handler t --- goD (DecTypedef t n) = (\rt -> DecTypedef rt n) <$> handler t - --- goB :: Block -> Error Block --- goB (Block stmts) = Block <$> sequence (map goS stmts) - --- goS :: Statement -> Error Statement --- goS (StBlock bl) = StBlock <$> goB bl --- goS (StVarDeclaration t n e) = (\rt -> StVarDeclaration rt n e) <$> handler t --- goS (StIf c t e) = do --- rt <- goS t --- re <- goS e --- return $ StIf c rt re --- goS (StWhile c b) = StWhile c <$> goS b --- goS s = return s + where + check :: MapperHandler Type + check (TypeName n) = Left $ "Undefined type name '" ++ n ++ "'" + check t = Right t + + +typeCheck :: Program -> Error Program +typeCheck (Program decls) = Program <$> mapM (goD topLevelNames) decls + where + topLevelNames :: Map.Map Name Type + topLevelNames = foldr (uncurry Map.insert) Map.empty pairs + where pairs = map ((,) <$> nameOf <*> typeOf) $ filter isVarDecl decls + + functionTypes :: Map.Map Name (Type,[Type]) + functionTypes = foldr (uncurry Map.insert) Map.empty pairs + where pairs = map ((,) <$> nameOf <*> getTypes) $ filter isFunctionDecl decls + getTypes (DecFunction rt _ args _) = (rt, map fst args) + getTypes _ = undefined + + isVarDecl (DecVariable {}) = True + isVarDecl _ = False + + isFunctionDecl (DecFunction {}) = True + isFunctionDecl _ = False + + goD :: Map.Map Name Type -> Declaration -> Error Declaration + goD names (DecFunction frt name args body) = do + newbody <- goB frt (foldr (\(t,n) m -> Map.insert n t m) names args) body + return $ DecFunction frt name args newbody + goD _ dec = return dec + + goB :: Type -- function return type + -> Map.Map Name Type -> Block -> Error Block + goB frt names (Block stmts) = Block . snd <$> foldl foldfunc (return (names, [])) stmts + where + foldfunc :: Error (Map.Map Name Type, [Statement]) -> Statement -> Error (Map.Map Name Type, [Statement]) + foldfunc ep st = do + (names', lst) <- ep + (newnames', newst) <- goS frt names' st + return (newnames', lst ++ [newst]) -- TODO: fix slow tail-append + + goS :: Type -- function return type + -> Map.Map Name Type -> Statement -> Error (Map.Map Name Type, Statement) + goS _ names st@(StVarDeclaration t n Nothing) = return (Map.insert n t names, st) + goS frt names (StVarDeclaration t n (Just e)) = do + (newnames, _) <- goS frt names (StVarDeclaration t n Nothing) + goS frt newnames (StAssignment n e) + goS _ names (StAssignment n e) = maybe (Left $ "Undefined variable '" ++ n ++ "'") go (Map.lookup n names) + where go dsttype = do + re <- goE names e + let (Just extype) = exTypeOf re + if canConvert extype dsttype + then return (names, StAssignment n re) + else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '" + ++ pshow dsttype ++ "' in assignment to variable '" ++ n ++ "'" + goS _ names st@StEmpty = return (names, st) + goS frt names (StBlock bl) = do + newbl <- goB frt names bl + return (names, StBlock newbl) + goS _ names (StExpr e) = do + re <- goE names e + return (names, StExpr re) + goS frt names (StIf e s1 s2) = do + re <- goE names e + (_, rs1) <- goS frt names s1 + (_, rs2) <- goS frt names s2 + return (names, StIf re rs1 rs2) + goS frt names (StWhile e s) = do + re <- goE names e + (_, rs) <- goS frt names s + return (names, StWhile re rs) + goS frt names (StReturn e) = do + re <- goE names e + let (Just extype) = exTypeOf re + if canConvert extype frt + then return (names, StReturn re) + else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '" + ++ pshow frt ++ "' in return statement" + + -- Postcondition: the expression (if any) has a type annotation. + goE :: Map.Map Name Type -> Expression -> Error Expression + goE _ (ExLit l@(LitInt i) _) = return $ ExLit l $ Just (smallestIntType i) + goE _ (ExLit l@(LitString _) _) = return $ ExLit l $ Just (TypePtr (TypeInt 8)) + goE names (ExLit l@(LitVar n) _) = maybe (Left $ "Undefined variable '" ++ n ++ "'") (return . ExLit l . Just) + (Map.lookup n names) + goE names (ExLit l@(LitCall n args) _) = do + ft <- maybe (Left $ "Unknown function '" ++ n ++ "'") return $ Map.lookup n functionTypes + rargs <- mapM (goE names) args + when (length rargs /= length (snd ft)) + $ Left ("Expected " ++ show (length (snd ft)) ++ "arguments to " + ++ "function '" ++ n ++ "', but got " ++ show (length rargs)) + >> return () + flip mapM_ rargs $ + \a -> let argtype = fromJust (exTypeOf a) + in if canConvert argtype (fst ft) + then return a + else Left $ "Cannot convert type '" ++ pshow argtype ++ "' to '" ++ pshow (fst ft) + ++ "' in call of function '" ++ pshow n ++ "'" + return $ ExLit l (Just (fst ft)) + goE names (ExBinOp bo e1 e2 _) = do + re1 <- goE names e1 + re2 <- goE names e2 + maybe (Left $ "Cannot use operator '" ++ pshow bo ++ "' with argument types '" + ++ pshow (fromJust $ exTypeOf re1) ++ "' and '" ++ pshow (fromJust $ exTypeOf re2) ++ "'") + (return . ExBinOp bo re1 re2 . Just) + $ typeCompatibleBO bo (fromJust $ exTypeOf re1) (fromJust $ exTypeOf re2) + goE names (ExUnOp uo e _) = do + re <- goE names e + maybe (Left $ "Cannot use operator '" ++ pshow uo ++ "' with argument type '" ++ pshow (fromJust $ exTypeOf re)) + (return . ExUnOp uo re . Just) + $ typeCompatibleUO uo (fromJust $ exTypeOf re) + + +collectVarDecls :: Program -> Error Program +collectVarDecls prog = mapProgram prog $ defaultPM {blockHandler = goBlock} + where + goBlock :: MapperHandler Block + goBlock (Block stmts) = + let isVarDecl (StVarDeclaration {}) = True + isVarDecl _ = False + + removeDecls [] = [] + removeDecls ((StVarDeclaration _ n (Just ex)):rest) = StAssignment n ex : removeDecls rest + removeDecls ((StVarDeclaration _ _ Nothing):rest) = removeDecls rest + removeDecls (st:rest) = st : removeDecls rest + + onlyDecl (StVarDeclaration t n _) = StVarDeclaration t n Nothing + onlyDecl _ = undefined + + vdecls = map onlyDecl $ filter isVarDecl stmts + in return $ Block $ vdecls ++ removeDecls stmts + + +canConvert :: Type -> Type -> Bool +canConvert x y | x == y = True +canConvert (TypeInt f) (TypeInt t) = f <= t +canConvert (TypeUInt f) (TypeUInt t) = f <= t +canConvert TypeFloat TypeDouble = True +canConvert _ _ = False + +arithBO, compareBO, logicBO, complogBO :: [BinaryOperator] +arithBO = [Plus, Minus, Times, Divide, Modulo] +compareBO = [Equal, Unequal, Greater, Less, GEqual, LEqual] +logicBO = [BoolAnd, BoolOr] +complogBO = compareBO ++ logicBO + +typeCompatibleBO :: BinaryOperator -> Type -> Type -> Maybe Type +typeCompatibleBO Minus (TypePtr t1) (TypePtr t2) | t1 == t2 = Just $ TypeInt 1 +typeCompatibleBO bo (TypePtr t1) (TypePtr t2) | t1 == t2 && bo `elem` complogBO = Just $ TypeInt 1 +typeCompatibleBO _ (TypePtr _) _ = Nothing +typeCompatibleBO _ _ (TypePtr _) = Nothing + +typeCompatibleBO bo (TypeInt s1) (TypeInt s2) | bo `elem` arithBO = Just $ TypeInt (max s1 s2) +typeCompatibleBO bo (TypeInt _) (TypeInt _) | bo `elem` complogBO = Just $ TypeInt 1 + +typeCompatibleBO bo (TypeUInt s1) (TypeUInt s2) | bo `elem` arithBO = Just $ TypeUInt (max s1 s2) +typeCompatibleBO bo (TypeUInt _) (TypeUInt _) | bo `elem` complogBO = Just $ TypeInt 1 + +typeCompatibleBO bo t1 t2 | bo `elem` complogBO && t1 == t2 = Just $ TypeInt 1 + +typeCompatibleBO bo TypeFloat (TypeInt s) | s <= 24 = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1 +typeCompatibleBO bo (TypeInt s) TypeFloat | s <= 24 = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1 +typeCompatibleBO bo TypeDouble (TypeInt s) | s <= 53 = Just $ if bo `elem` arithBO then TypeDouble else TypeInt 1 +typeCompatibleBO bo (TypeInt s) TypeDouble | s <= 53 = Just $ if bo `elem` arithBO then TypeDouble else TypeInt 1 +typeCompatibleBO bo TypeFloat TypeDouble = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1 +typeCompatibleBO bo TypeDouble TypeFloat = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1 + +typeCompatibleBO _ _ _ = Nothing + +typeCompatibleUO :: UnaryOperator -> Type -> Maybe Type +typeCompatibleUO Not _ = Just $ TypeInt 1 +typeCompatibleUO Address t = Just $ TypePtr t +typeCompatibleUO uo t@(TypeInt _) | uo `elem` [Negate, Invert] = Just t +typeCompatibleUO uo t@(TypeUInt _) | uo `elem` [Negate, Invert] = Just t +typeCompatibleUO Negate TypeFloat = Just TypeFloat +typeCompatibleUO Negate TypeDouble = Just TypeDouble +typeCompatibleUO Dereference t@(TypePtr _) = Just t +typeCompatibleUO _ _ = Nothing + +smallestIntType :: Integer -> Type +smallestIntType i + | i >= -2^7 && i < 2^7 = TypeInt 8 + | i >= -2^15 && i < 2^15 = TypeInt 16 + | i >= -2^31 && i < 2^31 = TypeInt 32 + | otherwise = TypeInt 64 + +-- smallestUIntType :: Integer -> Type +-- smallestUIntType i +-- | i >= 0 && i < 2^8 = TypeUInt 8 +-- | i >= 0 && i < 2^16 = TypeUInt 16 +-- | i >= 0 && i < 2^32 = TypeUInt 32 +-- | otherwise = TypeUInt 64 type MapperHandler a = a -> Error a @@ -135,14 +294,14 @@ defaultPM' = ProgramMapper' id id id id id id id id id mapProgram' :: Program -> ProgramMapper' -> Program mapProgram' prog mapper = (\(Right r) -> r) $ mapProgram prog $ ProgramMapper {declarationHandler = return . declarationHandler' mapper - ,blockHandler = return . blockHandler' mapper - ,typeHandler = return . typeHandler' mapper - ,literalHandler = return . literalHandler' mapper - ,binOpHandler = return . binOpHandler' mapper - ,unOpHandler = return . unOpHandler' mapper - ,expressionHandler = return . expressionHandler' mapper - ,statementHandler = return . statementHandler' mapper - ,nameHandler = return . nameHandler' mapper} + ,blockHandler = return . blockHandler' mapper + ,typeHandler = return . typeHandler' mapper + ,literalHandler = return . literalHandler' mapper + ,binOpHandler = return . binOpHandler' mapper + ,unOpHandler = return . unOpHandler' mapper + ,expressionHandler = return . expressionHandler' mapper + ,statementHandler = return . statementHandler' mapper + ,nameHandler = return . nameHandler' mapper} mapProgram :: Program -> ProgramMapper -> Error Program mapProgram prog mapper = goP prog @@ -157,10 +316,10 @@ mapProgram prog mapper = goP prog h_s = statementHandler mapper h_n = nameHandler mapper - goP :: Program -> Error Program + goP :: MapperHandler Program goP (Program decls) = Program <$> sequence (map (\d -> goD d >>= h_d) decls) - goD :: Declaration -> Error Declaration + goD :: MapperHandler Declaration goD (DecFunction t n a b) = do rt <- goT t rn <- goN n @@ -177,30 +336,32 @@ mapProgram prog mapper = goP prog rn <- goN n h_d $ DecTypedef rt rn - goT :: Type -> Error Type + goT :: MapperHandler Type goT (TypePtr t) = goT t >>= (h_t . TypePtr) goT (TypeName n) = goN n >>= (h_t . TypeName) goT t = h_t t - goN :: Name -> Error Name + goN :: MapperHandler Name goN = h_n - goB :: Block -> Error Block + goB :: MapperHandler Block goB (Block sts) = (Block <$> sequence (map goS sts)) >>= h_b - goE :: Expression -> Error Expression - goE (ExLit l) = goL l >>= (h_e . ExLit) - goE (ExBinOp bo e1 e2) = do + goE :: MapperHandler Expression + goE (ExLit l mt) = do + rl <- goL l + h_e $ ExLit rl mt + goE (ExBinOp bo e1 e2 mt) = do rbo <- goBO bo re1 <- goE e1 re2 <- goE e2 - h_e $ ExBinOp rbo re1 re2 - goE (ExUnOp uo e) = do + h_e $ ExBinOp rbo re1 re2 mt + goE (ExUnOp uo e mt) = do ruo <- goUO uo re <- goE e - h_e $ ExUnOp ruo re + h_e $ ExUnOp ruo re mt - goS :: Statement -> Error Statement + goS :: MapperHandler Statement goS StEmpty = h_s StEmpty goS (StBlock b) = goB b >>= (h_s . StBlock) goS (StExpr e) = goE e >>= (h_s . StExpr) @@ -224,15 +385,17 @@ mapProgram prog mapper = goP prog h_s $ StWhile re rs goS (StReturn e) = goE e >>= (h_s . StReturn) - goL :: Literal -> Error Literal + goL :: MapperHandler Literal + goL l@(LitString _) = h_l l + goL l@(LitInt _) = h_l l goL (LitVar n) = goN n >>= (h_l . LitVar) goL (LitCall n a) = do rn <- goN n ra <- sequence $ map goE a h_l $ LitCall rn ra - goBO :: BinaryOperator -> Error BinaryOperator + goBO :: MapperHandler BinaryOperator goBO = h_bo - goUO :: UnaryOperator -> Error UnaryOperator + goUO :: MapperHandler UnaryOperator goUO = h_uo diff --git a/main.hs b/main.hs index e1b975a..962a01d 100644 --- a/main.hs +++ b/main.hs @@ -33,6 +33,7 @@ main = do when (isLeft parseResult) $ dieShow $ fromLeft parseResult let ast = fromRight parseResult - pprint ast + -- print ast + putStrLn $ pshow ast either die print $ codegen ast "Module" fname diff --git a/parser.hs b/parser.hs index 7359ccf..6520294 100644 --- a/parser.hs +++ b/parser.hs @@ -85,14 +85,14 @@ exprTable = [binary "&&" BoolAnd E.AssocLeft, binary "||" BoolOr E.AssocLeft]] where - binary name op assoc = E.Infix (ExBinOp op <$ symbol name) assoc - prefix name op = E.Prefix (ExUnOp op <$ symbol name) + binary name op assoc = E.Infix (exBinOp_ op <$ symbol name) assoc + prefix name op = E.Prefix (exUnOp_ op <$ symbol name) pExpression :: Parser Expression pExpression = E.buildExpressionParser exprTable pExLit pExLit :: Parser Expression -pExLit = ExLit <$> pLiteral +pExLit = exLit_ <$> pLiteral pLiteral :: Parser Literal pLiteral = (LitInt <$> pInteger) <|> (LitString <$> pString) @@ -109,7 +109,7 @@ pLitCall = do pStatement :: Parser Statement pStatement = pStEmpty <|> pStIf <|> pStWhile <|> pStReturn <|> pStBlock - <|> try pStAssignment <|> pStVarDeclaration <|> pStExpr + <|> try pStAssignment <|> try pStVarDeclaration <|> pStExpr pStEmpty :: Parser Statement pStEmpty = symbol ";" >> return StEmpty @@ -165,7 +165,7 @@ pStReturn = do primitiveTypes :: Map.Map String Type primitiveTypes = Map.fromList - [("i8", TypeInt 8), ("i16", TypeInt 16), ("i32", TypeInt 32), ("i64", TypeInt 64), + [("i1", TypeInt 1), ("i8", TypeInt 8), ("i16", TypeInt 16), ("i32", TypeInt 32), ("i64", TypeInt 64), ("u8", TypeUInt 8), ("u16", TypeUInt 16), ("u32", TypeUInt 32), ("u64", TypeUInt 64), ("float", TypeFloat), ("double", TypeDouble)] -- cgit v1.2.3-54-g00ecf