summaryrefslogtreecommitdiff
path: root/Language/C/Print.hs
blob: 3d7688cb08f010a7dc50171c17ae179efdc8cae1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
{-# LANGUAGE LambdaCase #-}
module Language.C.Print where

import Data.List (intersperse)

import Language.C


-- | A printer in difference-list style: indentation depth -> tail -> string.
--
-- The 'Int' is the current indentation depth, i.e. the number of spaces that
-- 'printString' inserts after each newline it emits (see also 'addIndent' and
-- 'getIndent'). The 'String' argument is the text that follows the printed
-- fragment; the result is the fragment prepended to that tail.
type PrintS = Int -> String -> String

-- | Render a whole program: every function definition is printed with
-- 'printFunDef', and consecutive definitions are separated by a newline.
--
-- The resulting program will need @\<stdlib.h\>@ and @\<stdint.h\>@ (the
-- latter provides the fixed-width integer types emitted by 'printType').
printProgram :: Program -> PrintS
printProgram (Program defs) =
    intercalates "\n" [printFunDef def | def <- defs]

-- | Render a top-level definition.
--
-- * A 'FunDef' becomes a C function with the given return type whose body
--   consists of the statements followed by a @return (...)@ of the result
--   expression.
-- * A 'ProcDef' becomes a @void@ function whose body is printed as a block.
--
-- In both cases the definition is prefixed with @static@ when 'faStatic' is
-- set in the attributes.
printFunDef :: FunDef -> PrintS
printFunDef = \case
    FunDef attrs rt n params (StExpr ss rete) ->
        staticPrefix attrs
        % printType rt % printString " " % printName n
        % printString "("
        % paramList params
        % printString ") {\n  "
        % addIndent 2 (intercalates "\n" (map printStmt ss))
        % printString "\n  return (" % printExpr rete % printString ");\n}\n"
    ProcDef attrs n params ss ->
        staticPrefix attrs
        % printString "void " % printName n
        % printString "("
        % paramList params
        % printString ") " % printBlock ss % printString "\n"
  where
    -- Storage-class prefix; empty for non-static definitions.
    staticPrefix attrs
        | faStatic attrs = printString "static "
        | otherwise = printString ""

    -- Comma-separated formal parameters, each printed as "type name".
    paramList ps = intercalates ", " (map param ps)
    param (t, an) = printType t % printString " " % printName an

-- | Print an identifier by emitting the string wrapped in 'Name'.
printName :: Name -> PrintS
printName = \case
    Name ident -> printString ident

-- | Print a C type. Integer types are rendered as the fixed-width
-- @intN_t@ / @uintN_t@ names, floating-point types as @float@ / @double@,
-- and pointers as the pointee type followed by @*@.
printType :: Type -> PrintS
printType ty = printString (render ty)
  where
    -- Textual form of a type; recursive for pointer types.
    render :: Type -> String
    render (TInt b) = concat ["int", width b, "_t"]
    render (TUInt b) = concat ["uint", width b, "_t"]
    render TFloat = "float"
    render TDouble = "double"
    render (TPtr pointee) = render pointee ++ "*"

    -- Bit width as it appears inside the fixed-width type names.
    width :: Bits -> String
    width = \case
        B8 -> "8"
        B16 -> "16"
        B32 -> "32"
        B64 -> "64"

-- | Print a list of statements as a brace-delimited block, one statement
-- per line (statements are joined with newlines and handed to 'printBlock'').
printBlock :: [Stmt] -> PrintS
printBlock = printBlock' . intercalates "\n" . map printStmt

-- | Wrap an already-built body in braces, printing the body two columns
-- deeper than the surrounding code:
--
-- > {
-- >   <body>
-- > }
--
-- The newline after the opening brace is printed at the increased depth, so
-- 'printString' itself supplies the full indentation of the first body line.
-- (Embedding the indentation in the string literal as well would indent the
-- first line twice whenever the surrounding depth is non-zero, because
-- 'printString' already inserts the current depth after every newline.)
printBlock' :: PrintS -> PrintS
printBlock' body =
    printString "{"
    % addIndent 2 (printString "\n" % body)
    % printString "\n}"

-- | Print a single statement, without a trailing newline.
--
-- Nested statement lists are printed as brace-delimited blocks via
-- 'printBlock' / 'printBlock''.
printStmt :: Stmt -> PrintS
-- Declaration with an optional initialiser: "type name;" or "type name = e;".
printStmt (SDecl ty name rhs) =
    printType ty % printString " "
    % printName name
    % maybe (printString "") (\e -> printString " = " % printExpr e) rhs
    % printString ";"
-- Assignment to a variable.
printStmt (SAsg name e) =
    printName name % printString " = " % printExpr e % printString ";"
-- Store through an index: "name[idx] = val;".
printStmt (SStore name idx val) =
    printName name % printString "[" % printExpr idx % printString "] = " % printExpr val % printString ";"
-- Call used as a statement.
printStmt (SCall name args) =
    printName name % printString "(" % intercalates ", " (map printExpr args) % printString ");"
-- Counting loop from lo (inclusive) to hi (exclusive), stepping by one.
printStmt (SFor ty name lo hi body) =
    printString "for ("
    % printType ty % printString " " % printName name % printString " = " % printExpr lo
    % printString "; "
    % printName name % printString " < (" % printExpr hi % printString ")"
    % printString "; "
    % printName name % printString "++) "
    % printBlock body
-- While loop without pre-condition statements: a plain "while (cond) {...}".
printStmt (SWhile [] cond body) =
    printString "while (" % printExpr cond % printString ") " % printBlock body
-- While loop with statements that must run before each condition check:
--
-- > while (1) {
-- >   <pre>
-- >   if (!(cond)) break;
-- >   <post>
-- > }
--
-- The loop keeps running while the condition holds, matching the clause
-- above for an empty @pre@ (NOTE(review): this assumes the plain
-- @while (cond)@ clause reflects the intended meaning of 'SWhile').
-- The literal @1@ is used instead of @true@ so that the generated code does
-- not depend on @\<stdbool.h\>@. All parts are joined as one list so that an
-- empty @post@ does not leave a whitespace-only line before the closing brace.
printStmt (SWhile pre cond post) =
    printString "while (1) "
    % printBlock' (intercalates "\n" (map printStmt pre ++ [breakUnlessCond] ++ map printStmt post))
  where
    breakUnlessCond = printString "if (!(" % printExpr cond % printString ")) break;"
-- Conditional; the else branch is always printed, even when empty.
printStmt (SIf e b1 b2) =
    printString "if (" % printExpr e % printString ") "
    % printBlock b1 % printString " else " % printBlock b2

-- | Print an expression.
--
-- Operands of operators, negations, address-of and casts are always wrapped
-- in parentheses, so the output does not rely on C operator precedence.
-- Literals are parenthesised as well; variables, calls and index expressions
-- are printed bare.
printExpr :: Expr -> PrintS
printExpr = \case
    EOp lhs op rhs ->
        wrapped "(" lhs (") " ++ op ++ " (") % printExpr rhs % printString ")"
    ENot inner -> wrapped "!(" inner ")"
    ELit lit -> printString ("(" ++ lit ++ ")")
    EVar var -> printName var
    ECall fn args ->
        printName fn % printString "("
        % intercalates ", " (map printExpr args) % printString ")"
    EIndex arr idx -> printName arr % wrapped "[" idx "]"
    EPtrTo inner -> wrapped "&(" inner ")"
    ESizeOf ty -> printString "(sizeof (" % printType ty % printString "))"
    ECast ty inner -> printString "(" % printType ty % wrapped ")(" inner ")"
  where
    -- An expression printed between a fixed opening and closing string.
    wrapped open inner close =
        printString open % printExpr inner % printString close


-- | Run a printer with the indentation depth increased by the given amount.
addIndent :: Int -> PrintS -> PrintS
addIndent plusd printer = printer . (+ plusd)

-- | Run a printer at a fixed indentation depth, ignoring the depth of the
-- surrounding context.
withIndent :: Int -> PrintS -> PrintS
withIndent depth printer = const (printer depth)

-- | Build a printer that depends on the current indentation depth: the
-- continuation receives the depth and the resulting printer is then run at
-- that same depth.
getIndent :: (Int -> PrintS) -> PrintS
getIndent k = \depth -> k depth depth

-- | Concatenate printers, printing the separator string between each
-- adjacent pair. An empty list yields a printer that emits nothing.
intercalates :: String -> [PrintS] -> PrintS
intercalates sep = foldr (%) nothing . intersperse separator
  where
    separator = printString sep
    -- The empty printer: leaves the tail untouched at any depth.
    nothing _ = id

-- | Sequence two printers: the output of the first is followed by the
-- output of the second, both run at the same indentation depth.
(%) :: PrintS -> PrintS -> PrintS
(first % second) depth rest = first depth (second depth rest)

-- | Print a literal string, re-indenting after newlines.
--
-- Every newline in the string is followed by as many spaces as the current
-- indentation depth, except when the newline is immediately followed by
-- another newline within the same string (so blank lines inside the string
-- carry no trailing spaces). Text between newlines is emitted unchanged.
printString :: String -> PrintS
printString str depth rest = go str
  where
    go :: String -> String
    go [] = rest
    -- Newline directly followed by another newline: no indentation.
    go ('\n' : more@('\n' : _)) = '\n' : go more
    -- Any other newline (including a final one): indent what follows.
    go ('\n' : more) = '\n' : replicate depth ' ' ++ go more
    -- Copy everything up to the next newline verbatim.
    go text =
        case break (== '\n') text of
            (line, remainder) -> line ++ go remainder