1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
|
module Haskell.Rewrite
(rewrite
,betared, etared, casered
,etacase, casecase
,autoSimp
,normalise) where
import Control.Monad
import Data.List
import Data.Maybe
import qualified Data.Map.Strict as Map
import Haskell.AST
import Util
-- | @rewrite name repl expr@ replaces every free occurrence of the variable
-- @name@ in @expr@ with @repl@.
--
-- Occurrences below a lambda that lists @name@ among its parameters, or in a
-- case arm whose pattern binds @name@, are shadowed and left alone.
--
-- NOTE: the substitution is not capture-avoiding: free variables of @repl@
-- can be captured by binders inside @expr@.
rewrite :: Name -> Expr -> Expr -> Expr
rewrite name repl = go
  where
    go expr = case expr of
      Ref n _
        | n == name -> repl
        | otherwise -> expr
      Num {} -> expr
      App fn args ty -> App (go fn) (map go args) ty
      Tup items ty -> Tup (map go items) ty
      Lam params body ty
        | name `elem` params -> expr
        | otherwise -> Lam params (go body) ty
      Case scrut arms ty -> Case (go scrut) (map goArm arms) ty

    -- An arm whose pattern rebinds 'name' shadows it in the right-hand side.
    goArm arm@(pat, rhs)
      | name `elem` boundVars pat = arm
      | otherwise = (pat, go rhs)
-- | The variables bound by a pattern, without duplicates.
--
-- Only 'PatVar' introduces a binding. The constructor name of a 'PatCon' is a
-- reference to a constructor, not a binder, so it is not part of the result.
-- Including it would make callers treat the constructor's name as shadowed
-- inside the arm.
boundVars :: Pat -> [Name]
boundVars = nub . go
  where
    go PatAny = []
    go (PatVar n) = [n]
    go (PatCon _ ps) = concatMap go ps
    go (PatTup ps) = concatMap go ps
-- | One beta-reduction step.
--
-- For an application whose head is a lambda, the first argument is
-- substituted (via 'rewrite') for the first parameter. The remaining
-- parameters and arguments are kept; empty parameter or argument lists are
-- left for 'normalise' to remove. Any other expression is handed to Util's
-- 'recurse'.
--
-- The new lambda's type annotation is the old one applied to the argument's
-- type. If the lambda carries no annotation, or 'typeApply' gives 'Nothing',
-- the result is 'Nothing' rather than a 'fromJust' crash.
betared :: Expr -> Expr
betared (App (Lam (n:as) bd lty) (arg:args) ty) =
  App (Lam as (rewrite n arg bd) lty') args ty
  where
    lty' = lty >>= \t -> typeApply t (typeOf arg)
betared e = recurse id betared e
-- | One eta-reduction step.
--
-- A lambda whose body is an application ending in a reference to the
-- lambda's LAST parameter drops that parameter and that argument, i.e.
-- @(\\x1 .. xk -> f a1 .. am xk)@ becomes @(\\x1 .. x(k-1) -> f a1 .. am)@.
-- This is only valid when @xk@ does not occur free in @f@ or in @a1 .. am@,
-- and that is checked. Any other expression is handed to Util's 'recurse'.
etared :: Expr -> Expr
etared (Lam ns@(_:_) (App f es@(_:_) aty) ty)
  | Ref n' _ <- last es
  , n == n'
  , not (any (freeIn n) (f : init es)) =
      -- NOTE(review): aty is reused unchanged; if it is the type of the full
      -- application it is stale once the last argument is dropped. TODO confirm.
      Lam (init ns) (App f (init es) aty) ty
  where
    -- The parameter that must match the final argument is the last one;
    -- both lists are non-empty by the patterns above.
    n = last ns

    -- Does variable v occur free in an expression? Scoping mirrors 'rewrite'.
    freeIn v = go
      where
        go = \case
          Ref m _ -> m == v
          Num {} -> False
          App g xs _ -> go g || any go xs
          Tup xs _ -> any go xs
          Lam ps body _ -> v `notElem` ps && go body
          Case s arms _ ->
            go s || any (\(p, r) -> v `notElem` boundVars p && go r) arms
etared e = recurse id etared e
-- | Case-of-known-scrutinee reduction.
--
-- Arms are examined in source order, as Haskell matches them:
--
-- * an arm whose pattern unifies with the scrutinee is selected;
-- * an arm that provably cannot match is skipped;
-- * an arm whose outcome cannot be decided (for example, the scrutinee is a
--   variable or a function call) stops the search, and nothing is reduced.
--
-- The selected arm's right-hand side is returned with its pattern variables
-- replaced simultaneously (see 'rewriteAll').
--
-- With @ambig = False@, the reduction additionally requires that exactly one
-- arm unifies with the scrutinee. When no reduction happens, the expression
-- is handed to Util's 'recurse'.
casered :: Bool -> Expr -> Expr
casered ambig orig@(Case subj arms _) =
  case firstMatch arms of
    Just (binds, rhs)
      | ambig || unifying == 1 -> rewriteAll binds rhs
    _ -> recurse id (casered ambig) orig
  where
    -- Number of arms whose pattern unifies with the scrutinee.
    unifying = length [() | (p, _) <- arms, isJust (unify p subj)]

    -- Constructor names appearing anywhere in this case's patterns.
    -- Constructors are lexically distinct from variables in Haskell, so an
    -- expression headed by one of these names is a constructor application.
    knownCons = concatMap (patCons . fst) arms
    patCons (PatCon n ps) = n : concatMap patCons ps
    patCons (PatTup ps) = concatMap patCons ps
    patCons _ = []

    firstMatch [] = Nothing
    firstMatch ((p, rhs) : rest) = case unify p subj of
      Just binds -> Just (binds, rhs)
      Nothing
        | refutes p subj -> firstMatch rest
        | otherwise -> Nothing

    -- True only when the pattern certainly fails to match the expression.
    refutes (PatCon n ps) e = case e of
      Ref m _ -> m /= n && m `elem` knownCons
      App (Ref m _) es _
        | m /= n -> m `elem` knownCons
        | length ps == length es -> fieldsRefute (zip ps es)
      _ -> False
    refutes (PatTup ps) (Tup es _)
      | length ps == length es = fieldsRefute (zip ps es)
    refutes _ _ = False

    -- Fields are matched left to right, so a field can only prove a mismatch
    -- if every field before it certainly matched.
    fieldsRefute [] = False
    fieldsRefute ((p, e) : rest)
      | refutes p e = True
      | isJust (unify p e) = fieldsRefute rest
      | otherwise = False
casered ambig e = recurse id (casered ambig) e

-- | Simultaneous form of 'rewrite': every free occurrence of each key of the
-- map is replaced by its value in a single pass, so a substituted value is
-- never rewritten again by a later binding.
--
-- Scoping matches 'rewrite': a lambda parameter or a case-arm pattern
-- variable shadows a binding of the same name. Like 'rewrite', this is not
-- capture-avoiding.
rewriteAll :: Map.Map Name Expr -> Expr -> Expr
rewriteAll sub expr
  | Map.null sub = expr
  | otherwise = case expr of
      Ref n ty -> Map.findWithDefault (Ref n ty) n sub
      Num {} -> expr
      App f args ty -> App (rewriteAll sub f) (map (rewriteAll sub) args) ty
      Tup es ty -> Tup (map (rewriteAll sub) es) ty
      Lam ns body ty -> Lam ns (rewriteAll (without ns) body) ty
      Case s arms ty ->
        Case
          (rewriteAll sub s)
          [(p, rewriteAll (without (boundVars p)) rhs) | (p, rhs) <- arms]
          ty
  where
    without ns = foldr Map.delete sub ns
-- | Case-of-identity: a case in which every arm's right-hand side rebuilds
-- exactly that arm's pattern (as judged by 'eqPE') is replaced by its
-- scrutinee.
--
-- A case with no arms is not rewritten. 'all' is vacuously true on an empty
-- list and would otherwise collapse it to the scrutinee.
--
-- Any other expression is handed to Util's 'recurse'.
etacase :: Expr -> Expr
etacase (Case subj arms@(_:_) _)
  | all (uncurry eqPE) arms = subj
etacase e = recurse id etacase e
-- | Case-of-case: @case (case s of { pi -> ei }) of alts@ becomes
-- @case s of { pi -> case ei of alts }@, with the outer case's type.
--
-- Moving @alts@ under the inner patterns is only safe when none of the
-- variables those patterns bind occurs free in @alts@; otherwise it would be
-- captured. In that situation the transformation is skipped and the
-- expression is handed to Util's 'recurse'.
casecase :: Expr -> Expr
casecase (Case (Case subj arms1 _) arms2 ty)
  | not (any usedFreeInOuter (concatMap (boundVars . fst) arms1)) =
      Case subj [(p, Case e arms2 ty) | (p, e) <- arms1] ty
  where
    -- Is v referenced, unshadowed, by some outer arm's right-hand side?
    usedFreeInOuter v =
      any (\(p, r) -> v `notElem` boundVars p && freeIn v r) arms2

    -- Does variable v occur free in an expression? Scoping mirrors 'rewrite'.
    freeIn v = go
      where
        go = \case
          Ref m _ -> m == v
          Num {} -> False
          App g xs _ -> go g || any go xs
          Tup xs _ -> any go xs
          Lam ps body _ -> v `notElem` ps && go body
          Case s arms _ ->
            go s || any (\(p, r) -> v `notElem` boundVars p && go r) arms
casecase e = recurse id casecase e
-- | Automatic simplification: Util's 'fixpoint' (presumably iterating until
-- the expression stops changing) of a single pass.
--
-- The pass runs 'casecase', 'etacase', 'etared', @'casered' False@ and
-- 'betared' in that order, calling 'normalise' after each of them.
autoSimp :: Expr -> Expr
autoSimp = fixpoint pass
  where
    pass = foldr1 (.) [normalise . rule | rule <- rules]
    -- Composition order: the LAST rule listed is applied first.
    rules = [betared, casered False, etared, etacase, casecase]
-- | @eqPE p e@ holds when the expression @e@ rebuilds exactly the value
-- matched by the pattern @p@:
--
-- * every pattern variable is referenced back at its own position;
-- * every constructor or tuple is rebuilt with the same shape;
-- * a nullary constructor may appear either as a bare 'Ref' or as an 'App'
--   with no arguments.
--
-- A wildcard never qualifies, since the expression cannot name the value it
-- discarded. The previous 'unify'-based version accepted any expression for
-- a wildcard, so @case x of _ -> 5@ was treated as an identity case.
eqPE :: Pat -> Expr -> Bool
eqPE PatAny _ = False
eqPE (PatVar n) (Ref n' _) = n == n'
eqPE (PatCon n []) (Ref n' _) = n == n'
eqPE (PatCon n ps) (App (Ref n' _) es _) =
  n == n' && length ps == length es && and (zipWith eqPE ps es)
eqPE (PatTup ps) (Tup es _) =
  length ps == length es && and (zipWith eqPE ps es)
eqPE _ _ = False
-- | Match a pattern against an expression without evaluating it.
--
-- On success, returns the bindings of the pattern's variables. Repeated
-- variables must be bound to equal expressions (see 'reconcile').
--
-- 'Nothing' means the match could not be established syntactically. It does
-- not necessarily mean the pattern cannot match at run time.
--
-- A nullary constructor is accepted both as @App c []@ and as a bare @Ref c@,
-- because 'normalise' rewrites the former into the latter.
unify :: Pat -> Expr -> Maybe (Map.Map Name Expr)
unify PatAny _ = Just Map.empty
unify (PatVar n) e = Just (Map.singleton n e)
unify (PatCon n []) (Ref n' _)
  | n == n' = Just Map.empty
unify (PatCon n ps) (App (Ref n' _) es _)
  | n == n' = unifyFields ps es
unify (PatTup ps) (Tup es _) = unifyFields ps es
unify _ _ = Nothing

-- | Unify sub-patterns with sub-expressions pairwise, merging the bindings.
-- Fails when the two lists have different lengths.
unifyFields :: [Pat] -> [Expr] -> Maybe (Map.Map Name Expr)
unifyFields ps es
  | length ps == length es =
      foldM (\m (p, e) -> unify p e >>= reconcile m) Map.empty (zip ps es)
  | otherwise = Nothing
-- | Merge two binding maps.
--
-- Returns 'Nothing' if some name is bound in both maps to expressions that
-- are not '=='. Otherwise returns their union; for a name bound in both, the
-- binding from @m1@ is kept.
reconcile :: Map.Map Name Expr -> Map.Map Name Expr -> Maybe (Map.Map Name Expr)
reconcile m1 m2
  | and (Map.intersectionWith (\kept new -> new == kept) m1 m2) =
      Just (Map.union m1 m2)
  | otherwise = Nothing
-- | Remove degenerate nodes: an application to no arguments and a lambda with
-- no parameters are both replaced by their (normalised) inner expression.
-- Every other expression is handed to Util's 'recurse'.
normalise :: Expr -> Expr
normalise = \case
  App inner [] _ -> normalise inner
  Lam [] inner _ -> normalise inner
  other -> recurse id normalise other
|