-- | AST-level optimisation pass. Currently the only optimisation is inlining
-- of small, known functions at call sites followed by beta-reduction of
-- directly applied lambdas.
module ASTOptimiser
    ( optimiseAST
    ) where

import Data.List (foldl')
import qualified Data.Map.Strict as Map
import Data.Maybe
import Debug.Trace

import AST

-- | Optimise a whole program. At the moment this only runs the inliner.
optimiseAST :: Program -> Program
optimiseAST = performInlining

-- | Run 'inline' over every top-level value of the program, threading the
-- inliner state from one value to the next and discarding the final state.
performInlining :: Program -> Program
performInlining (Program values) =
    Program (snd (mapState inline initState values))

-- | Inliner state: a stack of scopes, innermost scope first.
data State
    -- A variable maps to Nothing if its value is unknown
    = State { sEnv :: [Map.Map Name (Maybe (Value, MetaInfo))] }
    deriving (Show)

-- | Information cached alongside a known value; 'miWeight' is the result of
-- 'computeWeight' on that value.
newtype MetaInfo = MetaInfo { miWeight :: Int }
    deriving (Show)

-- | Initial state: a single, empty scope.
initState :: State
initState = State [Map.empty]

-- Maximum weight of function to be considered for inlining
paramInlineMaxWeight :: Int
paramInlineMaxWeight = 10  -- TODO: very arbitrary

-- Maximum weight of argument to function to be inlined; the larger this
-- is, the more work we allow to be duplicated over multiple usage sites in
-- the inlined function.
-- TODO: perhaps make this dependent on how many times the value is used in
-- the function to be inlined?
paramInlineArgMaxWeight :: Int
paramInlineArgMaxWeight = 1

-- | Run an action with a fresh, empty innermost scope, and pop the innermost
-- scope from the state the action returns.
withScope :: State -> (State -> (State, a)) -> (State, a)
withScope state f =
    let (inner, ret) = f (state { sEnv = Map.empty : sEnv state })
    -- 'drop 1' rather than 'tail': total even if the action emptied the stack.
    in (inner { sEnv = drop 1 (sEnv inner) }, ret)

-- | Bind a name in the innermost scope. 'Nothing' records that the name
-- exists but its value is unknown; 'Just' stores the value together with its
-- weight.
defineName :: State -> Name -> Maybe Value -> State
defineName state name mvalue = state { sEnv = env' }
  where
    desc = (\v -> (v, MetaInfo (computeWeight v))) <$> mvalue
    env' = case sEnv state of
        scope : outer -> Map.insert name desc scope : outer
        -- No scope at all: open one instead of crashing on 'head'.
        [] -> [Map.singleton name desc]

-- | Look a name up, innermost scope first. Only the innermost scope that
-- binds the name is consulted, so a binding with unknown value shadows any
-- known binding further out.
lookupName :: State -> Name -> Maybe (Value, MetaInfo)
lookupName state name =
    case mapMaybe (Map.lookup name) (sEnv state) of
        -- A found 'Nothing' should still result in Nothing being returned
        (Just desc : _) -> Just desc
        _ -> Nothing

-- | Mark every given name as bound-but-unknown in the innermost scope.
defineUnknown :: State -> [Name] -> State
defineUnknown = foldl' (\s n -> defineName s n Nothing)

-- | Inline known functions inside a value, returning the updated state and
-- the rewritten value.
inline :: State -> Value -> (State, Value)
inline state origValue = case origValue of
    VDefine name value ->
        let (state', value') = inline state value
        in (defineName state' name (Just value'), VDefine name value')

    VList [] -> (state, origValue)

    -- A call: the head is only replaced by its definition when every
    -- argument is light enough to be duplicated.
    VList (vhead : vtail)
        | all ((<= paramInlineArgMaxWeight) . computeWeight) vtail ->
            trace ("\x1B[1;31minline: argument weight test passed\x1B[0m: " ++ show origValue) $
                betaReduce . VList . (inlineReplace state vhead :)
                    <$> mapState inline state vtail
        | otherwise ->
            trace ("\x1B[1;31minline: argument weight test FAILED\x1B[0m: " ++ show origValue) $
                (state, betaReduce origValue)

    VLambda as value ->
        withScope state $ \scoped ->
            VLambda as <$> inline (defineUnknown scoped as) value

    VLambdaRec r as value ->
        withScope state $ \scoped ->
            -- Also mark r as unknown, since we don't want to inline recursion
            VLambdaRec r as <$> inline (defineUnknown scoped (r : as)) value

    VLet pairs body -> uncurry VLet <$> inlineLet state pairs body

    VNum _ -> (state, origValue)
    VString _ -> (state, origValue)
    VName _ -> (state, origValue)
    VQuoted _ -> (state, origValue)
    VDeclare _ -> (state, origValue)
    VBuiltin _ -> (state, origValue)
    VEllipsis -> (state, origValue)

-- | Inline the bindings of a let one at a time, then its body. Each binding
-- opens its own scope in which all later bindings and the body are processed,
-- so a binding is visible to those later bindings and to the body.
inlineLet :: State -> [(Name, Value)] -> Value -> (State, ([(Name, Value)], Value))
inlineLet state [] body = (\body' -> ([], body')) <$> inline state body
inlineLet state ((name, value) : pairs) body =
    withScope state $ \scoped ->
        let (afterValue, value') = inline scoped value
            bound = defineName afterValue name (Just value')
            (final, (pairs', body')) = inlineLet bound pairs body
        in (final, ((name, value') : pairs', body'))

-- | Replace a name by its known definition when that definition weighs at
-- most 'paramInlineMaxWeight'; anything else is returned unchanged.
inlineReplace :: State -> Value -> Value
inlineReplace state origValue@(VName name) =
    case lookupName state name of
        Just (value, meta)
            | miWeight meta <= paramInlineMaxWeight ->
                trace ("\x1B[1;31minline: function weight small\x1B[0m (" ++ show (miWeight meta) ++ "): " ++ name) $
                    value
        Just (_, meta) ->
            trace ("\x1B[1;31minline: function weight large\x1B[0m (" ++ show (miWeight meta) ++ "): " ++ name) $
                origValue
        Nothing ->
            trace ("\x1B[1;31minline: variable not found\x1B[0m: " ++ name) $
                origValue
inlineReplace _ origValue = origValue

-- | Beta-reduce a lambda applied directly to the right number of arguments by
-- substituting the arguments for the parameters in the lambda body.
--
-- Substitution is skipped when an argument mentions a name that the body
-- rebinds: the substituted occurrence would then refer to the body's binding
-- instead of the caller's (variable capture). The check is conservative and
-- may reject some reductions that would have been safe.
--
-- NOTE(review): correct handling of the parameters themselves being shadowed
-- inside the body is assumed to be the job of 'replaceNames' -- confirm.
betaReduce :: Value -> Value
betaReduce (VList (VLambda names body : values))
    | length values == length names && not captures =
        replaceNames (Map.fromList (zip names values)) body
  where
    rebound = boundNames body
    captures = any (`elem` rebound) (concatMap mentionedNames values)
betaReduce value = value

-- | Every name bound anywhere inside a value by a define, lambda, recursive
-- lambda or let.
-- NOTE(review): 'VDeclare' and 'VQuoted' are treated as binding nothing;
-- verify against their definitions in the AST module.
boundNames :: Value -> [Name]
boundNames value = case value of
    VDefine name v -> name : boundNames v
    VList vs -> concatMap boundNames vs
    VLambda as v -> as ++ boundNames v
    VLambdaRec r as v -> r : as ++ boundNames v
    VLet ds v -> map fst ds ++ concatMap (boundNames . snd) ds ++ boundNames v
    _ -> []

-- | Every name referenced through 'VName' anywhere inside a value. Binders
-- are ignored, so this over-approximates the free variables.
mentionedNames :: Value -> [Name]
mentionedNames value = case value of
    VName name -> [name]
    VDefine _ v -> mentionedNames v
    VList vs -> concatMap mentionedNames vs
    VLambda _ v -> mentionedNames v
    VLambdaRec _ _ v -> mentionedNames v
    VLet ds v -> concatMap (mentionedNames . snd) ds ++ mentionedNames v
    _ -> []

-- | Rough size estimate of a value, used for the inlining thresholds.
-- TODO: Very arbitrary; should perhaps be tuned
computeWeight :: Value -> Int
computeWeight (VList vs) = 1 + sum (map computeWeight vs)
computeWeight (VNum _) = 1
computeWeight (VString _) = 1
computeWeight (VName _) = 1
computeWeight (VQuoted _) = 1
computeWeight (VDeclare _) = 0
computeWeight (VDefine _ v) = computeWeight v
computeWeight (VLambda _ v) = 2 + computeWeight v
computeWeight (VLambdaRec _ _ v) = 2 + computeWeight v
computeWeight (VLet ds v) = sum (map computeWeight (v : map snd ds))
computeWeight (VBuiltin _) = 1
computeWeight VEllipsis = 0

-- | Map a state-threading function over a list from left to right, returning
-- the final state together with the results in order.
mapState :: (State -> a -> (State, b)) -> State -> [a] -> (State, [b])
mapState _ state [] = (state, [])
mapState f state (x : xs) =
    let (state', y) = f state x
    in (y :) <$> mapState f state' xs