module Optimiser(optimise) where

import Data.List
import Data.Maybe
import qualified Data.Map.Strict as Map
import Debug.Trace
import Defs
import Intermediate
import ReplaceRefs
import Utils


-- | A pass over a whole program.
type Optimisation = IRProgram -> IRProgram

-- | A pass over a single function.
type FuncOptimisation = IRFunc -> IRFunc


-- | Run all passes, first to last, again and again until a full round leaves
-- the program unchanged, and return that fixpoint. Never produces 'Left'.
--
-- NOTE(review): termination relies on the passes converging; two passes that
-- keep undoing each other's work would make this loop forever.
optimise :: IRProgram -> Error IRProgram
optimise prog = Right (fixpoint runRound prog)
  where
    -- One round: every pass once, in list order (head of the list first).
    runRound :: Optimisation
    runRound p = foldl' (\acc pass -> pass acc) p optimisations

    optimisations :: [Optimisation]
    optimisations = map funcopt
        [ chainJumps, mergeTerminators, looseJumps, removeUnusedBlocks
        , identityOps, constantPropagate, removeNops, movPush
        , evaluateInstructions, evaluateTerminators
        ]

    -- Apply @f@ until the value stops changing.
    fixpoint :: Eq a => (a -> a) -> a -> a
    fixpoint f x
        | x' == x = x
        | otherwise = fixpoint f x'
      where
        x' = f x


-- | Lift a per-function pass to a whole program; the program's global
-- variables are passed through untouched.
funcopt :: FuncOptimisation -> Optimisation
funcopt fo (IRProgram vars funcs) = IRProgram vars (map fo funcs)


-- | Merge a block ending in an unconditional jump into that jump's target,
-- when it is the only block jumping unconditionally to the target.
--
-- The merged block (the jumper's instructions followed by the target's
-- instructions, with the target's terminator) is put at the front of the
-- block list. The target block itself is kept, since it may still be reached
-- through a conditional jump; 'removeUnusedBlocks' deletes it once nothing
-- refers to it. Merging repeats until no block qualifies.
chainJumps :: FuncOptimisation
chainJumps (IRFunc rt name al bbs sid) = IRFunc rt name al (mergeAll bbs) sid
  where
    mergeAll :: [BB] -> [BB]
    mergeAll blocks = maybe blocks mergeAll (mergeOnce blocks)

    -- Merge the first suitable block, or 'Nothing' if none qualifies.
    mergeOnce :: [BB] -> Maybe [BB]
    mergeOnce blocks = case break (isSuitable blocks) blocks of
        (before, BB bid1 inss1 (IJmp target) : after) ->
            let rest = before ++ after
            in case filter (hasId target) rest of
                [BB _ inss2 term2] -> Just (BB bid1 (inss1 ++ inss2) term2 : rest)
                _ -> Nothing  -- not reached: 'isSuitable' checked the target
        _ -> Nothing

    -- The block ends in @IJmp target@, is the only unconditional jumper to
    -- @target@, and @target@ is a different block that exists exactly once.
    -- A self-loop has to be excluded: it cannot be merged with itself.
    isSuitable :: [BB] -> BB -> Bool
    isSuitable blocks (BB bid _ (IJmp target)) =
        target /= bid
            && count (jumpsTo target) blocks == 1
            && count (hasId target) blocks == 1
    isSuitable _ _ = False

    jumpsTo :: Id -> BB -> Bool
    jumpsTo i (BB _ _ (IJmp i')) = i == i'
    jumpsTo _ _ = False

    hasId :: Id -> BB -> Bool
    hasId i (BB bid _ _) = bid == i

    count :: (a -> Bool) -> [a] -> Int
    count p = length . filter p


-- | A block whose unconditional jump targets an instruction-less block takes
-- over that block's terminator instead of jumping to it.
mergeTerminators :: FuncOptimisation
mergeTerminators (IRFunc rt name al bbs sid) = IRFunc rt name al (map go bbs) sid
  where
    -- (id, terminator) of every block without instructions.
    singles = [(i, t) | BB i [] t <- bbs]

    go bb@(BB bid inss (IJmp i)) = maybe bb (BB bid inss) (lookup i singles)
    go bb = bb


-- | Retarget jumps (conditional or not) that land on an instruction-less block
-- ending in @IJmp j@ so that they go to @j@ directly. One hop per call.
looseJumps :: FuncOptimisation
looseJumps (IRFunc rt name al bbs sid) = IRFunc rt name al (map go bbs) sid
  where
    go (BB bid inss (IJmp i)) = BB bid inss (IJmp (translate i))
    go (BB bid inss (IJcc ct r1 r2 i j)) =
        BB bid inss (IJcc ct r1 r2 (translate i) (translate j))
    go bb = bb

    translate i = Map.findWithDefault i i transmap

    transmap = Map.fromList [(bid, i) | BB bid [] (IJmp i) <- bbs]


-- | Drop every block that cannot be reached from the entry block @sid@ by
-- following 'IJmp' / 'IJcc' targets. The entry block is always kept.
--
-- Doing a real reachability search (rather than asking "does any block jump
-- here") also removes unreachable cycles and unreachable self-loops.
removeUnusedBlocks :: FuncOptimisation
removeUnusedBlocks (IRFunc rt name al bbs sid) = IRFunc rt name al bbs' sid
  where
    bbs' = [bb | bb@(BB bid _ _) <- bbs, Map.member bid reachable]

    -- Successor block ids, keyed by block id.
    successorMap :: Map.Map Id [Id]
    successorMap = Map.fromList [(bid, successors term) | BB bid _ term <- bbs]

    successors :: IRTerm -> [Id]
    successors (IJmp i) = [i]
    successors (IJcc _ _ _ i1 i2) = [i1, i2]
    successors _ = []

    -- Ids reachable from the entry block; the map is used as a set.
    reachable :: Map.Map Id ()
    reachable = explore Map.empty [sid]

    explore :: Map.Map Id () -> [Id] -> Map.Map Id ()
    explore seen [] = seen
    explore seen (i : todo)
        | Map.member i seen = explore seen todo
        | otherwise =
            explore (Map.insert i () seen) (Map.findWithDefault [] i successorMap ++ todo)


-- | Delete arithmetic instructions that leave their destination unchanged:
-- adding or subtracting constant 0, multiplying or dividing by constant 1.
identityOps :: FuncOptimisation
identityOps (IRFunc rt name al bbs sid) = IRFunc rt name al (map go bbs) sid
  where
    go :: BB -> BB
    go (BB bid inss term) = BB bid (filter (not . isIdentity) inss) term

    isIdentity :: IRIns -> Bool
    isIdentity (IAri AAdd _ (Constant _ 0)) = True
    isIdentity (IAri ASub _ (Constant _ 0)) = True
    isIdentity (IAri AMul _ (Constant _ 1)) = True
    isIdentity (IAri ADiv _ (Constant _ 1)) = True
    isIdentity _ = False


-- | Constant propagation for temporaries.
--
-- Picks the first temporary that is written by exactly one instruction in the
-- function, where that instruction is an 'IMov' of a 'Constant'. Every
-- reference to the temporary is then replaced by the constant, and the
-- defining instruction is turned into an 'INop'. At most one temporary is
-- handled per call; 'optimise' reruns the pass until nothing changes.
--
-- Only constant sources are propagated. Substituting a non-constant source (a
-- variable or another temporary) at every use would be wrong if that source
-- is reassigned between the move and a later use.
constantPropagate :: FuncOptimisation
constantPropagate (IRFunc rt name al bbs sid) = IRFunc rt name al bbs' sid
  where
    candidates :: [((Int, Int), Ref, Ref)]
    candidates =
        [ (loc, ref, value)
        | ref <- findAllTemps' bbs
        , [loc] <- [findMutations' bbs ref]
        , IMov _ value@(Constant _ _) <- [insAt bbs loc]
        ]

    bbs' = case candidates of
        [] -> bbs
        (loc, ref, value) : _ -> replaceRefsBBList ref value (nopifyInsAt bbs loc)
-- | Delete every 'INop' from every block ('nopifyInsAt' leaves them behind).
removeNops :: FuncOptimisation
removeNops (IRFunc rt name al bbs sid) = IRFunc rt name al (map go bbs) sid
  where
    go (BB bid inss term) = BB bid (filter (not . isNop) inss) term

    isNop INop = True
    isNop _ = False


-- | Block-local copy propagation.
--
-- Within each block, an @IMov d s@ whose destination is read by a later
-- instruction is pushed forward. The following instructions get their uses of
-- @d@ rewritten to @s@. The move itself is emitted only where pushing has to
-- stop, or at the end of the block (so the terminator still sees @d@).
-- Along the way:
--
--   * a later @IMov d s'@ makes the pending move dead, and the new move is
--     pushed instead (@IMov d d@ and @IResize d d@ are simply dropped);
--   * @IMov d c1@ followed by @IAri op d c2@, with constants of equal size,
--     is folded into a single move of the computed constant;
--   * pushing stops as soon as an instruction overwrites @s@, so later uses
--     of @d@ never see the new value of @s@.
--
-- NOTE(review): calls and stores are assumed not to change @s@; confirm this
-- for every kind of 'Ref' that can be the source of a move.
movPush :: FuncOptimisation
movPush (IRFunc rt name al bbs sid) = IRFunc rt name al (map goBB bbs) sid
  where
    goBB :: BB -> BB
    goBB (BB bid inss term) = BB bid (go inss) term

    -- Scan for a move whose destination is read later in the block.
    go :: [IRIns] -> [IRIns]
    go [] = []
    go (ins@(IMov d _) : rest)
        | d `elem` findAllRefsInss rest = push ins rest
    go (ins : rest) = ins : go rest

    -- Carry the pending move past the instructions that follow it.
    push :: IRIns -> [IRIns] -> [IRIns]
    push mov [] = [mov]
    push mov (INop : rest) = INop : push mov rest
    push mov@(IMov d s) (ins : rest) = case insDestination ins of
        Just d'
            | d' == d -> overwrite d s ins rest
            | d' == s -> mov : ins : go rest  -- source clobbered: stop here
        _ -> replaceRefsIns d s ins : push mov rest
    -- Not reached: 'go' and 'overwrite' only ever push an 'IMov'.
    push ins rest = ins : go rest

    -- @ins@ writes @d@, the destination of the pending @IMov d s@.
    overwrite :: Ref -> Ref -> IRIns -> [IRIns] -> [IRIns]
    overwrite d s ins rest = case ins of
        IMov _ s'
            | s' == d -> push mov rest   -- IMov d d is a no-op
            | otherwise -> push ins rest -- old move is dead; push the new one
        IResize _ s'
            | s' == d -> push mov rest
        IAri at _ s'
            | Constant sza a <- s, Constant szb b <- s' ->
                if sza == szb
                    then push (IMov d (Constant sza (evaluateArith at a b))) rest
                    else error $ "Inconsistent sizes in " ++ show mov ++ "; " ++ show ins
        -- Loads, calls with a result, non-foldable arithmetic and resizes
        -- from another ref: emitting the move first and stopping is always
        -- safe.
        _ -> mov : ins : go rest
      where
        mov = IMov d s


-- | Turn a resize of a constant into a move of that constant truncated to
-- the destination's size (@refSize@, multiplied by 8 to get a bit width).
evaluateInstructions :: FuncOptimisation
evaluateInstructions (IRFunc rt name al bbs sid) = IRFunc rt name al (map goBB bbs) sid
  where
    goBB :: BB -> BB
    goBB (BB bid inss term) = BB bid (map goI inss) term

    goI :: IRIns -> IRIns
    goI (IResize ref (Constant _ v)) =
        IMov ref $ Constant (refSize ref) $ truncValue (refSize ref) v
    goI ins = ins

    -- Reduce @v@ modulo 2^(8*sz), computed in 'Integer' to avoid overflow.
    truncValue :: Size -> Value -> Value
    truncValue sz v = fromIntegral $ (fromIntegral v :: Integer) `mod` (2 ^ (8 * sz))


-- | Resolve conditional jumps whose operands are both constants into plain
-- jumps. Constant operands of different sizes are an internal error.
evaluateTerminators :: FuncOptimisation
evaluateTerminators (IRFunc rt name al bbs sid) = IRFunc rt name al bbs' sid
  where
    bbs' = map (\(BB bid inss term) -> BB bid inss (go term)) bbs

    go :: IRTerm -> IRTerm
    go term@(IJcc ct (Constant sza a) (Constant szb b) i1 i2)
        | sza /= szb = error $ "Inconsistent sizes in " ++ show term
        | evaluateCmp ct a b = IJmp i1
        | otherwise = IJmp i2
    go term = term


-- | The instruction at (block index, instruction index). Partial: both
-- indices must be in range.
insAt :: [BB] -> (Int, Int) -> IRIns
insAt bbs (i, j) = let (BB _ inss _) = bbs !! i in inss !! j

-- | Replace the instruction at (block index, instruction index) by 'INop'.
-- Partial: both indices must be in range.
nopifyInsAt :: [BB] -> (Int, Int) -> [BB]
nopifyInsAt bbs (i, j) =
    let (pre, BB bid inss term : post) = splitAt i bbs
        (ipre, _ : ipost) = splitAt j inss
    in pre ++ BB bid (ipre ++ INop : ipost) term : post


-- | The 'Ref' an instruction assigns to, if any. 'IStore', 'ICall' and
-- 'INop' assign none.
insDestination :: IRIns -> Maybe Ref
insDestination ins = case ins of
    IMov d _ -> Just d
    ILoad d _ -> Just d
    IAri _ d _ -> Just d
    ICallr d _ _ -> Just d
    IResize d _ -> Just d
    _ -> Nothing

-- | Indices of the instructions in a block that assign to @ref@ (including
-- loads and resizes, not only moves, arithmetic and calls).
findMutations :: BB -> Ref -> [Int]
findMutations (BB _ inss _) ref =
    [idx | (ins, idx) <- zip inss [0 ..], insDestination ins == Just ref]

-- | (block index, instruction index) of every assignment to @ref@.
findMutations' :: [BB] -> Ref -> [(Int, Int)]
findMutations' bbs ref = [(i, j) | (bb, i) <- zip bbs [0..], j <- findMutations bb ref]


-- | Every distinct 'Ref' mentioned by a block's instructions (sorted).
findAllRefs :: BB -> [Ref]
findAllRefs (BB _ inss _) = findAllRefsInss inss

-- | Every distinct 'Ref' mentioned by a list of instructions (sorted).
findAllRefsInss :: [IRIns] -> [Ref]
findAllRefsInss inss = uniq $ sort $ concatMap go inss
  where
    go (IMov a b) = [a, b]
    go (IStore a b) = [a, b]
    go (ILoad a b) = [a, b]
    go (IAri _ a b) = [a, b]
    go (ICall _ al) = al
    go (ICallr a _ al) = a : al
    go (IResize a b) = [a, b]
    go INop = []

-- findAllRefs' :: [BB] -> [Ref]
-- findAllRefs' = uniq . sort . concatMap findAllRefs

-- | The 'Temp' refs among a block's refs.
findAllTemps :: BB -> [Ref]
findAllTemps bb = [ref | ref@(Temp _ _) <- findAllRefs bb]

-- | The 'Temp' refs of every block, concatenated in block order. A temp used
-- in several blocks appears once per block.
findAllTemps' :: [BB] -> [Ref]
findAllTemps' = concatMap findAllTemps