1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
|
{-# LANGUAGE TupleSections #-}
module Optimiser(optimise) where
import Data.List
import qualified Data.Map.Strict as Map
import Data.Maybe
import qualified Data.Set as Set
import AST (Name)
import Intermediate
-- | Run the fixed sequence of basic-block optimisation passes over a program.
--
-- Only the basic-block list is rewritten; the global function definitions
-- and the data section are passed through unchanged. The function
-- definitions are also handed to 'deadBBElim', which uses them as extra
-- entry points.
optimise :: IRProgram -> IRProgram
optimise (IRProgram bbs gfds datas) = IRProgram (runPasses bbs) gfds datas
  where
    -- Ordinary composition: the rightmost pass ('mergeBlocks') runs first,
    -- 'tailCallIntro' runs last.
    runPasses =
        tailCallIntro
            . deadBBElim gfds
            . mergeRets
            . deadStoreElim
            . deadBBElim gfds
            . map propAssigns
            . mergeRets
            . mergeBlocks
-- | Merge each block into its unique predecessor when that predecessor
-- reaches it through an unconditional jump.
--
-- A block @b@ is folded into a predecessor @p@ only if all of the following
-- hold:
--
--   * @p@ is the only other block with @b@ among its successors, counting
--     every terminator kind ('IBr' included, via 'outEdges');
--   * @p@ ends in 'IJmp' (so the jump to @b@ is its only edge);
--   * @b@ is not the start block;
--   * @b@ does not jump to itself.
--
-- The merged block keeps @p@'s id and runs @p@'s instructions followed by
-- @b@'s, ending with @b@'s terminator. The start block (the first block of
-- the input) is moved to the front of the result.
--
-- Counting only 'IJmp' predecessors would let a block that is also an 'IBr'
-- target be merged away, leaving that branch dangling.
--
-- NOTE(review): global-function entry blocks are not known here; this
-- assumes they are never the sole target of an 'IJmp' — confirm.
mergeBlocks :: [BB] -> [BB]
mergeBlocks [] = []
mergeBlocks allbbs@(BB startb _ _ : _) =
    uncurry (++) (partition ((== startb) . bidOf) (go allbbs (length allbbs)))
  where
    -- @n@ is how many blocks may still be rotated to the back without a
    -- merge happening; once it hits zero every block has been tried.
    go [] _ = []
    go bbs 0 = bbs
    go (bb@(BB bid inss term) : bbs) n =
        case partition ((bid `elem`) . outEdges) bbs of
            -- Exactly one other block can reach bb, and only through IJmp.
            -- An IBr predecessor fails the IJmp pattern.
            ([BB bid' inss' (IJmp _)], rest)
                | bid /= startb
                , bid `notElem` outEdgesT term ->
                    go (BB bid' (inss' ++ inss) term : rest) n
            _ -> go (bbs ++ [bb]) (n - 1)
-- | Short-circuit jumps to return-only blocks.
--
-- A block with no instructions whose terminator is an 'IRet' is collected.
-- Every 'IJmp' to such a block is replaced by a copy of that 'IRet'. The
-- return-only blocks themselves are not removed here.
mergeRets :: [BB] -> [BB]
mergeRets bbs = map shortcut bbs
  where
    retOnly = Map.fromList [(bid, ret) | BB bid [] ret@(IRet _) <- bbs]

    shortcut (BB bid inss (IJmp target))
        | Just ret <- Map.lookup target retOnly = BB bid inss ret
    shortcut bb = bb
-- | Block-local copy propagation.
--
-- Instructions are walked in order while keeping a map from temp id to the
-- reference most recently assigned to it by an 'IAssign' in this block.
-- Reads of such a temp in later instructions and in the terminator are
-- replaced by the recorded reference. That reference was itself rewritten
-- when it was recorded. The 'IAssign' instructions themselves are kept.
--
-- Any write to a temp @i@ drops the fact for @i@ and every fact whose
-- recorded reference reads @i@. Without this, @t1 := t2; t2 := c; ret t1@
-- would be rewritten to return the new value of @t2@. An 'ICallC' discards
-- all facts.
propAssigns :: BB -> BB
propAssigns (BB bid inss term) =
    let (facts, inss') = mapFoldl propagateI Map.empty inss
        term' = propagateT facts term
    in BB bid inss' term'
  where
    propagateI mp (d@(RTemp i), IAssign r) =
        let r' = propR mp r
        in (Map.insert i r' (forget d mp), (d, IAssign r'))
    propagateI mp (d, IAssign r) = (forget d mp, (d, IAssign (propR mp r)))
    propagateI mp ins@(d, IParam _) = (forget d mp, ins)
    propagateI mp ins@(d, IClosure _) = (forget d mp, ins)
    propagateI mp ins@(d, IData _) = (forget d mp, ins)
    -- Conservatively clear every fact across a call.
    propagateI mp (d, ICallC r rs) =
        (Map.empty, (d, ICallC (propR mp r) (map (propR mp) rs)))
    propagateI mp (d, IAllocClo n rs) =
        (forget d mp, (d, IAllocClo n (map (propR mp) rs)))
    propagateI mp (d, IDiscard r) = (forget d mp, (d, IDiscard (propR mp r)))

    propagateT mp (IBr r a b) = IBr (propR mp r) a b
    propagateT _ t@(IJmp _) = t
    propagateT mp (IRet r) = IRet (propR mp r)
    propagateT mp (ITailC r rs) = ITailC (propR mp r) (map (propR mp) rs)
    propagateT _ t@IExit = t
    propagateT _ t@IUnknown = t

    -- A write to destination temp i invalidates the fact about i and every
    -- fact whose replacement reads i. Non-temp destinations change nothing.
    forget (RTemp i) mp = Map.filter (notElem i . readTempsR) (Map.delete i mp)
    forget _ mp = mp

    propR mp ref@(RTemp i) = fromMaybe ref (Map.lookup i mp)
    propR _ ref = ref
-- | Remove basic blocks that cannot be entered.
--
-- Block id 0 and the entry block id of every 'GlobFuncDef' in @gfds@ are
-- always kept. Any other block is kept only if some remaining block lists it
-- as a successor ('outEdges').
--
-- Removing a block can leave its successors with no remaining predecessor,
-- so the filter repeats until nothing more is removed. A single pass would
-- only peel one layer off a chain of dead blocks. Cycles of blocks that only
-- jump to each other are still kept, since each keeps an incoming edge.
-- Block order is preserved.
deadBBElim :: Map.Map Name GlobFuncDef -> [BB] -> [BB]
deadBBElim gfds = prune
  where
    entries = Set.fromList (0 : [bid | GlobFuncDef bid _ _ <- Map.elems gfds])

    prune bbs =
        let targets = Set.fromList (concatMap outEdges bbs)
            isLive bb =
                let b = bidOf bb
                in b `Set.member` entries || b `Set.member` targets
            kept = filter isLive bbs
        -- filter only removes, so an unchanged length means a fixpoint.
        in if length kept == length bbs then bbs else prune kept
-- | Delete instructions whose only effect is on a temp that is never read.
--
-- A temp is /unread/ when some instruction writes it ('writtenTempsBB') but
-- no block reads it ('readTempsBB'). Because 'readTempsIC' reports nothing
-- for 'IDiscard', a discard does not count as a read.
--
-- The following are dropped:
--
--   * @(RNone, IDiscard RNone)@;
--   * @(RNone, IDiscard (RTemp i))@ when @i@ is unread;
--   * a write to an unread temp when 'pureIC' holds for it.
--
-- Everything else, including calls, is kept. This is a single pass, so a
-- removed write may leave further temps unread.
deadStoreElim :: [BB] -> [BB]
deadStoreElim bbs = map pruneBlock bbs
  where
    readSomewhere = Set.fromList (concatMap readTempsBB bbs)
    writtenSomewhere = Set.fromList (concatMap writtenTempsBB bbs)
    unread = writtenSomewhere `Set.difference` readSomewhere

    pruneBlock (BB bid inss term) = BB bid (filter keep inss) term

    keep :: Instruction -> Bool
    keep ins = case ins of
        (RNone, IDiscard RNone) -> False
        (RNone, IDiscard (RTemp i)) -> i `Set.notMember` unread
        (RTemp i, code) -> not (pureIC code && i `Set.member` unread)
        _ -> True

-- | Whether 'deadStoreElim' may delete an instruction of this kind when its
-- destination temp is unread. Calls and discards are never deletable this
-- way.
pureIC :: InsCode -> Bool
pureIC code = case code of
    IAssign _ -> True
    IParam _ -> True
    IClosure _ -> True
    IData _ -> True
    IAllocClo _ _ -> True
    ICallC _ _ -> False
    IDiscard _ -> False
-- | Turn a call whose result is immediately returned into a tail call.
--
-- Consider a block whose last instruction is @(RTemp i, ICallC cl as)@ and
-- whose terminator is @IRet (RTemp i)@. That instruction is dropped and the
-- terminator becomes @ITailC cl as@, provided temp @i@ is read neither in
-- any other block nor by an earlier instruction of the same block. This
-- guarantees that losing the write to @i@ cannot be observed.
tailCallIntro :: [BB] -> [BB]
tailCallIntro bbs = map rewrite bbs
  where
    perBlock = map (Set.fromList . readTempsBB) bbs
    -- Position k of prefixes/suffixes: temps read strictly before/after block k.
    prefixes = init (scanl Set.union Set.empty perBlock)
    suffixes = tail (scanr Set.union Set.empty perBlock)
    readElsewhere = Map.fromList
        (zipWith3 (\(BB b _ _) pre post -> (b, Set.union pre post)) bbs prefixes suffixes)

    rewrite bb@(BB _ [] _) = bb
    rewrite bb@(BB b body term) =
        let earlier = init body
        in case (last body, term) of
            ((RTemp dst, ICallC callee args), IRet (RTemp retTemp))
                | dst == retTemp
                , not (dst `Set.member` (readElsewhere Map.! b))
                , dst `notElem` concatMap (readTempsIC . snd) earlier ->
                    BB b earlier (ITailC callee args)
            _ -> bb
-- | Successor block ids of a basic block, taken from its terminator.
outEdges :: BB -> [Int]
outEdges bb = case bb of
    BB _ _ term -> outEdgesT term
-- | Successor block ids named by a terminator: both targets of an 'IBr' and
-- the target of an 'IJmp'. 'IRet', 'ITailC', 'IExit' and 'IUnknown' name
-- none.
outEdgesT :: Terminator -> [Int]
outEdgesT t = case t of
    IBr _ dest1 dest2 -> [dest1, dest2]
    IJmp dest -> [dest]
    IRet _ -> []
    ITailC _ _ -> []
    IExit -> []
    IUnknown -> []
-- | Every temp id read in a block. Instruction operands come first, in order
-- ('readTempsIC'), followed by the terminator's ('readTempsT'). Duplicates
-- are kept.
readTempsBB :: BB -> [Int]
readTempsBB (BB _ inss term) =
    [t | (_, code) <- inss, t <- readTempsIC code] ++ readTempsT term
-- | Every temp id used as an instruction destination in a block, in order,
-- with duplicates kept. The terminator is not inspected.
writtenTempsBB :: BB -> [Int]
writtenTempsBB (BB _ inss _) = [t | (dest, _) <- inss, t <- readTempsR dest]
-- | Temp ids read as operands by an instruction body. 'IParam', 'IClosure',
-- 'IData' and 'IDiscard' contribute none.
readTempsIC :: InsCode -> [Int]
readTempsIC code = case code of
    IAssign src -> readTempsR src
    IParam _ -> []
    IClosure _ -> []
    IData _ -> []
    ICallC callee args -> concatMap readTempsR (callee : args)
    IAllocClo _ refs -> concatMap readTempsR refs
    IDiscard _ -> []
-- | Temp ids read by a terminator: the condition of an 'IBr', the value of an
-- 'IRet', and the callee plus arguments of an 'ITailC'.
readTempsT :: Terminator -> [Int]
readTempsT t = case t of
    IBr cond _ _ -> readTempsR cond
    IJmp _ -> []
    IRet val -> readTempsR val
    ITailC callee args -> concatMap readTempsR (callee : args)
    IExit -> []
    IUnknown -> []
-- | The temp id named by a reference, as a zero- or one-element list.
-- 'RConst', 'RSClo' and 'RNone' name no temp.
readTempsR :: Ref -> [Int]
readTempsR ref = case ref of
    RTemp i -> [i]
    RConst _ -> []
    RSClo _ -> []
    RNone -> []
-- | Map over a list from left to right while threading an accumulator.
-- Returns the final accumulator together with the mapped list.
--
-- This is exactly 'mapAccumL' from "Data.List". A hand-rolled
-- @foldl'@ + @reverse@ would only force each step's result pair to WHNF,
-- not the accumulator inside it, so it buys no extra strictness.
mapFoldl :: (s -> a -> (s, b)) -> s -> [a] -> (s, [b])
mapFoldl = mapAccumL
|