summaryrefslogtreecommitdiff
path: root/Language/C/Print.hs
blob: e075b0e7d31cfa8393b57c9a5f7dc3a933c134a3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
{-# LANGUAGE LambdaCase #-}
module Language.C.Print where

import Data.List (intersperse)

import Language.C


-- | A printer in difference-list style.
--
-- The 'Int' is the current indentation: the number of spaces 'printString'
-- inserts after each newline ('addIndent' increases it for nested output).
-- The first 'String' is the text that should follow this printer's output
-- (the tail). The result is this printer's output with the tail appended.
--
-- indentation -> tail -> string
type PrintS = Int -> String -> String

-- | Render a whole program: each function definition in order, separated
-- by a single newline.
--
-- No include directives are emitted. The resulting program will need
-- @stdlib.h@ and @stdint.h@ (the latter for the fixed-width integer type
-- names produced by 'printType').
--
-- NOTE(review): no prototypes are printed. Callees presumably must appear
-- before their callers in the definition list; confirm against the code
-- generator.
printProgram :: Program -> PrintS
printProgram (Program functions) =
    intercalates "\n" [printFunDef fn | fn <- functions]

-- | Render one C function definition.
--
-- * A 'FunDef' becomes a function returning @rt@. Its body is the statement
--   list followed by a single @return@ of the result expression.
-- * A 'ProcDef' becomes a @void@ function whose body is printed with
--   'printBlock'.
--
-- Each body statement is preceded by its own newline, instead of the body
-- opening with a fixed newline-plus-indent. Non-empty bodies therefore
-- print exactly as before. An empty statement list no longer leaves a
-- whitespace-only line between the opening brace and the @return@.
printFunDef :: FunDef -> PrintS
printFunDef = \case
    FunDef rt n as (StExpr ss rete) ->
        printType rt % printString " " % printName n
        % printParams as
        % printString " {"
        -- At indentation d+2 every "\n" is followed by d+2 spaces. This is
        -- the same column the statements started at with the old
        -- ") {\n  " opener.
        % addIndent 2 (onNewLines ss)
        % printString "\n  return (" % printExpr rete % printString ");\n}"
    ProcDef n as ss ->
        printString "void " % printName n
        % printParams as
        % printString " " % printBlock ss
  where
    -- Parenthesised, comma-separated list of "type name" parameters.
    printParams :: [(Type, Name)] -> PrintS
    printParams params =
        printString "("
        % intercalates ", " [printType t % printString " " % printName an | (t, an) <- params]
        % printString ")"

    -- Each statement on a fresh line. The newline picks up the current
    -- indentation through 'printString'. An empty list prints nothing.
    onNewLines :: [Stmt] -> PrintS
    onNewLines = foldr (\s acc -> printString "\n" % printStmt s % acc) (printString "")

-- | Emit an identifier exactly as it is stored in its 'Name'.
printName :: Name -> PrintS
printName (Name ident) = printString ident

-- | Emit the C spelling of a type.
--
-- Integer types use the fixed-width @intN_t@ / @uintN_t@ names. Each pointer
-- level appends one @*@ after the pointee's spelling.
printType :: Type -> PrintS
printType ty = printString (spell ty)
  where
    spell :: Type -> String
    spell (TInt w) = fixedWidth "int" w
    spell (TUInt w) = fixedWidth "uint" w
    spell TFloat = "float"
    spell TDouble = "double"
    spell (TPtr pointee) = spell pointee ++ "*"

    -- e.g. fixedWidth "uint" B32 == "uint32_t"
    fixedWidth :: String -> Bits -> String
    fixedWidth prefix w = prefix ++ bitCount w ++ "_t"

    bitCount :: Bits -> String
    bitCount B8 = "8"
    bitCount B16 = "16"
    bitCount B32 = "32"
    bitCount B64 = "64"

-- | Render a brace-delimited statement block.
--
-- Statements go one per line, indented two columns deeper than the braces.
-- An empty block is printed as @{}@. The general layout would otherwise
-- produce a whitespace-only line between the braces, which 'printString'
-- otherwise takes care to avoid.
printBlock :: [Stmt] -> PrintS
printBlock [] = printString "{}"
printBlock ss =
    printString "{\n  "
    % addIndent 2 (intercalates "\n" (map printStmt ss))
    % printString "\n}"

-- | Render a single statement.
--
-- Simple statements (declarations, assignments, stores, calls) end in @;@.
-- @for@ and @if@ print their bodies with 'printBlock'. An @if@ always
-- prints its @else@ branch.
printStmt :: Stmt -> PrintS
printStmt = \case
    SDecl ty name rhs ->
        printType ty % printString " " % printName name
        % initializer rhs
        % semi
    SAsg name e ->
        printName name % printString " = " % printExpr e % semi
    SStore name idx val ->
        printName name % bracketed (printExpr idx)
        % printString " = " % printExpr val % semi
    SCall name args ->
        printName name % argList args % semi
    SFor ty name lo hi body ->
        printString "for ("
        % printType ty % printString " " % printName name
        % printString " = " % printExpr lo % printString "; "
        % printName name % printString " < (" % printExpr hi % printString "); "
        % printName name % printString "++) "
        % printBlock body
    SIf cond thenBranch elseBranch ->
        printString "if (" % printExpr cond % printString ") "
        % printBlock thenBranch
        % printString " else "
        % printBlock elseBranch
  where
    semi = printString ";"

    -- Optional "= expr" part of a declaration.
    initializer Nothing = printString ""
    initializer (Just e) = printString " = " % printExpr e

    bracketed inner = printString "[" % inner % printString "]"

    argList xs = printString "(" % intercalates ", " (map printExpr xs) % printString ")"

-- | Render an expression.
--
-- Binary-operator operands, literals, and the operands of @!@, @&@ and
-- @sizeof@ are wrapped in parentheses, so the output does not rely on C
-- operator precedence. The operator in 'EOp' is emitted verbatim.
printExpr :: Expr -> PrintS
printExpr = \case
    EOp lhs op rhs ->
        parens (printExpr lhs) % printString (" " ++ op ++ " ") % parens (printExpr rhs)
    ENot e -> printString "!" % parens (printExpr e)
    ELit lit -> parens (printString lit)
    EVar name -> printName name
    ECall name args ->
        printName name % parens (intercalates ", " (map printExpr args))
    EIndex name e ->
        printName name % printString "[" % printExpr e % printString "]"
    EPtrTo e -> printString "&" % parens (printExpr e)
    ESizeOf t -> parens (printString "sizeof " % parens (printType t))
  where
    parens inner = printString "(" % inner % printString ")"


-- | Run a printer with its indentation increased by @extra@ columns,
-- relative to the indentation the surrounding printer receives.
addIndent :: Int -> PrintS -> PrintS
addIndent extra printer = printer . (+ extra)

-- | Run a printer at a fixed, absolute indentation, ignoring whatever
-- indentation the enclosing printer supplies.
withIndent :: Int -> PrintS -> PrintS
withIndent fixed printer = const (printer fixed)

-- | Concatenate printers, inserting the separator @sep@ between adjacent
-- ones. An empty list prints nothing.
--
-- The separator is emitted with 'printString', so a newline in it is
-- followed by the current indentation.
intercalates :: String -> [PrintS] -> PrintS
intercalates sep pieces indent =
    foldr (\piece after -> piece indent . after) id (intersperse (printString sep) pieces)

-- | Sequence two printers. Both run at the same indentation, and the left
-- printer's output is followed by the right printer's output.
--
-- No fixity declaration is given, so Haskell's default (@infixl 9@)
-- applies. The operator is associative, so grouping never changes output.
(%) :: PrintS -> PrintS -> PrintS
(%) first second indent = first indent . second indent

-- | Emit literal text.
--
-- Every newline in the text is followed by @indent@ spaces, except a
-- newline that is immediately followed by another newline in the same
-- string. That exception keeps blank lines free of trailing spaces. A
-- newline at the very end of the string is still indented, because the
-- tail is not inspected.
printString :: String -> PrintS
printString str indent rest = go str
  where
    go [] = rest
    go ('\n' : more@('\n' : _)) = '\n' : go more
    go ('\n' : more) = '\n' : replicate indent ' ' ++ go more
    go chunk =
        let (line, remainder) = break (== '\n') chunk
        in line ++ go remainder