diff options
| -rw-r--r-- | Makefile | 2 | ||||
| -rw-r--r-- | ast.hs | 47 | ||||
| -rw-r--r-- | codegen.hs | 281 | ||||
| -rw-r--r-- | main.hs | 3 | ||||
| -rw-r--r-- | parser.hs | 10 | 
5 files changed, 263 insertions, 80 deletions
@@ -1,5 +1,5 @@  GHC = ghc -GHCFLAGS = -Wall -O3 +GHCFLAGS = -Wall -Wno-type-defaults -O3  TARGET = main @@ -1,7 +1,4 @@ -module AST( -    Name, -    Program(..), Declaration(..), Block(..), Type(..), Literal(..), -    BinaryOperator(..), UnaryOperator(..), Expression(..), Statement(..)) where +module AST where  import Data.List @@ -52,10 +49,10 @@ data UnaryOperator      = Negate | Not | Invert | Dereference | Address    deriving (Show, Eq) -data Expression -    = ExLit Literal -    | ExBinOp BinaryOperator Expression Expression -    | ExUnOp UnaryOperator Expression +data Expression    -- (Maybe Type)'s are type annotations by the type checker +    = ExLit Literal (Maybe Type) +    | ExBinOp BinaryOperator Expression Expression (Maybe Type) +    | ExUnOp UnaryOperator Expression (Maybe Type)    deriving (Show)  data Statement @@ -70,9 +67,24 @@ data Statement    deriving (Show) -indent :: Int -> String -> String -indent sz str = intercalate "\n" $ map (prefix++) $ lines str -    where prefix = replicate sz ' ' +exLit_ :: Literal -> Expression +exLit_ l = ExLit l Nothing + +exBinOp_ :: BinaryOperator -> Expression -> Expression -> Expression +exBinOp_ bo a b = ExBinOp bo a b Nothing + +exUnOp_ :: UnaryOperator -> Expression -> Expression +exUnOp_ uo e = ExUnOp uo e Nothing + +exTypeOf :: Expression -> Maybe Type +exTypeOf (ExLit _ mt) = mt +exTypeOf (ExBinOp _ _ _ mt) = mt +exTypeOf (ExUnOp _ _ mt) = mt + +-- exSetType :: Type -> Expression -> Expression +-- exSetType t (ExLit l _) = ExLit l t +-- exSetType t (ExBinOp bo e1 e2 _) = ExBinOp bo e1 e2 t +-- exSetType t (ExUnOp bo e _) = ExUnOp bo e t  instance PShow Program where @@ -92,6 +104,10 @@ instance PShow Declaration where  instance PShow Block where      pshow (Block []) = "{}"      pshow (Block stmts) = concat ["{\n", indent 4 $ intercalate "\n" (map pshow stmts), "\n}"] +      where +        indent :: Int -> String -> String +        indent sz str = intercalate "\n" $ map (prefix++) $ lines str +            where prefix = replicate sz ' '  instance PShow Type where      pshow (TypeInt sz) = 'i' : pshow sz @@ -130,9 +146,12 @@ instance PShow UnaryOperator where      pshow Address = "&"  instance PShow Expression where -    pshow (ExLit lit) = pshow lit -    pshow (ExBinOp op a b) = concat [pshow a, " ", pshow op, " ", pshow b] -    pshow (ExUnOp op a) = concat [pshow op, pshow a] +    pshow (ExLit lit Nothing) = pshow lit +    pshow (ExLit lit (Just t)) = concat ["(", pshow lit, " :: ", pshow t, ")"] +    pshow (ExBinOp op a b Nothing) = concat ["(", pshow a, " ", pshow op, " ", pshow b, ")"] +    pshow (ExBinOp op a b (Just t)) = concat ["(", pshow a, " ", pshow op, " ", pshow b, " :: ", pshow t, ")"] +    pshow (ExUnOp op a Nothing) = concat [pshow op, pshow a] +    pshow (ExUnOp op a (Just t)) = concat ["(", pshow op, pshow a, " :: ", pshow t, ")"]  instance PShow Statement where      pshow StEmpty = ";" @@ -1,5 +1,7 @@ -module Codegen(module Codegen, A.Module) where +module Codegen(codegen, A.Module) where +import Control.Monad +import Data.Maybe  import qualified Data.Map.Strict as Map  -- import qualified LLVM.General.AST.Type as A  -- import qualified LLVM.General.AST.Global as A @@ -8,8 +10,10 @@ import qualified Data.Map.Strict as Map  -- import qualified LLVM.General.AST.Name as A  -- import qualified LLVM.General.AST.Instruction as A  import qualified LLVM.General.AST as A +import Debug.Trace  import AST +import PShow  type Error a = Either String a @@ -52,52 +56,207 @@ preprocess prog@(Program decls) = mapProgram' filtered mapper  generateDefs :: Program -> Error [A.Definition]  generateDefs prog = do      checkUndefinedTypes prog -    checkUndefinedVars prog -    fail "TODO" +    checked <- typeCheck prog +    collected <- collectVarDecls checked +    void $ trace "Collected:" $ return [] +    void $ trace (pshow collected) $ return [] +    void $ fail "TODO"      return []  checkUndefinedTypes :: Program -> Error ()  checkUndefinedTypes prog = fmap (const ()) $ mapProgram prog $ defaultPM {typeHandler = check} -    where -        check :: Type -> Error Type -        check (TypeName n) = Left $ "Undefined type name '" ++ n ++ "'" -        check t = Right t +  where +    check :: MapperHandler Type +    check (TypeName n) = Left $ "Undefined type name '" ++ n ++ "'" +    check t = Right t + + +typeCheck :: Program -> Error Program +typeCheck (Program decls) = Program <$> mapM (goD topLevelNames) decls +  where +    topLevelNames :: Map.Map Name Type +    topLevelNames = foldr (uncurry Map.insert) Map.empty pairs +        where pairs = map ((,) <$> nameOf <*> typeOf) $ filter isVarDecl decls + +    functionTypes :: Map.Map Name (Type,[Type]) +    functionTypes = foldr (uncurry Map.insert) Map.empty pairs +        where pairs = map ((,) <$> nameOf <*> getTypes) $ filter isFunctionDecl decls +              getTypes (DecFunction rt _ args _) = (rt, map fst args) +              getTypes _ = undefined + +    isVarDecl (DecVariable {}) = True +    isVarDecl _ = False + +    isFunctionDecl (DecFunction {}) = True +    isFunctionDecl _ = False + +    goD :: Map.Map Name Type -> Declaration -> Error Declaration +    goD names (DecFunction frt name args body) = do +        newbody <- goB frt (foldr (\(t,n) m -> Map.insert n t m) names args) body +        return $ DecFunction frt name args newbody +    goD _ dec = return dec + +    goB :: Type  -- function return type +        -> Map.Map Name Type -> Block -> Error Block +    goB frt names (Block stmts) = Block . snd <$> foldl foldfunc (return (names, [])) stmts +      where +        foldfunc :: Error (Map.Map Name Type, [Statement]) -> Statement -> Error (Map.Map Name Type, [Statement]) +        foldfunc ep st = do +            (names', lst) <- ep +            (newnames', newst) <- goS frt names' st +            return (newnames', lst ++ [newst])  -- TODO: fix slow tail-append + +    goS :: Type  -- function return type +        -> Map.Map Name Type -> Statement -> Error (Map.Map Name Type, Statement) +    goS _ names st@(StVarDeclaration t n Nothing) = return (Map.insert n t names, st) +    goS frt names (StVarDeclaration t n (Just e)) = do +        (newnames, _) <- goS frt names (StVarDeclaration t n Nothing) +        goS frt newnames (StAssignment n e) +    goS _ names (StAssignment n e) = maybe (Left $ "Undefined variable '" ++ n ++ "'") go (Map.lookup n names) +      where go dsttype = do +              re <- goE names e +              let (Just extype) = exTypeOf re +              if canConvert extype dsttype +                  then return (names, StAssignment n re) +                  else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '" +                              ++ pshow dsttype ++ "' in assignment to variable '" ++ n ++ "'" +    goS _ names st@StEmpty = return (names, st) +    goS frt names (StBlock bl) = do +        newbl <- goB frt names bl +        return (names, StBlock newbl) +    goS _ names (StExpr e) = do +        re <- goE names e +        return (names, StExpr re) +    goS frt names (StIf e s1 s2) = do +        re <- goE names e +        (_, rs1) <- goS frt names s1 +        (_, rs2) <- goS frt names s2 +        return (names, StIf re rs1 rs2) +    goS frt names (StWhile e s) = do +        re <- goE names e +        (_, rs) <- goS frt names s +        return (names, StWhile re rs) +    goS frt names (StReturn e) = do +        re <- goE names e +        let (Just extype) = exTypeOf re +        if canConvert extype frt +            then return (names, StReturn re) +            else Left $ "Cannot convert type '" ++ pshow extype ++ "' to '" +                        ++ pshow frt ++ "' in return statement" + +    -- Postcondition: the expression (if any) has a type annotation. +    goE :: Map.Map Name Type -> Expression -> Error Expression +    goE _ (ExLit l@(LitInt i) _) = return $ ExLit l $ Just (smallestIntType i) +    goE _ (ExLit l@(LitString _) _) = return $ ExLit l $ Just (TypePtr (TypeInt 8)) +    goE names (ExLit l@(LitVar n) _) = maybe (Left $ "Undefined variable '" ++ n ++ "'") (return . ExLit l . Just) +                                             (Map.lookup n names) +    goE names (ExLit l@(LitCall n args) _) = do +        ft <- maybe (Left $ "Unknown function '" ++ n ++ "'") return $ Map.lookup n functionTypes +        rargs <- mapM (goE names) args +        when (length rargs /= length (snd ft)) +            $ Left ("Expected " ++ show (length (snd ft)) ++ "arguments to " +                    ++ "function '" ++ n ++ "', but got " ++ show (length rargs)) +              >> return () +        flip mapM_ rargs $ +            \a -> let argtype = fromJust (exTypeOf a) +                  in if canConvert argtype (fst ft) +                      then return a +                      else Left $ "Cannot convert type '" ++ pshow argtype ++ "' to '" ++ pshow (fst ft) +                                  ++ "' in call of function '" ++ pshow n ++ "'" +        return $ ExLit l (Just (fst ft)) +    goE names (ExBinOp bo e1 e2 _) = do +        re1 <- goE names e1 +        re2 <- goE names e2 +        maybe (Left $ "Cannot use operator '" ++ pshow bo ++ "' with argument types '" +                      ++ pshow (fromJust $ exTypeOf re1) ++ "' and '" ++ pshow (fromJust $ exTypeOf re2) ++ "'") +              (return . ExBinOp bo re1 re2 . Just) +              $ typeCompatibleBO bo (fromJust $ exTypeOf re1) (fromJust $ exTypeOf re2) +    goE names (ExUnOp uo e _) = do +        re <- goE names e +        maybe (Left $ "Cannot use operator '" ++ pshow uo ++ "' with argument type '" ++ pshow (fromJust $ exTypeOf re)) +              (return . ExUnOp uo re . Just) +              $ typeCompatibleUO uo (fromJust $ exTypeOf re) + + +collectVarDecls :: Program -> Error Program +collectVarDecls prog = mapProgram prog $ defaultPM {blockHandler = goBlock} +  where +    goBlock :: MapperHandler Block +    goBlock (Block stmts) = +        let isVarDecl (StVarDeclaration {}) = True +            isVarDecl _ = False + +            removeDecls [] = [] +            removeDecls ((StVarDeclaration _ n (Just ex)):rest) = StAssignment n ex : removeDecls rest +            removeDecls ((StVarDeclaration _ _ Nothing):rest) = removeDecls rest +            removeDecls (st:rest) = st : removeDecls rest + +            onlyDecl (StVarDeclaration t n _) = StVarDeclaration t n Nothing +            onlyDecl _ = undefined + +            vdecls = map onlyDecl $ filter isVarDecl stmts +        in return $ Block $ vdecls ++ removeDecls stmts + + +canConvert :: Type -> Type -> Bool +canConvert x y | x == y = True +canConvert (TypeInt f) (TypeInt t) = f <= t +canConvert (TypeUInt f) (TypeUInt t) = f <= t +canConvert TypeFloat TypeDouble = True +canConvert _ _ = False + +arithBO, compareBO, logicBO, complogBO :: [BinaryOperator] +arithBO = [Plus, Minus, Times, Divide, Modulo] +compareBO = [Equal, Unequal, Greater, Less, GEqual, LEqual] +logicBO = [BoolAnd, BoolOr] +complogBO = compareBO ++ logicBO + +typeCompatibleBO :: BinaryOperator -> Type -> Type -> Maybe Type +typeCompatibleBO Minus (TypePtr t1) (TypePtr t2) | t1 == t2 = Just $ TypeInt 1 +typeCompatibleBO bo (TypePtr t1) (TypePtr t2) | t1 == t2 && bo `elem` complogBO = Just $ TypeInt 1 +typeCompatibleBO _ (TypePtr _) _ = Nothing +typeCompatibleBO _ _ (TypePtr _) = Nothing + +typeCompatibleBO bo (TypeInt s1) (TypeInt s2) | bo `elem` arithBO = Just $ TypeInt (max s1 s2) +typeCompatibleBO bo (TypeInt _) (TypeInt _) | bo `elem` complogBO = Just $ TypeInt 1 --- checkUndefinedVars :: Program -> Error () --- checkUndefinedVars prog = do +typeCompatibleBO bo (TypeUInt s1) (TypeUInt s2) | bo `elem` arithBO = Just $ TypeUInt (max s1 s2) +typeCompatibleBO bo (TypeUInt _) (TypeUInt _) | bo `elem` complogBO = Just $ TypeInt 1 +typeCompatibleBO bo t1 t2 | bo `elem` complogBO && t1 == t2 = Just $ TypeInt 1 --- mapTypes' :: Program -> (Type -> Type) -> Program --- mapTypes' prog f = (\(Right res) -> res) $ mapTypes prog (return . f) +typeCompatibleBO bo TypeFloat (TypeInt s) | s <= 24  = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1 +typeCompatibleBO bo (TypeInt s) TypeFloat | s <= 24  = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1 +typeCompatibleBO bo TypeDouble (TypeInt s) | s <= 53 = Just $ if bo `elem` arithBO then TypeDouble else TypeInt 1 +typeCompatibleBO bo (TypeInt s) TypeDouble | s <= 53 = Just $ if bo `elem` arithBO then TypeDouble else TypeInt 1 +typeCompatibleBO bo TypeFloat TypeDouble = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1 +typeCompatibleBO bo TypeDouble TypeFloat = Just $ if bo `elem` arithBO then TypeFloat else TypeInt 1 --- mapTypes :: Program -> (Type -> Error Type) -> Error Program --- mapTypes (Program decls) f = Program <$> sequence (map goD decls) ---     where ---         handler :: Type -> Error Type ---         handler (TypePtr t) = f t >>= f . TypePtr ---         handler t = f t +typeCompatibleBO _ _ _ = Nothing ---         goD :: Declaration -> Error Declaration ---         goD (DecFunction t n a b) = do ---             rt <- handler t ---             ra <- sequence $ map (\(at,an) -> (\art -> (art,an)) <$> handler at) a ---             rb <- goB b ---             return $ DecFunction rt n ra rb ---         goD (DecVariable t n v) = (\rt -> DecVariable rt n v) <$> handler t ---         goD (DecTypedef t n) = (\rt -> DecTypedef rt n) <$> handler t +typeCompatibleUO :: UnaryOperator -> Type -> Maybe Type +typeCompatibleUO Not _ = Just $ TypeInt 1 +typeCompatibleUO Address t = Just $ TypePtr t +typeCompatibleUO uo t@(TypeInt _) | uo `elem` [Negate, Invert] = Just t +typeCompatibleUO uo t@(TypeUInt _) | uo `elem` [Negate, Invert] = Just t +typeCompatibleUO Negate TypeFloat = Just TypeFloat +typeCompatibleUO Negate TypeDouble = Just TypeDouble +typeCompatibleUO Dereference t@(TypePtr _) = Just t +typeCompatibleUO _ _ = Nothing ---         goB :: Block -> Error Block ---         goB (Block stmts) = Block <$> sequence (map goS stmts) +smallestIntType :: Integer -> Type +smallestIntType i +    | i >= -2^7 && i < 2^7 = TypeInt 8 +    | i >= -2^15 && i < 2^15 = TypeInt 16 +    | i >= -2^31 && i < 2^31 = TypeInt 32 +    | otherwise = TypeInt 64 ---         goS :: Statement -> Error Statement ---         goS (StBlock bl) = StBlock <$> goB bl ---         goS (StVarDeclaration t n e) = (\rt -> StVarDeclaration rt n e) <$> handler t ---         goS (StIf c t e) = do ---             rt <- goS t ---             re <- goS e ---             return $ StIf c rt re ---         goS (StWhile c b) = StWhile c <$> goS b ---         goS s = return s +-- smallestUIntType :: Integer -> Type +-- smallestUIntType i +--     | i >= 0 && i < 2^8 = TypeUInt 8 +--     | i >= 0 && i < 2^16 = TypeUInt 16 +--     | i >= 0 && i < 2^32 = TypeUInt 32 +--     | otherwise = TypeUInt 64  type MapperHandler a = a -> Error a @@ -135,14 +294,14 @@ defaultPM' = ProgramMapper' id id id id id id id id id  mapProgram' :: Program -> ProgramMapper' -> Program  mapProgram' prog mapper = (\(Right r) -> r) $ mapProgram prog $ ProgramMapper      {declarationHandler = return . declarationHandler' mapper -    ,blockHandler = return . blockHandler' mapper -    ,typeHandler = return . typeHandler' mapper -    ,literalHandler = return . literalHandler' mapper -    ,binOpHandler = return . binOpHandler' mapper -    ,unOpHandler = return . unOpHandler' mapper -    ,expressionHandler = return . expressionHandler' mapper -    ,statementHandler = return . statementHandler' mapper -    ,nameHandler = return . nameHandler' mapper} +    ,blockHandler       = return . blockHandler' mapper +    ,typeHandler        = return . typeHandler' mapper +    ,literalHandler     = return . literalHandler' mapper +    ,binOpHandler       = return . binOpHandler' mapper +    ,unOpHandler        = return . unOpHandler' mapper +    ,expressionHandler  = return . expressionHandler' mapper +    ,statementHandler   = return . statementHandler' mapper +    ,nameHandler        = return . nameHandler' mapper}  mapProgram :: Program -> ProgramMapper -> Error Program  mapProgram prog mapper = goP prog @@ -157,10 +316,10 @@ mapProgram prog mapper = goP prog          h_s = statementHandler mapper          h_n = nameHandler mapper -        goP :: Program -> Error Program +        goP :: MapperHandler Program          goP (Program decls) = Program <$> sequence (map (\d -> goD d >>= h_d) decls) -        goD :: Declaration -> Error Declaration +        goD :: MapperHandler Declaration          goD (DecFunction t n a b) = do              rt <- goT t              rn <- goN n @@ -177,30 +336,32 @@ mapProgram prog mapper = goP prog              rn <- goN n              h_d $ DecTypedef rt rn -        goT :: Type -> Error Type +        goT :: MapperHandler Type          goT (TypePtr t) = goT t >>= (h_t . TypePtr)          goT (TypeName n) = goN n >>= (h_t . TypeName)          goT t = h_t t -        goN :: Name -> Error Name +        goN :: MapperHandler Name          goN = h_n -        goB :: Block -> Error Block +        goB :: MapperHandler Block          goB (Block sts) = (Block <$> sequence (map goS sts)) >>= h_b -        goE :: Expression -> Error Expression -        goE (ExLit l) = goL l >>= (h_e . ExLit) -        goE (ExBinOp bo e1 e2) = do +        goE :: MapperHandler Expression +        goE (ExLit l mt) = do +            rl <- goL l +            h_e $ ExLit rl mt +        goE (ExBinOp bo e1 e2 mt) = do              rbo <- goBO bo              re1 <- goE e1              re2 <- goE e2 -            h_e $ ExBinOp rbo re1 re2 -        goE (ExUnOp uo e) = do +            h_e $ ExBinOp rbo re1 re2 mt +        goE (ExUnOp uo e mt) = do              ruo <- goUO uo              re <- goE e -            h_e $ ExUnOp ruo re +            h_e $ ExUnOp ruo re mt -        goS :: Statement -> Error Statement +        goS :: MapperHandler Statement          goS StEmpty = h_s StEmpty          goS (StBlock b) = goB b >>= (h_s . StBlock)          goS (StExpr e) = goE e >>= (h_s . StExpr) @@ -224,15 +385,17 @@ mapProgram prog mapper = goP prog              h_s $ StWhile re rs          goS (StReturn e) = goE e >>= (h_s . StReturn) -        goL :: Literal -> Error Literal +        goL :: MapperHandler Literal +        goL l@(LitString _) = h_l l +        goL l@(LitInt _) = h_l l          goL (LitVar n) = goN n >>= (h_l . LitVar)          goL (LitCall n a) = do              rn <- goN n              ra <- sequence $ map goE a              h_l $ LitCall rn ra -        goBO :: BinaryOperator -> Error BinaryOperator +        goBO :: MapperHandler BinaryOperator          goBO = h_bo -        goUO :: UnaryOperator -> Error UnaryOperator +        goUO :: MapperHandler UnaryOperator          goUO = h_uo @@ -33,6 +33,7 @@ main = do      when (isLeft parseResult) $ dieShow $ fromLeft parseResult      let ast = fromRight parseResult -    pprint ast +    -- print ast +    putStrLn $ pshow ast      either die print $ codegen ast "Module" fname @@ -85,14 +85,14 @@ exprTable =       [binary "&&" BoolAnd E.AssocLeft,        binary "||" BoolOr E.AssocLeft]]    where -    binary name op assoc = E.Infix (ExBinOp op <$ symbol name) assoc -    prefix name op = E.Prefix (ExUnOp op <$ symbol name) +    binary name op assoc = E.Infix (exBinOp_ op <$ symbol name) assoc +    prefix name op = E.Prefix (exUnOp_ op <$ symbol name)  pExpression :: Parser Expression  pExpression = E.buildExpressionParser exprTable pExLit  pExLit :: Parser Expression -pExLit = ExLit <$> pLiteral +pExLit = exLit_ <$> pLiteral  pLiteral :: Parser Literal  pLiteral = (LitInt <$> pInteger) <|> (LitString <$> pString) @@ -109,7 +109,7 @@ pLitCall = do  pStatement :: Parser Statement  pStatement = pStEmpty <|> pStIf <|> pStWhile <|> pStReturn <|> pStBlock -             <|> try pStAssignment <|> pStVarDeclaration <|> pStExpr +             <|> try pStAssignment <|> try pStVarDeclaration <|> pStExpr  pStEmpty :: Parser Statement  pStEmpty = symbol ";" >> return StEmpty @@ -165,7 +165,7 @@ pStReturn = do  primitiveTypes :: Map.Map String Type  primitiveTypes = Map.fromList -    [("i8", TypeInt 8), ("i16", TypeInt 16), ("i32", TypeInt 32), ("i64", TypeInt 64), +    [("i1", TypeInt 1), ("i8", TypeInt 8), ("i16", TypeInt 16), ("i32", TypeInt 32), ("i64", TypeInt 64),       ("u8", TypeUInt 8), ("u16", TypeUInt 16), ("u32", TypeUInt 32), ("u64", TypeUInt 64),       ("float", TypeFloat), ("double", TypeDouble)]  | 
