module Optimiser(optimise) where import Data.Either import Data.List import Data.Maybe import qualified Data.Map.Strict as Map import Debug.Trace import Defs import Intermediate import Pretty import ReplaceRefs import Utils type Optimisation = IRProgram -> IRProgram type FuncOptimisation = IRFunc -> IRFunc optimise :: IRProgram -> Error IRProgram optimise prog = let optlist = [trace "-- OPT PASS --" , \p -> trace (pretty p) p] ++ optimisations reslist = scanl (flip ($)) prog $ cycle optlist passreslist = map fst $ filter (\(_, i) -> i `mod` length optlist == 0) $ zip reslist [0..] applyFinalOpts p = foldl (flip ($)) p finaloptimisations in if True then return $ applyFinalOpts $ fst $ fromJust $ find (uncurry (==)) $ zip passreslist (tail passreslist) else return $ reslist !! 5 where optimisations = map funcopt [chainJumps, mergeTerminators, looseJumps, removeUnusedBlocks, removeDuplicateBlocks, identityOps, constantPropagate, movPush, arithPush False, removeUnusedInstructions, evaluateInstructions, evaluateTerminators, flipJccs] finaloptimisations = map funcopt [arithPush True] funcopt :: FuncOptimisation -> Optimisation funcopt fo (IRProgram vars funcs) = IRProgram vars (map fo funcs) chainJumps :: FuncOptimisation chainJumps (IRFunc rt name al bbs sid) = IRFunc rt name al bbs' sid where bbs' = snd $ last $ takeWhile fst $ iterate (mergeChain . snd) (True, bbs) mergeChain :: [BB] -> (Bool, [BB]) mergeChain [] = (False, []) mergeChain bbs2 = case findIndex isSuitable bbs2 of Nothing -> (False, bbs2) Just idx -> let (BB bid1 inss1 (IJmp target), rest) = (bbs2 !! idx, take idx bbs2 ++ drop (idx+1) bbs2) [BB _ inss2 term2] = filter (\(BB bid _ _) -> bid == target) rest merged = BB bid1 (inss1 ++ inss2) term2 in (True, merged : rest) where hasJmpTo :: Id -> BB -> Bool hasJmpTo i (BB _ _ (IJmp i')) = i == i' hasJmpTo _ _ = False isSuitable :: BB -> Bool isSuitable (BB _ _ (IJmp target)) = sum (map (fromEnum . hasJmpTo target) bbs2) == 1 isSuitable _ = False mergeTerminators :: FuncOptimisation mergeTerminators (IRFunc rt name al bbs sid) = IRFunc rt name al bbs' sid where bbs' = flip map bbs $ \bb@(BB bid inss term) -> case term of IJmp i -> case find ((== i) . fst) singles of Just (_, t) -> BB bid inss t Nothing -> bb _ -> bb singles = map (\(BB i _ t) -> (i, t)) $ filter (\(BB _ inss _) -> null inss) bbs looseJumps :: FuncOptimisation looseJumps (IRFunc rt name al bbs sid) = IRFunc rt name al bbs' sid where bbs' = flip map bbs $ \bb@(BB bid inss term) -> case term of IJmp i -> BB bid inss (IJmp (translate i)) IJcc ct r1 r2 i j -> BB bid inss (IJcc ct r1 r2 (translate i) (translate j)) _ -> bb translate i = fromMaybe i $ Map.lookup i transmap transmap = Map.fromList $ catMaybes $ flip map bbs $ \bb -> case bb of BB bid [] (IJmp i) -> Just (bid, i) _ -> Nothing removeUnusedBlocks :: FuncOptimisation removeUnusedBlocks (IRFunc rt name al bbs sid) = IRFunc rt name al bbs' sid where bbs' = filter isReachable bbs isReachable :: BB -> Bool isReachable (BB bid _ _) | bid == sid = True | otherwise = isJust $ flip find bbs $ \(BB _ _ term) -> case term of IJcc _ _ _ i1 i2 -> i1 == bid || i2 == bid IJmp i -> i == bid _ -> False removeDuplicateBlocks :: FuncOptimisation removeDuplicateBlocks (IRFunc rt name al bbs sid) = IRFunc rt name al bbs' sid where bbs' = let (bbspre, repls) = foldr foldfunc ([], []) bbs in foldl (\l (from, to) -> replaceBBIds from to l) bbspre repls foldfunc bb@(BB bid inss term) (l, repls) = case find (\(BB _ inss' term') -> inss == inss' && term == term') l of Nothing -> (bb : l, repls) Just (BB bid' _ _) -> (l, (bid, bid') : repls) replaceBBIds :: Id -> Id -> [BB] -> [BB] replaceBBIds from to = map $ \(BB bid inss term) -> BB bid inss $ case term of IJcc ct r1 r2 i1 i2 -> IJcc ct r1 r2 (trans from to i1) (trans from to i2) IJmp i -> IJmp (trans from to i) IRet -> IRet IRetr r -> IRetr r ITermNone -> undefined trans :: (Eq a) => a -> a -> a -> a trans a b c | a == c = b | otherwise = c identityOps :: FuncOptimisation identityOps (IRFunc rt name al bbs sid) = IRFunc rt name al (map go bbs) sid where go :: BB -> BB go (BB bid inss term) = BB bid (catMaybes $ map goI inss) term goI :: IRIns -> Maybe IRIns goI (IAri AAdd d s (Constant _ 0)) = Just $ IMov d s goI (IAri AAdd d (Constant _ 0) s) = Just $ IMov d s goI (IAri ASub d s (Constant _ 0)) = Just $ IMov d s goI (IAri AMul d s (Constant _ 1)) = Just $ IMov d s goI (IAri AMul d (Constant _ 1) s) = Just $ IMov d s goI (IAri ADiv d s (Constant _ 1)) = Just $ IMov d s goI (IMov d s) | d == s = Nothing goI i = Just i constantPropagate :: FuncOptimisation constantPropagate (IRFunc rt name al bbs sid) = IRFunc rt name al bbs' sid where alltemps = findAllTemps' bbs consttemps = catMaybes $ flip map alltemps $ \ref -> let locs = findMutations' bbs ref loc = head locs ins = insAt bbs loc in if length locs == 1 && isIMov ins then Just (loc, ins) else Nothing bbs' = case consttemps of [] -> bbs ((loc, IMov ref value) : _) -> replaceRefsBBList ref value (nopifyInsAt bbs loc) _ -> undefined movPush :: FuncOptimisation movPush (IRFunc rt name al bbs sid) = IRFunc rt name al (map goBB bbs) sid where goBB :: BB -> BB goBB (BB bid inss term) = let inss' = go inss term term' = if null inss' then term else pushT (last inss) term in BB bid inss' term' go :: [IRIns] -> IRTerm -> [IRIns] go [] _ = [] go (ins@(IMov d _) : rest) term | isJust (find (== d) (findAllRefsInss rest ++ findAllRefsTerm term)) = push ins rest term go (ins : rest) term = ins : go rest term push :: IRIns -> [IRIns] -> IRTerm -> [IRIns] push mov [] _ = [mov] push (IMov d s) l _ | d == s = l push mov@(IMov d s) (ins@(IMov d' s') : rest) term | d' == d = if d' == s' then push mov rest term else push ins rest term | d' == s = mov : push (IMov d' (replaceRef d s s')) rest term | otherwise = IMov d' (replaceRef d s s') : push mov rest term push mov@(IMov d s) (IResize d' s' : rest) term | d' == d = IResize d' (replaceRef d s s') : go rest term | d' == s = mov : IResize d' (replaceRef d s s') : go rest term | otherwise = IResize d' (replaceRef d s s') : push mov rest term push mov@(IMov d s) (ILoad d' s' : rest) term | d' == d = ILoad d' (replaceRef d s s') : go rest term | d' == s = mov : ILoad d' (replaceRef d s s') : go rest term | otherwise = ILoad d' (replaceRef d s s') : push mov rest term push mov@(IMov d s) (IAri at d' s1' s2' : rest) term | d' == d = IAri at d' (replaceRef d s s1') (replaceRef d s s2') : go rest term | d' == s = mov : IAri at d' (replaceRef d s s1') (replaceRef d s s2') : go rest term | otherwise = IAri at d' (replaceRef d s s1') (replaceRef d s s2') : push mov rest term -- I don't trust going past calls because globals might change. Might be able to -- catch that case, but that will go wrong when more stuff gets added. -- push mov@(IMov d s) (ins@(ICallr d' _ _) : rest) term -- | d' == d = mov : ins : go rest term -- | otherwise = replaceRefsIns d s ins : push mov rest term -- push mov@(IMov d s) (ins@(ICall _ _) : rest) term = replaceRefsIns d s ins : push mov rest term push mov@(IMov d s) (ins@(IStore _ _) : rest) term = replaceRefsIns d s ins : push mov rest term push mov (INop : rest) term = push mov rest term push mov l term = mov : go l term pushT :: IRIns -> IRTerm -> IRTerm pushT (IMov d s) term = replaceRefsTerm d s term pushT _ term = term arithPush :: Bool -> FuncOptimisation arithPush ariari (IRFunc rt name al bbs sid) = IRFunc rt name al (map goBB bbs) sid where goBB :: BB -> BB goBB (BB bid inss term) = let (inss', [Right term']) = span isLeft $ go (map Left inss ++ [Right term]) in BB bid (map (fromLeft undefined) inss') term' go :: [Either IRIns IRTerm] -> [Either IRIns IRTerm] go [] = [] go (Left ari@(IAri _ _ _ _) : rest) = Left ari : go (propagate ari rest) go (ins : rest) = ins : go rest propagate :: IRIns -> [Either IRIns IRTerm] -> [Either IRIns IRTerm] propagate _ [] = [] propagate ari@(IAri at d s1 s2) (Left ins@(IMov md ms) : rest) | d == ms = Left (IAri at md s1 s2) : (if d /= md then propagate ari rest else rest) | d /= md && md /= s1 && md /= s2 = Left ins : propagate ari rest | otherwise = Left ins : rest propagate ari@(IAri _ d _ _) (Left ins@(IStore md ms) : rest) | null (intersect [d] [md,ms]) = Left ins : propagate ari rest | otherwise = Left ins : rest propagate ari@(IAri _ d s1 s2) (Left ins@(ILoad md ms) : rest) | null (intersect [d] [md,ms] ++ intersect [s1,s2] [md]) = Left ins : propagate ari rest | otherwise = Left ins : rest propagate ari@(IAri at d s1 s2) (Left ins@(IAri mat md ms1 ms2) : rest) | ariari && d /= md && (at, s1, s2) == (mat, ms1, ms2) = Left (IMov md d) : propagate ari rest | null (intersect [d] [md,ms1,ms2] ++ intersect [s1,s2] [md]) = Left ins : propagate ari rest | otherwise = Left ins : propagate ins rest -- I don't trust going past calls because globals might change. Might be able to -- catch that case, but that will go wrong when more stuff gets added. -- propagate ari@(IAri _ d s1 s2) (Left ins@(ICall _ mal) : rest) -- | null (intersect [d] mal) = Left ins : propagate ari rest -- | otherwise = Left ins : rest -- propagate ari@(IAri _ d s1 s2) (Left ins@(ICallr md _ mal) : rest) -- | null (intersect [d,s1,s2] (md : mal)) = Left ins : propagate ari rest -- | otherwise = Left ins : rest propagate ari@(IAri _ d s1 s2) (Left ins@(IResize md ms) : rest) | null (intersect [d] [md,ms] ++ intersect [s1,s2] [md]) = Left ins : propagate ari rest | otherwise = Left ins : rest propagate ari@(IAri _ _ _ _) (Left INop : rest) = propagate ari rest propagate (IAri at d s1 s2) (Right term@(IJcc ct r1 r2 i1 i2) : rest) | (r1 == d || r2 == d) && (isConstant r1 || isConstant r2) && at `elem` [AEq, ANeq, AGt, ALt, AGeq, ALeq] = let ct' = if isConstant r2 then ct else flipCmpType ct conref = if isConstant r2 then r2 else r1 (ct'', con) = case (ct', conref) of (CEq, Constant _ c) -> (CEq, if c `elem` [0, 1] then c else (-1)) (CNeq, Constant _ c) -> (CNeq, if c `elem` [0, 1] then c else (-1)) (CGt, Constant _ c) | c < 0 -> (CNeq, (-1)) | c == 0 -> (CEq, 1) | otherwise -> (CEq, (-1)) (CLt, Constant _ c) | c > 1 -> (CNeq, (-1)) | c == 1 -> (CEq, 0) | otherwise -> (CEq, (-1)) (CGeq, Constant _ c) | c <= 0 -> (CNeq, (-1)) | c == 1 -> (CEq, 1) | otherwise -> (CEq, (-1)) (CLeq, Constant _ c) | c >= 1 -> (CNeq, (-1)) | c == 0 -> (CEq, 0) | otherwise -> (CEq, (-1)) _ -> undefined resterm = case (ct'', con) of (CEq, 0) -> IJcc (invertCmpType (arithTypeToCmpType at)) s1 s2 i1 i2 (CEq, 1) -> IJcc (arithTypeToCmpType at) s1 s2 i1 i2 (CEq, _) -> IJmp i2 (CNeq, 0) -> IJcc (arithTypeToCmpType at) s1 s2 i1 i2 (CNeq, 1) -> IJcc (invertCmpType (arithTypeToCmpType at)) s1 s2 i1 i2 (CNeq, _) -> IJmp i1 _ -> undefined in Right resterm : rest | otherwise = Right term : rest propagate _ l = l flipCmpType :: CmpType -> CmpType flipCmpType CEq = CEq flipCmpType CNeq = CNeq flipCmpType CGt = CLt flipCmpType CLt = CGt flipCmpType CGeq = CLeq flipCmpType CLeq = CGeq invertCmpType :: CmpType -> CmpType invertCmpType CEq = CNeq invertCmpType CNeq = CEq invertCmpType CGt = CLeq invertCmpType CLt = CGeq invertCmpType CGeq = CLt invertCmpType CLeq = CGt arithTypeToCmpType :: ArithType -> CmpType arithTypeToCmpType AEq = CEq arithTypeToCmpType ANeq = CNeq arithTypeToCmpType AGt = CGt arithTypeToCmpType ALt = CLt arithTypeToCmpType AGeq = CGeq arithTypeToCmpType ALeq = CLeq arithTypeToCmpType _ = undefined removeUnusedInstructions :: FuncOptimisation removeUnusedInstructions (IRFunc rt name al bbs sid) = IRFunc rt name al (map goBB bbs) sid where goBB :: BB -> BB goBB (BB bid inss term) = BB bid (catMaybes $ map goI inss) term goI :: IRIns -> Maybe IRIns goI ins@(IMov d _) = pureInstruction d ins goI ins@(IStore _ _) = Just ins goI ins@(ILoad d _) = pureInstruction d ins goI ins@(IAri _ d _ _) = pureInstruction d ins goI ins@(ICall _ _) = Just ins goI ins@(ICallr _ _ _) = Just ins goI ins@(IResize d _) = pureInstruction d ins goI INop = Nothing pureInstruction :: Ref -> IRIns -> Maybe IRIns pureInstruction d ins = if length (findMentions' bbs d) == 1 then Nothing else Just ins evaluateInstructions :: FuncOptimisation evaluateInstructions (IRFunc rt name al bbs sid) = IRFunc rt name al (map goBB bbs) sid where goBB :: BB -> BB goBB (BB bid inss term) = BB bid (map goI inss) term goI :: IRIns -> IRIns goI (IAri at ref (Constant _ v1) (Constant _ v2)) = IMov ref $ Constant (refSize ref) $ truncValue (refSize ref) $ evaluateArith at v1 v2 goI (IResize ref (Constant _ v)) = IMov ref $ Constant (refSize ref) $ truncValue (refSize ref) v goI ins = ins truncValue :: Size -> Value -> Value truncValue sz v = fromIntegral $ (fromIntegral v :: Integer) `mod` (2 ^ (8 * sz)) evaluateTerminators :: FuncOptimisation evaluateTerminators (IRFunc rt name al bbs sid) = IRFunc rt name al bbs' sid where bbs' = map (\(BB bid inss term) -> BB bid inss (go term)) bbs go :: IRTerm -> IRTerm go term@(IJcc ct (Constant sza a) (Constant szb b) i1 i2) | sza /= szb = error $ "Inconsistent sizes in " ++ show term | evaluateCmp ct a b = IJmp i1 | otherwise = IJmp i2 go term = term flipJccs :: FuncOptimisation flipJccs (IRFunc rt name al bbs sid) = IRFunc rt name al (map goBB bbs) sid where goBB :: BB -> BB goBB (BB bid inss term) = BB bid inss (goT term) goT :: IRTerm -> IRTerm goT (IJcc ct r1@(Constant _ _) r2 i1 i2) = IJcc (flipCmpType ct) r2 r1 i1 i2 goT term = term insAt :: [BB] -> (Int, Int) -> IRIns insAt bbs (i, j) = let (BB _ inss _) = bbs !! i in inss !! j nopifyInsAt :: [BB] -> (Int, Int) -> [BB] nopifyInsAt bbs (i, j) = let (pre, BB bid inss term : post) = splitAt i bbs (ipre, _ : ipost) = splitAt j inss in pre ++ BB bid (ipre ++ INop : ipost) term : post findMutations :: BB -> Ref -> [Int] findMutations (BB _ inss _) ref = catMaybes $ flip map (zip inss [0..]) $ \(ins, idx) -> case ins of (IMov r _) | r == ref -> Just idx (IAri _ r _ _) | r == ref -> Just idx (ICallr r _ _) | r == ref -> Just idx (IResize r _) | r == ref -> Just idx _ -> Nothing findMutations' :: [BB] -> Ref -> [(Int, Int)] findMutations' bbs ref = [(i, j) | (bb, i) <- zip bbs [0..], j <- findMutations bb ref] findMentions :: BB -> Ref -> [Int] findMentions (BB _ inss term) ref = insres ++ termres where insres = catMaybes $ flip map (zip inss [0..]) $ \(ins, idx) -> if ref `elem` findAllRefsIns ins then Just idx else Nothing termres = if ref `elem` findAllRefsTerm term then [length inss] else [] findMentions' :: [BB] -> Ref -> [(Int, Int)] findMentions' bbs ref = [(i, j) | (bb, i) <- zip bbs [0..], j <- findMentions bb ref] findAllRefs :: BB -> [Ref] findAllRefs (BB _ inss _) = findAllRefsInss inss findAllRefsInss :: [IRIns] -> [Ref] findAllRefsInss inss = uniq $ sort $ concatMap findAllRefsIns inss findAllRefsIns :: IRIns -> [Ref] findAllRefsIns (IMov a b) = [a, b] findAllRefsIns (IStore a b) = [a, b] findAllRefsIns (ILoad a b) = [a, b] findAllRefsIns (IAri _ a b c) = [a, b, c] findAllRefsIns (ICall _ al) = al findAllRefsIns (ICallr a _ al) = a : al findAllRefsIns (IResize a b) = [a, b] findAllRefsIns INop = [] findAllRefsTerm :: IRTerm -> [Ref] findAllRefsTerm (IJcc _ a b _ _) = [a, b] findAllRefsTerm (IJmp _) = [] findAllRefsTerm IRet = [] findAllRefsTerm (IRetr a) = [a] findAllRefsTerm ITermNone = undefined -- findAllRefs' :: [BB] -> [Ref] -- findAllRefs' = uniq . sort . concatMap findAllRefs findAllTemps :: BB -> [Ref] findAllTemps bb = flip filter (findAllRefs bb) $ \ref -> case ref of (Temp _ _) -> True _ -> False findAllTemps' :: [BB] -> [Ref] findAllTemps' = concatMap findAllTemps