module RegAlloc(regalloc, Allocation(..)) where

import Data.Function
import Data.List
import Data.Maybe
import qualified Data.Map.Strict as Map
import Debug.Trace

import Utils

-- Follows the Linear Scan Register Allocation algorithm specified in:
-- https://www.cs.purdue.edu/homes/suresh/502-Fall2008/papers/linear-scan.pdf

-- Test case:
-- regalloc [((1, 3), "a"), ((2, 2), "x"), ((2, 10), "b"), ((3, 7), "c"), ((4, 5), "d")] ["ra", "rb"] []

-- | Where a variable lives after allocation: in a register, or spilled to
-- memory.
data Allocation a = AllocReg a | AllocMem deriving Show

-- | A live interval @(start, end)@. Both endpoints count as live: an interval
-- ending at @t@ is still active when another interval starts at @t@.
type Interval = (Int, Int)

-- | Allocator state threaded through the scan.
--
-- Invariants maintained by 'regalloc':
--
-- * 'stActive' holds indices (into the start-sorted interval list) of the
--   intervals currently holding a register, sorted by increasing end point.
-- * 'stAlloc' has exactly one entry per interval processed so far, in
--   processing order, so its length equals the index of the next interval.
-- * Every active interval is allocated 'AllocReg'; the registers of the
--   active intervals together with 'stFreeRegs' account for every register
--   passed to 'regalloc'.
data State a = State
  { stActive   :: [Int]
  , stAlloc    :: [Allocation a]
  , stFreeRegs :: [a]
  } deriving Show

-- | Linear-scan register allocation.
--
-- Intervals are visited in order of increasing start point (ties keep their
-- input order). Each variable takes a free register, preferring one already
-- held by an alias partner that has been processed. When no register is free,
-- the current interval is compared with the active interval that ends last:
-- if that active interval ends strictly later, it is spilled to memory and
-- its register is handed to the current interval; otherwise the current
-- interval goes to memory.
--
-- With an empty register list every variable is allocated to 'AllocMem'.
-- If several entries share a name, the returned map keeps the allocation of
-- the one processed last.
--
-- Successful pair allocations are reported on stderr via 'trace'.
--
-- Note: lookups use list indexing ('!!') and appends, so running time is
-- quadratic in the number of variables.
regalloc :: (Show a, Show b, Ord a, Ord b)
         => [(Interval, b)]          -- ^ [(live interval, name of variable)]
         -> [a]                      -- ^ the available registers
         -> [(b, b)]                 -- ^ pairs to be allocated to the same register if possible
         -> Map.Map b (Allocation a) -- ^ allocation map
regalloc vars' regs aliaspairs =
  Map.fromList $ zip intnames $ stAlloc $
    foldl' step (State [] [] regs) (zip vars [0 ..])
  where
    -- Stable sort on the interval start point.
    vars = sortBy (compare `on` (fst . fst)) vars'
    (ints, intnames) = unzip vars

    -- End point of the interval with the given index.
    endOf :: Int -> Int
    endOf idx = snd (ints !! idx)

    -- Insert an index into an end-sorted active list, ahead of entries with
    -- the same end point (identical to a stable re-sort of @idx : active@).
    insertActive :: Int -> [Int] -> [Int]
    insertActive = insertBy (compare `on` endOf)

    -- One scan step: retire finished intervals, then either hand out a free
    -- register or fall back to spilling.
    step st0 ((int, name), index) =
      let st = expireOldIntervals st0 int
      in case stFreeRegs st of
           -- No free register (this also covers an empty register list).
           [] -> spillAtInterval st index
           free@(firstFree : otherFree) ->
             let wantedregs = aliasRegisters st index name
                 (regchoice, fr) = case find (`elem` wantedregs) free of
                   Nothing -> (firstFree, otherFree)
                   Just wr ->
                     trace ("Pair-allocated " ++ show name ++ " in " ++ show wr)
                       (wr, free \\ [wr])
             in State (insertActive index (stActive st))
                      (stAlloc st ++ [AllocReg regchoice])
                      fr

    -- Registers currently recorded for the alias partners of @name@.
    -- Partners at position >= index have not been processed yet, and
    -- partners recorded as AllocMem hold no register; both are skipped.
    aliasRegisters st index name =
      uniq $ sort $ catMaybes $ map heldBy (findWantedAliases aliaspairs name)
      where
        heldBy n = case findIndex (== n) intnames of
          Just idx
            | idx < index
            , AllocReg r <- stAlloc st !! idx
            -> Just r
          _ -> Nothing

    -- Retire active intervals that end before @intstart@ and put their
    -- registers at the front of the free list.
    expireOldIntervals :: State a -> Interval -> State a
    expireOldIntervals st (intstart, _) =
      let (expired, stillActive) = span ((< intstart) . endOf) (stActive st)
          released = selectAllocRegs (map (stAlloc st !!) expired)
      in st { stActive = stillActive, stFreeRegs = released ++ stFreeRegs st }

    -- No register is free for interval @index@. If the active interval that
    -- ends last outlives the current one, it is moved to memory and its
    -- register given to the current interval; otherwise (including ties, and
    -- the case where nothing is active because there are no registers) the
    -- current interval goes to memory.
    spillAtInterval :: State a -> Int -> State a
    spillAtInterval st index =
      case reverse (stActive st) of
        spill : _
          | endOf spill > endOf index ->
              State (insertActive index (stActive st \\ [spill]))
                    (setAt spill AllocMem (stAlloc st) ++ [stAlloc st !! spill])
                    (stFreeRegs st)
        _ -> st { stAlloc = stAlloc st ++ [AllocMem] }

-- | Names paired with @x@ in either position of an alias pair, sorted and
-- passed through 'uniq' (from "Utils"; presumably removes duplicates).
findWantedAliases :: (Ord b) => [(b, b)] -> b -> [b]
findWantedAliases pairs x =
  uniq $ sort $ [r | (l, r) <- pairs, l == x] ++ [l | (l, r) <- pairs, r == x]

-- | The registers of the 'AllocReg' entries, in order; 'AllocMem' entries
-- are skipped.
selectAllocRegs :: [Allocation a] -> [a]
selectAllocRegs allocs = [r | AllocReg r <- allocs]

-- | Replace the element at position @idx@ with @v@. An index at or beyond
-- the end of the list leaves the list unchanged.
setAt :: Int -> a -> [a] -> [a]
setAt idx v l = case splitAt idx l of
  (pre, _ : post) -> pre ++ v : post
  _ -> l