summaryrefslogtreecommitdiff
path: root/Optimiser.hs
blob: a6c5e103fc25fde0f3ebd2e69edc66f8ba7f277d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
{-# LANGUAGE TupleSections #-}
module Optimiser(optimise) where

import Data.List
import qualified Data.Map.Strict as Map
import Data.Maybe
import qualified Data.Set as Set

import AST (Name)
import Intermediate


-- | Run the optimisation pipeline over the basic blocks of a program.
--
-- Only the block list is rewritten; the global function definitions and
-- the data section are passed through unchanged.
optimise :: IRProgram -> IRProgram
optimise (IRProgram bbs gfds datas) = IRProgram (pipeline bbs) gfds datas
  where
    -- Function composition applies right to left, so the passes execute
    -- from the bottom of this chain ('mergeBlocks') up to the top
    -- ('tailCallIntro').
    pipeline =
        tailCallIntro
            . deadBBElim gfds
            . mergeRets
            . deadStoreElim
            . deadBBElim gfds
            . map propAssigns
            . mergeRets
            . mergeBlocks

-- | Merge a block into its predecessor when that predecessor reaches it
-- through an unconditional jump and nothing else refers to the block.
--
-- A block @B@ is appended to block @A@ (keeping @A@'s id and @B@'s
-- terminator) only when all of the following hold:
--
-- * @A@ is the only block ending in @IJmp B@;
-- * no other terminator (for example an 'IBr', or @B@'s own terminator)
--   has an edge to @B@, since removing @B@ would leave that edge dangling;
-- * @B@ is not the first block of the input, because the result is
--   reordered below so that this block comes first and it must therefore
--   survive.
--
-- NOTE(review): blocks that are entered by other means than a terminator
-- edge (e.g. function entry blocks) are not known here; if such a block can
-- also be the target of exactly one 'IJmp', it would still be merged away.
-- Confirm against the IR generator.
mergeBlocks :: [BB] -> [BB]
mergeBlocks [] = []
mergeBlocks allbbs@(BB startb _ _ : _) =
    -- The work list below rotates the blocks, so move the original first
    -- block back to the front of the result.
    uncurry (++) (partition ((== startb) . bidOf) (go allbbs (length allbbs)))
  where
    -- Work-list loop. The head block is either merged into its unique
    -- jump-predecessor (the merged block is then examined next) or rotated
    -- to the back. The counter is decremented on each rotation and the loop
    -- stops when it reaches zero; every merge shrinks the list, so the loop
    -- terminates.
    go [] _ = []
    go bbs 0 = bbs
    go (bb@(BB bid inss term) : bbs) n = case partition (hasJumpTo bid . termOf) bbs of
        ([BB bid' inss' _], rest)
            | bid /= startb
            -- 'rest' holds every block that does not end in @IJmp bid@;
            -- together with the block itself these are all the places a
            -- second reference to @bid@ could come from.
            , bid `notElem` concatMap outEdges (bb : rest) ->
                go (BB bid' (inss' ++ inss) term : rest) n
        _ -> go (bbs ++ [bb]) (n - 1)

    hasJumpTo bid (IJmp a) = a == bid
    hasJumpTo _ _ = False

-- | Replace a jump to an instruction-free returning block by the return
-- itself.
--
-- Every block that has no instructions and ends in 'IRet' is recorded; any
-- block whose terminator is an 'IJmp' to such a block gets that 'IRet' as
-- its terminator instead. The now possibly unreferenced return blocks are
-- left in place.
mergeRets :: [BB] -> [BB]
mergeRets bbs = map inlineRet bbs
  where
    -- Block id -> return terminator, for blocks consisting of a return only.
    bareRets = Map.fromList [(bid, ret) | BB bid [] ret@(IRet _) <- bbs]

    inlineRet bb@(BB bid inss (IJmp target)) =
        maybe bb (BB bid inss) (Map.lookup target bareRets)
    inlineRet bb = bb

-- | Propagate assignments forward within a single basic block.
--
-- After an instruction @RTemp i := IAssign r@, later reads of temporary @i@
-- in the same block (in instruction operands and in the terminator) are
-- replaced by @r@, where @r@ has itself already been substituted. A closure
-- call ('ICallC') has its operands substituted and then forgets all known
-- substitutions.
--
-- NOTE(review): a recorded substitution for temporary @i@ is not dropped
-- when @i@ is later written by a non-'IAssign' instruction, nor when the
-- temporary it maps to is overwritten. This is fine if temporaries are
-- assigned at most once per block -- confirm.
propAssigns :: BB -> BB
propAssigns (BB bid inss term) = BB bid rewritten (rewriteTerm finalEnv term)
  where
    (finalEnv, rewritten) = mapFoldl step Map.empty inss

    -- Rewrite one instruction and update the substitution environment.
    step env (dest, code) = case code of
        IAssign r ->
            let r' = subst env r
                env' = case dest of
                    RTemp i -> Map.insert i r' env
                    _ -> env
            in (env', (dest, IAssign r'))
        IParam _ -> (env, (dest, code))
        IClosure _ -> (env, (dest, code))
        IData _ -> (env, (dest, code))
        -- Operands use the environment from before the call; afterwards
        -- nothing is known any more.
        ICallC f args -> (Map.empty, (dest, ICallC (subst env f) (map (subst env) args)))
        IAllocClo n rs -> (env, (dest, IAllocClo n (map (subst env) rs)))
        IDiscard r -> (env, (dest, IDiscard (subst env r)))

    -- Rewrite the operands of the terminator; jump targets are untouched.
    rewriteTerm env t = case t of
        IBr r a b -> IBr (subst env r) a b
        IJmp _ -> t
        IRet r -> IRet (subst env r)
        ITailC f args -> ITailC (subst env f) (map (subst env) args)
        IExit -> t
        IUnknown -> t

    -- Look up a temporary in the environment; other references and unknown
    -- temporaries stay as they are.
    subst env ref = case ref of
        RTemp i -> Map.findWithDefault ref i env
        _ -> ref

-- | Drop basic blocks that nothing refers to.
--
-- A block is kept when its id is 0, when it is the block id stored in one
-- of the global function definitions, or when the terminator of any block
-- in the input has an edge to it. This is a single pass rather than a
-- reachability fixpoint: an edge coming from a block that is itself removed
-- still keeps its target alive.
deadBBElim :: Map.Map Name GlobFuncDef -> [BB] -> [BB]
deadBBElim gfds bbs = [bb | bb <- bbs, Set.member (bidOf bb) keep]
  where
    entryPoints = 0 : [bid | GlobFuncDef bid _ _ <- Map.elems gfds]
    keep = Set.fromList (concatMap outEdges bbs ++ entryPoints)

-- | Remove instructions whose result is never used.
--
-- A temporary is considered dead when it is written somewhere in the
-- program but read nowhere (as computed by 'writtenTempsBB' and
-- 'readTempsBB' over all blocks). The following instructions are dropped:
--
-- * a discard of 'RNone';
-- * a discard of a dead temporary;
-- * a side-effect-free instruction (see @pureIC@) writing a dead temporary.
--
-- Terminators are never changed.
deadStoreElim :: [BB] -> [BB]
deadStoreElim bbs =
    [BB bid [ins | ins <- inss, not (shouldRemove ins)] term | BB bid inss term <- bbs]
  where
    readSet = Set.fromList (concatMap readTempsBB bbs)
    writtenSet = Set.fromList (concatMap writtenTempsBB bbs)
    -- Temporaries that are written but never read.
    deadTemps = writtenSet `Set.difference` readSet

    isDead i = Set.member i deadTemps

    shouldRemove :: Instruction -> Bool
    shouldRemove ins = case ins of
        (RNone, IDiscard RNone) -> True
        (RNone, IDiscard (RTemp i)) -> isDead i
        (RTemp i, code) -> pureIC code && isDead i
        _ -> False

    -- Whether an instruction can be dropped when its result is unused.
    pureIC :: InsCode -> Bool
    pureIC code = case code of
        ICallC _ _ -> False
        IDiscard _ -> False
        IAssign _ -> True
        IParam _ -> True
        IClosure _ -> True
        IData _ -> True
        IAllocClo _ _ -> True

-- | Turn a closure call that is immediately returned into a tail call.
--
-- A block whose last instruction is @RTemp t := ICallC f args@ and whose
-- terminator is @IRet (RTemp t)@ is rewritten to end in @ITailC f args@,
-- with the call instruction removed, provided that @t@ is read neither by
-- the earlier instructions of the same block nor by any block at another
-- position in the list. All other blocks are returned unchanged.
tailCallIntro :: [BB] -> [BB]
tailCallIntro bbs = map introduce bbs
  where
    readPerBB = map (Set.fromList . readTempsBB) bbs
    -- For the block at each position: temporaries read by the blocks
    -- strictly before it, respectively strictly after it.
    readEarlier = init (scanl (<>) Set.empty readPerBB)
    readLater = tail (scanr (<>) Set.empty readPerBB)
    readElsewhere =
        Map.fromList
            [ (bid, earlier <> later)
            | (BB bid _ _, earlier, later) <- zip3 bbs readEarlier readLater
            ]

    -- Matching on the reversed instruction list exposes the last
    -- instruction; an empty block simply falls through to the default case.
    introduce orig@(BB bid inss term) =
        case (reverse inss, term) of
            ((RTemp callDst, ICallC cl args) : revEarlier, IRet (RTemp retSrc))
                | callDst == retSrc
                , Just elsewhere <- Map.lookup bid readElsewhere
                , callDst `Set.notMember` elsewhere
                , callDst `notElem` concatMap (readTempsIC . snd) revEarlier ->
                    BB bid (reverse revEarlier) (ITailC cl args)
            _ -> orig

-- | Ids of the blocks that the terminator of the given block can
-- transfer control to (see 'outEdgesT').
outEdges :: BB -> [Int]
outEdges bb = case bb of
    BB _ _ terminator -> outEdgesT terminator

-- | Block ids named by a terminator: both targets of a branch, the single
-- target of a jump, and nothing for the remaining terminators.
outEdgesT :: Terminator -> [Int]
outEdgesT terminator = case terminator of
    IBr _ target1 target2 -> [target1, target2]
    IJmp target -> [target]
    IRet _ -> []
    ITailC _ _ -> []
    IExit -> []
    IUnknown -> []

-- | All temporaries read in a block: first those read by its instructions,
-- in order, then those read by its terminator. Duplicates are kept.
readTempsBB :: BB -> [Int]
readTempsBB (BB _ inss term) =
    [t | ins <- inss, t <- readTempsIC (snd ins)] ++ readTempsT term

-- | All temporaries that appear as the destination of an instruction in
-- the block, in instruction order. Non-temporary destinations contribute
-- nothing; duplicates are kept.
writtenTempsBB :: BB -> [Int]
writtenTempsBB (BB _ inss _) = [t | ins <- inss, t <- readTempsR (fst ins)]

-- | Temporaries read by the operation part of an instruction.
--
-- Note that 'IDiscard' is deliberately reported as reading nothing, so a
-- discard alone does not make a temporary count as "read".
readTempsIC :: InsCode -> [Int]
readTempsIC code = case code of
    IAssign r -> readTempsR r
    ICallC f args -> concatMap readTempsR (f : args)
    IAllocClo _ rs -> concatMap readTempsR rs
    IParam _ -> []
    IClosure _ -> []
    IData _ -> []
    IDiscard _ -> []

-- | Temporaries read by a terminator: the condition of a branch, the
-- returned reference, or the callee and arguments of a tail call.
readTempsT :: Terminator -> [Int]
readTempsT terminator = case terminator of
    IBr cond _ _ -> readTempsR cond
    IRet r -> readTempsR r
    ITailC f args -> concatMap readTempsR (f : args)
    IJmp _ -> []
    IExit -> []
    IUnknown -> []

-- | The temporary named by a reference, as a zero- or one-element list.
readTempsR :: Ref -> [Int]
readTempsR ref = case ref of
    RTemp i -> [i]
    RConst _ -> []
    RSClo _ -> []
    RNone -> []

-- | Map over a list from left to right while threading a state through.
--
-- The step function receives the current state and an element and returns
-- the next state together with the output for that element. The result is
-- the final state and the outputs in the order of the input elements.
mapFoldl :: (s -> a -> (s, b)) -> s -> [a] -> (s, [b])
mapFoldl f s0 xs =
    case foldl' step (s0, []) xs of
        (final, revOut) -> (final, reverse revOut)
  where
    -- Outputs are accumulated in reverse and flipped once at the end.
    step (st, acc) x = case f st x of
        (st', y) -> (st', y : acc)