summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortomsmeding <tom.smeding@gmail.com>2017-01-29 13:03:44 +0100
committertomsmeding <tom.smeding@gmail.com>2017-01-29 13:03:44 +0100
commitce13c3ff2b64e1bfde13f735d871ea0a0e58a145 (patch)
tree1f56d96fe80c8abe3fc026fa9abcaa3bf14fd5b7
parentaf1523e4b51f432d3df4d2e2ae57de95e3440d12 (diff)
Call functions
-rw-r--r--ast.hs9
-rw-r--r--check.hs9
-rw-r--r--codegen.hs78
-rw-r--r--parser.hs36
-rw-r--r--test_string.nl9
5 files changed, 107 insertions, 34 deletions
diff --git a/ast.hs b/ast.hs
index e3db600..f566335 100644
--- a/ast.hs
+++ b/ast.hs
@@ -31,6 +31,8 @@ data Type = TypeInt Int
| TypeDouble
| TypePtr Type
| TypeName Name
+ | TypeFunc Type [Type]
+ | TypeVoid
deriving (Show, Eq)
data Literal = LitInt Integer
@@ -63,7 +65,7 @@ data Statement
| StAssignment Name Expression
| StIf Expression Statement Statement
| StWhile Expression Statement
- | StReturn Expression
+ | StReturn (Maybe Expression)
deriving (Show)
@@ -116,6 +118,8 @@ instance PShow Type where
pshow TypeDouble = "double"
pshow (TypePtr t) = concat ["ptr(", pshow t, ")"]
pshow (TypeName n) = n
+ pshow TypeVoid = "void"
+ pshow (TypeFunc ret args) = concat ["func ", pshow ret, "("] ++ intercalate "," (map pshow args) ++ ")"
instance PShow Literal where
pshow (LitInt i) = pshow i
@@ -164,4 +168,5 @@ instance PShow Statement where
pshow (StIf c t@(StBlock _) e) = concat ["if (", pshow c, ") ", pshow t, " else ", pshow e]
pshow (StIf c t e) = concat ["if (", pshow c, ") ", pshow t, "\nelse ", pshow e]
pshow (StWhile c s) = concat ["while (", pshow c, ") ", pshow s]
- pshow (StReturn e) = concat ["return ", pshow e, ";"]
+ pshow (StReturn Nothing) = "return;"
+ pshow (StReturn (Just e)) = concat ["return ", pshow e, ";"]
diff --git a/check.hs b/check.hs
index 88b20a9..a29f18b 100644
--- a/check.hs
+++ b/check.hs
@@ -123,11 +123,13 @@ typeCheck (Program decls) = Program <$> mapM (goD topLevelNames) decls
re <- goE names e
(_, rs) <- goS frt names s
return (names, StWhile re rs)
- goS frt names (StReturn e) = do
+ goS TypeVoid names (StReturn Nothing) = return (names, StReturn Nothing)
+ goS _ _ (StReturn Nothing) = Left $ "Non-void function should return a value"
+ goS frt names (StReturn (Just e)) = do
re <- goE names e
let (Just extype) = exTypeOf re
if canConvert extype frt
- then return (names, StReturn re)
+ then return (names, StReturn (Just re))
else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '"
++ pshow frt ++ "' in return statement"
@@ -370,7 +372,8 @@ mapProgram prog mapper = goP prog
re <- goE e
rs <- goS s
h_s $ StWhile re rs
- goS (StReturn e) = goE e >>= (h_s . StReturn)
+ goS (StReturn Nothing) = h_s (StReturn Nothing)
+ goS (StReturn (Just e)) = goE e >>= (h_s . StReturn . Just)
goL :: MapperHandler Literal
goL l@(LitString _) = h_l l
diff --git a/codegen.hs b/codegen.hs
index 0deb959..1df87b4 100644
--- a/codegen.hs
+++ b/codegen.hs
@@ -8,6 +8,7 @@ import Data.Maybe
import qualified Data.Map.Strict as Map
import qualified LLVM.General.AST.Type as A
import qualified LLVM.General.AST.Global as A.G
+import qualified LLVM.General.AST.CallingConvention as A.CC
import qualified LLVM.General.AST.Constant as A.C
import qualified LLVM.General.AST.Float as A.F
-- import qualified LLVM.General.AST.Operand as A
@@ -34,6 +35,7 @@ data GenState
,definitions :: [A.Definition]
,variables :: Map.Map Name (Type, LLName)
,globalVariables :: Map.Map Name (Type, LLName)
+ ,globalFunctions :: Map.Map Name (Type, LLName)
,stringLiterals :: [(LLName, String)]}
deriving (Show)
@@ -46,6 +48,7 @@ initialGenState
,definitions = []
,variables = Map.empty
,globalVariables = Map.empty
+ ,globalFunctions = Map.empty
,stringLiterals = []}
newtype CGMonad a = CGMonad {unMon :: ExceptT String (State GenState) a}
@@ -109,6 +112,10 @@ setGlobalVar :: Name -> LLName -> Type -> CGMonad ()
setGlobalVar name label t = do
state $ \s -> ((), s {globalVariables = Map.insert name (t, label) $ globalVariables s})
+setGlobalFunction :: Name -> LLName -> Type -> CGMonad ()
+setGlobalFunction name label t = do
+ state $ \s -> ((), s {globalFunctions = Map.insert name (t, label) $ globalFunctions s})
+
lookupVar :: Name -> CGMonad (Type, LLName)
lookupVar name | trace ("Looking up var " ++ name) False = undefined
lookupVar name = do
@@ -122,6 +129,9 @@ lookupVar name = do
lookupGlobalVar :: Name -> CGMonad (Type, LLName)
lookupGlobalVar name = liftM (fromJust . Map.lookup name . globalVariables) get
+lookupGlobalFunction :: Name -> CGMonad (Type, LLName)
+lookupGlobalFunction name = liftM (fromJust . Map.lookup name . globalFunctions) get
+
addStringLiteral :: String -> CGMonad LLName
addStringLiteral str = do
name <- getNewName "str"
@@ -167,8 +177,8 @@ codegen :: Program -- Program to compile
codegen prog name fname = do
(defs, st) <- runCGMonad $ do
defs <- generateDefs prog
- traceShow defs $ return ()
- liftM stringLiterals get >>= flip traceShow (return ())
+ -- traceShow defs $ return ()
+ -- liftM stringLiterals get >>= flip traceShow (return ())
return defs
traceShow st $ return ()
@@ -185,19 +195,22 @@ generateDefs prog
= liftM concat $ sequence $ [genGlobalVars prog, genFunctions prog, genStringLiterals]
genGlobalVars :: Program -> CGMonad [A.Definition]
-genGlobalVars (Program decs) = mapM gen $ filter isDecVariable decs
+genGlobalVars (Program decs) = liftM (mapMaybe id) $ mapM gen decs
where
- gen :: Declaration -> CGMonad A.Definition
+ gen :: Declaration -> CGMonad (Maybe A.Definition)
gen (DecVariable t n Nothing) = do
setGlobalVar n n t
- return $ A.GlobalDefinition $
+ return $ Just $ A.GlobalDefinition $
A.globalVariableDefaults {
A.G.name = A.Name n,
A.G.type' = toLLVMType t,
A.G.initializer = Just $ initializerFor t
}
gen (DecVariable _ _ (Just _)) = throwError $ "Initialised global variables not supported yet"
- gen _ = undefined
+ gen (DecFunction rt n a _) = do
+ setGlobalFunction n n (TypeFunc rt (map fst a))
+ return Nothing
+ gen _ = return Nothing
genStringLiterals :: CGMonad [A.Definition]
genStringLiterals = liftM stringLiterals get >>= return . map gen
@@ -211,9 +224,9 @@ genStringLiterals = liftM stringLiterals get >>= return . map gen
}
genFunctions :: Program -> CGMonad [A.Definition]
-genFunctions (Program decs) = mapM gen $ filter isDecFunction decs
+genFunctions (Program decs) = liftM (mapMaybe id) $ mapM gen decs
where
- gen :: Declaration -> CGMonad A.Definition
+ gen :: Declaration -> CGMonad (Maybe A.Definition)
gen dec@(DecFunction rettype name args body) = do
setCurrentFunction dec
firstbb <- genBlock' body
@@ -221,13 +234,14 @@ genFunctions (Program decs) = mapM gen $ filter isDecFunction decs
blockmap <- liftM allBlocks get
let bbs' = map snd $ filter (\x -> fst x /= firstbb) $ Map.toList blockmap
bbs = fromJust (Map.lookup firstbb blockmap) : bbs'
- return $ A.GlobalDefinition $ A.functionDefaults {
+ state $ \s -> ((), s {allBlocks = Map.empty})
+ return $ Just $ A.GlobalDefinition $ A.functionDefaults {
A.G.returnType = toLLVMType rettype,
A.G.name = A.Name name,
A.G.parameters = ([A.Parameter (toLLVMType t) (A.Name n) [] | (t,n) <- args], False),
A.G.basicBlocks = bbs
}
- gen _ = undefined
+ gen _ = return Nothing
@@ -275,7 +289,11 @@ genSingle (StAssignment name expr) following = do
ref <- variableStoreOperand name
void $ addInstr $ A.Store False ref oper' Nothing 0 []
return bb
-genSingle (StReturn expr) _ = do
+genSingle (StReturn Nothing) _ = do
+ bb <- newBlock
+ setTerminator $ A.Ret Nothing []
+ return bb
+genSingle (StReturn (Just expr)) _ = do
bb <- newBlock
oper <- genExpression expr
rettype <- liftM (typeOf . currentFunction) get
@@ -316,6 +334,8 @@ genExpression (ExBinOp bo e1 e2 (Just t)) = do
TypeDouble -> addInstr $ A.FAdd A.NoFastMathFlags e1op' e2op' []
(TypePtr _) -> addInstr $ A.Add False False e1op' e2op' []
(TypeName _) -> undefined
+ (TypeFunc _ _) -> throwError $ "Plus '+' operator not defined on function pointers"
+ TypeVoid -> undefined
return $ A.LocalReference (toLLVMType t) (A.Name label)
Minus -> do
e1op' <- castOperand e1op t
@@ -327,6 +347,8 @@ genExpression (ExBinOp bo e1 e2 (Just t)) = do
TypeDouble -> addInstr $ A.FSub A.NoFastMathFlags e1op' e2op' []
(TypePtr _) -> addInstr $ A.Sub False False e1op' e2op' []
(TypeName _) -> undefined
+ (TypeFunc _ _) -> throwError $ "Minus '-' operator not defined on function pointers"
+ TypeVoid -> undefined
return $ A.LocalReference (toLLVMType t) (A.Name label)
Divide -> do
e1op' <- castOperand e1op t
@@ -338,6 +360,8 @@ genExpression (ExBinOp bo e1 e2 (Just t)) = do
TypeDouble -> addInstr $ A.FDiv A.NoFastMathFlags e1op' e2op' []
(TypePtr _) -> throwError $ "Modulo '%' operator not defined on pointers"
(TypeName _) -> undefined
+ (TypeFunc _ _) -> throwError $ "Divide '/' operator not defined on function pointers"
+ TypeVoid -> undefined
return $ A.LocalReference (toLLVMType t) (A.Name label)
Modulo -> do
e1op' <- castOperand e1op t
@@ -349,6 +373,8 @@ genExpression (ExBinOp bo e1 e2 (Just t)) = do
TypeDouble -> addInstr $ A.FRem A.NoFastMathFlags e1op' e2op' []
(TypePtr _) -> throwError $ "Modulo '%' operator not defined on pointers"
(TypeName _) -> undefined
+ (TypeFunc _ _) -> throwError $ "Modulo '%' operator not defined on function pointers"
+ TypeVoid -> undefined
return $ A.LocalReference (toLLVMType t) (A.Name label)
Equal -> do
sharedType <- commonTypeM (fromJust (exTypeOf e1)) (fromJust (exTypeOf e2))
@@ -395,6 +421,8 @@ genExpression (ExUnOp uo e1 (Just t)) = do
TypeDouble -> addInstr $ A.FSub A.NoFastMathFlags (A.ConstantOperand (A.C.Float (A.F.Double 0))) e1op []
(TypePtr _) -> throwError $ "Negate '-' operator not defined on a pointer"
(TypeName _) -> undefined
+ (TypeFunc _ _) -> throwError $ "Negate '-' operator not defined on a function pointer"
+ TypeVoid -> undefined
return $ A.LocalReference (toLLVMType t) (A.Name label)
_ -> throwError $ "Unary operator " ++ pshow uo ++ " not implemented"
genExpression ex = throwError $ "Expression '" ++ pshow ex ++ "' not implemented"
@@ -416,6 +444,23 @@ literalToOperand (LitString s) (TypePtr (TypeInt 8)) = do
label <- addInstr $ A.Load False loadoper Nothing 0 []
return $ A.LocalReference (A.ptr A.i8) (A.Name label)
literalToOperand (LitString _) _ = undefined
+literalToOperand (LitCall n args) _ = do
+ ((TypeFunc rt ats), lname) <- lookupGlobalFunction n
+ let processArgs :: [Expression] -> [Type] -> CGMonad [A.Operand]
+ processArgs [] [] = return []
+ processArgs [] _ = undefined
+ processArgs _ [] = undefined
+ processArgs (ex:exs) (t:ts) = do
+ first <- genExpression ex >>= flip castOperand t
+ rest <- processArgs exs ts
+ return $ first : rest
+ rargs <- processArgs args ats
+ let argpairs = map (\a -> (a,[])) rargs
+ foper = A.ConstantOperand $
+ A.C.GlobalReference (A.FunctionType (toLLVMType rt) (map toLLVMType ats) False)
+ (A.Name lname)
+ label <- addInstr $ A.Call Nothing A.CC.C [] (Right foper) argpairs [] []
+ return $ A.LocalReference (toLLVMType rt) (A.Name label)
literalToOperand lit _ = throwError $ "Literal '" ++ pshow lit ++ "' not implemented"
castOperand :: A.Operand -> Type -> CGMonad A.Operand
@@ -533,17 +578,10 @@ toLLVMType TypeFloat = A.float
toLLVMType TypeDouble = A.double
toLLVMType (TypePtr t) = A.ptr $ toLLVMType t
toLLVMType (TypeName _) = undefined
+toLLVMType (TypeFunc r a) = A.FunctionType (toLLVMType r) (map toLLVMType a) False
+toLLVMType TypeVoid = A.VoidType
initializerFor :: Type -> A.C.Constant
initializerFor (TypeInt s) = A.C.Int (fromIntegral s) 0
initializerFor (TypeUInt s) = A.C.Int (fromIntegral s) 0
initializerFor _ = undefined
-
-
-isDecVariable :: Declaration -> Bool
-isDecVariable (DecVariable {}) = True
-isDecVariable _ = False
-
-isDecFunction :: Declaration -> Bool
-isDecFunction (DecFunction {}) = True
-isDecFunction _ = False
diff --git a/parser.hs b/parser.hs
index 46bd3d0..615f6e7 100644
--- a/parser.hs
+++ b/parser.hs
@@ -29,9 +29,11 @@ pProgram = pWhiteComment >> (Program <$> many1 pDeclaration)
pDeclaration :: Parser Declaration
pDeclaration = pDecTypedef <|> do
- t <- pType
+ t <- pTypeVoid <|> pType
n <- pName
- pDecFunction' t n <|> pDecVariable' t n
+ if t == TypeVoid
+ then pDecFunction' t n
+ else pDecFunction' t n <|> pDecVariable' t n
pDecTypedef :: Parser Declaration
pDecTypedef = do
@@ -158,9 +160,10 @@ pStWhile = do
pStReturn :: Parser Statement
pStReturn = do
symbol "return"
- e <- pExpression
- symbol ";"
- return $ StReturn e
+ (symbol ";" >> return (StReturn Nothing)) <|> do
+ e <- pExpression
+ symbol ";"
+ return $ StReturn (Just e)
primitiveTypes :: Map.Map String Type
@@ -173,7 +176,7 @@ findPrimType :: String -> Type
findPrimType s = fromJust $ Map.lookup s primitiveTypes
pType :: Parser Type
-pType = pPrimType <|> pPtrType <|> pTypeName
+pType = pPrimType <|> pTypePtr <|> pTypeFunc <|> pTypeName
pPrimType :: Parser Type
pPrimType = findPrimType <$> choice (map typeParser $ Map.keys primitiveTypes)
@@ -183,14 +186,26 @@ pPrimType = findPrimType <$> choice (map typeParser $ Map.keys primitiveTypes)
pWhiteComment
return t
-pPtrType :: Parser Type
-pPtrType = do
+pTypeVoid :: Parser Type
+pTypeVoid = symbol "void" >> return TypeVoid
+
+pTypePtr :: Parser Type
+pTypePtr = do
symbol "ptr"
symbol "("
t <- pType
symbol ")"
return $ TypePtr t
+pTypeFunc :: Parser Type
+pTypeFunc = do
+ symbol "func"
+ r <- pTypeVoid <|> pType
+ symbol "("
+ a <- sepBy pType (symbol ",")
+ symbol ")"
+ return $ TypeFunc r a
+
pTypeName :: Parser Type
pTypeName = TypeName <$> pName
@@ -232,7 +247,10 @@ pString = do
<|> (liftM (\c -> ord c - ord 'A' + 10) (oneOf "ABCDEF"))
symbol :: String -> Parser ()
-symbol s = try (string s) >> pWhiteComment
+symbol s = do
+ void $ try (string s)
+ when (isAlphaNum (last s)) $ notFollowedBy alphaNum
+ pWhiteComment
pWhiteComment :: Parser ()
pWhiteComment = sepBy pWhitespace pComment >> return ()
diff --git a/test_string.nl b/test_string.nl
index ec8cd02..37841a7 100644
--- a/test_string.nl
+++ b/test_string.nl
@@ -2,6 +2,15 @@ type int = i32;
type char = i8;
type string = ptr(char);
+void func(string s) {
+ int i = 1;
+ return;
+}
+
int main(int argc, ptr(string) argv) {
string s = "kaas";
+ ptr(i8) s2 = "kaas2";
+ //func void(string) the_func = func;
+ func(s);
+ return 0;
}