{-# LANGUAGE GeneralizedNewtypeDeriving, TupleSections #-}

-- | Code generation: lowers a 'Program' to an llvm-general 'A.Module'.
--
-- Every statement is generated into its own basic block which ends in an
-- unconditional branch to the block that follows it.  The empty
-- "trampoline" blocks this produces are removed by 'cleanupTrampolines'
-- once a function body is complete.
module Codegen(codegen) where

import Control.Monad.State.Strict
import Control.Monad.Except
import Data.Maybe
import qualified Data.Map.Strict as Map
import qualified LLVM.General.AST.Type as A
import qualified LLVM.General.AST.Global as A.G
import qualified LLVM.General.AST.Constant as A.C
-- import qualified LLVM.General.AST.Operand as A
-- import qualified LLVM.General.AST.Name as A
-- import qualified LLVM.General.AST.Instruction as A
import qualified LLVM.General.AST as A
import Debug.Trace

import AST
import PShow


-- | Result of a computation that can fail with a message.
type Error a = Either String a

-- | Name of an LLVM-level entity (basic block, temporary, variable slot).
type LLName = String

-- | Mutable state threaded through code generation.
data GenState = GenState
    {currentBlock :: Maybe LLName                     -- block that receives new instructions
    ,allBlocks :: Map.Map LLName A.BasicBlock         -- blocks of the function being generated
    ,currentFunction :: Maybe Declaration             -- Nothing until the first function is entered
    ,nextId :: Integer                                -- counter behind 'getUniqueId'
    ,definitions :: [A.Definition]
    ,variables :: Map.Map Name (Type, LLName)         -- locals of the current function
    ,globalVariables :: Map.Map Name (Type, LLName)}
    deriving (Show)

-- | State at the start of code generation.
initialGenState :: GenState
initialGenState = GenState
    {currentBlock = Nothing
    ,allBlocks = Map.empty
    ,currentFunction = Nothing
    ,nextId = 1
    ,definitions = []
    ,variables = Map.empty
    ,globalVariables = Map.empty}

-- | The code generation monad: 'GenState' plus failure with a 'String'.
newtype CGMonad a = CGMonad {unMon :: ExceptT String (State GenState) a}
    deriving (Functor, Applicative, Monad, MonadState GenState, MonadError String)

-- | Run a generator from 'initialGenState'; on success also return the
-- final state.
runCGMonad :: CGMonad a -> Error (a, GenState)
runCGMonad m = fmap (,s) e
  where
    (e, s) = runState (runExceptT (unMon m)) initialGenState

-- | Return the next unused id and advance the counter.
getUniqueId :: CGMonad Integer
getUniqueId = state $ \s -> (nextId s, s {nextId = nextId s + 1})

-- | Build a fresh name by appending a unique id to @base@.
getNewName :: String -> CGMonad String
getNewName base = liftM ((base ++) . show) getUniqueId

-- | Create an empty basic block, make it the current block and return its
-- name.  The terminator must be installed afterwards with 'setTerminator'.
newBlock :: CGMonad LLName
newBlock = do
    name <- getNewName "bb"
    let unset = error $ "INTERNAL ERROR: terminator of block '" ++ name ++ "' was never set"
    modify $ \s -> s
        { currentBlock = Just name
        , allBlocks = Map.insert name (A.BasicBlock (A.Name name) [] unset) (allBlocks s)
        }
    return name

-- | Like 'newBlock', but the new block ends in a branch to @next@.
newBlockJump :: LLName -> CGMonad LLName
newBlockJump next = do
    bb <- newBlock
    setTerminator $ A.Br (A.Name next) []
    return bb

-- | Make the named block the current block.
changeBlock :: LLName -> CGMonad ()
changeBlock name = modify $ \s -> s {currentBlock = Just name}

-- | Apply a function to the current basic block; fails (instead of
-- crashing) when no block has been started yet.
modifyCurrentBlock :: (A.BasicBlock -> A.BasicBlock) -> CGMonad ()
modifyCurrentBlock f = do
    mcur <- gets currentBlock
    case mcur of
        Nothing -> throwError "INTERNAL ERROR: no current basic block"
        Just cur -> modify $ \s -> s {allBlocks = Map.adjust f cur (allBlocks s)}

-- | Append an instruction to the current block, binding its result to a
-- fresh temporary whose name is returned.
addInstr :: A.Instruction -> CGMonad LLName
addInstr instr = do
    name <- getNewName "t"
    addNamedInstr $ A.Name name A.:= instr

-- | Append an instruction that is already bound to a string name and
-- return that name.
addNamedInstr :: A.Named A.Instruction -> CGMonad LLName
addNamedInstr instr@(A.Name name A.:= _) = do
    modifyCurrentBlock $ \(A.BasicBlock n il t) -> A.BasicBlock n (il ++ [instr]) t
    return name
addNamedInstr _ =
    throwError "INTERNAL ERROR: addNamedInstr needs an instruction bound to a string name"

-- addNamedInstrList :: [A.Named A.Instruction] -> CGMonad LLName
-- addNamedInstrList l = mapM addNamedInstr l >>= return . last

-- | Replace the terminator of the current block.
setTerminator :: A.Terminator -> CGMonad ()
setTerminator term =
    modifyCurrentBlock $ \(A.BasicBlock n il _) -> A.BasicBlock n il (A.Do term)

-- | Record the function whose body is being generated.
setCurrentFunction :: Declaration -> CGMonad ()
setCurrentFunction dec = modify $ \s -> s {currentFunction = Just dec}

-- | Forget the per-function state (blocks, locals, current block) so that
-- one function's blocks and variables cannot end up in the next one.
-- 'nextId' is deliberately kept, so generated names stay unique.
resetFunctionState :: CGMonad ()
resetFunctionState = modify $ \s -> s
    { currentBlock = Nothing
    , allBlocks = Map.empty
    , variables = Map.empty
    }

-- | Register a local variable: source name, LLVM name of its slot, type.
setVar :: Name -> LLName -> Type -> CGMonad ()
setVar name label t =
    modify $ \s -> s {variables = Map.insert name (t, label) (variables s)}

-- | Register a global variable: source name, LLVM name, type.
setGlobalVar :: Name -> LLName -> Type -> CGMonad ()
setGlobalVar name label t =
    modify $ \s -> s {globalVariables = Map.insert name (t, label) (globalVariables s)}

-- | Look up a local variable; fails if it is not registered.
lookupVar :: Name -> CGMonad (Type, LLName)
lookupVar name =
    gets (Map.lookup name . variables)
        >>= maybe (throwError $ "Undefined local variable '" ++ name ++ "'") return

-- | Look up a global variable; fails if it is not registered.
lookupGlobalVar :: Name -> CGMonad (Type, LLName)
lookupGlobalVar name =
    gets (Map.lookup name . globalVariables)
        >>= maybe (throwError $ "Undefined variable '" ++ name ++ "'") return

-- | Resolve a variable (locals shadow globals) to its declared type and an
-- operand referring to its storage.  The operand is pointer-typed.
variablePointer :: Name -> CGMonad (Type, A.Operand)
variablePointer name = do
    local <- gets (Map.lookup name . variables)
    case local of
        Just (t, nm) ->
            return (t, A.LocalReference (A.ptr (toLLVMType t)) (A.Name nm))
        Nothing -> do
            (t, nm) <- lookupGlobalVar name
            return (t, A.ConstantOperand $ A.C.GlobalReference (A.ptr (toLLVMType t)) (A.Name nm))

-- | Operand to use as the address of a store to the variable.
variableStoreOperand :: Name -> CGMonad A.Operand
variableStoreOperand name = liftM snd (variablePointer name)

-- | Emit a load of the variable and return the operand holding its value.
variableOperand :: Name -> CGMonad A.Operand
variableOperand name = do
    (t, addr) <- variablePointer name
    label <- addInstr $ A.Load False addr Nothing 0 []
    return $ A.LocalReference (toLLVMType t) (A.Name label)

-- namedName :: A.Named a -> LLName
-- namedName (A.Name name A.:= _) = name
-- namedName _ = undefined


-- | Generate an LLVM module for a program.
codegen :: Program  -- Program to compile
        -> String   -- Module name
        -> String   -- File name of source
        -> Error A.Module
codegen prog name fname = do
    (defs, st) <- runCGMonad (generateDefs prog)
    -- Debugging aid: dumps the final generator state via Debug.Trace.
    traceShow st $ return ()
    return $ A.defaultModule {
            A.moduleName = name,
            A.moduleSourceFileName = fname,
            A.moduleDefinitions = defs
        }

-- | All definitions of the program: global variables first, then functions.
generateDefs :: Program -> CGMonad [A.Definition]
generateDefs prog = do
    vardecls <- genGlobalVars prog
    fundecls <- genFunctions prog
    return $ vardecls ++ fundecls

-- | Generate (and register) the global variables of the program.
genGlobalVars :: Program -> CGMonad [A.Definition]
genGlobalVars (Program decs) = mapM gen $ filter isDecVariable decs
  where
    gen :: Declaration -> CGMonad A.Definition
    gen (DecVariable t n Nothing) = do
        initval <- either throwError return (initializerFor t)
        setGlobalVar n n t
        return $ A.GlobalDefinition $
            A.globalVariableDefaults {
                A.G.name = A.Name n,
                A.G.type' = toLLVMType t,
                A.G.initializer = Just initval
            }
    gen (DecVariable _ _ (Just _)) =
        throwError $ "Initialised global variables not supported yet"
    gen _ = throwError "INTERNAL ERROR: genGlobalVars given a non-variable declaration"

-- | Generate the functions of the program.
--
-- NOTE(review): parameters are emitted in the signature but are not
-- registered as variables, so function bodies cannot refer to them yet.
genFunctions :: Program -> CGMonad [A.Definition]
genFunctions (Program decs) = mapM gen $ filter isDecFunction decs
  where
    gen :: Declaration -> CGMonad A.Definition
    gen dec@(DecFunction rettype name args body) = do
        resetFunctionState
        setCurrentFunction dec
        firstbb <- genBlock' body
        -- The first block may itself be a trampoline that the cleanup
        -- deletes, so find out where it really leads beforehand.
        entry <- gets (resolveTrampolines firstbb . allBlocks)
        cleanupTrampolines
        blockmap <- gets allBlocks
        entrybb <- maybe
            (throwError $ "INTERNAL ERROR: entry block '" ++ entry ++ "' is missing")
            return
            (Map.lookup entry blockmap)
        let others = [bb | (nm, bb) <- Map.toList blockmap, nm /= entry]
        return $ A.GlobalDefinition $
            A.functionDefaults {
                A.G.returnType = toLLVMType rettype,
                A.G.name = A.Name name,
                A.G.parameters = ([A.Parameter (toLLVMType t) (A.Name n) [] | (t,n) <- args], False),
                A.G.basicBlocks = entrybb : others
            }
    gen _ = throwError "INTERNAL ERROR: genFunctions given a non-function declaration"


-- | Generate a function body.  Control that runs off its end reaches a
-- block terminated by @unreachable@.  Returns the name of the first block.
genBlock' :: Block -> CGMonad LLName
genBlock' bl = do
    termbb <- newBlock
    setTerminator $ A.Unreachable []
    genBlock bl termbb

-- | Generate a block of statements.
genBlock :: Block
         -> LLName  -- name of BasicBlock following this Block
         -> CGMonad LLName  -- name of first BasicBlock
genBlock (Block []) following = genBlock (Block [StEmpty]) following
genBlock (Block [stmt]) following = genSingle stmt following
genBlock (Block (stmt:rest)) following = do
    -- interbb sits between the first statement and the rest; its branch
    -- can only be filled in once the rest has been generated.
    interbb <- newBlock
    firstbb <- genSingle stmt interbb
    restbb <- genBlock (Block rest) following
    changeBlock interbb
    setTerminator $ A.Br (A.Name restbb) []
    return firstbb

-- | Generate a single statement.
genSingle :: Statement
          -> LLName  -- name of BasicBlock following this statement
          -> CGMonad LLName  -- name of first BasicBlock
genSingle StEmpty following = newBlockJump following
genSingle (StBlock block) following = genBlock block following
genSingle (StExpr expr) following = do
    bb <- newBlockJump following
    void $ genExpression expr
    return bb
genSingle (StVarDeclaration t n Nothing) following = do
    bb <- newBlockJump following
    label <- addInstr $ A.Alloca (toLLVMType t) Nothing 0 []
    setVar n label t
    return bb
genSingle (StVarDeclaration _ _ (Just _)) _ =
    throwError "Initialised local variables not supported yet"
genSingle (StAssignment name expr) following = do
    bb <- newBlockJump following
    oper <- genExpression expr
    -- Resolves locals and globals alike.
    (dsttype, ref) <- variablePointer name
    oper' <- castOperand oper dsttype
    void $ addInstr $ A.Store False ref oper' Nothing 0 []
    return bb
genSingle (StReturn expr) _ = do
    bb <- newBlock
    oper <- genExpression expr
    rettype <- gets currentFunction
        >>= maybe (throwError "INTERNAL ERROR: return statement outside of a function")
                  (return . typeOf)
    oper' <- castOperand oper rettype
    setTerminator $ A.Ret (Just oper') []
    return bb
genSingle _ _ = throwError "Statement not implemented"

-- | Generate the instructions for an expression into the current block and
-- return the operand holding its value.
genExpression :: Expression -> CGMonad A.Operand
genExpression (ExLit lit (Just t)) = literalToOperand lit t
-- genExpression (ExLit (LitInt i) (Just t@(TypeInt sz))) = do
--     aname <- getNewName "t"
--     void $ addNamedInstr $ A.Name aname A.:= A.Alloca (toLLVMType t) Nothing 0 []
--     void $ addInstr $ A.Store False (A.LocalReference (toLLVMType t) (A.Name aname))
--                                     (A.ConstantOperand (A.C.Int (fromIntegral sz) i)) Nothing 0 []
--     return aname
genExpression (ExBinOp bo e1 e2 (Just t)) = do
    e1op <- genExprArgument e1
    e2op <- genExprArgument e2
    case bo of
        Plus -> genArith (\a b -> A.Add False False a b [])
                         (\a b -> A.FAdd A.NoFastMathFlags a b [])
                         t e1op e2op
        Minus -> genArith (\a b -> A.Sub False False a b [])
                          (\a b -> A.FSub A.NoFastMathFlags a b [])
                          t e1op e2op
        _ -> throwError $ "Binary operator " ++ pshow bo ++ " not implemented"
genExpression ex = throwError $ "Expression '" ++ pshow ex ++ "' not implemented"

-- | Cast both operands to the result type and emit the integer or the
-- floating-point instruction, depending on that type.
genArith :: (A.Operand -> A.Operand -> A.Instruction)  -- integer instruction
         -> (A.Operand -> A.Operand -> A.Instruction)  -- floating-point instruction
         -> Type                                       -- result type
         -> A.Operand
         -> A.Operand
         -> CGMonad A.Operand
genArith intOp floatOp t lhs rhs = do
    lhs' <- castOperand lhs t
    rhs' <- castOperand rhs t
    label <- case t of
        TypeInt _ -> addInstr $ intOp lhs' rhs'
        TypeUInt _ -> addInstr $ intOp lhs' rhs'
        TypeFloat -> addInstr $ floatOp lhs' rhs'
        TypeDouble -> addInstr $ floatOp lhs' rhs'
        -- NOTE(review): pointers use the integer instruction as before;
        -- castOperand currently rejects pointer targets first.
        TypePtr _ -> addInstr $ intOp lhs' rhs'
        TypeName _ -> throwError "Arithmetic on named types not implemented"
    return $ A.LocalReference (toLLVMType t) (A.Name label)

-- | Generate an operand of a binary operator.
genExprArgument :: Expression -> CGMonad A.Operand
genExprArgument expr = case expr of
    (ExLit lit (Just t)) -> literalToOperand lit t
    _ -> genExpression expr

-- | Operand for a literal of the given type.  Variable references are
-- loaded and cast to the requested type.
literalToOperand :: Literal -> Type -> CGMonad A.Operand
literalToOperand (LitInt i) (TypeInt sz) =
    return $ A.ConstantOperand (A.C.Int (fromIntegral sz) i)
literalToOperand (LitVar n) t = variableOperand n >>= flip castOperand t
literalToOperand lit _ = throwError $ "Literal '" ++ pshow lit ++ "' not implemented"

-- | Implicitly convert an operand to the given type.  Only conversions
-- between signed integer widths are implemented: equal widths are a no-op,
-- widening sign-extends, narrowing is an error.
castOperand :: A.Operand -> Type -> CGMonad A.Operand
castOperand orig@(A.LocalReference (A.IntegerType s1) _) t2@(TypeInt s2) =
    widenValue orig (TypeInt (fromIntegral s1)) t2 (compare (fromIntegral s1) s2)
castOperand orig@(A.ConstantOperand (A.C.Int s1 val)) t2@(TypeInt s2) =
    case compare (fromIntegral s1) s2 of
        EQ -> return orig
        -- A constant is widened by re-typing it; no instruction is needed.
        LT -> return $ A.ConstantOperand (A.C.Int (fromIntegral s2) val)
        GT -> throwError $ "Integer " ++ show val ++ " too large for type '" ++ pshow t2 ++ "'"
castOperand orig@(A.ConstantOperand (A.C.GlobalReference (A.IntegerType s1) _)) t2@(TypeInt s2) =
    widenValue orig (TypeInt (fromIntegral s1)) t2 (compare (fromIntegral s1) s2)
castOperand orig t2 = throwError $ "Cast from '" ++ show orig ++ "' to type '" ++ pshow t2 ++ "' not implemented"

-- | Shared integer case of 'castOperand'.  The 'Ordering' is the source
-- width compared to the target width.
widenValue :: A.Operand  -- value to convert
           -> Type       -- its source type (for the error message)
           -> Type       -- target type
           -> Ordering
           -> CGMonad A.Operand
widenValue orig _ _ EQ = return orig
widenValue orig _ to LT = do
    label <- addInstr $ A.SExt orig (toLLVMType to) []
    return $ A.LocalReference (toLLVMType to) (A.Name label)
widenValue _ from to GT =
    throwError $ "Cannot implicitly cast '" ++ pshow from ++ "' to '" ++ pshow to ++ "'"


-- | Follow a chain of trampolines (blocks without instructions that end in
-- a plain unconditional branch) from @start@ to the first block that is
-- not one.  A block that is revisited ends the search, so cycles cannot
-- make this loop forever.
resolveTrampolines :: LLName -> Map.Map LLName A.BasicBlock -> LLName
resolveTrampolines start blocks = follow [] start
  where
    follow :: [LLName] -> LLName -> LLName
    follow seen name
        | name `elem` seen = name
        | otherwise = case Map.lookup name blocks of
            Just (A.BasicBlock _ [] (A.Do (A.Br (A.Name dst) []))) -> follow (name : seen) dst
            _ -> name

-- | Delete every trampoline block and retarget the branches that pointed
-- at it to its destination.  A block that only branches to itself is kept,
-- since deleting it would leave dangling references.
cleanupTrampolines :: CGMonad ()
cleanupTrampolines = modify $ \st -> st {allBlocks = go (allBlocks st)}
  where
    -- Remove one trampoline at a time; every step deletes a block, so
    -- this terminates.
    go :: Map.Map LLName A.BasicBlock -> Map.Map LLName A.BasicBlock
    go bbs = case mapMaybe trampoline (Map.toList bbs) of
        [] -> bbs
        ((name, dst) : _) -> go $ Map.map (retarget name dst) $ Map.delete name bbs

    trampoline :: (LLName, A.BasicBlock) -> Maybe (LLName, LLName)
    trampoline (name, A.BasicBlock (A.Name name2) [] (A.Do (A.Br (A.Name dst) [])))
        | name /= name2 = error "INTERNAL ERROR: name /= name2"
        | dst == name = Nothing
        | otherwise = Just (name, dst)
    trampoline _ = Nothing

    retarget :: LLName -> LLName -> A.BasicBlock -> A.BasicBlock
    retarget from to (A.BasicBlock nm instrs term) =
        A.BasicBlock nm instrs $ case term of
            n A.:= t -> n A.:= goT t
            A.Do t -> A.Do (goT t)
      where
        rename :: A.Name -> A.Name
        rename = changeName from to

        -- Metadata is carried over unchanged.
        goT :: A.Terminator -> A.Terminator
        goT (A.CondBr cond d1 d2 md) = A.CondBr cond (rename d1) (rename d2) md
        goT (A.Br d md) = A.Br (rename d) md
        goT (A.Switch op d1 ds md) = A.Switch op (rename d1) [(c, rename n) | (c, n) <- ds] md
        goT (A.IndirectBr {}) = error "INTERNAL ERROR: cleanupTrampolines: IndirectBr not supported"
        goT (A.Invoke {}) = error "INTERNAL ERROR: cleanupTrampolines: Invoke not supported"
        goT other = other

    -- Trampolines are matched by string name only, so an unnamed label
    -- never refers to one and is returned as is.
    changeName :: LLName -> LLName -> A.Name -> A.Name
    changeName from to (A.Name x)
        | x == from = A.Name to
        | otherwise = A.Name x
    changeName _ _ other = other


-- | LLVM type corresponding to a source type.
toLLVMType :: Type -> A.Type
toLLVMType (TypeInt s) = A.IntegerType $ fromIntegral s
toLLVMType (TypeUInt s) = A.IntegerType $ fromIntegral s
toLLVMType TypeFloat = A.float
toLLVMType TypeDouble = A.double
toLLVMType (TypePtr t) = A.ptr $ toLLVMType t
toLLVMType (TypeName _) = error "toLLVMType: named types are not supported yet"

-- | Zero initializer for a global of the given type; only integer types
-- are supported.
initializerFor :: Type -> Error A.C.Constant
initializerFor (TypeInt s) = Right $ A.C.Int (fromIntegral s) 0
initializerFor (TypeUInt s) = Right $ A.C.Int (fromIntegral s) 0
initializerFor t = Left $ "Global variables of type '" ++ pshow t ++ "' not supported yet"

-- | Is the declaration a variable declaration?
isDecVariable :: Declaration -> Bool
isDecVariable (DecVariable {}) = True
isDecVariable _ = False

-- | Is the declaration a function declaration?
isDecFunction :: Declaration -> Bool
isDecFunction (DecFunction {}) = True
isDecFunction _ = False