summaryrefslogtreecommitdiff
path: root/codegen.hs
diff options
context:
space:
mode:
Diffstat (limited to 'codegen.hs')
-rw-r--r--codegen.hs295
1 files changed, 229 insertions, 66 deletions
diff --git a/codegen.hs b/codegen.hs
index f2c35b4..f314c63 100644
--- a/codegen.hs
+++ b/codegen.hs
@@ -1,5 +1,7 @@
-module Codegen(module Codegen, A.Module) where
+module Codegen(codegen, A.Module) where
+import Control.Monad
+import Data.Maybe
import qualified Data.Map.Strict as Map
-- import qualified LLVM.General.AST.Type as A
-- import qualified LLVM.General.AST.Global as A
@@ -8,8 +10,10 @@ import qualified Data.Map.Strict as Map
-- import qualified LLVM.General.AST.Name as A
-- import qualified LLVM.General.AST.Instruction as A
import qualified LLVM.General.AST as A
+import Debug.Trace
import AST
+import PShow
type Error a = Either String a
@@ -52,52 +56,207 @@ preprocess prog@(Program decls) = mapProgram' filtered mapper
generateDefs :: Program -> Error [A.Definition]
generateDefs prog = do
checkUndefinedTypes prog
- checkUndefinedVars prog
- fail "TODO"
+ checked <- typeCheck prog
+ collected <- collectVarDecls checked
+ void $ trace "Collected:" $ return []
+ void $ trace (pshow collected) $ return []
+ void $ fail "TODO"
return []
checkUndefinedTypes :: Program -> Error ()
checkUndefinedTypes prog = fmap (const ()) $ mapProgram prog $ defaultPM {typeHandler = check}
- where
- check :: Type -> Error Type
- check (TypeName n) = Left $ "Undefined type name '" ++ n ++ "'"
- check t = Right t
-
--- checkUndefinedVars :: Program -> Error ()
--- checkUndefinedVars prog = do
-
-
--- mapTypes' :: Program -> (Type -> Type) -> Program
--- mapTypes' prog f = (\(Right res) -> res) $ mapTypes prog (return . f)
-
--- mapTypes :: Program -> (Type -> Error Type) -> Error Program
--- mapTypes (Program decls) f = Program <$> sequence (map goD decls)
--- where
--- handler :: Type -> Error Type
--- handler (TypePtr t) = f t >>= f . TypePtr
--- handler t = f t
-
--- goD :: Declaration -> Error Declaration
--- goD (DecFunction t n a b) = do
--- rt <- handler t
--- ra <- sequence $ map (\(at,an) -> (\art -> (art,an)) <$> handler at) a
--- rb <- goB b
--- return $ DecFunction rt n ra rb
--- goD (DecVariable t n v) = (\rt -> DecVariable rt n v) <$> handler t
--- goD (DecTypedef t n) = (\rt -> DecTypedef rt n) <$> handler t
-
--- goB :: Block -> Error Block
--- goB (Block stmts) = Block <$> sequence (map goS stmts)
-
--- goS :: Statement -> Error Statement
--- goS (StBlock bl) = StBlock <$> goB bl
--- goS (StVarDeclaration t n e) = (\rt -> StVarDeclaration rt n e) <$> handler t
--- goS (StIf c t e) = do
--- rt <- goS t
--- re <- goS e
--- return $ StIf c rt re
--- goS (StWhile c b) = StWhile c <$> goS b
--- goS s = return s
+ where
+ check :: MapperHandler Type
+ check (TypeName n) = Left $ "Undefined type name '" ++ n ++ "'"
+ check t = Right t
+
+
+typeCheck :: Program -> Error Program
+typeCheck (Program decls) = Program <$> mapM (goD topLevelNames) decls
+ where
+ topLevelNames :: Map.Map Name Type
+ topLevelNames = foldr (uncurry Map.insert) Map.empty pairs
+ where pairs = map ((,) <$> nameOf <*> typeOf) $ filter isVarDecl decls
+
+ functionTypes :: Map.Map Name (Type,[Type])
+ functionTypes = foldr (uncurry Map.insert) Map.empty pairs
+ where pairs = map ((,) <$> nameOf <*> getTypes) $ filter isFunctionDecl decls
+ getTypes (DecFunction rt _ args _) = (rt, map fst args)
+ getTypes _ = undefined
+
+ isVarDecl (DecVariable {}) = True
+ isVarDecl _ = False
+
+ isFunctionDecl (DecFunction {}) = True
+ isFunctionDecl _ = False
+
+ goD :: Map.Map Name Type -> Declaration -> Error Declaration
+ goD names (DecFunction frt name args body) = do
+ newbody <- goB frt (foldr (\(t,n) m -> Map.insert n t m) names args) body
+ return $ DecFunction frt name args newbody
+ goD _ dec = return dec
+
+ goB :: Type -- function return type
+ -> Map.Map Name Type -> Block -> Error Block
+ goB frt names (Block stmts) = Block . snd <$> foldl foldfunc (return (names, [])) stmts
+ where
+ foldfunc :: Error (Map.Map Name Type, [Statement]) -> Statement -> Error (Map.Map Name Type, [Statement])
+ foldfunc ep st = do
+ (names', lst) <- ep
+ (newnames', newst) <- goS frt names' st
+ return (newnames', lst ++ [newst]) -- TODO: fix slow tail-append
+
+ goS :: Type -- function return type
+ -> Map.Map Name Type -> Statement -> Error (Map.Map Name Type, Statement)
+ goS _ names st@(StVarDeclaration t n Nothing) = return (Map.insert n t names, st)
+ goS frt names (StVarDeclaration t n (Just e)) = do
+ (newnames, _) <- goS frt names (StVarDeclaration t n Nothing)
+ goS frt newnames (StAssignment n e)
+ goS _ names (StAssignment n e) = maybe (Left $ "Undefined variable '" ++ n ++ "'") go (Map.lookup n names)
+ where go dsttype = do
+ re <- goE names e
+ let (Just extype) = exTypeOf re
+ if canConvert extype dsttype
+ then return (names, StAssignment n re)
+ else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '"
+ ++ pshow dsttype ++ "' in assignment to variable '" ++ n ++ "'"
+ goS _ names st@StEmpty = return (names, st)
+ goS frt names (StBlock bl) = do
+ newbl <- goB frt names bl
+ return (names, StBlock newbl)
+ goS _ names (StExpr e) = do
+ re <- goE names e
+ return (names, StExpr re)
+ goS frt names (StIf e s1 s2) = do
+ re <- goE names e
+ (_, rs1) <- goS frt names s1
+ (_, rs2) <- goS frt names s2
+ return (names, StIf re rs1 rs2)
+ goS frt names (StWhile e s) = do
+ re <- goE names e
+ (_, rs) <- goS frt names s
+ return (names, StWhile re rs)
+ goS frt names (StReturn e) = do
+ re <- goE names e
+ let (Just extype) = exTypeOf re
+ if canConvert extype frt
+ then return (names, StReturn re)
+ else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '"
+ ++ pshow frt ++ "' in return statement"
+
+ -- Postcondition: the expression (if any) has a type annotation.
+ goE :: Map.Map Name Type -> Expression -> Error Expression
+ goE _ (ExLit l@(LitInt i) _) = return $ ExLit l $ Just (smallestIntType i)
+ goE _ (ExLit l@(LitString _) _) = return $ ExLit l $ Just (TypePtr (TypeInt 8))
+ goE names (ExLit l@(LitVar n) _) = maybe (Left $ "Undefined variable '" ++ n ++ "'") (return . ExLit l . Just)
+ (Map.lookup n names)
+ goE names (ExLit l@(LitCall n args) _) = do
+ ft <- maybe (Left $ "Unknown function '" ++ n ++ "'") return $ Map.lookup n functionTypes
+ rargs <- mapM (goE names) args
+ when (length rargs /= length (snd ft))
+ $ Left ("Expected " ++ show (length (snd ft)) ++ "arguments to "
+ ++ "function '" ++ n ++ "', but got " ++ show (length rargs))
+ >> return ()
+ flip mapM_ rargs $
+ \a -> let argtype = fromJust (exTypeOf a)
+ in if canConvert argtype (fst ft)
+ then return a
+ else Left $ "Cannot convert type '" ++ pshow argtype ++ "' to '" ++ pshow (fst ft)
+ ++ "' in call of function '" ++ pshow n ++ "'"
+ return $ ExLit l (Just (fst ft))
+ goE names (ExBinOp bo e1 e2 _) = do
+ re1 <- goE names e1
+ re2 <- goE names e2
+ maybe (Left $ "Cannot use operator '" ++ pshow bo ++ "' with argument types '"
+ ++ pshow (fromJust $ exTypeOf re1) ++ "' and '" ++ pshow (fromJust $ exTypeOf re2) ++ "'")
+ (return . ExBinOp bo re1 re2 . Just)
+ $ typeCompatibleBO bo (fromJust $ exTypeOf re1) (fromJust $ exTypeOf re2)
+ goE names (ExUnOp uo e _) = do
+ re <- goE names e
+ maybe (Left $ "Cannot use operator '" ++ pshow uo ++ "' with argument type '" ++ pshow (fromJust $ exTypeOf re))
+ (return . ExUnOp uo re . Just)
+ $ typeCompatibleUO uo (fromJust $ exTypeOf re)
+
+
+collectVarDecls :: Program -> Error Program
+collectVarDecls prog = mapProgram prog $ defaultPM {blockHandler = goBlock}
+ where
+ goBlock :: MapperHandler Block
+ goBlock (Block stmts) =
+ let isVarDecl (StVarDeclaration {}) = True
+ isVarDecl _ = False
+
+ removeDecls [] = []
+ removeDecls ((StVarDeclaration _ n (Just ex)):rest) = StAssignment n ex : removeDecls rest
+ removeDecls ((StVarDeclaration _ _ Nothing):rest) = removeDecls rest
+ removeDecls (st:rest) = st : removeDecls rest
+
+ onlyDecl (StVarDeclaration t n _) = StVarDeclaration t n Nothing
+ onlyDecl _ = undefined
+
+ vdecls = map onlyDecl $ filter isVarDecl stmts
+ in return $ Block $ vdecls ++ removeDecls stmts
+
+
+canConvert :: Type -> Type -> Bool
+canConvert x y | x == y = True
+canConvert (TypeInt f) (TypeInt t) = f <= t
+canConvert (TypeUInt f) (TypeUInt t) = f <= t
+canConvert TypeFloat TypeDouble = True
+canConvert _ _ = False
+
+arithBO, compareBO, logicBO, complogBO :: [BinaryOperator]
+arithBO = [Plus, Minus, Times, Divide, Modulo]
+compareBO = [Equal, Unequal, Greater, Less, GEqual, LEqual]
+logicBO = [BoolAnd, BoolOr]
+complogBO = compareBO ++ logicBO
+
+typeCompatibleBO :: BinaryOperator -> Type -> Type -> Maybe Type
+typeCompatibleBO Minus (TypePtr t1) (TypePtr t2) | t1 == t2 = Just $ TypeInt 1
+typeCompatibleBO bo (TypePtr t1) (TypePtr t2) | t1 == t2 && bo `elem` complogBO = Just $ TypeInt 1
+typeCompatibleBO _ (TypePtr _) _ = Nothing
+typeCompatibleBO _ _ (TypePtr _) = Nothing
+
+typeCompatibleBO bo (TypeInt s1) (TypeInt s2) | bo `elem` arithBO = Just $ TypeInt (max s1 s2)
+typeCompatibleBO bo (TypeInt _) (TypeInt _) | bo `elem` complogBO = Just $ TypeInt 1
+
+typeCompatibleBO bo (TypeUInt s1) (TypeUInt s2) | bo `elem` arithBO = Just $ TypeUInt (max s1 s2)
+typeCompatibleBO bo (TypeUInt _) (TypeUInt _) | bo `elem` complogBO = Just $ TypeInt 1
+
+typeCompatibleBO bo t1 t2 | bo `elem` complogBO && t1 == t2 = Just $ TypeInt 1
+
+typeCompatibleBO bo TypeFloat (TypeInt s) | s <= 24 = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1
+typeCompatibleBO bo (TypeInt s) TypeFloat | s <= 24 = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1
+typeCompatibleBO bo TypeDouble (TypeInt s) | s <= 53 = Just $ if bo `elem` arithBO then TypeDouble else TypeInt 1
+typeCompatibleBO bo (TypeInt s) TypeDouble | s <= 53 = Just $ if bo `elem` arithBO then TypeDouble else TypeInt 1
+typeCompatibleBO bo TypeFloat TypeDouble = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1
+typeCompatibleBO bo TypeDouble TypeFloat = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1
+
+typeCompatibleBO _ _ _ = Nothing
+
+typeCompatibleUO :: UnaryOperator -> Type -> Maybe Type
+typeCompatibleUO Not _ = Just $ TypeInt 1
+typeCompatibleUO Address t = Just $ TypePtr t
+typeCompatibleUO uo t@(TypeInt _) | uo `elem` [Negate, Invert] = Just t
+typeCompatibleUO uo t@(TypeUInt _) | uo `elem` [Negate, Invert] = Just t
+typeCompatibleUO Negate TypeFloat = Just TypeFloat
+typeCompatibleUO Negate TypeDouble = Just TypeDouble
+typeCompatibleUO Dereference t@(TypePtr _) = Just t
+typeCompatibleUO _ _ = Nothing
+
+smallestIntType :: Integer -> Type
+smallestIntType i
+ | i >= -2^7 && i < 2^7 = TypeInt 8
+ | i >= -2^15 && i < 2^15 = TypeInt 16
+ | i >= -2^31 && i < 2^31 = TypeInt 32
+ | otherwise = TypeInt 64
+
+-- smallestUIntType :: Integer -> Type
+-- smallestUIntType i
+-- | i >= 0 && i < 2^8 = TypeUInt 8
+-- | i >= 0 && i < 2^16 = TypeUInt 16
+-- | i >= 0 && i < 2^32 = TypeUInt 32
+-- | otherwise = TypeUInt 64
type MapperHandler a = a -> Error a
@@ -135,14 +294,14 @@ defaultPM' = ProgramMapper' id id id id id id id id id
mapProgram' :: Program -> ProgramMapper' -> Program
mapProgram' prog mapper = (\(Right r) -> r) $ mapProgram prog $ ProgramMapper
{declarationHandler = return . declarationHandler' mapper
- ,blockHandler = return . blockHandler' mapper
- ,typeHandler = return . typeHandler' mapper
- ,literalHandler = return . literalHandler' mapper
- ,binOpHandler = return . binOpHandler' mapper
- ,unOpHandler = return . unOpHandler' mapper
- ,expressionHandler = return . expressionHandler' mapper
- ,statementHandler = return . statementHandler' mapper
- ,nameHandler = return . nameHandler' mapper}
+ ,blockHandler = return . blockHandler' mapper
+ ,typeHandler = return . typeHandler' mapper
+ ,literalHandler = return . literalHandler' mapper
+ ,binOpHandler = return . binOpHandler' mapper
+ ,unOpHandler = return . unOpHandler' mapper
+ ,expressionHandler = return . expressionHandler' mapper
+ ,statementHandler = return . statementHandler' mapper
+ ,nameHandler = return . nameHandler' mapper}
mapProgram :: Program -> ProgramMapper -> Error Program
mapProgram prog mapper = goP prog
@@ -157,10 +316,10 @@ mapProgram prog mapper = goP prog
h_s = statementHandler mapper
h_n = nameHandler mapper
- goP :: Program -> Error Program
+ goP :: MapperHandler Program
goP (Program decls) = Program <$> sequence (map (\d -> goD d >>= h_d) decls)
- goD :: Declaration -> Error Declaration
+ goD :: MapperHandler Declaration
goD (DecFunction t n a b) = do
rt <- goT t
rn <- goN n
@@ -177,30 +336,32 @@ mapProgram prog mapper = goP prog
rn <- goN n
h_d $ DecTypedef rt rn
- goT :: Type -> Error Type
+ goT :: MapperHandler Type
goT (TypePtr t) = goT t >>= (h_t . TypePtr)
goT (TypeName n) = goN n >>= (h_t . TypeName)
goT t = h_t t
- goN :: Name -> Error Name
+ goN :: MapperHandler Name
goN = h_n
- goB :: Block -> Error Block
+ goB :: MapperHandler Block
goB (Block sts) = (Block <$> sequence (map goS sts)) >>= h_b
- goE :: Expression -> Error Expression
- goE (ExLit l) = goL l >>= (h_e . ExLit)
- goE (ExBinOp bo e1 e2) = do
+ goE :: MapperHandler Expression
+ goE (ExLit l mt) = do
+ rl <- goL l
+ h_e $ ExLit rl mt
+ goE (ExBinOp bo e1 e2 mt) = do
rbo <- goBO bo
re1 <- goE e1
re2 <- goE e2
- h_e $ ExBinOp rbo re1 re2
- goE (ExUnOp uo e) = do
+ h_e $ ExBinOp rbo re1 re2 mt
+ goE (ExUnOp uo e mt) = do
ruo <- goUO uo
re <- goE e
- h_e $ ExUnOp ruo re
+ h_e $ ExUnOp ruo re mt
- goS :: Statement -> Error Statement
+ goS :: MapperHandler Statement
goS StEmpty = h_s StEmpty
goS (StBlock b) = goB b >>= (h_s . StBlock)
goS (StExpr e) = goE e >>= (h_s . StExpr)
@@ -224,15 +385,17 @@ mapProgram prog mapper = goP prog
h_s $ StWhile re rs
goS (StReturn e) = goE e >>= (h_s . StReturn)
- goL :: Literal -> Error Literal
+ goL :: MapperHandler Literal
+ goL l@(LitString _) = h_l l
+ goL l@(LitInt _) = h_l l
goL (LitVar n) = goN n >>= (h_l . LitVar)
goL (LitCall n a) = do
rn <- goN n
ra <- sequence $ map goE a
h_l $ LitCall rn ra
- goBO :: BinaryOperator -> Error BinaryOperator
+ goBO :: MapperHandler BinaryOperator
goBO = h_bo
- goUO :: UnaryOperator -> Error UnaryOperator
+ goUO :: MapperHandler UnaryOperator
goUO = h_uo