aboutsummaryrefslogtreecommitdiff
path: root/Optimiser.hs
blob: 6e6227c12ff78c9b609b02e41ae517daa508a234 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
module Optimiser(optimise) where

import Data.List
import Data.Maybe
import qualified Data.Map.Strict as Map
import Debug.Trace

import Defs
import Intermediate
import ReplaceRefs
import Utils


-- | A transformation of a whole IR program.
type Optimisation = IRProgram -> IRProgram
-- | A transformation of a single IR function; 'funcopt' lifts one of these
-- to an 'Optimisation' by applying it to every function of a program.
type FuncOptimisation = IRFunc -> IRFunc

-- | Run all optimisation passes over the program, round after round, until a
-- complete round leaves the program unchanged (as judged by '==').
--
-- Within one round the passes run in the order in which they are listed in
-- @passes@. The result is always wrapped in 'Right'; this function never
-- produces an error value itself.
optimise :: IRProgram -> Error IRProgram
optimise prog = Right (settle prog)
  where
    -- Keep applying a full round until a fixed point is reached.
    settle :: IRProgram -> IRProgram
    settle current
        | current == next = current
        | otherwise = settle next
      where
        next = fullRound current  -- hook point for: trace "-- OPT PASS --"

    -- One round of all passes; the first list element is applied first.
    fullRound :: Optimisation
    fullRound = foldr (flip (.)) id passes

    passes :: [Optimisation]
    passes = map funcopt
        -- minimal alternative: [chainJumps, removeUnusedBlocks]
        [ chainJumps
        , mergeTerminators
        , looseJumps
        , removeUnusedBlocks
        , identityOps
        , constantPropagate
        , removeNops
        , movPush
        , evaluateInstructions
        , evaluateTerminators
        ]


-- | Lift a per-function optimisation to a whole-program optimisation: the
-- program's variables are kept as they are and every function is transformed.
funcopt :: FuncOptimisation -> Optimisation
funcopt fo prog = case prog of
    IRProgram vars funcs -> IRProgram vars [fo func | func <- funcs]


-- | Merge straight-line jump chains.
--
-- A block that ends in an unconditional @IJmp target@ is merged with a copy
-- of the target block (its instructions are appended and its terminator is
-- taken over) when that block is the only one in the function ending in an
-- @IJmp@ to @target@. Merging is repeated until no candidate is left.
--
-- The target block itself is not removed here; the merged block is placed at
-- the front of the block list and all other blocks keep their relative order.
--
-- A block that jumps to itself is never a candidate, and a candidate whose
-- target block cannot be found is skipped. Previously both situations made an
-- irrefutable pattern fail at runtime.
chainJumps :: FuncOptimisation
chainJumps (IRFunc rt name al bbs sid) = IRFunc rt name al (mergeAll bbs) sid
  where
    -- Apply single merges until none is possible any more.
    mergeAll :: [BB] -> [BB]
    mergeAll blocks = maybe blocks mergeAll (mergeChain blocks)

    -- Perform the first possible merge, if there is one.
    mergeChain :: [BB] -> Maybe [BB]
    mergeChain blocks = listToMaybe (mapMaybe tryMerge (zip [0 ..] blocks))
      where
        tryMerge :: (Int, BB) -> Maybe [BB]
        tryMerge (idx, BB bid inss (IJmp target))
            | target /= bid && jumpsTo target == 1 =
                let others = take idx blocks ++ drop (idx + 1) blocks
                in case find (\(BB b _ _) -> b == target) others of
                    Just (BB _ tinss tterm) ->
                        Just (BB bid (inss ++ tinss) tterm : others)
                    Nothing -> Nothing
        tryMerge _ = Nothing

        -- Number of blocks that end in an unconditional jump to the given id.
        -- Conditional jumps are deliberately not counted, as before.
        jumpsTo :: Id -> Int
        jumpsTo i = length (filter (hasJmpTo i) blocks)

        hasJmpTo :: Id -> BB -> Bool
        hasJmpTo i (BB _ _ (IJmp i')) = i == i'
        hasJmpTo _ _ = False

-- | Short-circuit jumps into instruction-less blocks.
--
-- When a block ends in @IJmp i@ and block @i@ contains no instructions, the
-- jump is replaced by block @i@'s own terminator. All other blocks are left
-- unchanged.
mergeTerminators :: FuncOptimisation
mergeTerminators (IRFunc rt name al bbs sid) = IRFunc rt name al (map shortcut bbs) sid
  where
    -- (block id, terminator) for every block without instructions.
    emptyBlocks = [(bid, term) | BB bid [] term <- bbs]

    shortcut :: BB -> BB
    shortcut bb@(BB bid inss (IJmp dest)) =
        maybe bb (BB bid inss) (lookup dest emptyBlocks)
    shortcut bb = bb

-- | Redirect jumps that land on a pure forwarding block.
--
-- A forwarding block has no instructions and ends in an unconditional
-- @IJmp@. Every @IJmp@ target and both @IJcc@ targets that name such a block
-- are replaced by the block's own jump target. Only one hop is followed per
-- application.
looseJumps :: FuncOptimisation
looseJumps (IRFunc rt name al bbs sid) = IRFunc rt name al (map retarget bbs) sid
  where
    retarget :: BB -> BB
    retarget (BB bid inss (IJmp i)) =
        BB bid inss (IJmp (follow i))
    retarget (BB bid inss (IJcc ct r1 r2 i j)) =
        BB bid inss (IJcc ct r1 r2 (follow i) (follow j))
    retarget bb = bb

    -- The forwarded destination of a block id, or the id itself.
    follow :: Id -> Id
    follow i = Map.findWithDefault i i forwarding

    -- Forwarding block id -> where it jumps to.
    forwarding = Map.fromList [(bid, dest) | BB bid [] (IJmp dest) <- bbs]

-- | Drop every basic block that cannot be reached from the start block.
--
-- Reachability is computed by following @IJmp@ and @IJcc@ targets starting at
-- the function's start block id; any other terminator has no successors.
-- The surviving blocks keep their original order.
--
-- Earlier versions kept a block as soon as /any/ block (including the block
-- itself) named it in a terminator, so dead self-loops and dead cycles were
-- never removed.
removeUnusedBlocks :: FuncOptimisation
removeUnusedBlocks (IRFunc rt name al bbs sid) = IRFunc rt name al bbs' sid
  where
    bbs' = filter (\(BB bid _ _) -> bid `elem` reachable) bbs

    -- Ids of all blocks reachable from the start block (which is always included).
    reachable :: [Id]
    reachable = explore [] [sid]

    -- Worklist traversal: first argument is the visited set, second the
    -- ids still to visit.
    explore :: [Id] -> [Id] -> [Id]
    explore seen [] = seen
    explore seen (i : todo)
        | i `elem` seen = explore seen todo
        | otherwise = explore (i : seen) (concatMap successors (blocksWithId i) ++ todo)

    -- Ids that do not name an existing block simply contribute no successors.
    blocksWithId :: Id -> [BB]
    blocksWithId i = filter (\(BB bid _ _) -> bid == i) bbs

    successors :: BB -> [Id]
    successors (BB _ _ term) = case term of
        IJcc _ _ _ i1 i2 -> [i1, i2]
        IJmp i -> [i]
        _ -> []

-- | Delete arithmetic instructions that cannot change their destination:
-- adding or subtracting the constant 0, and multiplying or dividing by the
-- constant 1.
identityOps :: FuncOptimisation
identityOps (IRFunc rt name al bbs sid) = IRFunc rt name al (map strip bbs) sid
  where
    strip :: BB -> BB
    strip (BB bid inss term) = BB bid (filter (not . isIdentity) inss) term

    isIdentity :: IRIns -> Bool
    isIdentity ins = case ins of
        IAri AAdd _ (Constant _ 0) -> True
        IAri ASub _ (Constant _ 0) -> True
        IAri AMul _ (Constant _ 1) -> True
        IAri ADiv _ (Constant _ 1) -> True
        _ -> False

-- | Propagate the value of one singly-assigned temporary.
--
-- Looks for temporaries for which 'findMutations'' reports exactly one
-- location and whose instruction at that location is an @IMov@. For the first
-- such temporary, the @IMov@ is turned into an @INop@ and
-- 'replaceRefsBBList' is applied with the mov's destination and source
-- (presumably substituting the source for the destination everywhere; see
-- the ReplaceRefs module). At most one temporary is handled per application;
-- the surrounding fixpoint loop takes care of the others.
--
-- NOTE(review): the mov's source is not required to be a constant here;
-- confirm that propagating a non-constant source is always valid for this IR.
constantPropagate :: FuncOptimisation
constantPropagate (IRFunc rt name al bbs sid) = IRFunc rt name al bbs' sid
  where
    -- (location of the mov, its destination, its source) per candidate.
    candidates = mapMaybe singleMov (findAllTemps' bbs)

    singleMov ref = case findMutations' bbs ref of
        [loc] | IMov dst src <- insAt bbs loc -> Just (loc, dst, src)
        _ -> Nothing

    bbs' = case candidates of
        [] -> bbs
        ((loc, dst, src) : _) -> replaceRefsBBList dst src (nopifyInsAt bbs loc)

-- | Remove every @INop@ instruction from every basic block.
removeNops :: FuncOptimisation
removeNops (IRFunc rt name al bbs sid) = IRFunc rt name al (map dropNops bbs) sid
  where
    dropNops (BB bid inss term) = BB bid [ins | ins <- inss, keep ins] term

    keep INop = False
    keep _ = True

-- | Push @IMov@ instructions down within their basic block.
--
-- For an @IMov d s@ whose destination @d@ is still referenced by a later
-- instruction of the same block, the mov is moved past the following
-- instructions. Each instruction that is passed is rewritten with
-- @replaceRefsIns d s@ (from ReplaceRefs; presumably replacing uses of @d@
-- by @s@). The mov ends up at the end of the block, is replaced by a later
-- write to @d@, or is re-emitted in front of an instruction it cannot pass.
--
-- Fix: an @IResize d s'@ with @s' /= d@ overwrites @d@ without reading it,
-- so the pending mov is dead. The old code tried to keep pushing the
-- @IResize@ itself, which no equation of @push@ handles unless the remaining
-- list is empty or starts with @INop@, so it ended in 'undefined'. The
-- resize is now emitted and normal scanning resumes.
--
-- NOTE(review): @push@ does not check whether an instruction it passes
-- overwrites the mov's /source/ @s@; if that can happen in this IR, the
-- rewritten later instructions would observe the new value. Verify.
movPush :: FuncOptimisation
movPush (IRFunc rt name al bbs sid) = IRFunc rt name al bbs' sid
  where
    bbs' = map goBB bbs

    goBB :: BB -> BB
    goBB (BB bid inss term) = BB bid (go inss) term

    -- Scan a block; start pushing at the first mov whose destination is
    -- referenced later on.
    go :: [IRIns] -> [IRIns]
    go [] = []
    go (ins@(IMov d _) : rest) | isJust (find (== d) (findAllRefsInss rest)) = push ins rest
    go (ins : rest) = ins : go rest

    -- @push mov rest@: move the pending mov through @rest@.
    push :: IRIns -> [IRIns] -> [IRIns]
    push mov [] = [mov]
    push mov@(IMov d s) (ins@(IMov d' s') : rest)
        -- A later mov into d: a self-move is dropped, anything else
        -- supersedes the pending mov and is pushed in its place.
        | d' == d = if s' == d then push mov rest else push ins rest
        | otherwise = replaceRefsIns d s ins : push mov rest
    push mov@(IMov d s) (ins@(IResize d' s') : rest)
        -- A resize of d onto itself is dropped; a resize from another source
        -- makes the pending mov dead, so emit the resize and carry on.
        | d' == d = if s' == d then push mov rest else ins : go rest
        | otherwise = replaceRefsIns d s ins : push mov rest
    push mov@(IMov d s) (ins@(ILoad d' _) : rest)
        -- Stop pushing at a load into d; keep the mov in front of it.
        | d' == d = mov : ins : go rest
        | otherwise = replaceRefsIns d s ins : push mov rest
    push mov@(IMov d s) (ins@(IAri at d' s') : rest)
        | d' == d = case (s, s') of
            -- Both operands constant: fold the arithmetic into the mov.
            (Constant sza a, Constant szb b)
                | sza == szb -> push (IMov d (Constant sza $ evaluateArith at a b)) rest
                | otherwise -> error $ "Inconsistent sizes in " ++ show mov ++ "; " ++ show ins
            _ -> mov : ins : go rest
        | otherwise = replaceRefsIns d s ins : push mov rest
    push mov@(IMov d s) (ins@(ICallr d' _ _) : rest)
        -- Stop pushing at a call that writes its result into d.
        | d' == d = mov : ins : go rest
        | otherwise = replaceRefsIns d s ins : push mov rest
    push mov@(IMov d s) (ins@(IStore _ _) : rest) = replaceRefsIns d s ins : push mov rest
    push mov@(IMov d s) (ins@(ICall _ _) : rest) = replaceRefsIns d s ins : push mov rest
    push mov (ins@INop : rest) = ins : push mov rest
    -- Only reachable if push is handed something other than an IMov, or an
    -- instruction kind not listed above.
    push _ _ = error "movPush: push called with an unsupported instruction"

-- | Evaluate instructions whose result is known at compile time.
--
-- Currently this handles only @IResize@ with a constant source: it becomes an
-- @IMov@ of a constant that has the destination's size and whose value is the
-- original value reduced modulo @2 ^ (8 * size)@.
evaluateInstructions :: FuncOptimisation
evaluateInstructions (IRFunc rt name al bbs sid) = IRFunc rt name al (map foldBlock bbs) sid
  where
    foldBlock :: BB -> BB
    foldBlock (BB bid inss term) = BB bid (map foldIns inss) term

    foldIns :: IRIns -> IRIns
    foldIns ins = case ins of
        IResize ref (Constant _ v) ->
            let sz = refSize ref
            in IMov ref (Constant sz (wrap sz v))
        _ -> ins

    -- Reduce a value modulo 2^(8*sz); the computation is done on Integer.
    wrap :: Size -> Value -> Value
    wrap sz v = fromIntegral (toInteger v `mod` modulus)
      where
        modulus :: Integer
        modulus = 2 ^ (8 * sz)

-- | Resolve conditional jumps whose two operands are both constants.
--
-- Such an @IJcc@ becomes an @IJmp@ to its first target when 'evaluateCmp'
-- holds for the two values, and to its second target otherwise. Constants of
-- different sizes are reported with 'error'. Other terminators are unchanged.
evaluateTerminators :: FuncOptimisation
evaluateTerminators (IRFunc rt name al bbs sid) = IRFunc rt name al (map foldBlock bbs) sid
  where
    foldBlock :: BB -> BB
    foldBlock (BB bid inss term) = BB bid inss (resolve term)

    resolve :: IRTerm -> IRTerm
    resolve term = case term of
        IJcc ct (Constant sza a) (Constant szb b) whenTrue whenFalse
            | sza /= szb -> error $ "Inconsistent sizes in " ++ show term
            | evaluateCmp ct a b -> IJmp whenTrue
            | otherwise -> IJmp whenFalse
        _ -> term


-- | The instruction at position @(block index, instruction index)@.
-- Both indices are zero-based; an out-of-range index is a runtime error.
insAt :: [BB] -> (Int, Int) -> IRIns
insAt bbs (i, j) = inss !! j
  where
    BB _ inss _ = bbs !! i

-- | Replace the instruction at @(block index, instruction index)@ by @INop@,
-- leaving everything else untouched. Both indices are zero-based and must be
-- in range; otherwise a pattern match fails at runtime.
nopifyInsAt :: [BB] -> (Int, Int) -> [BB]
nopifyInsAt bbs (i, j) =
    blocksBefore ++ BB bid (insBefore ++ INop : insAfter) term : blocksAfter
  where
    (blocksBefore, BB bid inss term : blocksAfter) = splitAt i bbs
    (insBefore, _ : insAfter) = splitAt j inss

-- | Zero-based indices of the instructions in a block that write to the
-- given ref.
--
-- An instruction counts as a write when the ref is its destination operand:
-- @IMov@, @IAri@, @ICallr@, @ILoad@ and @IResize@. The last two were
-- previously not considered, although the rest of this module treats their
-- first operand as a destination; a ref written by them could therefore be
-- mistaken for singly-assigned.
findMutations :: BB -> Ref -> [Int]
findMutations (BB _ inss _) ref =
    [idx | (idx, ins) <- zip [0 ..] inss, writesRef ins]
  where
    writesRef :: IRIns -> Bool
    writesRef (IMov r _) = r == ref
    writesRef (IAri _ r _) = r == ref
    writesRef (ICallr r _ _) = r == ref
    writesRef (ILoad r _) = r == ref
    writesRef (IResize r _) = r == ref
    writesRef _ = False

-- | Like 'findMutations', but over a list of blocks: the result contains
-- @(block index, instruction index)@ pairs, ordered by block and then by
-- instruction.
findMutations' :: [BB] -> Ref -> [(Int, Int)]
findMutations' bbs ref = concat (zipWith locate [0 ..] bbs)
  where
    locate :: Int -> BB -> [(Int, Int)]
    locate i bb = map (\j -> (i, j)) (findMutations bb ref)

-- | All refs mentioned by the instructions of a block (the terminator is not
-- inspected); see 'findAllRefsInss'.
findAllRefs :: BB -> [Ref]
findAllRefs bb = case bb of
    BB _ inss _ -> findAllRefsInss inss

-- | All refs that occur as operands of the given instructions, passed
-- through 'sort' and then 'uniq'.
findAllRefsInss :: [IRIns] -> [Ref]
findAllRefsInss inss = uniq (sort (inss >>= operands))
  where
    operands ins = case ins of
        IMov a b -> [a, b]
        IStore a b -> [a, b]
        ILoad a b -> [a, b]
        IAri _ a b -> [a, b]
        ICall _ al -> al
        ICallr a _ al -> a : al
        IResize a b -> [a, b]
        INop -> []

-- findAllRefs' :: [BB] -> [Ref]
-- findAllRefs' = uniq . sort . concatMap findAllRefs

-- | The refs of a block (as found by 'findAllRefs') that are built with the
-- @Temp@ constructor.
findAllTemps :: BB -> [Ref]
findAllTemps bb = [ref | ref@(Temp _ _) <- findAllRefs bb]

-- | The temporaries of every block, concatenated in block order. A temporary
-- used in several blocks appears once per block.
findAllTemps' :: [BB] -> [Ref]
findAllTemps' bbs = concat [findAllTemps bb | bb <- bbs]