aboutsummaryrefslogtreecommitdiff
path: root/RegAlloc.hs
blob: 3a41aac8688b076db379fd9c07b15ebd36d93e8a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
module RegAlloc(regalloc, Allocation(..)) where

import Data.Function
import Data.List
import Data.Maybe
import qualified Data.Map.Strict as Map
import Debug.Trace

import Utils


-- Follows the Linear Scan Register Allocation algorithm specified in:
-- https://www.cs.purdue.edu/homes/suresh/502-Fall2008/papers/linear-scan.pdf

-- Test case (no alias pairs):
-- regalloc [((1, 3), "a"), ((2, 2), "x"), ((2, 10), "b"), ((3, 7), "c"), ((4, 5), "d")] ["ra", "rb"] []
-- Expected: a -> ra, x -> rb, b -> memory, c -> rb, d -> ra


-- | Where a single variable ended up after allocation.
data Allocation a
  = AllocReg a  -- ^ kept in the given register
  | AllocMem    -- ^ spilled to memory
  deriving (Show)

-- | A live interval as (start point, end point).  Both endpoints count as
-- live: an interval ending at point p still conflicts with one starting at p,
-- because the allocator only retires intervals whose end is strictly before
-- the next start.
type Interval = (Int, Int)

-- | Accumulator threaded through the linear scan in 'regalloc'.
data State a = State
  { stActive   :: [Int]           -- ^ indices of intervals currently holding a register (kept sorted by end point)
  , stAlloc    :: [Allocation a]  -- ^ allocation decided so far, one entry per interval index, in index order
  , stFreeRegs :: [a]             -- ^ registers not held by any active interval
  }
  deriving (Show)

-- | Linear scan register allocation.
--
-- Intervals are visited in order of increasing start point.  Before each one
-- is placed, active intervals that ended strictly before its start release
-- their registers.  If a register is then free, the interval takes one,
-- preferring a free register already given to one of its earlier-allocated
-- aliases (from @aliaspairs@).  Otherwise, if the active interval that ends
-- last ends after the new one, that interval is spilled to memory and its
-- register handed over; else the new interval itself is spilled.
--
-- With an empty register list every variable is allocated to 'AllocMem'.
regalloc :: (Show a, Show b, Ord a, Ord b)
         => [(Interval, b)]           -- ^ [(live interval, name of variable)]
         -> [a]                       -- ^ the available registers
         -> [(b, b)]                  -- ^ pairs to be allocated to the same register if possible
         -> Map.Map b (Allocation a)  -- ^ allocation map
regalloc vars' regs aliaspairs =
    Map.fromList $ zip intnames $ stAlloc $
        foldl' allocOne (State [] [] regs) (zip vars [0 ..])
  where
    -- Intervals sorted by start point (sortBy is stable, so equal starts keep
    -- their input order).  The position in this list is the interval's index.
    vars = sortBy (compare `on` fst . fst) vars'
    (ints, intnames) = (map fst vars, map snd vars)

    -- End points by index: O(log n) lookups instead of repeated (ints !!).
    endMap :: Map.Map Int Int
    endMap = Map.fromList (zip [0 ..] (map snd ints))

    endOf :: Int -> Int
    endOf i = endMap Map.! i  -- every index used here comes from zip vars [0 ..]

    -- Insert an index into the active list, which is kept sorted by end point.
    -- It goes before existing entries with the same end point.
    insertByEnd :: Int -> [Int] -> [Int]
    insertByEnd = insertBy (compare `on` endOf)

    -- Decide the allocation for interval @index@ and append it to stAlloc.
    -- Invariant: length (stActive st) + length (stFreeRegs st) == length regs,
    -- so an empty free list means every register is held by an active interval
    -- (or there are no registers at all).
    allocOne st' ((int, name), index) =
        let st = expireOldIntervals st' int
        in case stFreeRegs st of
               [] -> spillAtInterval st index
               free@(firstFree : otherFree) ->
                   let wanted = aliasRegs st name
                       (regchoice, fr) = case find (`elem` wanted) free of
                           Nothing -> (firstFree, otherFree)
                           Just wr -> trace ("Pair-allocated " ++ show name ++ " in " ++ show wr) $
                                      (wr, free \\ [wr])
                   in State (insertByEnd index (stActive st))
                            (stAlloc st ++ [AllocReg regchoice])
                            fr

    -- Registers already given to aliases of @name@.  Aliases that are not yet
    -- allocated (later intervals) or that were spilled contribute nothing.
    aliasRegs st name = uniq $ sort
        [ r
        | n <- findWantedAliases aliaspairs name
        , Just idx <- [findIndex (== n) intnames]
        , idx < length (stAlloc st)
        , AllocReg r <- [stAlloc st !! idx]
        ]

    -- Retire active intervals that end strictly before the new interval's
    -- start.  Because the active list is sorted by end point they form a prefix
    -- of it; their registers go to the front of the free list.
    expireOldIntervals :: State a -> Interval -> State a
    expireOldIntervals st (intstart, _) =
        let (dropped, active) = span ((< intstart) . endOf) (stActive st)
            freed = selectAllocRegs (map (stAlloc st !!) dropped)
        in State active (stAlloc st) (freed ++ stFreeRegs st)

    -- No register is free.  Either steal the register of the active interval
    -- that ends last (spilling it to memory), or spill the new interval if it
    -- ends no earlier than that one.  With no active intervals (no registers
    -- at all) there is nothing to steal, so the new interval goes to memory.
    spillAtInterval :: State a -> Int -> State a
    spillAtInterval st index = case stActive st of
        [] -> State [] (stAlloc st ++ [AllocMem]) (stFreeRegs st)
        active ->
            let spill = last active
            in if endOf spill > endOf index
                   then State (insertByEnd index (delete spill active))
                              (setAt spill AllocMem (stAlloc st) ++ [stAlloc st !! spill])
                              (stFreeRegs st)
                   else State active (stAlloc st ++ [AllocMem]) (stFreeRegs st)

    -- Every name paired with @x@, in either position of an alias pair.
    findWantedAliases :: (Ord b) => [(b, b)] -> b -> [b]
    findWantedAliases pairs x =
        uniq $ sort $ [q | (p, q) <- pairs, p == x] ++ [p | (p, q) <- pairs, q == x]

-- | Registers named by the 'AllocReg' entries, in order; 'AllocMem' entries
-- are skipped.
selectAllocRegs :: [Allocation a] -> [a]
selectAllocRegs allocs = [r | AllocReg r <- allocs]

-- | Replace the element at a zero-based index.
--
-- An index outside @[0, length l)@ leaves the list unchanged, the same
-- convention as @Data.Sequence.update@.  The unchanged tail after the index is
-- shared with the input list.
setAt :: Int -> a -> [a] -> [a]
setAt idx v l
    | idx < 0 = l
    | otherwise = case splitAt idx l of
        (pre, _ : post) -> pre ++ v : post
        _               -> l  -- idx >= length l