diff options
author | tomsmeding <tom.smeding@gmail.com> | 2017-01-29 13:03:44 +0100 |
---|---|---|
committer | tomsmeding <tom.smeding@gmail.com> | 2017-01-29 13:03:44 +0100 |
commit | ce13c3ff2b64e1bfde13f735d871ea0a0e58a145 (patch) | |
tree | 1f56d96fe80c8abe3fc026fa9abcaa3bf14fd5b7 | |
parent | af1523e4b51f432d3df4d2e2ae57de95e3440d12 (diff) |
Call functions
-rw-r--r-- | ast.hs | 9 | ||||
-rw-r--r-- | check.hs | 9 | ||||
-rw-r--r-- | codegen.hs | 78 | ||||
-rw-r--r-- | parser.hs | 36 | ||||
-rw-r--r-- | test_string.nl | 9 |
5 files changed, 107 insertions, 34 deletions
@@ -31,6 +31,8 @@ data Type = TypeInt Int | TypeDouble | TypePtr Type | TypeName Name + | TypeFunc Type [Type] + | TypeVoid deriving (Show, Eq) data Literal = LitInt Integer @@ -63,7 +65,7 @@ data Statement | StAssignment Name Expression | StIf Expression Statement Statement | StWhile Expression Statement - | StReturn Expression + | StReturn (Maybe Expression) deriving (Show) @@ -116,6 +118,8 @@ instance PShow Type where pshow TypeDouble = "double" pshow (TypePtr t) = concat ["ptr(", pshow t, ")"] pshow (TypeName n) = n + pshow TypeVoid = "void" + pshow (TypeFunc ret args) = concat ["func ", pshow ret, "("] ++ intercalate "," (map pshow args) ++ ")" instance PShow Literal where pshow (LitInt i) = pshow i @@ -164,4 +168,5 @@ instance PShow Statement where pshow (StIf c t@(StBlock _) e) = concat ["if (", pshow c, ") ", pshow t, " else ", pshow e] pshow (StIf c t e) = concat ["if (", pshow c, ") ", pshow t, "\nelse ", pshow e] pshow (StWhile c s) = concat ["while (", pshow c, ") ", pshow s] - pshow (StReturn e) = concat ["return ", pshow e, ";"] + pshow (StReturn Nothing) = "return;" + pshow (StReturn (Just e)) = concat ["return ", pshow e, ";"] @@ -123,11 +123,13 @@ typeCheck (Program decls) = Program <$> mapM (goD topLevelNames) decls re <- goE names e (_, rs) <- goS frt names s return (names, StWhile re rs) - goS frt names (StReturn e) = do + goS TypeVoid names (StReturn Nothing) = return (names, StReturn Nothing) + goS _ _ (StReturn Nothing) = Left $ "Non-void function should return a value" + goS frt names (StReturn (Just e)) = do re <- goE names e let (Just extype) = exTypeOf re if canConvert extype frt - then return (names, StReturn re) + then return (names, StReturn (Just re)) else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '" ++ pshow frt ++ "' in return statement" @@ -370,7 +372,8 @@ mapProgram prog mapper = goP prog re <- goE e rs <- goS s h_s $ StWhile re rs - goS (StReturn e) = goE e >>= (h_s . StReturn) + goS (StReturn Nothing) = h_s (StReturn Nothing) + goS (StReturn (Just e)) = goE e >>= (h_s . StReturn . Just) goL :: MapperHandler Literal goL l@(LitString _) = h_l l @@ -8,6 +8,7 @@ import Data.Maybe import qualified Data.Map.Strict as Map import qualified LLVM.General.AST.Type as A import qualified LLVM.General.AST.Global as A.G +import qualified LLVM.General.AST.CallingConvention as A.CC import qualified LLVM.General.AST.Constant as A.C import qualified LLVM.General.AST.Float as A.F -- import qualified LLVM.General.AST.Operand as A @@ -34,6 +35,7 @@ data GenState ,definitions :: [A.Definition] ,variables :: Map.Map Name (Type, LLName) ,globalVariables :: Map.Map Name (Type, LLName) + ,globalFunctions :: Map.Map Name (Type, LLName) ,stringLiterals :: [(LLName, String)]} deriving (Show) @@ -46,6 +48,7 @@ initialGenState ,definitions = [] ,variables = Map.empty ,globalVariables = Map.empty + ,globalFunctions = Map.empty ,stringLiterals = []} newtype CGMonad a = CGMonad {unMon :: ExceptT String (State GenState) a} @@ -109,6 +112,10 @@ setGlobalVar :: Name -> LLName -> Type -> CGMonad () setGlobalVar name label t = do state $ \s -> ((), s {globalVariables = Map.insert name (t, label) $ globalVariables s}) +setGlobalFunction :: Name -> LLName -> Type -> CGMonad () +setGlobalFunction name label t = do + state $ \s -> ((), s {globalFunctions = Map.insert name (t, label) $ globalFunctions s}) + lookupVar :: Name -> CGMonad (Type, LLName) lookupVar name | trace ("Looking up var " ++ name) False = undefined lookupVar name = do @@ -122,6 +129,9 @@ lookupVar name = do lookupGlobalVar :: Name -> CGMonad (Type, LLName) lookupGlobalVar name = liftM (fromJust . Map.lookup name . globalVariables) get +lookupGlobalFunction :: Name -> CGMonad (Type, LLName) +lookupGlobalFunction name = liftM (fromJust . Map.lookup name . globalFunctions) get + addStringLiteral :: String -> CGMonad LLName addStringLiteral str = do name <- getNewName "str" @@ -167,8 +177,8 @@ codegen :: Program -- Program to compile codegen prog name fname = do (defs, st) <- runCGMonad $ do defs <- generateDefs prog - traceShow defs $ return () - liftM stringLiterals get >>= flip traceShow (return ()) + -- traceShow defs $ return () + -- liftM stringLiterals get >>= flip traceShow (return ()) return defs traceShow st $ return () @@ -185,19 +195,22 @@ generateDefs prog = liftM concat $ sequence $ [genGlobalVars prog, genFunctions prog, genStringLiterals] genGlobalVars :: Program -> CGMonad [A.Definition] -genGlobalVars (Program decs) = mapM gen $ filter isDecVariable decs +genGlobalVars (Program decs) = liftM (mapMaybe id) $ mapM gen decs where - gen :: Declaration -> CGMonad A.Definition + gen :: Declaration -> CGMonad (Maybe A.Definition) gen (DecVariable t n Nothing) = do setGlobalVar n n t - return $ A.GlobalDefinition $ + return $ Just $ A.GlobalDefinition $ A.globalVariableDefaults { A.G.name = A.Name n, A.G.type' = toLLVMType t, A.G.initializer = Just $ initializerFor t } gen (DecVariable _ _ (Just _)) = throwError $ "Initialised global variables not supported yet" - gen _ = undefined + gen (DecFunction rt n a _) = do + setGlobalFunction n n (TypeFunc rt (map fst a)) + return Nothing + gen _ = return Nothing genStringLiterals :: CGMonad [A.Definition] genStringLiterals = liftM stringLiterals get >>= return . map gen @@ -211,9 +224,9 @@ genStringLiterals = liftM stringLiterals get >>= return . map gen } genFunctions :: Program -> CGMonad [A.Definition] -genFunctions (Program decs) = mapM gen $ filter isDecFunction decs +genFunctions (Program decs) = liftM (mapMaybe id) $ mapM gen decs where - gen :: Declaration -> CGMonad A.Definition + gen :: Declaration -> CGMonad (Maybe A.Definition) gen dec@(DecFunction rettype name args body) = do setCurrentFunction dec firstbb <- genBlock' body @@ -221,13 +234,14 @@ genFunctions (Program decs) = mapM gen $ filter isDecFunction decs blockmap <- liftM allBlocks get let bbs' = map snd $ filter (\x -> fst x /= firstbb) $ Map.toList blockmap bbs = fromJust (Map.lookup firstbb blockmap) : bbs' - return $ A.GlobalDefinition $ A.functionDefaults { + state $ \s -> ((), s {allBlocks = Map.empty}) + return $ Just $ A.GlobalDefinition $ A.functionDefaults { A.G.returnType = toLLVMType rettype, A.G.name = A.Name name, A.G.parameters = ([A.Parameter (toLLVMType t) (A.Name n) [] | (t,n) <- args], False), A.G.basicBlocks = bbs } - gen _ = undefined + gen _ = return Nothing @@ -275,7 +289,11 @@ genSingle (StAssignment name expr) following = do ref <- variableStoreOperand name void $ addInstr $ A.Store False ref oper' Nothing 0 [] return bb -genSingle (StReturn expr) _ = do +genSingle (StReturn Nothing) _ = do + bb <- newBlock + setTerminator $ A.Ret Nothing [] + return bb +genSingle (StReturn (Just expr)) _ = do bb <- newBlock oper <- genExpression expr rettype <- liftM (typeOf . currentFunction) get @@ -316,6 +334,8 @@ genExpression (ExBinOp bo e1 e2 (Just t)) = do TypeDouble -> addInstr $ A.FAdd A.NoFastMathFlags e1op' e2op' [] (TypePtr _) -> addInstr $ A.Add False False e1op' e2op' [] (TypeName _) -> undefined + (TypeFunc _ _) -> throwError $ "Plus '+' operator not defined on function pointers" + TypeVoid -> undefined return $ A.LocalReference (toLLVMType t) (A.Name label) Minus -> do e1op' <- castOperand e1op t @@ -327,6 +347,8 @@ genExpression (ExBinOp bo e1 e2 (Just t)) = do TypeDouble -> addInstr $ A.FSub A.NoFastMathFlags e1op' e2op' [] (TypePtr _) -> addInstr $ A.Sub False False e1op' e2op' [] (TypeName _) -> undefined + (TypeFunc _ _) -> throwError $ "Minus '-' operator not defined on function pointers" + TypeVoid -> undefined return $ A.LocalReference (toLLVMType t) (A.Name label) Divide -> do e1op' <- castOperand e1op t @@ -338,6 +360,8 @@ genExpression (ExBinOp bo e1 e2 (Just t)) = do TypeDouble -> addInstr $ A.FDiv A.NoFastMathFlags e1op' e2op' [] (TypePtr _) -> throwError $ "Modulo '%' operator not defined on pointers" (TypeName _) -> undefined + (TypeFunc _ _) -> throwError $ "Divide '/' operator not defined on function pointers" + TypeVoid -> undefined return $ A.LocalReference (toLLVMType t) (A.Name label) Modulo -> do e1op' <- castOperand e1op t @@ -349,6 +373,8 @@ genExpression (ExBinOp bo e1 e2 (Just t)) = do TypeDouble -> addInstr $ A.FRem A.NoFastMathFlags e1op' e2op' [] (TypePtr _) -> throwError $ "Modulo '%' operator not defined on pointers" (TypeName _) -> undefined + (TypeFunc _ _) -> throwError $ "Modulo '%' operator not defined on function pointers" + TypeVoid -> undefined return $ A.LocalReference (toLLVMType t) (A.Name label) Equal -> do sharedType <- commonTypeM (fromJust (exTypeOf e1)) (fromJust (exTypeOf e2)) @@ -395,6 +421,8 @@ genExpression (ExUnOp uo e1 (Just t)) = do TypeDouble -> addInstr $ A.FSub A.NoFastMathFlags (A.ConstantOperand (A.C.Float (A.F.Double 0))) e1op [] (TypePtr _) -> throwError $ "Negate '-' operator not defined on a pointer" (TypeName _) -> undefined + (TypeFunc _ _) -> throwError $ "Negate '-' operator not defined on a function pointer" + TypeVoid -> undefined return $ A.LocalReference (toLLVMType t) (A.Name label) _ -> throwError $ "Unary operator " ++ pshow uo ++ " not implemented" genExpression ex = throwError $ "Expression '" ++ pshow ex ++ "' not implemented" @@ -416,6 +444,23 @@ literalToOperand (LitString s) (TypePtr (TypeInt 8)) = do label <- addInstr $ A.Load False loadoper Nothing 0 [] return $ A.LocalReference (A.ptr A.i8) (A.Name label) literalToOperand (LitString _) _ = undefined +literalToOperand (LitCall n args) _ = do + ((TypeFunc rt ats), lname) <- lookupGlobalFunction n + let processArgs :: [Expression] -> [Type] -> CGMonad [A.Operand] + processArgs [] [] = return [] + processArgs [] _ = undefined + processArgs _ [] = undefined + processArgs (ex:exs) (t:ts) = do + first <- genExpression ex >>= flip castOperand t + rest <- processArgs exs ts + return $ first : rest + rargs <- processArgs args ats + let argpairs = map (\a -> (a,[])) rargs + foper = A.ConstantOperand $ + A.C.GlobalReference (A.FunctionType (toLLVMType rt) (map toLLVMType ats) False) + (A.Name lname) + label <- addInstr $ A.Call Nothing A.CC.C [] (Right foper) argpairs [] [] + return $ A.LocalReference (toLLVMType rt) (A.Name label) literalToOperand lit _ = throwError $ "Literal '" ++ pshow lit ++ "' not implemented" castOperand :: A.Operand -> Type -> CGMonad A.Operand @@ -533,17 +578,10 @@ toLLVMType TypeFloat = A.float toLLVMType TypeDouble = A.double toLLVMType (TypePtr t) = A.ptr $ toLLVMType t toLLVMType (TypeName _) = undefined +toLLVMType (TypeFunc r a) = A.FunctionType (toLLVMType r) (map toLLVMType a) False +toLLVMType TypeVoid = A.VoidType initializerFor :: Type -> A.C.Constant initializerFor (TypeInt s) = A.C.Int (fromIntegral s) 0 initializerFor (TypeUInt s) = A.C.Int (fromIntegral s) 0 initializerFor _ = undefined - - -isDecVariable :: Declaration -> Bool -isDecVariable (DecVariable {}) = True -isDecVariable _ = False - -isDecFunction :: Declaration -> Bool -isDecFunction (DecFunction {}) = True -isDecFunction _ = False @@ -29,9 +29,11 @@ pProgram = pWhiteComment >> (Program <$> many1 pDeclaration) pDeclaration :: Parser Declaration pDeclaration = pDecTypedef <|> do - t <- pType + t <- pTypeVoid <|> pType n <- pName - pDecFunction' t n <|> pDecVariable' t n + if t == TypeVoid + then pDecFunction' t n + else pDecFunction' t n <|> pDecVariable' t n pDecTypedef :: Parser Declaration pDecTypedef = do @@ -158,9 +160,10 @@ pStWhile = do pStReturn :: Parser Statement pStReturn = do symbol "return" - e <- pExpression - symbol ";" - return $ StReturn e + (symbol ";" >> return (StReturn Nothing)) <|> do + e <- pExpression + symbol ";" + return $ StReturn (Just e) primitiveTypes :: Map.Map String Type @@ -173,7 +176,7 @@ findPrimType :: String -> Type findPrimType s = fromJust $ Map.lookup s primitiveTypes pType :: Parser Type -pType = pPrimType <|> pPtrType <|> pTypeName +pType = pPrimType <|> pTypePtr <|> pTypeFunc <|> pTypeName pPrimType :: Parser Type pPrimType = findPrimType <$> choice (map typeParser $ Map.keys primitiveTypes) @@ -183,14 +186,26 @@ pPrimType = findPrimType <$> choice (map typeParser $ Map.keys primitiveTypes) pWhiteComment return t -pPtrType :: Parser Type -pPtrType = do +pTypeVoid :: Parser Type +pTypeVoid = symbol "void" >> return TypeVoid + +pTypePtr :: Parser Type +pTypePtr = do symbol "ptr" symbol "(" t <- pType symbol ")" return $ TypePtr t +pTypeFunc :: Parser Type +pTypeFunc = do + symbol "func" + r <- pTypeVoid <|> pType + symbol "(" + a <- sepBy pType (symbol ",") + symbol ")" + return $ TypeFunc r a + pTypeName :: Parser Type pTypeName = TypeName <$> pName @@ -232,7 +247,10 @@ pString = do <|> (liftM (\c -> ord c - ord 'A' + 10) (oneOf "ABCDEF")) symbol :: String -> Parser () -symbol s = try (string s) >> pWhiteComment +symbol s = do + void $ try (string s) + when (isAlphaNum (last s)) $ notFollowedBy alphaNum + pWhiteComment pWhiteComment :: Parser () pWhiteComment = sepBy pWhitespace pComment >> return () diff --git a/test_string.nl b/test_string.nl index ec8cd02..37841a7 100644 --- a/test_string.nl +++ b/test_string.nl @@ -2,6 +2,15 @@ type int = i32; type char = i8; type string = ptr(char); +void func(string s) { + int i = 1; + return; +} + int main(int argc, ptr(string) argv) { string s = "kaas"; + ptr(i8) s2 = "kaas2"; + //func void(string) the_func = func; + func(s); + return 0; } |