summaryrefslogtreecommitdiff
path: root/Optimiser.hs
blob: 4cff2ea84c2b081002df6d6bd1870aa0e496b649 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
{-# LANGUAGE LambdaCase, TupleSections #-}
module Optimiser(optimise) where

import Data.Function (on)
import Data.List
import qualified Data.Map.Strict as Map
import Data.Maybe
import qualified Data.Set as Set

import AST (Name)
import Intermediate
import Util


-- | Run the optimisation pipeline over a whole program.
--
-- The basic blocks go through the block-level passes listed in
-- 'blockPipeline', applied from first to last. The rebuilt program (with
-- @gfds@ and @datas@ carried over unchanged) then goes through the
-- program-level passes in 'programPipeline'.
optimise :: IRProgram -> IRProgram
optimise (IRProgram bbs gfds datas) =
    programPipeline (IRProgram (blockPipeline bbs) gfds datas)
  where
    -- Compose a list of passes so that the head of the list runs first.
    inSequence :: [a -> a] -> a -> a
    inSequence = foldr (\pass later -> later . pass) id

    blockPipeline = inSequence
        [ mergeBlocks
        , mergeRets
        , map propAssigns
        , globalPropAssigns
        , deadBBElim gfds
        , deadStoreElim
        , mergeRets
        , map propAssigns
        , deadBBElim gfds
        , deadStoreElim
        , mergeRets
        , tailCallIntro
        , deadBBElim gfds
        ]

    programPipeline = inSequence [dedupDatas]

-- | Fuse straight-line chains of basic blocks.
--
-- Whenever a block has exactly one successor and that successor has exactly
-- one predecessor, the two are merged. The merged block keeps the first
-- block's id, holds the concatenation of both instruction lists, and ends
-- with the second block's terminator. Chains are merged maximally.
--
-- The output starts with the block whose id equals the first input block's
-- id (if such a block still exists). The remaining blocks follow in
-- ascending id order.
--
-- A closed cycle made only of such blocks (no edge enters it from outside)
-- has no natural head. Both chain walks therefore stop when they come back
-- to the block they started from; without that check they would recurse
-- forever. Such a cycle is merged into a single block whose only successor
-- is itself, keeping the id of the cycle member that appears first in the
-- input.
mergeBlocks :: [BB] -> [BB]
mergeBlocks [] = []
mergeBlocks allbbs@(BB startb _ _ : _) =
    let bbmap = Map.fromList [(bidOf bb, bb) | bb <- allbbs]
        cfg = Map.fromList [(bidOf bb, outEdges bb) | bb <- allbbs]
        revcfg = oppositeGraph cfg
        resbbs = go bbmap cfg revcfg Set.empty (map bidOf allbbs)
    in uncurry (++) (partition ((== startb) . bidOf) resbbs)
  where
    -- Visit block ids in input order. Each not-yet-merged id causes its whole
    -- chain to be merged into one block, stored under the chain head's id.
    go bbmap _ _ _ [] = Map.elems bbmap
    go bbmap cfg revcfg seen (curid:rest)
        | curid `Set.member` seen = go bbmap cfg revcfg seen rest
        | otherwise =
            let topid = walkBack cfg revcfg curid
                (ids, bb') = walkForward bbmap cfg revcfg topid
                bbmap' = Map.insert topid bb' (foldr Map.delete bbmap ids)
                seen' = seen <> Set.fromList ids
            in go bbmap' cfg revcfg seen' rest

    -- Follow unique-predecessor links backwards to the head of the chain.
    -- The only block such a walk can revisit is the one it started from.
    -- Reaching it again means the chain is a closed cycle, and the starting
    -- block is used as its head.
    walkBack cfg revcfg start = back start
      where
        back cur = fromMaybe cur $ do
            [v] <- Map.lookup cur revcfg
            [_] <- Map.lookup v cfg
            return (if v == start then start else back v)

    -- Collect the chain that begins at 'start' and build the merged block.
    -- Returns the ids of all chain members (head first) and the merged block.
    -- Stops before re-entering 'start', the only block a forward walk can
    -- revisit.
    walkForward bbmap cfg revcfg start = forward start
      where
        forward cur = fromMaybe ([cur], bbmap Map.! cur) $ do
            [v] <- Map.lookup cur cfg
            [_] <- Map.lookup v revcfg
            if v == start then Nothing else Just ()
            let (ids, BB _ inss term) = forward v
            return (cur : ids, BB cur (inssOf (bbmap Map.! cur) ++ inss) term)

-- | Short-circuit jumps to bare return blocks.
--
-- A block whose terminator is @IJmp target@, where @target@ is a block with
-- no instructions and an 'IRet' terminator, gets that 'IRet' as its own
-- terminator. All other blocks are returned unchanged, and the block order is
-- preserved. The return blocks themselves are not removed here.
mergeRets :: [BB] -> [BB]
mergeRets bbs = map shortcut bbs
  where
    -- Id of every instruction-free block that only returns, with its 'IRet'.
    retOnly = Map.fromList [(bid, ret) | BB bid [] ret@(IRet _) <- bbs]

    shortcut bb@(BB bid inss (IJmp target)) =
        maybe bb (BB bid inss) (Map.lookup target retOnly)
    shortcut bb = bb

-- | Local copy and constant propagation within a single basic block.
--
-- Instructions are processed in order. For each temporary, the pass tracks
-- the (already propagated) reference it was last given by an 'IAssign'.
-- Operands of later instructions and of the block's terminator that read a
-- tracked temporary are replaced by that reference.
--
-- A tracked fact @t = r@ stays valid only while neither @t@ nor the
-- temporary @r@ refers to is overwritten. Every write to a temporary (an
-- instruction's destination, or an output of 'IPop') therefore forgets two
-- things: the fact for that temporary, and every fact whose value is that
-- temporary. Forgetting only the former would, for @a = x; x = y; b = a@,
-- rewrite the last instruction to @b = x@, which reads the overwritten @x@.
-- An 'ICallC' forgets all facts.
propAssigns :: BB -> BB
propAssigns (BB bid inss term) =
    let (state, inss') = mapFoldl propagateI Map.empty inss
        term' = propagateT state term
    in BB bid inss' term'
  where
    propagateI mp (d, ic) =
        -- Operands are read before the destination is written, so they are
        -- rewritten with the facts known *before* this instruction.
        let ic' = propIC mp ic
            written = destTemps d ++ poppedTemps ic
            survivors = case ic of
                ICallC _ _ -> Map.empty
                _          -> foldr invalidate mp written
            mp' = case (d, ic') of
                (RTemp i, IAssign r') -> Map.insert i r' survivors
                _                     -> survivors
        in (mp', (d, ic'))

    -- Rewrite the operands an instruction reads. 'IPop' operands are output
    -- parameters and are left alone.
    propIC mp = \case
        IAssign r       -> IAssign (propR mp r)
        ICallC r rs     -> ICallC (propR mp r) (map (propR mp) rs)
        IAllocClo n rs  -> IAllocClo n (map (propR mp) rs)
        IDiscard r      -> IDiscard (propR mp r)
        IPush rs        -> IPush (map (propR mp) rs)
        ic@(IParam _)   -> ic
        ic@(IClosure _) -> ic
        ic@(IData _)    -> ic
        ic@(IPop _)     -> ic

    -- Temporary written through an instruction's destination, if any.
    destTemps (RTemp i) = [i]
    destTemps _ = []

    -- Temporaries written as 'IPop' outputs.
    poppedTemps (IPop rs) = onlyTemporaries rs
    poppedTemps _ = []

    -- Forget everything known about temporary i: its own fact and any fact
    -- whose value is i.
    invalidate i = Map.delete i . Map.filter (/= RTemp i)

    propagateT mp (IBr r a b)   = IBr (propR mp r) a b
    propagateT _  t@(IJmp _)    = t
    propagateT mp (IRet r)      = IRet (propR mp r)
    propagateT mp (ITailC r rs) = ITailC (propR mp r) (map (propR mp) rs)
    propagateT _  t@IExit       = t
    propagateT _  t@IUnknown    = t

    propR mp ref@(RTemp i) = fromMaybe ref (Map.lookup i mp)
    propR _ ref = ref

-- | Program-wide constant propagation.
--
-- Every instruction in every block is grouped by its destination. Consider a
-- destination that appears on exactly one instruction, where that
-- instruction is an 'IAssign' of an 'RConst' or 'RSClo'. That destination is
-- treated as always holding the assigned reference. All reads of it, in
-- instruction operands and in block terminators, are replaced by the
-- reference itself.
--
-- Destinations are never rewritten; the now-unread assignment is left for
-- 'deadStoreElim' to remove. 'IPop' operands are output parameters and are
-- also left alone.
globalPropAssigns :: [BB] -> [BB]
globalPropAssigns bbs = map rewriteBB bbs
  where
    -- All instruction codes writing to each destination (order irrelevant:
    -- only singleton lists are used).
    writesByDest = Map.fromListWith (++)
        [(dest, [ic]) | BB _ inss _ <- bbs, (dest, ic) <- inss]
    replMap = Map.mapMaybe constantSource writesByDest

    constantSource [IAssign ref@(RConst _)] = Just ref
    constantSource [IAssign ref@(RSClo _)]  = Just ref
    constantSource _                        = Nothing

    replace r = Map.findWithDefault r r replMap

    rewriteBB (BB bid inss term) = BB bid (map rewriteIns inss) (rewriteTerm term)

    rewriteIns = \case
        (d, IAssign r)      -> (d, IAssign (replace r))
        (d, ICallC r rs)    -> (d, ICallC (replace r) (map replace rs))
        (d, IAllocClo n rs) -> (d, IAllocClo n (map replace rs))
        (d, IDiscard r)     -> (d, IDiscard (replace r))
        (d, IPush rs)       -> (d, IPush (map replace rs))
        ins@(_, IParam _)   -> ins
        ins@(_, IClosure _) -> ins
        ins@(_, IData _)    -> ins
        -- Cannot replace in an IPop, because its arguments are output parameters
        ins@(_, IPop _)     -> ins

    -- Terminators read references too; without this, a constant read only
    -- by a branch, return or tail call would keep its assignment alive.
    rewriteTerm = \case
        IBr r a b   -> IBr (replace r) a b
        t@(IJmp _)  -> t
        IRet r      -> IRet (replace r)
        ITailC r rs -> ITailC (replace r) (map replace rs)
        t@IExit     -> t
        t@IUnknown  -> t

-- | Remove basic blocks that cannot be reached.
--
-- The roots are block 0 and the entry block of every global function in
-- @gfds@. A block survives only if some root reaches it by following
-- 'outEdges'. Edges are followed only out of blocks that are themselves
-- reachable. As a result, dead chains and dead cycles (blocks that are only
-- jumped to from other dead blocks) are removed in a single pass. Surviving
-- blocks keep their relative order.
deadBBElim :: Map.Map Name GlobFuncDef -> [BB] -> [BB]
deadBBElim gfds bbs = filter ((`Set.member` reachable) . bidOf) bbs
  where
    roots = 0 : [bid | GlobFuncDef bid _ _ <- Map.elems gfds]
    successors = Map.fromList [(bidOf bb, outEdges bb) | bb <- bbs]

    -- Depth-first worklist traversal. An id with no block in @bbs@ simply
    -- contributes no successors.
    reachable = explore Set.empty roots
    explore seen [] = seen
    explore seen (b : pending)
        | b `Set.member` seen = explore seen pending
        | otherwise = explore (Set.insert b seen)
                              (Map.findWithDefault [] b successors ++ pending)

-- | Program-wide dead store elimination.
--
-- A temporary is dead when some block writes it ('bbWrittenTemps') but no
-- block reads it ('bbReadTemps'). An instruction is dropped in two cases:
--
-- * it writes a dead temporary using an instruction with no other effect
--   ('IAssign', 'IParam', 'IClosure', 'IData', 'IAllocClo');
-- * it has no destination and discards 'RNone', an 'RConst', an 'RSClo', or
--   a dead temporary.
--
-- Calls, pushes, pops and every other discard are always kept.
deadStoreElim :: [BB] -> [BB]
deadStoreElim bbs = map prune bbs
  where
    prune (BB bid inss term) = BB bid [ins | ins <- inss, not (removable ins)] term

    readTemps = Set.fromList (concatMap bbReadTemps bbs)
    writtenTemps = Set.fromList (concatMap bbWrittenTemps bbs)
    deadTemps = writtenTemps `Set.difference` readTemps
    isDead i = i `Set.member` deadTemps

    removable :: Instruction -> Bool
    removable (RNone, IDiscard ref) = discardable ref
    removable (RTemp i, ic)         = effectFree ic && isDead i
    removable _                     = False

    -- Operands whose discard can be dropped.
    discardable RNone      = True
    discardable (RConst _) = True
    discardable (RSClo _)  = True
    discardable (RTemp i)  = isDead i
    discardable _          = False

    -- Instruction codes whose only effect is writing their destination.
    effectFree :: InsCode -> Bool
    effectFree = \case
        IAssign _     -> True
        IParam _      -> True
        IClosure _    -> True
        IData _       -> True
        IAllocClo _ _ -> True
        ICallC _ _    -> False
        IDiscard _    -> False
        IPush _       -> False
        IPop _        -> False

-- | Tail-call introduction.
--
-- Matches a block whose last instruction is @tN = call cl(args)@ and whose
-- terminator is @ret tN@. Such a block is rewritten to drop the call and end
-- in @ITailC cl args@, provided @tN@ meets two conditions: no other block
-- reads it, and no earlier instruction of the same block mentions it. Blocks
-- that do not match are returned unchanged.
tailCallIntro :: [BB] -> [BB]
tailCallIntro bbs = map rewrite bbs
  where
    -- Temporaries read by each block, in block order.
    readsPerBlock = map (Set.fromList . bbReadTemps) bbs

    -- For each block: the union of the reads of all blocks before it and all
    -- blocks after it.
    readsElsewhere = Map.fromList
        (zipWith3 (\(BB bid _ _) before after -> (bid, before <> after))
                  bbs
                  (scanl (<>) Set.empty readsPerBlock)
                  (drop 1 (scanr (<>) Set.empty readsPerBlock)))

    rewrite bb@(BB bid inss term)
        | (RTemp dest, ICallC callee args) : revEarlier <- reverse inss
        , IRet (RTemp retd) <- term
        , dest == retd
        , dest `Set.notMember` (readsElsewhere Map.! bid)
        , let earlier = reverse revEarlier
        , dest `notElem` onlyTemporaries (concatMap (allRefs . snd) earlier)
        = BB bid earlier (ITailC callee args)
        | otherwise = bb

-- | Deduplicate the program's data table.
--
-- The new table is @uniq (sort datatbl)@. 'uniq' comes from "Util";
-- presumably it drops adjacent duplicates, so after sorting one copy of each
-- value remains (TODO confirm). Every 'IData' index is rewritten to the
-- position of the same value in the new table. Blocks and all other
-- instructions are left unchanged.
dedupDatas :: IRProgram -> IRProgram
dedupDatas (IRProgram origbbs gfds datatbl) = IRProgram (map goBB origbbs) gfds values
  where
    values = uniq (sort datatbl)
    -- Value -> index in the new table.
    valueIdx = Map.fromList (zip values [0..])
    -- Old index -> value. Looking an index up here is O(log n), whereas
    -- indexing the list with (!!) costs O(n) for every IData instruction.
    oldValue = Map.fromList (zip [0..] datatbl)

    goBB (BB bid inss term) = BB bid (map goI inss) term

    goI (ref, IData i) = (ref, IData (valueIdx Map.! (oldValue Map.! i)))
    goI ins = ins

-- | Map over a list from left to right while threading a state, returning
-- the final state and the outputs in input order.
--
-- Like 'Data.List.mapAccumL', but strict: the fold is a 'foldl'', and each
-- step's result pair is forced as the fold proceeds.
mapFoldl :: (s -> a -> (s, b)) -> s -> [a] -> (s, [b])
mapFoldl step initial xs =
    case foldl' advance (initial, []) xs of
        (final, revOut) -> (final, reverse revOut)
  where
    -- Outputs are consed onto the front and reversed once at the end.
    advance (acc, out) x = case step acc x of
        (acc', y) -> (acc', y : out)