aboutsummaryrefslogtreecommitdiff
path: root/LifetimeAnalysis.hs
blob: aaf0c30b52816c40b3f80ed686f10b1561376f44 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
{-# LANGUAGE ScopedTypeVariables #-}

module LifetimeAnalysis(fullLifetimeAnalysis, lifetimeAnalysis, Access(..), unAccess) where

import Data.List
import Data.Maybe
import qualified Data.Map.Strict as Map
import Debug.Trace

import Utils


-- | A single access to a value of type @a@, tagged with whether the value
-- is written or read.
data Access a
  = Write a
  | Read a
  deriving (Show, Eq)

-- | Strip the 'Write' / 'Read' tag and return the wrapped value.
unAccess :: Access a -> a
unAccess access = case access of
    Write payload -> payload
    Read payload -> payload

-- | A basic block as consumed by the analysis.
--
-- The first component holds one list of accesses per instruction, in
-- instruction order. The second component holds indices into the enclosing
-- list of blocks; 'analysisIterator' takes the live-in sets of the blocks at
-- those indices to build this block's live-out set, i.e. they act as the
-- block's successors.
type BB a = ([[Access a]], [Int])

-- | The four data-flow sets tracked per basic block. 'Eq' is derived so that
-- 'eqFixpoint' can detect when an iteration step no longer changes anything.
data DUIO a = DUIO
    { defs :: [a]
      -- ^ Values written by an instruction when no earlier instruction of the
      -- block had read them (see 'defUseAnalysis').
    , uses :: [a]
      -- ^ Values read by an instruction when no earlier instruction of the
      -- block had written them (see 'defUseAnalysis').
    , ins :: [a]
      -- ^ Live-in set; recomputed by 'analysisIterator' as the sorted union
      -- of 'uses' and ('outs' minus 'defs').
    , outs :: [a]
      -- ^ Live-out set; recomputed by 'analysisIterator' as the sorted union
      -- of the 'ins' of the blocks named as successors.
    }
  deriving Eq

-- | Liveness of a single value.
--
-- The result has the same shape as that of 'fullLifetimeAnalysis' (one list
-- per block, one entry per instruction), but each entry only states whether
-- @target@ is a member of the live set reported for that instruction.
lifetimeAnalysis :: (Eq a, Ord a) => a -> [BB a] -> [[Bool]]
lifetimeAnalysis target bbs =
    [ [ target `elem` liveSet | liveSet <- perBlock ]
    | perBlock <- fullLifetimeAnalysis bbs
    ]

-- | Run the complete analysis over a list of basic blocks.
--
-- Every block is seeded with its def/use sets and empty in/out sets, then
-- 'analysisIterator' is applied until the sets stop changing. Finally
-- 'markLive' turns each block, together with its converged sets, into one
-- live set per instruction. The result is indexed by block, then by
-- instruction.
fullLifetimeAnalysis :: (Eq a, Ord a) => [BB a] -> [[[a]]]
fullLifetimeAnalysis bbs = zipWith (curry markLive) bbs converged
  where
    -- Starting point of the iteration: def/use known, in/out still empty.
    seeded =
        [ (DUIO d u [] [], nexts)
        | bb@(_, nexts) <- bbs
        , let (d, u) = defUseAnalysis bb
        ]

    -- Only the data-flow sets are needed once the fixpoint is reached.
    converged = map fst (eqFixpoint analysisIterator seeded)

-- | Compute, for every instruction of a block, the set of live values
-- recorded before that instruction is processed.
--
-- The input is a block together with its converged data-flow sets. The
-- result has exactly one entry per instruction of the block; the first entry
-- is the block's live-in set ('ins').
--
-- A value stays live after instruction @i@ if it is read by instruction @i@
-- itself, or if it is already live or written by @i@ and is read again by
-- some later instruction of the block.
markLive :: forall a. (Eq a, Ord a) => (BB a, DUIO a) -> [[a]]
markLive ((inaccs, _), duio) = init $ go fullaccs 0 (ins duio)
  where
    -- The block's instructions followed by a pseudo-instruction that reads
    -- every live-out value, so those values count as read past the last real
    -- instruction. The live set belonging to the pseudo-instruction itself
    -- is dropped again by the 'init' above.
    fullaccs = inaccs ++ [map Read (outs duio)]

    -- Index of the last instruction reading each value, built in a single
    -- pass: 'Map.fromList' keeps the last binding for a duplicated key and
    -- the pairs are produced in increasing index order. Values that are
    -- never read get no entry at all.
    lastreads = Map.fromList
        [ (v, idx) | (idx, acl) <- zip [0 ..] fullaccs, Read v <- acl ]

    -- Whether @v@ is read by some instruction strictly after index @i@.
    -- A value that is never read behaves as if its last read were at -1.
    readAfter :: Int -> a -> Bool
    readAfter i v = Map.findWithDefault (-1) v lastreads > i

    go :: [[Access a]] -> Int -> [a] -> [[a]]
    go [] _ _ = []
    go (acl : rest) i lives =
        let (ws, rs) = partitionAccess acl
            newlives = union rs $ filter (readAfter i) (union ws lives)
        in lives : go rest (i + 1) newlives

-- | One step of the data-flow iteration over all blocks.
--
-- For every block the live-out set becomes the sorted union of the live-in
-- sets its successors had in the input list, and the live-in set is then
-- recomputed from that new live-out set as @uses ∪ (outs \\ defs)@. Both
-- sets are kept sorted so that equal sets compare equal in 'eqFixpoint'.
-- The successor lists are passed through unchanged.
analysisIterator :: (Eq a, Ord a) => [(DUIO a, [Int])] -> [(DUIO a, [Int])]
analysisIterator toplist = map refresh toplist
  where
    refresh (duio, nexts) =
        let newOuts = sort $ foldl union [] [ insOf n | n <- nexts ]
            newIns = sort $ union (uses duio) (newOuts \\ defs duio)
        in (duio {outs = newOuts, ins = newIns}, nexts)

    -- Live-in set of the block at the given index, as of the previous step.
    insOf n = ins (fst (toplist !! n))

-- | Collect the def and use sets of a basic block, returned as @(defs, uses)@.
--
-- Instructions are visited in order. A written value becomes a def unless an
-- earlier instruction already recorded it as a use; a read value becomes a
-- use unless an earlier instruction already recorded it as a def. Both
-- checks for one instruction look only at the sets gathered from the
-- instructions before it.
defUseAnalysis :: Eq a => BB a -> ([a], [a])
defUseAnalysis (inss, _) = walk [] [] inss
  where
    walk ds us [] = (ds, us)
    walk ds us (accs : later) =
        let (ws, rs) = partitionAccess accs
            freshDefs = [ w | w <- ws, w `notElem` us ]
            freshUses = [ r | r <- rs, r `notElem` ds ]
        in walk (ds `union` freshDefs) (us `union` freshUses) later

-- | Split a list of accesses into the written values (first component) and
-- the read values (second component), each in its original relative order.
partitionAccess :: [Access a] -> ([a], [a])
partitionAccess = foldr place ([], [])
  where
    -- The lazy patterns keep the fold as lazy as direct recursion with a
    -- @let@-bound pair would be.
    place (Write x) ~(written, readVals) = (x : written, readVals)
    place (Read x) ~(written, readVals) = (written, x : readVals)

-- | Apply @f@ repeatedly, starting from the given value, until applying it
-- once more yields an equal value; that value is returned. Does not
-- terminate if no such value is ever reached.
eqFixpoint :: Eq a => (a -> a) -> a -> a
eqFixpoint f = settle
  where
    settle current
        | next == current = current
        | otherwise = settle next
      where
        next = f current