{-# LANGUAGE LambdaCase, TupleSections #-}
-- | Optimisation passes over the intermediate representation.
--
-- Each pass is a pure rewrite of the basic-block list (or, for 'dedupDatas',
-- of the whole 'IRProgram'); 'optimise' runs them in a fixed order and is the
-- only thing exported.
module Optimiser(optimise) where

import Data.Function (on)
import Data.List
import qualified Data.Map.Strict as Map
import Data.Maybe
import qualified Data.Set as Set

import AST (Name)
import Intermediate
import Util


-- | Run the whole optimisation pipeline over a program.
--
-- The basic-block passes in @bbPasses@ run first, in list order (the head of
-- the list runs first); the whole-program passes in @programPasses@ then run
-- on the result, also in list order.
optimise :: IRProgram -> IRProgram
optimise (IRProgram bbs gfds datas) =
    runPasses programPasses (IRProgram (runPasses bbPasses bbs) gfds datas)
  where
    bbPasses =
        [ mergeBlocks gfds, mergeRets
        , map propAssigns, globalPropAssigns
        , deadBBElim gfds, deadStoreElim, mergeRets
        , map propAssigns
        , deadBBElim gfds, deadStoreElim, mergeRets
        , tailCallIntro, deadBBElim gfds ]
    programPasses = [ dedupDatas ]

    -- Apply the passes left to right: the first list element runs first.
    runPasses :: [a -> a] -> a -> a
    runPasses passes start = foldl' (\acc pass -> pass acc) start passes

-- | Fuse straight-line chains of basic blocks.
--
-- Block @a@ absorbs block @b@ when @a@'s only CFG successor is @b@ and @b@'s
-- only CFG predecessor is @a@. The fused block keeps @a@'s id, runs @a@'s
-- instructions followed by @b@'s, and ends with @b@'s terminator; this is
-- repeated along the whole chain.
--
-- Blocks whose id must survive are never absorbed into a predecessor:
--
-- * the first block of the input;
-- * block 0 (the root 'deadBBElim' keeps alive);
-- * every function entry listed in the global function table.
--
-- A closed cycle of single-edge blocks is fused once around and then stops,
-- instead of walking round it forever.
--
-- The first input block is first in the output; the remaining blocks come
-- out in ascending id order ('Map.elems').
mergeBlocks :: Map.Map Name GlobFuncDef -> [BB] -> [BB]
mergeBlocks _ [] = []
mergeBlocks gfds allbbs@(BB startb _ _ : _) =
    let resbbs = go initialMap Set.empty (map bidOf allbbs)
    in uncurry (++) (partition ((== startb) . bidOf) resbbs)
  where
    initialMap = Map.fromList [(bidOf bb, bb) | bb <- allbbs]
    cfg = Map.fromList [(bidOf bb, outEdges bb) | bb <- allbbs]
    revcfg = oppositeGraph cfg
    -- Ids that must keep a block of their own: absorbing them would leave the
    -- program entry or a function entry pointing at a deleted block.
    protected = Set.fromList (startb : 0 : [bid | GlobFuncDef bid _ _ <- Map.elems gfds])

    -- Visit every id once; each unseen id is expanded to its whole chain,
    -- which replaces the chain's member blocks in the map.
    go bbmap _ [] = Map.elems bbmap
    go bbmap seen (curid : rest)
        | curid `Set.member` seen = go bbmap seen rest
        | otherwise =
            let topid = walkBack curid curid
                (ids, merged) = walkForward bbmap topid topid
                bbmap' = Map.insert topid merged (foldr Map.delete bbmap ids)
            in go bbmap' (seen <> Set.fromList ids) rest

    -- Walk backwards from @cur@ to the head of its chain. @origin@ is the id
    -- the walk started at; meeting it again means the chain is a closed
    -- cycle, so we stop rather than recurse forever.
    walkBack origin cur
        | cur `Set.member` protected = cur
        | otherwise = fromMaybe cur $ do
            [v] <- Map.lookup cur revcfg
            [_] <- Map.lookup v cfg
            if v == origin then Nothing else Just (walkBack origin v)

    -- Walk forwards from chain head @top@. Returns the ids consumed (starting
    -- with @cur@) and the fused block carrying @cur@'s id. Stops before
    -- re-entering @top@ (closed cycle) or absorbing a protected block.
    walkForward bbmap top cur = fromMaybe ([cur], here) $ do
        [v] <- Map.lookup cur cfg
        [_] <- Map.lookup v revcfg
        if v == top || v `Set.member` protected
            then Nothing
            else do
                let (ids, BB _ inss term) = walkForward bbmap top v
                Just (cur : ids, BB cur (inssOf here ++ inss) term)
      where
        here = bbmap Map.! cur

-- | Replace a jump to an instruction-less block that just returns with that
-- return itself, so the jumping block returns directly.
mergeRets :: [BB] -> [BB]
mergeRets bbs = map redirect bbs
  where
    rets = Map.fromList [(bid, ret) | BB bid [] ret@(IRet _) <- bbs]
    redirect (BB bid inss (IJmp target))
        | Just ret <- Map.lookup target rets = BB bid inss ret
    redirect bb = bb

-- | Block-local copy/constant propagation.
--
-- Walks the instructions in order, remembering for each temporary written
-- by an 'IAssign' the (already propagated) reference it received. Later
-- reads of that temporary, including in the terminator, are replaced by that
-- reference.
--
-- Everything remembered is forgotten at an 'ICallC'. A temporary's entry is
-- dropped when an 'IPop' writes it.
-- NOTE(review): presumably a call may clobber temporaries; confirm.
propAssigns :: BB -> BB
propAssigns (BB bid inss term) =
    let (known, inss') = mapAccumL propagateI Map.empty inss
    in BB bid inss' (propagateT known term)
  where
    propagateI mp (d@(RTemp i), IAssign r) =
        let r' = propR mp r
        in (Map.insert i r' mp, (d, IAssign r'))
    propagateI mp (d, IAssign r) = (mp, (d, IAssign (propR mp r)))
    propagateI mp ins@(_, IParam _) = (mp, ins)
    propagateI mp ins@(_, IClosure _) = (mp, ins)
    propagateI mp ins@(_, IData _) = (mp, ins)
    propagateI mp (d, ICallC r rs) =
        (Map.empty, (d, ICallC (propR mp r) (map (propR mp) rs)))
    propagateI mp (d, IAllocClo n rs) = (mp, (d, IAllocClo n (map (propR mp) rs)))
    propagateI mp (d, IDiscard r) = (mp, (d, IDiscard (propR mp r)))
    propagateI mp (d, IPush rs) = (mp, (d, IPush (map (propR mp) rs)))
    -- IPop overwrites its operands, so forget what we knew about them.
    propagateI mp (d, IPop rs) =
        (foldr Map.delete mp (onlyTemporaries rs), (d, IPop rs))

    propagateT mp (IBr r a b) = IBr (propR mp r) a b
    propagateT _ t@(IJmp _) = t
    propagateT mp (IRet r) = IRet (propR mp r)
    propagateT mp (ITailC r rs) = ITailC (propR mp r) (map (propR mp) rs)
    propagateT _ t@IExit = t
    propagateT _ t@IUnknown = t

    propR mp ref@(RTemp i) = fromMaybe ref (Map.lookup i mp)
    propR _ ref = ref

-- | Program-wide constant propagation.
--
-- A destination qualifies when exactly one instruction in the whole block
-- list writes it, and that instruction is an 'IAssign' of an 'RConst' or
-- 'RSClo'. Every read of a qualifying destination is replaced by the
-- assigned reference, in instructions and in terminators alike.
--
-- The assignment itself is left in place; 'deadStoreElim' removes it once
-- nothing reads it.
globalPropAssigns :: [BB] -> [BB]
globalPropAssigns bbs = map rewrite bbs
  where
    allIns = [pair | BB _ inss _ <- bbs, pair <- inss]
    -- Each destination paired with every instruction code that writes it.
    defsByDest =
        [ (dest, map snd grp)
        | grp@((dest, _) : _) <- groupBy ((==) `on` fst) (sortOn fst allIns) ]
    replMap = Map.fromList
        [ (dest, ref) | (dest, [IAssign ref]) <- defsByDest, isConstLike ref ]

    isConstLike (RConst _) = True
    isConstLike (RSClo _) = True
    isConstLike _ = False

    replace r = Map.findWithDefault r r replMap

    rewrite (BB bid inss term) = BB bid (map replaceI inss) (replaceT term)

    -- Destinations are deliberately not rewritten (see 'deadStoreElim').
    replaceI = \case
        (d, IAssign r) -> (d, IAssign (replace r))
        (d, ICallC r rs) -> (d, ICallC (replace r) (map replace rs))
        (d, IAllocClo n rs) -> (d, IAllocClo n (map replace rs))
        (d, IDiscard r) -> (d, IDiscard (replace r))
        (d, IPush rs) -> (d, IPush (map replace rs))
        ins@(_, IParam _) -> ins
        ins@(_, IClosure _) -> ins
        ins@(_, IData _) -> ins
        -- Cannot replace in an IPop, because its arguments are output parameters
        ins@(_, IPop _) -> ins

    replaceT = \case
        IBr r a b -> IBr (replace r) a b
        t@(IJmp _) -> t
        IRet r -> IRet (replace r)
        ITailC r rs -> ITailC (replace r) (map replace rs)
        t@IExit -> t
        t@IUnknown -> t

-- | Drop blocks that cannot be reached.
--
-- The roots are block 0 and every function entry in the global function
-- table. A block is kept iff it can be reached from a root by following
-- 'outEdges'. Blocks reachable only from other dead blocks are therefore
-- removed as well, including dead cycles.
deadBBElim :: Map.Map Name GlobFuncDef -> [BB] -> [BB]
deadBBElim gfds bbs = filter ((`Set.member` reachable) . bidOf) bbs
  where
    succs = Map.fromList [(bidOf bb, outEdges bb) | bb <- bbs]
    roots = 0 : [bid | GlobFuncDef bid _ _ <- Map.elems gfds]
    reachable = explore Set.empty roots

    explore seen [] = seen
    explore seen (b : todo)
        | b `Set.member` seen = explore seen todo
        | otherwise =
            explore (Set.insert b seen) (Map.findWithDefault [] b succs ++ todo)

-- | Remove dead instructions, repeating until nothing more can be removed.
--
-- An instruction is dead when:
--
-- * it is a pure instruction ('pureIC') writing a temporary that is never
--   read anywhere in the block list; or
-- * it is an 'IDiscard' of 'RNone', a constant, a static closure, or an
--   unread temporary.
--
-- Removing one dead instruction can leave its operands unread, hence the
-- iteration. Each round strictly shrinks the program, so it terminates.
deadStoreElim :: [BB] -> [BB]
deadStoreElim bbs
    | insCount bbs' < insCount bbs = deadStoreElim bbs'
    | otherwise = bbs'
  where
    bbs' = [BB bid (filter (not . shouldRemove) inss) term | BB bid inss term <- bbs]
    insCount = sum . map (length . inssOf)

    readtemps = Set.fromList (concatMap bbReadTemps bbs)
    alltemps = readtemps <> Set.fromList (concatMap bbWrittenTemps bbs)
    elim = alltemps Set.\\ readtemps

    shouldRemove :: Instruction -> Bool
    shouldRemove (RNone, IDiscard RNone) = True
    shouldRemove (RNone, IDiscard (RConst _)) = True
    shouldRemove (RNone, IDiscard (RSClo _)) = True
    shouldRemove (RNone, IDiscard (RTemp i)) = i `Set.member` elim
    shouldRemove (RTemp i, ins) = pureIC ins && i `Set.member` elim
    shouldRemove _ = False

-- | Whether an instruction code may be deleted when its result is unused.
pureIC :: InsCode -> Bool
pureIC (IAssign _) = True
pureIC (IParam _) = True
pureIC (IClosure _) = True
pureIC (IData _) = True
pureIC (IAllocClo _ _) = True
pureIC (ICallC _ _) = False
pureIC (IDiscard _) = False
pureIC (IPush _) = False
pureIC (IPop _) = False

-- | Turn @t = call cl args; ret t@ at the end of a block into a tail call.
--
-- The rewrite only happens when @t@ is read nowhere else. That means no
-- other block reads it, and no earlier instruction of the same block
-- refers to it.
tailCallIntro :: [BB] -> [BB]
tailCallIntro bbs = map introduce bbs
  where
    readInBB = map (Set.fromList . bbReadTemps) bbs
    -- Union of reads in all blocks strictly before / strictly after each block.
    readBefore = scanl (<>) Set.empty readInBB
    readAfter = drop 1 (scanr (<>) Set.empty readInBB)
    readInOthers = Map.fromList
        [ (bid, before <> after)
        | (BB bid _ _, before, after) <- zip3 bbs readBefore readAfter ]

    introduce orig@(BB bid inss term) = case (reverse inss, term) of
        ((RTemp i1, ICallC cl args) : revPrefix, IRet (RTemp i2))
            | i1 == i2
            , i1 `Set.notMember` (readInOthers Map.! bid)
            , i1 `notElem` onlyTemporaries (concatMap (allRefs . snd) revPrefix)
            -> BB bid (reverse revPrefix) (ITailC cl args)
        _ -> orig

-- | Deduplicate the data table and renumber every 'IData' reference to it.
--
-- The new table is the sorted, duplicate-free list of the old entries.
dedupDatas :: IRProgram -> IRProgram
dedupDatas (IRProgram origbbs gfds datatbl) = IRProgram (map goBB origbbs) gfds values
  where
    values = uniq (sort datatbl)
    valueIdx = Map.fromList (zip values [0..])
    -- Old index -> new index, built once instead of a '!!' walk per use.
    remap = Map.fromList (zip [0..] (map (valueIdx Map.!) datatbl))

    goBB (BB bid inss term) = BB bid (map goI inss) term

    goI (ref, IData i) = (ref, IData (remap Map.! i))
    goI ins = ins