summaryrefslogtreecommitdiff
path: root/ASTOptimiser.hs
blob: d577f4ca96ce4c74ee1d1d0aa76e6872f73755f7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
module ASTOptimiser (
    optimiseAST
) where

import qualified Data.Map.Strict as Map
import Data.Maybe

import Debug.Trace

import AST


-- | Entry point of the optimiser. At present the only pass applied to the
-- program is inlining (see 'performInlining').
optimiseAST :: Program -> Program
optimiseAST program = performInlining program


-- | Run the inliner over every top-level value in order, starting from
-- 'initState' and threading the state from one value to the next, so that a
-- definition can be inlined into values that follow it.
performInlining :: Program -> Program
performInlining (Program values) = Program optimised
  where
    (_, optimised) = mapState inline initState values

-- | State threaded through the inliner.
data State = State
    { sEnv :: [Map.Map Name (Maybe (Value, MetaInfo))]
      -- ^ Stack of scopes, innermost first. A variable maps to 'Nothing' if
      -- its value is unknown; otherwise to its value and cached metadata.
    }
  deriving (Show)

-- | Metadata cached next to a known value in the environment.
newtype MetaInfo = MetaInfo
    { miWeight :: Int
      -- ^ Weight of the value, as computed by 'computeWeight'.
    }
  deriving (Show)

-- | Starting state: a stack holding one empty (outermost) scope.
initState :: State
initState = State { sEnv = [Map.empty] }

-- | Maximum weight (see 'computeWeight') a known function may have for it
-- to be considered for inlining at a call site.
--
-- TODO: very arbitrary
paramInlineMaxWeight :: Int
paramInlineMaxWeight =
    10

-- | Maximum weight (see 'computeWeight') of every argument of a call for
-- that call to be a candidate for inlining.
--
-- The larger this is, the more work we allow to be duplicated over
-- multiple usage sites in the inlined function.
--
-- TODO: perhaps make this dependent on how many times the value is used in
-- the function to be inlined?
paramInlineArgMaxWeight :: Int
paramInlineArgMaxWeight =
    1

-- | Run @f@ with a fresh, empty scope pushed as the innermost scope, then
-- drop the innermost scope of the state @f@ returns (expected to be the one
-- pushed here). The result of @f@ is passed through unchanged.
withScope :: State -> (State -> (State, a)) -> (State, a)
withScope state f =
    -- 'drop 1' instead of the partial 'tail': identical on the normal,
    -- non-empty stack, but cannot crash if 'f' returned an empty one.
    let (st', ret) = f (state { sEnv = Map.empty : sEnv state })
    in (st' { sEnv = drop 1 (sEnv st') }, ret)

-- | Bind @name@ in the innermost scope. 'Just' a value records it, together
-- with its 'computeWeight', as a candidate for inlining. 'Nothing' marks the
-- name as unknown, which hides any outer binding of the same name from
-- 'lookupName'.
defineName :: State -> Name -> Maybe Value -> State
defineName state name mvalue =
    state { sEnv = env' }
  where
    desc = (\v -> (v, MetaInfo (computeWeight v))) <$> mvalue
    -- Pattern match instead of the partial 'head'/'tail': an empty scope
    -- stack gets a fresh scope holding the binding instead of crashing.
    env' = case sEnv state of
        innermost : outer -> Map.insert name desc innermost : outer
        [] -> [Map.singleton name desc]

-- | Look up the innermost binding of @name@. The result is 'Nothing' both
-- when the name is unbound and when its innermost binding marks it as
-- unknown. In the second case the search stops there, so an unknown inner
-- variable hides a known outer one of the same name.
lookupName :: State -> Name -> Maybe (Value, MetaInfo)
lookupName state name =
    case mapMaybe (Map.lookup name) (sEnv state) of
        innermost : _ -> innermost
        [] -> Nothing

-- | Walk a value, inlining small known functions at call sites and
-- beta-reducing the resulting lambda applications.
--
-- The 'State' carries the scope stack:
--
-- * a 'VDefine' records its (already optimised) value in the innermost
--   scope, so that later call sites can inline it;
-- * lambda parameters and the name of a recursive lambda are recorded as
--   unknown, so they are never replaced;
-- * each let binding opens a new scope holding that binding.
--
-- A call is only considered for inlining when every argument weighs at most
-- 'paramInlineArgMaxWeight'. Heavier arguments would be duplicated (or
-- dropped) by beta reduction.
inline :: State -> Value -> (State, Value)
inline state origValue =
    case origValue of
        VDefine name value ->
            let (state', value') = inline state value
            in (defineName state' name (Just value'), VDefine name value')
        VList [] -> (state, origValue)
        VList (vhead:vtail)
            | all ((<= paramInlineArgMaxWeight) . computeWeight) vtail ->
                trace ("\x1B[1;31minline: argument weight test passed\x1B[0m: " ++ show origValue) $
                let (stateH, vhead') = inlineHead state vhead
                    (state', vtail') = mapState inline stateH vtail
                in (state', betaReduce (VList (vhead' : vtail')))
            | otherwise ->
                -- Arguments are too heavy to duplicate into a function body:
                -- neither inline the head nor beta-reduce this call (doing so
                -- would copy or drop the heavy arguments). Still optimise the
                -- sub-expressions so nested calls benefit.
                trace ("\x1B[1;31minline: argument weight test FAILED\x1B[0m: " ++ show origValue) $
                VList <$> mapState inline state (vhead : vtail)
        VLambda as value ->
            withScope state $ \state1 ->
                let state1' = foldl (\s n -> defineName s n Nothing) state1 as
                in VLambda as <$> inline state1' value
        VLambdaRec r as value ->
            withScope state $ \state1 ->
                -- Also mark r as unknown, since we don't want to inline recursion
                let state1' = foldl (\s n -> defineName s n Nothing) state1 (r : as)
                in VLambdaRec r as <$> inline state1' value
        VLet [] body ->
            VLet [] <$> inline state body
        VLet ((name, value) : pairs) body ->
            withScope state $ \state1 ->
                let (state1', value') = inline state1 value
                    state1'' = defineName state1' name (Just value')
                in case inline state1'' (VLet pairs body) of
                    (state1''', VLet pairs' body') ->
                        (state1''', VLet ((name, value') : pairs') body')
                    -- Not produced by the cases above, but keep the match
                    -- total: nest whatever came back as this binding's body.
                    (state1''', rest) ->
                        (state1''', VLet [(name, value')] rest)
        VNum _ -> (state, origValue)
        VString _ -> (state, origValue)
        VName _ -> (state, origValue)
        VQuoted _ -> (state, origValue)
        VDeclare _ -> (state, origValue)
        VBuiltin _ -> (state, origValue)
        VEllipsis -> (state, origValue)
  where
    -- A name in call-head position is a candidate for inlining. Any other
    -- head expression (e.g. a lambda literal) is optimised recursively
    -- before the call is beta-reduced.
    inlineHead st v@(VName _) = (st, inlineReplace st v)
    inlineHead st v = inline st v

-- | Replace a variable by its known definition when that definition weighs
-- at most 'paramInlineMaxWeight'. Names that are unknown (unbound or marked
-- unknown) or too heavy, and all non-name values, are returned unchanged.
inlineReplace :: State -> Value -> Value
inlineReplace state nameValue@(VName name) =
    maybe unknown substitute (lookupName state name)
  where
    unknown =
        trace ("\x1B[1;31minline: variable not found\x1B[0m: " ++ name) nameValue
    substitute (definition, meta)
        | weight <= paramInlineMaxWeight =
            trace ("\x1B[1;31minline: function weight small\x1B[0m (" ++ show weight ++ "): " ++ name) definition
        | otherwise =
            trace ("\x1B[1;31minline: function weight large\x1B[0m (" ++ show weight ++ "): " ++ name) nameValue
      where
        weight = miWeight meta
inlineReplace _ other = other

-- | Beta-reduce a direct application of a lambda literal whose parameter
-- count matches the argument count. The arguments are substituted for the
-- parameters in the body via 'replaceNames'. Any other value is returned
-- unchanged.
--
-- Reduction is skipped when it could capture a variable. This happens when
-- a name occurring in an argument is also bound somewhere inside the body
-- (by a lambda, recursive lambda, let or define). Substituting the argument
-- under such a binder would make the name refer to the inner binding. Both
-- sets are over-approximated, so a safe reduction may occasionally be
-- refused.
-- NOTE(review): this assumes 'replaceNames' itself handles bodies that
-- rebind one of the lambda's own parameters; confirm in its definition.
betaReduce :: Value -> Value
betaReduce (VList (VLambda names body : values))
    | length values == length names
    , all (`notElem` bodyBinders) argNames =
        replaceNames (Map.fromList (zip names values)) body
  where
    bodyBinders = binders body
    argNames = concatMap mentioned values

    -- Every name bound anywhere inside a value.
    binders (VList vs) = concatMap binders vs
    binders (VDefine n v) = n : binders v
    binders (VLambda ps v) = ps ++ binders v
    binders (VLambdaRec r ps v) = r : ps ++ binders v
    binders (VLet ds v) = map fst ds ++ concatMap (binders . snd) ds ++ binders v
    binders _ = []

    -- Every variable occurring anywhere inside a value (a superset of its
    -- free variables).
    mentioned (VName n) = [n]
    mentioned (VList vs) = concatMap mentioned vs
    mentioned (VDefine _ v) = mentioned v
    mentioned (VLambda _ v) = mentioned v
    mentioned (VLambdaRec _ _ v) = mentioned v
    mentioned (VLet ds v) = concatMap (mentioned . snd) ds ++ mentioned v
    mentioned _ = []
betaReduce value = value
                

-- | Rough size estimate of a value, used as the cost model for inlining.
--
-- * Numbers, strings, names, quoted values and builtins count 1.
-- * Declarations and ellipses are free.
-- * A list adds 1 on top of its elements.
-- * A lambda adds 2 on top of its body.
-- * A define or let costs only what it contains.
--
-- TODO: Very arbitrary; should perhaps be tuned
computeWeight :: Value -> Int
computeWeight value = case value of
    VList items           -> 1 + sumWeights items
    VNum _                -> 1
    VString _             -> 1
    VName _               -> 1
    VQuoted _             -> 1
    VBuiltin _            -> 1
    VDeclare _            -> 0
    VEllipsis             -> 0
    VDefine _ inner       -> computeWeight inner
    VLambda _ inner       -> 2 + computeWeight inner
    VLambdaRec _ _ inner  -> 2 + computeWeight inner
    VLet bindings inner   -> sumWeights (inner : map snd bindings)
  where
    sumWeights = sum . map computeWeight

-- | Map a state-threading function over a list from left to right. Returns
-- the final state together with the results, in input order.
mapState :: (State -> a -> (State, b)) -> State -> [a] -> (State, [b])
mapState step = go
  where
    go st [] = (st, [])
    go st (item : rest) =
        let (st', out) = step st item
        in case go st' rest of
            (stFinal, outs) -> (stFinal, out : outs)