{-# LANGUAGE TupleSections #-}

-- | Block-level optimisation passes over the intermediate representation.
--
-- Only 'optimise' is exported; every other definition is one pass of the
-- pipeline or a helper used by a pass.
module Optimiser (optimise) where

import Data.List
import qualified Data.Map.Strict as Map
import Data.Maybe
import qualified Data.Set as Set

import AST (Name)
import Intermediate

-- | Run the full pass pipeline over the program's basic blocks.
--
-- Global function definitions and data are passed through unchanged; the
-- function definitions are consulted to know which blocks are entry points.
optimise :: IRProgram -> IRProgram
optimise (IRProgram bbs gfds datas) = IRProgram (pipeline bbs) gfds datas
  where
    -- @foldr (.) id [f1, ..., fn] == f1 . ... . fn@, so the passes run from
    -- the bottom of this list to the top.
    pipeline =
      foldr
        (.)
        id
        [ tailCallIntro
        , deadBBElim gfds
        , mergeRets
        , deadStoreElim
        , deadBBElim gfds
        , map propAssigns
        , mergeRets
        , mergeBlocks (entryBlocks gfds)
        ]

-- | Ids of blocks that are entered from outside the control-flow graph:
-- block 0 and the entry block of every global function.
entryBlocks :: Map.Map Name GlobFuncDef -> Set.Set Int
entryBlocks gfds =
  Set.fromList (0 : [bid | GlobFuncDef bid _ _ <- Map.elems gfds])

-- | Fuse a block into its predecessor when the predecessor reaches it with an
-- unconditional 'IJmp' and nothing else refers to it.
--
-- A block is absorbed only when it has exactly one incoming edge in total
-- (counting 'IBr' targets and self-loops, so no jump is left pointing at a
-- removed id) and that edge is an 'IJmp' from another block. Blocks whose
-- ids are in @keep@, and the first block of the input, are never absorbed.
-- The first block of the input is placed first in the result.
mergeBlocks :: Set.Set Int -> [BB] -> [BB]
mergeBlocks _ [] = []
mergeBlocks keep allbbs@(BB startb _ _ : _) =
  uncurry (++) (partition ((== startb) . bidOf) (go allbbs (length allbbs)))
  where
    -- The counter only decreases on visits that merge nothing, so the loop
    -- stops once that many unproductive visits have happened. A successful
    -- merge puts the grown predecessor at the front to be looked at again.
    go [] _ = []
    go bbs 0 = bbs
    go (bb@(BB bid inss term) : bbs) n
      | bid /= startb
      , bid `Set.notMember` keep
      , incoming bid (bb : bbs) == 1
      , ([BB bid' inss' _], rest) <- partition (hasJumpTo bid . termOf) bbs =
          go (BB bid' (inss' ++ inss) term : rest) n
      | otherwise = go (bbs ++ [bb]) (n - 1)

    -- Number of control-flow edges (of any kind) that target @target@.
    incoming target = length . filter (== target) . concatMap outEdges

    hasJumpTo bid (IJmp a) = a == bid
    hasJumpTo _ _ = False

-- | Replace an 'IJmp' to an empty block whose terminator is 'IRet' with that
-- 'IRet' directly. The jumped-to block itself is left in place; later
-- 'deadBBElim' passes drop it once nothing reaches it.
mergeRets :: [BB] -> [BB]
mergeRets bbs = map retarget bbs
  where
    rets = Map.fromList [(bid, ret) | BB bid [] ret@(IRet _) <- bbs]

    retarget (BB bid inss (IJmp target))
      | Just ret <- Map.lookup target rets = BB bid inss ret
    retarget bb = bb

-- | Forward-substitute copies within a single block.
--
-- After @t := r@ ('IAssign' into temporary @t@), later uses of @t@ in the
-- same block are replaced by @r@. Writing a temporary forgets its own binding
-- and every binding whose value is that temporary; an 'ICallC' forgets
-- everything.
propAssigns :: BB -> BB
propAssigns (BB bid inss term) =
  let (state, inss') = mapFoldl propagateI Map.empty inss
   in BB bid inss' (propagateT state term)
  where
    propagateI mp (d, IAssign r) =
      let r' = propR mp r
       in (record d r' (clobber d mp), (d, IAssign r'))
    propagateI mp ins@(d, IParam _) = (clobber d mp, ins)
    propagateI mp ins@(d, IClosure _) = (clobber d mp, ins)
    propagateI mp ins@(d, IData _) = (clobber d mp, ins)
    propagateI mp (d, ICallC r rs) =
      (Map.empty, (d, ICallC (propR mp r) (map (propR mp) rs)))
    -- Arguments are rewritten with the knowledge from *before* the write.
    propagateI mp (d, IAllocClo n rs) =
      (clobber d mp, (d, IAllocClo n (map (propR mp) rs)))
    propagateI mp (d, IDiscard r) =
      (clobber d mp, (d, IDiscard (propR mp r)))

    -- Remember that the destination temporary now holds @r@.
    record (RTemp i) r mp = Map.insert i r mp
    record _ _ mp = mp

    -- Drop every fact made stale by writing @d@: the binding for @d@ itself
    -- and any binding whose value is @d@.
    clobber (RTemp i) mp = Map.filter (not . isTemp i) (Map.delete i mp)
    clobber _ mp = mp

    isTemp i (RTemp j) = i == j
    isTemp _ _ = False

    propagateT mp (IBr r a b) = IBr (propR mp r) a b
    propagateT _ t@(IJmp _) = t
    propagateT mp (IRet r) = IRet (propR mp r)
    propagateT mp (ITailC r rs) = ITailC (propR mp r) (map (propR mp) rs)
    propagateT _ t@IExit = t
    propagateT _ t@IUnknown = t

    propR mp ref@(RTemp i) = fromMaybe ref (Map.lookup i mp)
    propR _ ref = ref

-- | Keep only blocks reachable from an entry point ('entryBlocks') by
-- following 'IJmp'/'IBr' edges. Surviving blocks keep their relative order.
deadBBElim :: Map.Map Name GlobFuncDef -> [BB] -> [BB]
deadBBElim gfds bbs = filter ((`Set.member` reachable) . bidOf) bbs
  where
    successors = Map.fromListWith (++) [(bidOf bb, outEdges bb) | bb <- bbs]

    reachable = visit Set.empty (Set.toList (entryBlocks gfds))

    -- Worklist search: pop an id, and if it is new, push its successors.
    visit seen [] = seen
    visit seen (b : todo)
      | b `Set.member` seen = visit seen todo
      | otherwise =
          visit (Set.insert b seen) (Map.findWithDefault [] b successors ++ todo)

-- | Remove instructions whose result temporary is written but never read
-- anywhere in the given blocks, provided the instruction is 'pureIC'.
--
-- Discards of such temporaries, and discards of 'RNone', are removed too.
-- 'IDiscard' does not count as a read (see 'readTempsIC').
deadStoreElim :: [BB] -> [BB]
deadStoreElim bbs =
  [BB bid (filter (not . shouldRemove) inss) term | BB bid inss term <- bbs]
  where
    readtemps = Set.fromList (concatMap readTempsBB bbs)
    elim = Set.fromList (concatMap writtenTempsBB bbs) Set.\\ readtemps

    shouldRemove :: Instruction -> Bool
    shouldRemove (RNone, IDiscard RNone) = True
    shouldRemove (RNone, IDiscard (RTemp i)) = i `Set.member` elim
    shouldRemove (RTemp i, ins) = pureIC ins && i `Set.member` elim
    shouldRemove _ = False

    -- Instruction kinds that may be dropped when their result is unused.
    pureIC :: InsCode -> Bool
    pureIC (IAssign _) = True
    pureIC (IParam _) = True
    pureIC (IClosure _) = True
    pureIC (IData _) = True
    pureIC (IAllocClo _ _) = True
    pureIC (ICallC _ _) = False
    pureIC (IDiscard _) = False

-- | Turn @t := call cl args; ret t@ at the end of a block into a tail call.
--
-- Applied only when @t@ is not read by any other block and is not read by
-- the earlier instructions of the same block.
tailCallIntro :: [BB] -> [BB]
tailCallIntro bbs = zipWith3 introduce bbs readBefore readAfter
  where
    readInBB = map (Set.fromList . readTempsBB) bbs
    -- Element k: temporaries read by blocks 0 .. k-1 (one extra trailing
    -- element is ignored by zipWith3).
    readBefore = scanl (<>) Set.empty readInBB
    -- Element k: temporaries read by blocks k+1 .. end.
    readAfter = drop 1 (scanr (<>) Set.empty readInBB)

    introduce orig@(BB bid inss term) before after =
      case (splitLast inss, term) of
        (Just (body, (RTemp i1, ICallC cl args)), IRet (RTemp i2))
          | i1 == i2
          , i1 `Set.notMember` before
          , i1 `Set.notMember` after
          , i1 `notElem` concatMap (readTempsIC . snd) body ->
              BB bid body (ITailC cl args)
        _ -> orig

-- | Split a list into all but its last element and that element, or
-- 'Nothing' for an empty list.
splitLast :: [a] -> Maybe ([a], a)
splitLast = foldr step Nothing
  where
    step x Nothing = Just ([], x)
    step x (Just (ys, y)) = Just (x : ys, y)

-- | Ids of the blocks a block's terminator may transfer control to.
outEdges :: BB -> [Int]
outEdges (BB _ _ term) = outEdgesT term

outEdgesT :: Terminator -> [Int]
outEdgesT (IBr _ a b) = [a, b]
outEdgesT (IJmp a) = [a]
outEdgesT (IRet _) = []
outEdgesT (ITailC _ _) = []
outEdgesT IExit = []
outEdgesT IUnknown = []

-- | Temporaries read by a block's instructions and terminator.
readTempsBB :: BB -> [Int]
readTempsBB (BB _ inss term) =
  concatMap (readTempsIC . snd) inss ++ readTempsT term

-- | Temporaries used as instruction destinations in a block ('readTempsR'
-- is reused here just to extract the temporary ids from each destination).
writtenTempsBB :: BB -> [Int]
writtenTempsBB (BB _ inss _) = concatMap (readTempsR . fst) inss

readTempsIC :: InsCode -> [Int]
readTempsIC (IAssign r) = readTempsR r
readTempsIC (IParam _) = []
readTempsIC (IClosure _) = []
readTempsIC (IData _) = []
readTempsIC (ICallC r rs) = readTempsR r ++ concatMap readTempsR rs
readTempsIC (IAllocClo _ rs) = concatMap readTempsR rs
readTempsIC (IDiscard _) = []

readTempsT :: Terminator -> [Int]
readTempsT (IBr r _ _) = readTempsR r
readTempsT (IJmp _) = []
readTempsT (IRet r) = readTempsR r
readTempsT (ITailC r rs) = readTempsR r ++ concatMap readTempsR rs
readTempsT IExit = []
readTempsT IUnknown = []

readTempsR :: Ref -> [Int]
readTempsR (RConst _) = []
readTempsR (RTemp i) = [i]
readTempsR (RSClo _) = []
readTempsR RNone = []

-- | Map over a list while threading an accumulator from left to right;
-- this is exactly 'mapAccumL' specialised to lists.
mapFoldl :: (s -> a -> (s, b)) -> s -> [a] -> (s, [b])
mapFoldl = mapAccumL