{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TupleSections #-}
-- | LLVM code generation for the front end's AST.
--
-- Generation runs in 'CGMonad': a state monad over 'GenState' that can
-- abort with a 'String' error message.
module Codegen(codegen) where

import Control.Monad.State.Strict
import Control.Monad.Except
import Data.Char
import Data.Maybe
import qualified Data.Map.Strict as Map
import qualified LLVM.General.AST.Type as A
import qualified LLVM.General.AST.Global as A.G
import qualified LLVM.General.AST.CallingConvention as A.CC
import qualified LLVM.General.AST.Constant as A.C
import qualified LLVM.General.AST.Float as A.F
-- import qualified LLVM.General.AST.Operand as A
-- import qualified LLVM.General.AST.Name as A
-- import qualified LLVM.General.AST.Instruction as A
import qualified LLVM.General.AST.IntegerPredicate as A.IP
import qualified LLVM.General.AST.FloatingPointPredicate as A.FPP
import qualified LLVM.General.AST.Linkage as A.L
-- import qualified LLVM.General.AST.Visibility as A.V
import qualified LLVM.General.AST as A
import Debug.Trace

import AST
import PShow

-- | Either an error message or a result.
type Error a = Either String a

-- | A name used for LLVM values, globals and basic blocks.
type LLName = String

-- | State threaded through code generation.
data GenState = GenState
  { currentBlock    :: Maybe LLName                 -- ^ block that new instructions are appended to
  , allBlocks       :: Map.Map LLName A.BasicBlock  -- ^ blocks generated so far, keyed by block name
  , currentFunction :: Declaration                  -- ^ declaration of the function being generated
  , nextId          :: Integer                      -- ^ counter used to make fresh names
  , definitions     :: [A.Definition]
  , variables       :: Map.Map Name (Type, LLName)  -- ^ locals: source name -> (type, label of its storage)
  , globalVariables :: Map.Map Name (Type, LLName)
  , globalFunctions :: Map.Map Name (Type, LLName)
  , stringLiterals  :: [(LLName, String)]           -- ^ string constants to emit (most recent first)
  } deriving (Show)

-- | Empty state: no blocks, no variables, names start at 1.
initialGenState :: GenState
initialGenState = GenState
  { currentBlock    = Nothing
  , allBlocks       = Map.empty
    -- Placeholder until a function is entered; forcing it earlier is a bug,
    -- so fail with a message instead of a bare 'undefined'.
  , currentFunction = error "Codegen: currentFunction read before any function was entered"
  , nextId          = 1
  , definitions     = []
  , variables       = Map.empty
  , globalVariables = Map.empty
  , globalFunctions = Map.empty
  , stringLiterals  = []
  }

-- | Code generation monad.
newtype CGMonad a = CGMonad {unMon :: ExceptT String (State GenState) a}
  deriving (Functor, Applicative, Monad, MonadState GenState, MonadError String)

-- | Run a generator from 'initialGenState', returning its result together
-- with the final state, or the first error thrown.
runCGMonad :: CGMonad a -> Error (a, GenState)
runCGMonad m =
  let (result, finalState) = runState (runExceptT (unMon m)) initialGenState
  in (, finalState) <$> result

-- | Return the next value of the name counter and advance it.
getUniqueId :: CGMonad Integer
getUniqueId = state $ \s -> (nextId s, s {nextId = nextId s + 1})

-- | A fresh name formed by appending a unique number to @base@.
getNewName :: String -> CGMonad LLName
getNewName base = fmap ((base ++) . show) getUniqueId

-- | Create a new, empty basic block terminated by @unreachable@ and make it
-- the current block. Returns the block's name.
newBlock :: CGMonad LLName
newBlock = do
  name <- getNewName ".bb"
  state $ \s ->
    ( name
    , s { currentBlock = Just name
        , allBlocks = Map.insert name
            (A.BasicBlock (A.Name name) [] (A.Do $ A.Unreachable []))
            (allBlocks s)
        } )

-- | Like 'newBlock', but the new block branches unconditionally to @next@.
newBlockJump :: LLName -> CGMonad LLName
newBlockJump next = do
  bb <- newBlock
  setTerminator $ A.Br (A.Name next) []
  return bb

-- | Make @name@ the block that subsequent instructions are appended to.
changeBlock :: LLName -> CGMonad ()
changeBlock name = state $ \s -> ((), s {currentBlock = Just name})

-- | True for instructions that produce no value and so must not be given a
-- result name: stores, and direct calls whose callee type returns void.
instrReturnsVoid :: A.Instruction -> Bool
instrReturnsVoid (A.Store {}) = True
instrReturnsVoid (A.Call _ _ _ (Right oper) _ _ _) = case oper of
  (A.LocalReference (A.FunctionType A.VoidType _ _) _) -> True
  (A.ConstantOperand (A.C.GlobalReference (A.FunctionType A.VoidType _ _) _)) -> True
  _ -> False
instrReturnsVoid _ = False

-- | Append an instruction to the current block. Returns the fresh name of
-- its result, or @""@ for instructions that produce no value.
addInstr :: A.Instruction -> CGMonad LLName
addInstr instr
  | instrReturnsVoid instr = addNamedInstr $ A.Do instr
  | otherwise = do
      name <- getNewName ".t"
      addNamedInstr $ A.Name name A.:= instr

-- | Append an already-named instruction to the current block.
--
-- Returns the result name for @name := instr@ and @""@ for @Do instr@.
-- Throws if there is no current block or the instruction uses a numbered
-- ('A.UnName') result, which this generator never produces.
addNamedInstr :: A.Named A.Instruction -> CGMonad LLName
addNamedInstr instr = do
  resultName <- case instr of
    A.Name name A.:= _ -> return name
    A.Do _             -> return ""
    _ -> throwError "INTERNAL ERROR: addNamedInstr: numbered (UnName) results are not supported"
  st <- get
  case currentBlock st of
    Nothing -> throwError "INTERNAL ERROR: addNamedInstr: no current basic block"
    Just bb -> do
      let append (A.BasicBlock n il t) = A.BasicBlock n (il ++ [instr]) t
      put st {allBlocks = Map.adjust append bb (allBlocks st)}
      return resultName

-- addNamedInstrList :: [A.Named A.Instruction] -> CGMonad LLName
-- addNamedInstrList l = mapM addNamedInstr l >>= return . last
-- | Replace the terminator of the current basic block.
setTerminator :: A.Terminator -> CGMonad ()
setTerminator term =
  modify $ \s -> s {allBlocks = Map.adjust replace (fromJust (currentBlock s)) (allBlocks s)}
  where
    replace (A.BasicBlock n il _) = A.BasicBlock n il (A.Do term)

-- | Read the terminator of the current basic block.
getTerminator :: CGMonad (A.Named A.Terminator)
getTerminator = do
  s <- get
  let A.BasicBlock _ _ t = fromJust $ Map.lookup (fromJust $ currentBlock s) (allBlocks s)
  return t

-- | Remember which declaration is being generated (read by @return@).
setCurrentFunction :: Declaration -> CGMonad ()
setCurrentFunction dec = modify $ \s -> s {currentFunction = dec}

-- | Bind a local variable to the label of its storage.
setVar :: Name -> LLName -> Type -> CGMonad ()
setVar name label t = modify $ \s -> s {variables = Map.insert name (t, label) (variables s)}

-- | Bind a global variable to its LLVM global name.
setGlobalVar :: Name -> LLName -> Type -> CGMonad ()
setGlobalVar name label t =
  modify $ \s -> s {globalVariables = Map.insert name (t, label) (globalVariables s)}

-- | Bind a function to its LLVM global name.
setGlobalFunction :: Name -> LLName -> Type -> CGMonad ()
setGlobalFunction name label t =
  modify $ \s -> s {globalFunctions = Map.insert name (t, label) (globalFunctions s)}

-- | Look a variable up, preferring locals over globals.
-- Throws a descriptive error if the name is bound in neither.
lookupVar :: Name -> CGMonad (Type, LLName)
lookupVar name = do
  st <- get
  case Map.lookup name (variables st) of
    Just found -> return found
    Nothing -> case Map.lookup name (globalVariables st) of
      Just found -> return found
      Nothing -> throwError $ "Undefined variable '" ++ name ++ "'"

-- | Look up a global variable, throwing if it is not declared.
lookupGlobalVar :: Name -> CGMonad (Type, LLName)
lookupGlobalVar name =
  gets (Map.lookup name . globalVariables)
    >>= maybe (throwError $ "Undefined global variable '" ++ name ++ "'") return

-- | Look up a function, throwing if it is not declared.
lookupGlobalFunction :: Name -> CGMonad (Type, LLName)
lookupGlobalFunction name =
  gets (Map.lookup name . globalFunctions)
    >>= maybe (throwError $ "Undefined function '" ++ name ++ "'") return

-- | Register a string constant to be emitted later by 'genStringLiterals'.
-- Returns the global's type (pointer to a NUL-terminated i8 array) and name.
addStringLiteral :: String -> CGMonad (A.Type, LLName)
addStringLiteral str = do
  name <- getNewName ".str"
  let ty = A.ptr $ A.ArrayType (fromIntegral (length str + 1)) A.i8
  modify $ \s -> s {stringLiterals = (name, str) : stringLiterals s}
  return (ty, name)

-- | Pointer operand addressing a variable's storage (local first, then global).
variableStoreOperand :: Name -> CGMonad A.Operand
variableStoreOperand name = gets (Map.lookup name . variables) >>= maybe global local
  where
    local :: (Type, LLName) -> CGMonad A.Operand
    local (t, nm) = return $ A.LocalReference (A.ptr (toLLVMType t)) (A.Name nm)
    global :: CGMonad A.Operand
    global = do
      (t, nm) <- lookupGlobalVar name
      return $ A.ConstantOperand $ A.C.GlobalReference (A.ptr (toLLVMType t)) (A.Name nm)

-- | Load a variable's current value into a fresh local.
variableOperand :: Name -> CGMonad A.Operand
variableOperand name = do
  (t, _) <- lookupVar name
  ptr <- variableStoreOperand name
  label <- addInstr $ A.Load False ptr Nothing 0 []
  return $ A.LocalReference (toLLVMType t) (A.Name label)

-- namedName :: A.Named a -> LLName
-- namedName (A.Name name A.:= _) = name
-- namedName _ = undefined

-- | Compile a whole program into an LLVM module.
codegen :: Program -- Program to compile
        -> String  -- Module name
        -> String  -- File name of source
        -> Error A.Module
codegen prog name fname = do
  (defs, _) <- runCGMonad (generateDefs prog)
  return A.defaultModule
    { A.moduleName = name
    , A.moduleSourceFileName = fname
    , A.moduleDefinitions = defs
    }

-- | All module definitions. String literals come last because they are
-- collected while the functions are generated.
generateDefs :: Program -> CGMonad [A.Definition]
generateDefs prog =
  concat <$> sequence [genGlobalVars prog, genFunctions prog, genStringLiterals]

-- | Definitions for global variables and externs. Also registers every
-- function signature so that calls can be resolved.
genGlobalVars :: Program -> CGMonad [A.Definition]
genGlobalVars (Program decs) = catMaybes <$> mapM gen decs
  where
    gen :: Declaration -> CGMonad (Maybe A.Definition)
    gen (DecVariable t n Nothing) = do
      setGlobalVar n n t
      return $ Just $ A.GlobalDefinition $ A.globalVariableDefaults
        { A.G.name = A.Name n
        , A.G.type' = toLLVMType t
        , A.G.initializer = Just $ initializerFor t
        }
    gen (DecVariable _ _ (Just _)) =
      throwError $ "Initialised global variables not supported yet"
    gen (DecFunction rt n a _) = do
      setGlobalFunction n n (TypeFunc rt (map fst a))
      return Nothing
    gen (DecExtern t@(TypeFunc rt ats) n) = do
      setGlobalFunction n n t
      argnames <- replicateM (length ats) (getNewName ".arg")
      return $ Just $ A.GlobalDefinition $ A.functionDefaults
        { A.G.returnType = toLLVMType rt
        , A.G.name = A.Name n
        , A.G.parameters =
            ([A.Parameter (toLLVMType at) (A.Name an) [] | (at, an) <- zip ats argnames], False)
        , A.G.basicBlocks = []
        }
    gen (DecExtern t n) = do
      setGlobalVar n n t
      return $ Just $ A.GlobalDefinition $ A.globalVariableDefaults
        { A.G.name = A.Name n
        , A.G.type' = toLLVMType t
        , A.G.initializer = Nothing
        }
    gen (DecTypedef _ _) = return Nothing

-- | Private constant NUL-terminated i8 arrays for the collected literals.
genStringLiterals :: CGMonad [A.Definition]
genStringLiterals = map gen <$> gets stringLiterals
  where
    gen :: (LLName, String) -> A.Definition
    gen (name, str) = A.GlobalDefinition $ A.globalVariableDefaults
      { A.G.name = A.Name name
      , A.G.linkage = A.L.Private
      , A.G.isConstant = True
      , A.G.type' = A.ArrayType (fromIntegral (length str + 1)) A.i8
      , A.G.initializer = Just $ A.C.Array A.i8 $
          [A.C.Int 8 (fromIntegral (ord c)) | c <- str] ++ [A.C.Int 8 0]
      }

-- | Function definitions, one per 'DecFunction'.
genFunctions :: Program -> CGMonad [A.Definition]
genFunctions (Program decs) = catMaybes <$> mapM gen decs
  where
    gen :: Declaration -> CGMonad (Maybe A.Definition)
    gen dec@(DecFunction rettype name args body) = do
      setCurrentFunction dec
      modify $ \s -> s {allBlocks = Map.empty, variables = Map.empty}
      firstbb <- genFunctionBlock body (rettype, name) args
      cleanupTrampolines firstbb
      blockmap <- gets allBlocks
      -- LLVM treats the first listed block as the entry block.
      -- cleanupTrampolines never removes firstbb, so the lookup succeeds.
      let entry = fromJust (Map.lookup firstbb blockmap)
          others = Map.elems (Map.delete firstbb blockmap)
      return $ Just $ A.GlobalDefinition $ A.functionDefaults
        { A.G.returnType = toLLVMType rettype
        , A.G.name = A.Name name
        , A.G.parameters =
            ([A.Parameter (toLLVMType t) (A.Name (".farg_" ++ n)) [] | (t, n) <- args], False)
        , A.G.basicBlocks = entry : others
        }
    gen _ = return Nothing

-- | Generate a function body. Returns the name of the entry block.
--
-- The entry block holds the argument spills and jumps to the body. It is
-- always returned, even for functions without arguments. LLVM forbids
-- predecessors for the entry block, and the body's first block can be a
-- branch target (for example, a leading @while@ loop header).
genFunctionBlock :: Block -> (Type, Name) -> [(Type, Name)] -> CGMonad LLName
genFunctionBlock bl (rettype, fname) args = do
  firstbb <- newBlock
  mapM_ prepArg args
  -- Block reached when control falls off the end of the body.
  termbb <- newBlock
  setTerminator $ if rettype == TypeVoid then A.Ret Nothing [] else A.Unreachable []
  bodybb <- genBlock bl termbb
  changeBlock firstbb
  setTerminator $ A.Br (A.Name bodybb) []
  when (rettype /= TypeVoid) $
    whenM (bbIsReferenced termbb) $
      throwError $ "Control reaches end of non-void function '" ++ fname ++ "'"
  return firstbb
  where
    -- Copy a parameter into an alloca so it can be used like a local variable.
    prepArg :: (Type, Name) -> CGMonad ()
    prepArg (t, n) = do
      label <- addInstr $ A.Alloca (toLLVMType t) Nothing 0 []
      void $ addInstr $ A.Store False
        (A.LocalReference (A.ptr (toLLVMType t)) (A.Name label))
        (A.LocalReference (toLLVMType t) (A.Name (".farg_" ++ n)))
        Nothing 0 []
      setVar n label t

-- | Run @value@ only when the monadic condition yields True.
whenM :: (Monad m) => m Bool -> m a -> m ()
whenM cond value = cond >>= \c -> if c then void value else return ()

-- | Generate a statement list.
genBlock :: Block
         -> LLName        -- name of BasicBlock following this Block
         -> CGMonad LLName -- name of first BasicBlock
genBlock (Block []) following = genBlock (Block [StEmpty]) following
genBlock (Block [stmt]) following = genSingle stmt following
genBlock (Block (stmt : rest)) following = do
  interbb <- newBlock
  firstbb <- genSingle stmt interbb
  restbb <- genBlock (Block rest) following
  changeBlock interbb
  setTerminator $ A.Br (A.Name restbb) []
  return firstbb

-- | Generate one statement.
genSingle :: Statement
          -> LLName        -- name of BasicBlock following this statement
          -> CGMonad LLName -- name of first BasicBlock
genSingle StEmpty following = return following
genSingle (StBlock block) following = genBlock block following
genSingle (StExpr expr) following = do
  bb <- newBlockJump following
  void $ genExpression expr
  return bb
genSingle (StVarDeclaration t n Nothing) following = do
  bb <- newBlockJump following
  label <- addInstr $ A.Alloca (toLLVMType t) Nothing 0 []
  setVar n label t
  return bb
genSingle (StVarDeclaration _ _ (Just _)) _ =
  throwError "Initialised local variables not supported yet"
genSingle (StAssignment name expr) following = do
  bb <- newBlockJump following
  oper <- genExpression expr
  (dsttype, _) <- lookupVar name
  oper' <- castOperand oper dsttype
  ref <- variableStoreOperand name
  void $ addInstr $ A.Store False ref oper' Nothing 0 []
  return bb
genSingle (StReturn Nothing) _ = do
  bb <- newBlock
  setTerminator $ A.Ret Nothing []
  return bb
genSingle (StReturn (Just expr)) _ = do
  bb <- newBlock
  oper <- genExpression expr
  rettype <- gets (typeOf . currentFunction)
  oper' <- castOperand oper rettype
  setTerminator $ A.Ret (Just oper') []
  return bb
genSingle (StIf cexpr st1 st2) following = do
  stbb1 <- genSingle st1 following
  stbb2 <- genSingle st2 following
  cbb <- newBlock
  coper <- genExpression cexpr >>= castToBool
  setTerminator $ A.CondBr coper (A.Name stbb1) (A.Name stbb2) []
  return cbb
genSingle (StWhile cexpr st) following = do
  cbb <- newBlock
  loopbb <- newBlockJump cbb
  stbb <- genSingle st loopbb
  changeBlock cbb
  coper <- genExpression cexpr >>= castToBool
  setTerminator $ A.CondBr coper (A.Name stbb) (A.Name following) []
  return cbb

-- | Generate code for an expression in the current block. Returns the operand
-- holding its value.
genExpression :: Expression -> CGMonad A.Operand
genExpression (ExLit lit (Just t)) = literalToOperand lit t
genExpression (ExBinOp bo e1 e2 (Just t)) = case bo of
  Plus -> genArith "Plus '+'"
    (\a b -> A.Add False False a b [])
    (\a b -> A.Add False False a b [])
    (\a b -> A.FAdd A.NoFastMathFlags a b []) t e1 e2
  Minus -> genArith "Minus '-'"
    (\a b -> A.Sub False False a b [])
    (\a b -> A.Sub False False a b [])
    (\a b -> A.FSub A.NoFastMathFlags a b []) t e1 e2
  Times -> genArith "Multiply '*'"
    (\a b -> A.Mul False False a b [])
    (\a b -> A.Mul False False a b [])
    (\a b -> A.FMul A.NoFastMathFlags a b []) t e1 e2
  Divide -> genArith "Divide '/'"
    (\a b -> A.SDiv False a b [])
    (\a b -> A.UDiv False a b [])
    (\a b -> A.FDiv A.NoFastMathFlags a b []) t e1 e2
  Modulo -> genArith "Modulo '%'"
    (\a b -> A.SRem a b [])
    (\a b -> A.URem a b [])
    (\a b -> A.FRem A.NoFastMathFlags a b []) t e1 e2
  Equal   -> genCompare A.IP.EQ  A.IP.EQ  A.FPP.OEQ e1 e2
  Greater -> genCompare A.IP.SGT A.IP.UGT A.FPP.OGT e1 e2
  Less    -> genCompare A.IP.SLT A.IP.ULT A.FPP.OLT e1 e2
  GEqual  -> genCompare A.IP.SGE A.IP.UGE A.FPP.OGE e1 e2
  LEqual  -> genCompare A.IP.SLE A.IP.ULE A.FPP.OLE e1 e2
  BoolAnd -> genBoolAnd e1 e2
  -- BoolOr -> do
  --   e1op' <- castToBool e1op
  --   e2op' <- castToBool e2op
  --   label <- addInstr $ A.Or e1op' e2op' []
  --   return $ A.LocalReference (A.IntegerType 1) (A.Name label)
  _ -> throwError $ "Binary operator " ++ pshow bo ++ " not implemented"
genExpression (ExUnOp uo e1 (Just t)) = do
  e1op <- genExprArgument e1
  case uo of
    Negate -> do
      label <- case t of
        TypeInt s    -> addInstr $ A.Sub False False (intZero s) e1op []
        TypeUInt s   -> addInstr $ A.Sub False False (intZero s) e1op []
        TypeFloat    -> addInstr $ A.FSub A.NoFastMathFlags (A.ConstantOperand (A.C.Float (A.F.Single 0))) e1op []
        TypeDouble   -> addInstr $ A.FSub A.NoFastMathFlags (A.ConstantOperand (A.C.Float (A.F.Double 0))) e1op []
        TypePtr _    -> throwError $ "Negate '-' operator not defined on a pointer"
        TypeName _   -> unsupportedType t
        TypeFunc _ _ -> throwError $ "Negate '-' operator not defined on a function pointer"
        TypeVoid     -> unsupportedType t
      return $ A.LocalReference (toLLVMType t) (A.Name label)
    Not -> do
      label <- case t of
        TypeInt s    -> addInstr $ A.ICmp A.IP.EQ (intZero s) e1op []
        TypeUInt s   -> addInstr $ A.ICmp A.IP.EQ (intZero s) e1op []
        TypeFloat    -> addInstr $ A.FCmp A.FPP.OEQ (A.ConstantOperand (A.C.Float (A.F.Single 0))) e1op []
        TypeDouble   -> addInstr $ A.FCmp A.FPP.OEQ (A.ConstantOperand (A.C.Float (A.F.Double 0))) e1op []
        TypePtr _    -> addInstr $ A.ICmp A.IP.EQ (A.ConstantOperand (A.C.Null (toLLVMType t))) e1op []
        TypeName _   -> unsupportedType t
        TypeFunc _ _ -> throwError $ "Not '!' operator not defined on a function pointer"
        TypeVoid     -> unsupportedType t
      -- ICmp/FCmp always produce an i1, whatever the operand type.
      return $ A.LocalReference A.i1 (A.Name label)
    Dereference -> case exTypeOf e1 of
      -- The loaded value has the pointee type, not the pointer type.
      Just (TypePtr pointee) -> do
        label <- addInstr $ A.Load False e1op Nothing 0 []
        return $ A.LocalReference (toLLVMType pointee) (A.Name label)
      _ -> throwError $ "Dereference '*' operator only defined on pointers"
    _ -> throwError $ "Unary operator " ++ pshow uo ++ " not implemented"
genExpression ex = throwError $ "Expression '" ++ pshow ex ++ "' not implemented"

-- | Zero constant of the given integer width.
intZero :: Integral a => a -> A.Operand
intZero s = A.ConstantOperand (A.C.Int (fromIntegral s) 0)

-- | Error for operand types that should never reach an operator.
unsupportedType :: Type -> CGMonad a
unsupportedType t =
  throwError $ "INTERNAL ERROR: unexpected type '" ++ pshow t ++ "' in expression"

-- | The type annotation of an expression, or an error if it has none.
exprType :: Expression -> CGMonad Type
exprType e =
  maybe (throwError $ "INTERNAL ERROR: expression '" ++ pshow e ++ "' has no type")
        return (exTypeOf e)

-- | Shared code for the arithmetic operators. Both operands are cast to the
-- result type @t@. The instruction is then chosen by signedness or by
-- floating point; @opdesc@ names the operator in error messages.
genArith :: String
         -> (A.Operand -> A.Operand -> A.Instruction) -- signed integer
         -> (A.Operand -> A.Operand -> A.Instruction) -- unsigned integer
         -> (A.Operand -> A.Operand -> A.Instruction) -- floating point
         -> Type -> Expression -> Expression -> CGMonad A.Operand
genArith opdesc sInstr uInstr fInstr t e1 e2 = do
  e1op <- genExprArgument e1 >>= flip castOperand t
  e2op <- genExprArgument e2 >>= flip castOperand t
  label <- case t of
    TypeInt _    -> addInstr $ sInstr e1op e2op
    TypeUInt _   -> addInstr $ uInstr e1op e2op
    TypeFloat    -> addInstr $ fInstr e1op e2op
    TypeDouble   -> addInstr $ fInstr e1op e2op
    TypePtr _    -> throwError $ opdesc ++ " operator not defined on pointers"
    TypeFunc _ _ -> throwError $ opdesc ++ " operator not defined on function pointers"
    TypeName _   -> unsupportedType t
    TypeVoid     -> unsupportedType t
  return $ A.LocalReference (toLLVMType t) (A.Name label)

-- | Shared code for the comparison operators. Both operands are cast to
-- their common type. Pointers and function pointers compare as unsigned.
-- The result is an i1.
genCompare :: A.IP.IntegerPredicate          -- signed integer predicate
           -> A.IP.IntegerPredicate          -- unsigned integer / pointer predicate
           -> A.FPP.FloatingPointPredicate   -- floating-point predicate
           -> Expression -> Expression -> CGMonad A.Operand
genCompare spred upred fpred e1 e2 = do
  ty1 <- exprType e1
  ty2 <- exprType e2
  sharedType <- commonTypeM ty1 ty2
  e1op <- genExprArgument e1 >>= flip castOperand sharedType
  e2op <- genExprArgument e2 >>= flip castOperand sharedType
  label <- case sharedType of
    TypeInt _    -> addInstr $ A.ICmp spred e1op e2op []
    TypeUInt _   -> addInstr $ A.ICmp upred e1op e2op []
    TypeFloat    -> addInstr $ A.FCmp fpred e1op e2op []
    TypeDouble   -> addInstr $ A.FCmp fpred e1op e2op []
    TypePtr _    -> addInstr $ A.ICmp upred e1op e2op []
    TypeFunc _ _ -> addInstr $ A.ICmp upred e1op e2op []
    TypeName _   -> unsupportedType sharedType
    TypeVoid     -> unsupportedType sharedType
  return $ A.LocalReference A.i1 (A.Name label)

-- | Short-circuiting @e1 && e2@, producing an i1.
--
-- @e2@ is evaluated only when @e1@ is true. A phi in a new join block picks
-- either false (coming from @e1@'s block) or @e2@'s value.
genBoolAnd :: Expression -> Expression -> CGMonad A.Operand
genBoolAnd e1 e2 = do
  lhs <- genExprArgument e1 >>= castToBool
  -- Either operand may open new blocks (e.g. a nested &&). So branch from
  -- the block where its code ended, not the block where it started.
  lhsEnd <- gets (fromJust . currentBlock)
  origterm <- getTerminator >>= plainTerminator
  rhsStart <- newBlock
  rhs <- genExprArgument e2 >>= castToBool
  rhsEnd <- gets (fromJust . currentBlock)
  joinbb <- newBlock
  changeBlock lhsEnd
  setTerminator $ A.CondBr lhs (A.Name rhsStart) (A.Name joinbb) []
  changeBlock rhsEnd
  setTerminator $ A.Br (A.Name joinbb) []
  -- The join block inherits the terminator the left operand's block had.
  changeBlock joinbb
  setTerminator origterm
  reslabel <- addInstr $ A.Phi A.i1
    [ (A.ConstantOperand (A.C.Int 1 0), A.Name lhsEnd)
    , (rhs, A.Name rhsEnd)
    ] []
  return $ A.LocalReference A.i1 (A.Name reslabel)
  where
    plainTerminator :: A.Named A.Terminator -> CGMonad A.Terminator
    plainTerminator (A.Do term) = return term
    plainTerminator _ = throwError "INTERNAL ERROR: named terminator in BoolAnd"

-- | Generate an operand for an expression used as an operator argument.
genExprArgument :: Expression -> CGMonad A.Operand
genExprArgument expr = case expr of
  (ExLit lit (Just t)) -> literalToOperand lit t
  _ -> genExpression expr

-- | Operand for a literal, converted to type @t@ where that applies.
literalToOperand :: Literal -> Type -> CGMonad A.Operand
literalToOperand (LitInt i) (TypeInt sz) = return $ A.ConstantOperand (A.C.Int (fromIntegral sz) i)
literalToOperand (LitInt i) (TypeUInt sz) = return $ A.ConstantOperand (A.C.Int (fromIntegral sz) i)
literalToOperand (LitFloat f) TypeFloat = return $ A.ConstantOperand (A.C.Float (A.F.Single (realToFrac f)))
literalToOperand (LitFloat f) TypeDouble = return $ A.ConstantOperand (A.C.Float (A.F.Double f))
literalToOperand (LitVar n) t = variableOperand n >>= flip castOperand t
literalToOperand (LitString s) (TypePtr (TypeInt 8)) = do
  (ty, name) <- addStringLiteral s
  -- Address of the first character of the global array.
  label <- addInstr $ A.GetElementPtr True
    (A.ConstantOperand $ A.C.GlobalReference ty (A.Name name))
    [A.ConstantOperand $ A.C.Int 64 0, A.ConstantOperand $ A.C.Int 32 0] []
  return $ A.LocalReference (A.ptr A.i8) (A.Name label)
literalToOperand (LitCall n args) _ = do
  (ftype, lname) <- lookupGlobalFunction n
  (retTy, argTys) <- case ftype of
    TypeFunc r as -> return (r, as)
    _ -> throwError $ "'" ++ n ++ "' is not a function"
  when (length args /= length argTys) $
    throwError $ "Function '" ++ n ++ "' expects " ++ show (length argTys)
      ++ " argument(s), but " ++ show (length args) ++ " were given"
  rargs <- zipWithM (\ex t -> genExpression ex >>= flip castOperand t) args argTys
  let foper = A.ConstantOperand $ A.C.GlobalReference
        (A.FunctionType (toLLVMType retTy) (map toLLVMType argTys) False) (A.Name lname)
  label <- addInstr $ A.Call Nothing A.CC.C [] (Right foper) [(a, []) | a <- rargs] [] []
  return $ A.LocalReference (toLLVMType retTy) (A.Name label)
literalToOperand lit _ = throwError $ "Literal '" ++ pshow lit ++ "' not implemented"

-- | Emit the conversion @conv orig@ to type @t@ and return its result.
convertTo :: (A.Operand -> A.Type -> [m] -> A.Instruction) -> A.Operand -> Type -> CGMonad A.Operand
convertTo conv orig t = do
  let ty = toLLVMType t
  label <- addInstr $ conv orig ty []
  return $ A.LocalReference ty (A.Name label)

-- | Implicitly convert an operand to type @t2@. Throws for narrowing or
-- unsupported conversions. An i1 is a boolean (comparison result), so it is
-- zero-extended or converted as unsigned: true becomes 1, not -1.
castOperand :: A.Operand -> Type -> CGMonad A.Operand
castOperand orig@(A.ConstantOperand (A.C.Int s1 val)) t2@(TypeInt s2)
  | fromIntegral s1 == s2 = return orig
  | fromIntegral s1 < s2  = return $ A.ConstantOperand (A.C.Int (fromIntegral s2) val)
  | otherwise             = throwError $ "Integer " ++ show val ++ " too large for type '" ++ pshow t2 ++ "'"
castOperand orig@(A.ConstantOperand (A.C.Int s1 val)) t2@(TypeUInt s2)
  | fromIntegral s1 == s2 = return orig
  | fromIntegral s1 < s2  = return $ A.ConstantOperand (A.C.Int (fromIntegral s2) val)
  | otherwise             = throwError $ "Integer " ++ show val ++ " too large for type '" ++ pshow t2 ++ "'"
castOperand (A.ConstantOperand (A.C.Int _ val)) TypeFloat =
  return $ A.ConstantOperand (A.C.Float (A.F.Single (fromIntegral val)))
castOperand (A.ConstantOperand (A.C.Int _ val)) TypeDouble =
  return $ A.ConstantOperand (A.C.Float (A.F.Double (fromIntegral val)))
castOperand orig@(A.ConstantOperand (A.C.Float (A.F.Single _))) TypeFloat = return orig
castOperand orig@(A.ConstantOperand (A.C.Float (A.F.Double _))) TypeDouble = return orig
castOperand (A.ConstantOperand (A.C.Float (A.F.Single f))) TypeDouble =
  return $ A.ConstantOperand (A.C.Float (A.F.Double (realToFrac f)))
castOperand orig@(A.LocalReference (A.IntegerType s1) _) t2@(TypeInt s2)
  | fromIntegral s1 == s2 = return orig
  | fromIntegral s1 < s2  = convertTo (if s1 == 1 then A.ZExt else A.SExt) orig t2
  | otherwise = throwError $ "Cannot implicitly cast '" ++ pshow (TypeInt (fromIntegral s1)) ++ "' to '" ++ pshow t2 ++ "'"
castOperand orig@(A.LocalReference (A.IntegerType s1) _) t2@(TypeUInt s2)
  | fromIntegral s1 == s2 = return orig
  | fromIntegral s1 < s2  = convertTo A.ZExt orig t2
  | otherwise = throwError $ "Cannot implicitly cast '" ++ pshow (TypeUInt (fromIntegral s1)) ++ "' to '" ++ pshow t2 ++ "'"
castOperand orig@(A.ConstantOperand (A.C.GlobalReference (A.IntegerType s1) _)) t2@(TypeInt s2)
  | fromIntegral s1 == s2 = return orig
  | fromIntegral s1 < s2  = convertTo (if s1 == 1 then A.ZExt else A.SExt) orig t2
  | otherwise = throwError $ "Cannot implicitly cast '" ++ pshow (TypeInt (fromIntegral s1)) ++ "' to '" ++ pshow t2 ++ "'"
castOperand orig@(A.ConstantOperand (A.C.GlobalReference (A.IntegerType s1) _)) t2@(TypeUInt s2)
  | fromIntegral s1 == s2 = return orig
  | fromIntegral s1 < s2  = convertTo A.ZExt orig t2
  | otherwise = throwError $ "Cannot implicitly cast '" ++ pshow (TypeUInt (fromIntegral s1)) ++ "' to '" ++ pshow t2 ++ "'"
-- Integer values mixed with floats ('commonType' allows int/float pairs).
castOperand orig@(A.LocalReference (A.IntegerType s1) _) t2
  | t2 == TypeFloat || t2 == TypeDouble = convertTo (if s1 == 1 then A.UIToFP else A.SIToFP) orig t2
castOperand orig@(A.LocalReference t _) TypeFloat
  | t == toLLVMType TypeFloat = return orig
castOperand orig@(A.LocalReference t _) TypeDouble
  | t == toLLVMType TypeDouble = return orig
castOperand orig@(A.LocalReference t _) TypeDouble
  | t == toLLVMType TypeFloat = convertTo A.FPExt orig TypeDouble
castOperand orig@(A.LocalReference (A.PointerType t1 _) _) (TypePtr t2)
  | toLLVMType t2 == t1 = return orig
  | otherwise = throwError $ "Cannot implicitly cast between pointer to '" ++ show t1 ++ "' and '" ++ pshow t2 ++ "'"
castOperand orig@(A.ConstantOperand (A.C.GlobalReference (A.PointerType t1 _) _)) (TypePtr t2)
  | toLLVMType t2 == t1 = return orig
  | otherwise = throwError $ "Cannot implicitly cast between pointer to '" ++ show t1 ++ "' and '" ++ pshow t2 ++ "'"
castOperand orig@(A.LocalReference (A.PointerType (A.FunctionType rt1 at1 False) _) _) t2@(TypeFunc rt2 at2)
  | toLLVMType rt2 == rt1 && at1 == map toLLVMType at2 = return orig
  | otherwise = throwError $ "Cannot implicitly cast between '" ++ show orig ++ "' and '" ++ pshow t2 ++ "'"
castOperand orig t2 =
  throwError $ "Cast from '" ++ show orig ++ "' to type '" ++ pshow t2 ++ "' not implemented"

-- | Convert an operand to an i1 truth value (C semantics: nonzero is true).
castToBool :: A.Operand -> CGMonad A.Operand
castToBool orig@(A.LocalReference (A.IntegerType 1) _) = return orig
castToBool orig@(A.LocalReference (A.IntegerType s1) _) = do
  label <- addInstr $ A.ICmp A.IP.NE orig (A.ConstantOperand (A.C.Int s1 0)) []
  return $ A.LocalReference A.i1 (A.Name label)
castToBool orig@(A.LocalReference ty@(A.PointerType _ _) _) = do
  label <- addInstr $ A.ICmp A.IP.NE orig (A.ConstantOperand (A.C.Null ty)) []
  return $ A.LocalReference A.i1 (A.Name label)
castToBool orig@(A.LocalReference ty _)
  | ty == A.float || ty == A.double = do
      -- Unordered "not equal" so that NaN counts as true, as in C.
      let zero = if ty == A.float then A.F.Single 0 else A.F.Double 0
      label <- addInstr $ A.FCmp A.FPP.UNE orig (A.ConstantOperand (A.C.Float zero)) []
      return $ A.LocalReference A.i1 (A.Name label)
castToBool (A.ConstantOperand (A.C.Int _ val)) = return $ boolConstant (val /= 0)
castToBool (A.ConstantOperand (A.C.Float (A.F.Single f))) = return $ boolConstant (f /= 0)
castToBool (A.ConstantOperand (A.C.Float (A.F.Double f))) = return $ boolConstant (f /= 0)
castToBool other = throwError $ "Cannot use '" ++ show other ++ "' as a condition"

-- | Constant i1 operand: 1 for True, 0 for False.
boolConstant :: Bool -> A.Operand
boolConstant b = A.ConstantOperand (A.C.Int 1 (if b then 1 else 0))

-- | Type to which two operand types are implicitly converted, if any.
commonType :: Type -> Type -> Maybe Type
commonType (TypePtr t1) (TypePtr t2) | t1 == t2 = Just $ TypePtr t1
commonType (TypePtr _) _ = Nothing
commonType _ (TypePtr _) = Nothing
commonType (TypeInt s1) (TypeInt s2) = Just $ TypeInt (max s1 s2)
commonType (TypeUInt s1) (TypeUInt s2) = Just $ TypeUInt (max s1 s2)
commonType TypeFloat (TypeInt _) = Just TypeFloat
commonType (TypeInt _) TypeFloat = Just TypeFloat
commonType TypeDouble (TypeInt _) = Just TypeDouble
commonType (TypeInt _) TypeDouble = Just TypeDouble
commonType TypeFloat TypeFloat = Just TypeFloat
commonType TypeDouble TypeDouble = Just TypeDouble
commonType TypeFloat TypeDouble = Just TypeDouble
commonType TypeDouble TypeFloat = Just TypeDouble
commonType t@(TypeFunc rt1 at1) (TypeFunc rt2 at2)
  | rt1 == rt2 && at1 == at2 = Just t
  | otherwise = Nothing
commonType _ _ = Nothing

-- | 'commonType' that throws when the types cannot be reconciled.
commonTypeM :: Type -> Type -> CGMonad Type
commonTypeM t1 t2 = maybe err return $ commonType t1 t2
  where
    err = throwError $ "Cannot implicitly find common type of '" ++ pshow t1 ++ "' and '" ++ pshow t2 ++ "'"

-- | Remove "trampoline" blocks: blocks with no instructions whose only
-- action is an unconditional branch. Branches to them are redirected to
-- their target. Three kinds of trampoline are kept:
--
-- * the entry block @toskip@;
-- * self-loops, which have nowhere to redirect to;
-- * blocks named as incoming blocks of a phi, because removing them would
--   leave the phi referring to a block that is no longer a predecessor.
cleanupTrampolines :: LLName -> CGMonad ()
cleanupTrampolines toskip = modify $ \s -> s {allBlocks = go (allBlocks s)}
  where
    go :: Map.Map LLName A.BasicBlock -> Map.Map LLName A.BasicBlock
    go bbs = folder bbs (Map.toList bbs)

    -- Scan for a removable trampoline; after each removal restart the scan,
    -- since redirection can turn other blocks into removable trampolines.
    folder :: Map.Map LLName A.BasicBlock -> [(LLName, A.BasicBlock)] -> Map.Map LLName A.BasicBlock
    folder whole [] = whole
    folder whole ((name, A.BasicBlock (A.Name name2) [] (A.Do (A.Br (A.Name dst) []))) : rest)
      | name /= name2 = error "INTERNAL ERROR: name /= name2"
      | removable =
          let res = eliminate name dst (Map.delete name whole)
          in folder res (Map.toList res)
      | otherwise = folder whole rest
      where
        removable = name /= toskip && name /= dst && name `notElem` phiSources whole
    folder whole (_ : rest) = folder whole rest

    -- Names of all blocks that appear as incoming blocks of some phi.
    phiSources :: Map.Map LLName A.BasicBlock -> [LLName]
    phiSources bbs =
      [ src
      | A.BasicBlock _ instrs _ <- Map.elems bbs
      , A.Phi _ incoming _ <- map unNamed instrs
      , (_, A.Name src) <- incoming
      ]

    unNamed :: A.Named a -> a
    unNamed (_ A.:= x) = x
    unNamed (A.Do x) = x

    -- Redirect every branch to @name@ so that it targets @dst@ instead.
    eliminate :: LLName -> LLName -> Map.Map LLName A.BasicBlock -> Map.Map LLName A.BasicBlock
    eliminate name dst = Map.map goBB
      where
        goBB :: A.BasicBlock -> A.BasicBlock
        goBB (A.BasicBlock nm instrs (A.Name n A.:= term)) = A.BasicBlock nm instrs (A.Name n A.:= goT term)
        goBB (A.BasicBlock nm instrs (A.Do term)) = A.BasicBlock nm instrs (A.Do (goT term))
        goBB _ = error "INTERNAL ERROR: cleanupTrampolines: numbered terminator"

        goT :: A.Terminator -> A.Terminator
        goT (A.CondBr cond d1 d2 []) = A.CondBr cond (rename d1) (rename d2) []
        goT (A.Br d []) = A.Br (rename d) []
        goT (A.Switch op d1 ds []) = A.Switch op (rename d1) [(c, rename d) | (c, d) <- ds] []
        goT (A.IndirectBr {}) = error "INTERNAL ERROR: cleanupTrampolines: indirectbr not supported"
        goT (A.Invoke {}) = error "INTERNAL ERROR: cleanupTrampolines: invoke not supported"
        goT term = term

        rename :: A.Name -> A.Name
        rename (A.Name x)
          | x == name = A.Name dst
          | otherwise = A.Name x
        rename (A.UnName _) = error "INTERNAL ERROR: cleanupTrampolines: numbered block name"

-- | True if any block's terminator branches to @bb@.
bbIsReferenced :: LLName -> CGMonad Bool
bbIsReferenced bb = gets (any checkBlock . allBlocks)
  where
    checkBlock :: A.BasicBlock -> Bool
    checkBlock (A.BasicBlock _ _ (_ A.:= term)) = refersTo term
    checkBlock (A.BasicBlock _ _ (A.Do term)) = refersTo term

    refersTo :: A.Terminator -> Bool
    refersTo term = case term of
      A.Ret _ _ -> False
      A.CondBr _ (A.Name d1) (A.Name d2) _ -> d1 == bb || d2 == bb
      A.Br (A.Name d) _ -> d == bb
      A.Unreachable _ -> False
      _ -> error "INTERNAL ERROR: bbIsReferenced: unexpected terminator"

-- | LLVM representation of a source type. Function types are represented
-- as pointers to LLVM function types.
toLLVMType :: Type -> A.Type
toLLVMType (TypeInt s) = A.IntegerType $ fromIntegral s
toLLVMType (TypeUInt s) = A.IntegerType $ fromIntegral s
toLLVMType TypeFloat = A.float
toLLVMType TypeDouble = A.double
toLLVMType (TypePtr t) = A.ptr $ toLLVMType t
toLLVMType t@(TypeName _) =
  error $ "INTERNAL ERROR: unresolved type name '" ++ pshow t ++ "' reached code generation"
toLLVMType (TypeFunc r a) = A.ptr $ A.FunctionType (toLLVMType r) (map toLLVMType a) False
toLLVMType TypeVoid = A.VoidType

-- | Zero initializer for an uninitialised global of the given type.
initializerFor :: Type -> A.C.Constant
initializerFor (TypeInt s) = A.C.Int (fromIntegral s) 0
initializerFor (TypeUInt s) = A.C.Int (fromIntegral s) 0
initializerFor TypeFloat = A.C.Float (A.F.Single 0)
initializerFor TypeDouble = A.C.Float (A.F.Double 0)
initializerFor t@(TypePtr _) = A.C.Null (toLLVMType t)
initializerFor t@(TypeFunc _ _) = A.C.Null (toLLVMType t)
initializerFor t = error $ "No default initializer for type '" ++ pshow t ++ "'"