aboutsummaryrefslogtreecommitdiff
path: root/src/Haskell/Rewrite.hs
blob: 2e3f6c55ff39b8047a85f208a1a562507ece773e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
{-# LANGUAGE TupleSections #-}
module Haskell.Rewrite
    (rewrite
    ,betared, etared, casered
    ,etacase, casecase
    ,autoSimp
    ,normalise) where

import Control.Monad
import Data.List
import Data.Maybe
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set

import Haskell.AST
import Util


-- | @rewrite target repl expr@ replaces every free occurrence of the variable
-- 'target' in 'expr' by 'repl'.
--
-- Binders in 'expr' that would capture a free variable of 'repl' are
-- alpha-renamed, but only in subexpressions where a replacement actually
-- happens; elsewhere the original binder names are kept.
rewrite :: Name -> Expr -> Expr -> Expr
rewrite target repl topexpr = fst (rewrite' mempty topexpr)
  where
    -- When moving into the subexpression E under a binding for some variable
    -- 'x', if 'x' is free in 'repl', we need to alpha-rename 'x' to something
    -- that is:
    -- - Not free in 'repl';
    -- - Not free in E;
    -- - Not equal to a name 'y' bound in E if the subexpression under the
    --   binder for 'y' has an occurrence of 'x'.
    -- If we strengthen the third requirement to prohibit all bound variables
    -- in E, the second and third together mean "all variables in E". This is
    -- what we will use.
    frees :: Set.Set Name
    frees = freeVariables repl

    -- Returns the rewritten expression, and whether any actual rewrites were
    -- performed. This allows preventing unnecessary alpha-renames.
    --
    -- 'mapping' holds the alpha-renames introduced by enclosing binders. It
    -- must reach every free occurrence in the result, including inside
    -- subexpressions where no rewrite happened (see 'renameFree'). Its keys
    -- are always members of 'frees' (only 'freshenL' adds entries), and
    -- 'target' is never a key because a binder for 'target' stops the
    -- descent before 'freshenL' is called.
    rewrite' :: Map.Map Name Name -> Expr -> (Expr, Bool)
    rewrite' mapping = \case
        App e as ->
            let (e', b) = rewrite' mapping e
                (as', bs) = unzip (map (rewrite' mapping) as)
            in (App e' as', b || or bs)
        Ref n -> case Map.lookup n mapping of
                     Just n' -> (Ref n', False)  -- renamed variable cannot be target
                     Nothing | n == target -> (repl, True)
                             | otherwise   -> (Ref n, False)
        Con n -> (Con n, False)
        Num k -> (Num k, False)
        Tup es -> let (es', bs) = unzip (map (rewrite' mapping) es)
                  in (Tup es', or bs)
        Lam ns e
          -- 'target' is shadowed here: only the enclosing renames apply.
          | target `elem` ns -> (Lam ns (renameFree (unbind ns mapping) e), False)
          | otherwise ->
              let forbidden = frees <> allVars e <> Set.fromList ns
                  -- Note that Map.<> is left-preferring
                  mapping' = freshenL ns frees forbidden <> mapping
                  ns' = map (rename mapping') ns
                  (e', b) = rewrite' mapping' e
              in if b
                     then (Lam ns' e', True)
                     -- Nothing was rewritten below: keep this lambda's own
                     -- binder names, but still apply the enclosing renames.
                     else (Lam ns (renameFree (unbind ns mapping) e), False)
        Case e as ->
            let (scrutinee, b1) = rewrite' mapping e
                (arms, bs) = unzip (map (rewriteArm mapping) as)
            in (Case scrutinee arms, b1 || or bs)

    -- Rewrites one case arm. The pattern's variables scope over the
    -- right-hand side only, so they are treated exactly like lambda binders.
    rewriteArm :: Map.Map Name Name -> (Pat, Expr) -> ((Pat, Expr), Bool)
    rewriteArm mapping (p, rhs)
        | target `elem` ns = ((p, kept), False)
        | b                = ((renamePat mapping' p, rhs'), True)
        | otherwise        = ((p, kept), False)
      where
        ns = Set.toList (boundVars p)
        -- Arm without any rewrite: original pattern, enclosing renames only.
        kept = renameFree (unbind ns mapping) rhs
        forbidden = frees <> allVars rhs <> Set.fromList ns
        mapping' = freshenL ns frees forbidden <> mapping
        (rhs', b) = rewrite' mapping' rhs

    -- Applies 'mapping' to the free variable occurrences of an expression in
    -- which no rewrite takes place, respecting shadowing by inner binders.
    -- The new names were chosen outside 'allVars' of the enclosing body, so
    -- no inner binder can capture them.
    renameFree :: Map.Map Name Name -> Expr -> Expr
    renameFree mapping expr
        | Map.null mapping = expr
        | otherwise = case expr of
            Ref n -> Ref (rename mapping n)
            Lam ns e -> Lam ns (renameFree (unbind ns mapping) e)
            Case scrut arms ->
                Case (renameFree mapping scrut)
                     [ (p, renameFree (unbind (Set.toList (boundVars p)) mapping) rhs)
                     | (p, rhs) <- arms ]
            _ -> recurse id (renameFree mapping) expr

    -- Drops the renames of names rebound by an inner binder.
    unbind :: [Name] -> Map.Map Name Name -> Map.Map Name Name
    unbind ns mapping = foldr Map.delete mapping ns

    -- Freshens all variables found in 'vars' to something that is not in
    -- 'bnd', returning a replace mapping. Later names override earlier ones
    -- in the mapping.
    freshenL :: [Name] -> Set.Set Name -> Set.Set Name -> Map.Map Name Name
    freshenL ns vars bnd =
        foldl' (\mp n ->
                   if n `Set.member` vars
                       then Map.insert n (freshName bnd n) mp
                       else mp)
               mempty ns

    -- Finds a fresh name for 'n' that is not in 'bnd'. The candidate that is
    -- checked against 'bnd' is the one returned. 'head' is safe because the
    -- candidate list is infinite.
    freshName :: Set.Set Name -> Name -> Name
    freshName bnd n =
        head [n' | i <- [1::Int ..]
                 , let n' = n ++ "_R_" ++ show i
                 , n' `Set.notMember` bnd]

    renamePat :: Map.Map Name Name -> Pat -> Pat
    renamePat mapping = \case
        PatAny -> PatAny
        PatVar n -> PatVar (rename mapping n)
        PatCon n ps -> PatCon (rename mapping n) (map (renamePat mapping) ps)
        PatTup ps -> PatTup (map (renamePat mapping) ps)

    rename :: Map.Map Name Name -> Name -> Name
    rename mapping n = fromMaybe n (Map.lookup n mapping)

-- | Beta-reduces outermost redexes. An application of a lambda to at least
-- one argument substitutes the first argument for the first binder. This may
-- leave an empty binder or argument list behind, which 'normalise' removes.
-- Any other expression has the rule applied to its immediate subexpressions.
betared :: Expr -> Expr
betared expr = case expr of
    App (Lam (param : params) body) (arg : extra) ->
        App (rewrite param arg (Lam params body)) extra
    _ -> recurse id betared expr

-- | Eta-reduction on the innermost (last) lambda binder:
--
-- > \a ... z -> f x ... z   ==>   \a ... -> f x ...
--
-- The rule only fires when the final argument is exactly the LAST binder
-- 'z', and 'z' does not occur free in the remaining function and arguments.
-- Otherwise the rule is applied to the immediate subexpressions.
etared :: Expr -> Expr
etared (Lam ns@(_:_) (App f es@(_:_)))
    | last es == Ref z
    , z `Set.notMember` freeVariables rest =
        Lam (init ns) rest
  where
    -- 'last'/'init' are safe: both lists are matched as non-empty above.
    z = last ns
    rest = App f (init es)
etared e = recurse id etared e

-- | Case-of-known-constructor reduction. It replaces @case scrutinee of arms@
-- by the right-hand side of the arm that the scrutinee selects, with the
-- pattern variables substituted. The scrutinee itself is reduced first.
--
-- Arms are considered top to bottom. An arm is skipped only when its pattern
-- provably cannot match, i.e. there is a constructor mismatch. If an earlier
-- arm might still match (for example, the scrutinee is just a variable), the
-- expression is left alone, because choosing a later arm would be unsound.
--
-- With @ambig = False@, the reduction is also refused when some arm after the
-- selected one would unify with the scrutinee as well.
casered :: Bool -> Expr -> Expr
casered ambig orig@(Case subj arms) =
    case dropWhile (\(p, _) -> refuted p subj') arms of
        (p, rhs) : later
            | Just mp <- unify p subj'
            , ambig || not (any (\(p', _) -> isJust (unify p' subj')) later) ->
                substituteAll mp rhs
        _ -> recurse id (casered ambig) orig
  where
    subj' = casered ambig subj

    -- True only when 'p' can never match 'e' because the two disagree on a
    -- constructor. Anything undecided (variables, applications of
    -- non-constructors, ...) counts as "might match".
    refuted :: Pat -> Expr -> Bool
    refuted (PatCon n ps) e = case conHead e of
        Just (n', es) ->
            n /= n' || (length ps == length es && or (zipWith refuted ps es))
        Nothing -> False
    refuted (PatTup ps) (Tup es) =
        length ps == length es && or (zipWith refuted ps es)
    refuted _ _ = False

    -- Constructor name and arguments of an expression headed by a constructor.
    conHead :: Expr -> Maybe (Name, [Expr])
    conHead (Con n) = Just (n, [])
    conHead (App (Con n) es) = Just (n, es)
    conHead _ = Nothing

    -- Substitutes all bindings of 'mp' into 'rhs' simultaneously. Running
    -- the 'rewrite's one after another is wrong when a bound value mentions
    -- a pattern variable: with x := y and y := 1, @x + y@ must become @y + 1@,
    -- not @1 + 1@. In that case every pattern variable is first renamed to a
    -- temporary name that occurs nowhere else, and only then substituted.
    substituteAll :: Map.Map Name Expr -> Expr -> Expr
    substituteAll mp rhs
        | Set.null (patVars `Set.intersection` valueVars) =
            rewriteAll (Map.assocs mp) rhs
        | otherwise =
            rewriteAll [(tmp, v) | (_, v, tmp) <- temps]
                       (rewriteAll [(n, Ref tmp) | (n, _, tmp) <- temps] rhs)
      where
        patVars = Map.keysSet mp
        valueVars = foldMap freeVariables (Map.elems mp)
        avoid = patVars <> valueVars <> allVars rhs <> freeVariables rhs
        temps = [(n, v, tempName n) | (n, v) <- Map.assocs mp]
        -- 'head' is safe: the candidate list is infinite.
        tempName n = head [c | i <- [1 :: Int ..]
                             , let c = n ++ "_S_" ++ show i
                             , c `Set.notMember` avoid]
        rewriteAll binds e = foldl' (\acc (n, v) -> rewrite n v acc) e binds
casered ambig e = recurse id (casered ambig) e

-- | Case eta-reduction: @case s of { p1 -> p1; ...; pn -> pn }@, where every
-- arm just rebuilds what its pattern matched, is replaced by @s@.
-- A case with no arms is never reduced: it is not equivalent to its
-- scrutinee, even though 'all' holds vacuously for an empty list.
etacase :: Expr -> Expr
etacase (Case subj arms@(_:_)) | all (uncurry eqPE) arms = subj
etacase e = recurse id etacase e

-- | Case-of-case: pushes the outer case into every arm of the inner one.
--
-- > case (case s of p1 -> e1; ...) of arms2
-- >   ==>  case s of p1 -> case e1 of arms2; ...
--
-- This moves 'arms2' under the binders of every inner pattern. It is
-- therefore only done when none of those pattern variables occurs free in
-- 'arms2'; otherwise the outer arms' free variables would be captured.
-- When the rule does not fire, it is applied to the immediate subexpressions.
casecase :: Expr -> Expr
casecase (Case (Case subj arms1) arms2)
    | Set.null (innerBound `Set.intersection` outerFree) =
        Case subj [(p, Case e arms2) | (p, e) <- arms1]
  where
    innerBound = Set.unions [boundVars p | (p, _) <- arms1]
    outerFree = Set.unions [ freeVariables rhs `Set.difference` boundVars p
                           | (p, rhs) <- arms2 ]
casecase e = recurse id casecase e

-- | Runs every automatic simplification pass, with 'normalise' after each
-- one, and iterates that round using 'fixpoint'.
autoSimp :: Expr -> Expr
autoSimp = fixpoint onePass
  where
    -- Composition runs right to left: casecase first, betared last.
    onePass = normalise . foldr (.) id (intersperse normalise passes)
    passes = [betared, casered False, etared, etacase, casecase]

-- | Does the expression rebuild exactly the value that the pattern matched?
-- This holds when the expression is the pattern read back as an expression,
-- e.g. pattern @Just x@ with expression @Just x@. Every pattern variable must
-- be mapped to a reference to itself.
--
-- A wildcard throws away part of the matched value, so a pattern that
-- contains one never qualifies. Unification alone would accept it, since
-- @_@ unifies with anything and binds nothing; the arm @_ -> 5@ would then
-- be mistaken for one that returns the scrutinee.
eqPE :: Pat -> Expr -> Bool
eqPE pat expr
    | hasWildcard pat = False
    | otherwise = case unify pat expr of
        Nothing -> False
        Just mp -> all (uncurry isIdmap) (Map.assocs mp)
  where
    isIdmap n = (== Ref n)

    hasWildcard PatAny = True
    hasWildcard (PatCon _ ps) = any hasWildcard ps
    hasWildcard (PatTup ps) = any hasWildcard ps
    hasWildcard _ = False

-- | Structurally matches a pattern against an expression. On success, it
-- returns the binding of each pattern variable; repeated variables must be
-- bound to equal expressions (see 'reconcile'). 'Nothing' means the match was
-- not established structurally.
unify :: Pat -> Expr -> Maybe (Map.Map Name Expr)
unify pat expr = case (pat, expr) of
    (PatAny, _) -> Just Map.empty
    (PatVar n, _) -> Just (Map.singleton n expr)
    (PatCon n ps, App hd es)
        | hd == Con n, length ps == length es -> unifyAll ps es
    (PatCon n [], Con n')
        | n == n' -> Just Map.empty
    (PatTup ps, Tup es)
        | length ps == length es -> unifyAll ps es
    _ -> Nothing
  where
    -- Unify component-wise, then merge the partial bindings left to right.
    unifyAll ps es = zipWithM unify ps es >>= foldM reconcile Map.empty

-- | Merges two binding maps. It fails when a name is bound in both maps to
-- different expressions; on success the result holds every binding of both.
reconcile :: Map.Map Name Expr -> Map.Map Name Expr -> Maybe (Map.Map Name Expr)
reconcile m1 m2
    | and (Map.intersectionWith (==) m1 m2) = Just (Map.union m1 m2)
    | otherwise                            = Nothing

-- | Removes degenerate nodes everywhere in the tree: an application to no
-- arguments and a lambda with no binders are both replaced by their
-- (normalised) body.
normalise :: Expr -> Expr
normalise expr = case expr of
    App inner [] -> normalise inner
    Lam [] inner -> normalise inner
    _            -> recurse id normalise expr

-- | One-level traversal. It applies 'onExpr' to every immediate
-- subexpression and 'onPat' to every case pattern, and rebuilds the node.
-- Leaves (references, constructors, numbers) come back unchanged. The
-- functions are not applied to the node itself, and 'recurse' does not
-- descend any deeper.
recurse :: (Pat -> Pat) -> (Expr -> Expr) -> Expr -> Expr
recurse onPat onExpr expr = case expr of
    Case scrut arms  -> Case (onExpr scrut) [(onPat p, onExpr rhs) | (p, rhs) <- arms]
    Lam binders body -> Lam binders (onExpr body)
    Tup items        -> Tup (map onExpr items)
    App hd args      -> App (onExpr hd) (map onExpr args)
    Num k            -> Num k
    Con n            -> Con n
    Ref n            -> Ref n