aboutsummaryrefslogtreecommitdiff
path: root/src/Haskell/Rewrite.hs
blob: c33f62ee8be5ba4eb3837971bbd60cf4e0cb404c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
module Haskell.Rewrite
    (rewrite
    ,betared, etared, casered
    ,etacase, casecase
    ,autoSimp
    ,normalise
    ,fixpoint) where

import Control.Monad
import Data.List
import Data.Maybe
import qualified Data.Map.Strict as Map
import Haskell.AST


-- | @rewrite name repl expr@ replaces every occurrence of @Ref name@ in
-- @expr@ that is not shadowed with @repl@.
--
-- Descent stops at any binder that rebinds @name@: a 'Lam' listing it among
-- its parameters, or a case arm whose pattern binds it (per 'boundVars').
-- The scrutinee of a 'Case' is always rewritten.
--
-- NOTE(review): this substitution is not capture-avoiding; free variables of
-- @repl@ can be captured by lambdas or case patterns inside @expr@.
rewrite :: Name -> Expr -> Expr -> Expr
rewrite name repl = go
  where
    go expr = case expr of
        Ref n
            | n == name -> repl
            | otherwise -> Ref n
        Num k -> Num k
        App f args -> App (go f) (map go args)
        Tup items -> Tup (map go items)
        Lam params body
            | name `elem` params -> Lam params body
            | otherwise -> Lam params (go body)
        Case scrut alts -> Case (go scrut) (map goAlt alts)
    -- An arm whose pattern rebinds 'name' is left untouched.
    goAlt (pat, rhs)
        | name `elem` boundVars pat = (pat, rhs)
        | otherwise = (pat, go rhs)

-- | The variables bound by a pattern, without duplicates.
--
-- Only 'PatVar' introduces a binding. The constructor name of a 'PatCon'
-- refers to an existing constructor and is not bound by the pattern, so it
-- is not reported. Previously it was, which made 'rewrite' skip arm bodies
-- when substituting a name equal to a constructor in the arm's pattern.
boundVars :: Pat -> [Name]
boundVars = nub . go
  where
    go PatAny = []
    go (PatVar n) = [n]
    go (PatCon _ ps) = concatMap go ps
    go (PatTup ps) = concatMap go ps

-- | One step of beta reduction at the outermost matching node.
--
-- @App (Lam (n:rest) body) (arg:args)@ becomes
-- @App (Lam rest body[n := arg]) args@. Parameters are consumed left to
-- right, one per call; leftover empty 'Lam'/'App' wrappers are removed by
-- 'normalise'. Any other expression is traversed one level via 'recurse'.
--
-- NOTE(review): substitution uses 'rewrite', which is not capture-avoiding;
-- free variables of @arg@ named like a parameter in @rest@ will be captured.
betared :: Expr -> Expr
betared (App (Lam (n:rest) body) (arg:args)) =
    -- Substitute through @Lam rest body@ rather than @body@ directly: if a
    -- later parameter repeats the name @n@, it shadows @n@ in @body@ and
    -- 'rewrite' then correctly leaves the body alone.
    App (rewrite n arg (Lam rest body)) args
betared e = recurse id betared e

-- | One step of eta reduction at the outermost matching node:
-- @Lam (ps ++ [x]) (App f (as ++ [Ref x]))@ becomes @Lam ps (App f as)@.
-- Empty 'Lam'/'App' wrappers left behind are removed by 'normalise'.
-- Any other expression is traversed one level via 'recurse'.
--
-- Parameters are bound left to right (see 'betared'), so it is the /last/
-- parameter that must match the last argument. The reduction is also
-- refused when that parameter occurs free anywhere else in the body,
-- because e.g. @\\x -> g x x@ is not @g x@.
etared :: Expr -> Expr
etared (Lam params@(_:_) (App f args@(_:_)))
    | Ref v <- last args
    , v == last params
    , not (any (occursFree v) (f : init args)) =
        Lam (init params) (App f (init args))
  where
    -- Does variable x occur free in the expression?
    occursFree x = \case
        App hd xs -> occursFree x hd || any (occursFree x) xs
        Ref m -> m == x
        Num _ -> False
        Tup xs -> any (occursFree x) xs
        Lam ms body -> x `notElem` ms && occursFree x body
        Case scrut alts ->
            occursFree x scrut
                || any (\(q, rhs) -> x `notElem` patVars q && occursFree x rhs) alts
    -- Names bound by a pattern (constructor names bind nothing).
    patVars PatAny = []
    patVars (PatVar n) = [n]
    patVars (PatCon _ ps) = concatMap patVars ps
    patVars (PatTup ps) = concatMap patVars ps
etared e = recurse id etared e

-- | Case reduction at the outermost 'Case'. If the subject unifies with an
-- arm's pattern, the whole 'Case' is replaced by that arm's right-hand side,
-- with the pattern's variables replaced by the matched subexpressions.
--
-- With @ambig = False@ the reduction happens only when exactly one arm
-- unifies. With @ambig = True@ the first unifying arm is taken even if
-- later arms also unify. When no reduction applies, recurse one level.
--
-- The pattern bindings are substituted simultaneously. The previous
-- sequential substitution re-rewrote already-inserted expressions: with the
-- subject @(f b, 1)@ and pattern @(a, b)@, the body @g a b@ became
-- @g (f 1) 1@ instead of @g (f b) 1@.
--
-- NOTE(review): 'unify' also fails when the subject is not yet a
-- constructor (e.g. a bare variable). Such an arm is skipped, so a later
-- catch-all arm may be chosen although the earlier arm could match at
-- runtime. Confirm this is acceptable for callers.
casered :: Bool -> Expr -> Expr
casered ambig orig@(Case subj arms) =
    case mapMaybe matchArm arms of
        (binds, rhs) : rest
            | ambig || null rest -> substAll binds rhs
        _ -> recurse id (casered ambig) orig
  where
    matchArm (p, rhs) = (,rhs) <$> unify p subj
    -- Simultaneous substitution of every binding in m. Binders that rebind
    -- a name remove it from the map for their scope, mirroring 'rewrite'.
    -- Like 'rewrite', this is not capture-avoiding.
    substAll m ex
        | Map.null m = ex
        | otherwise = case ex of
            Ref n -> Map.findWithDefault (Ref n) n m
            Lam ns body -> Lam ns (substAll (foldr Map.delete m ns) body)
            Case s alts ->
                Case (substAll m s)
                    [ (p, substAll (foldr Map.delete m (boundVars p)) b)
                    | (p, b) <- alts ]
            _ -> recurse id (substAll m) ex
casered ambig e = recurse id (casered ambig) e

-- | Remove a 'Case' in which every arm returns exactly what its pattern
-- matched (as judged by 'eqPE'): @case e of { p -> p; ... }@ becomes @e@.
-- With no arms the check holds vacuously and @Case subj []@ also becomes
-- @subj@. Any other expression is traversed one level via 'recurse'.
etacase :: Expr -> Expr
etacase expr = case expr of
    Case subj arms
        | all (\(pat, rhs) -> eqPE pat rhs) arms -> subj
    _ -> recurse id etacase expr

-- | Case-of-case at the outermost matching node:
-- @case (case s of { p_i -> e_i }) of alts@ becomes
-- @case s of { p_i -> case e_i of alts }@.
-- Any other expression is traversed one level via 'recurse'.
--
-- The rewrite is skipped when a variable bound by some inner pattern @p_i@
-- occurs free in @alts@. Moving @alts@ under @p_i@ would otherwise capture
-- that variable and change its meaning.
casecase :: Expr -> Expr
casecase (Case (Case subj inner) outer)
    | not (any captures inner) =
        Case subj [(p, Case e outer) | (p, e) <- inner]
  where
    -- Would moving 'outer' under this arm's pattern capture one of the
    -- variables 'outer' refers to freely?
    captures (p, _) = any freeInOuter (patVars p)
    freeInOuter x =
        any (\(q, rhs) -> x `notElem` patVars q && occursFree x rhs) outer
    -- Does variable x occur free in the expression?
    occursFree x = \case
        App hd xs -> occursFree x hd || any (occursFree x) xs
        Ref m -> m == x
        Num _ -> False
        Tup xs -> any (occursFree x) xs
        Lam ms body -> x `notElem` ms && occursFree x body
        Case scrut alts ->
            occursFree x scrut
                || any (\(q, rhs) -> x `notElem` patVars q && occursFree x rhs) alts
    -- Names bound by a pattern (constructor names bind nothing).
    patVars PatAny = []
    patVars (PatVar n) = [n]
    patVars (PatCon _ ps) = concatMap patVars ps
    patVars (PatTup ps) = concatMap patVars ps
casecase e = recurse id casecase e

-- | Run the automatic simplification passes until the expression stops
-- changing.
--
-- Each round applies, in order: 'casecase', 'etacase', 'etared',
-- @casered False@, 'betared'. 'normalise' runs between consecutive passes
-- and once more at the end of the round.
autoSimp :: Expr -> Expr
autoSimp = fixpoint oneRound
  where
    -- The rightmost step in 'passes' runs first.
    oneRound = normalise . foldr1 (\f g -> f . normalise . g) passes
    passes = [betared, casered False, etared, etacase, casecase]

-- | True when the expression rebuilds exactly the value the pattern
-- matched: each pattern variable maps to a 'Ref' of itself and no part of
-- the pattern is a wildcard.
--
-- Wildcards are rejected because 'unify' accepts 'PatAny' against anything
-- with an empty substitution. Otherwise an arm like @_ -> 5@ (or
-- @Just _ -> Just 0@) would be judged an identity arm, and 'etacase' would
-- wrongly replace the case with its subject.
eqPE :: Pat -> Expr -> Bool
eqPE pat expr = wildcardFree pat && case unify pat expr of
    Nothing -> False
    Just mp -> all (uncurry isIdmap) (Map.assocs mp)
  where
    isIdmap n = (== Ref n)
    wildcardFree PatAny = False
    wildcardFree (PatVar _) = True
    wildcardFree (PatCon _ ps) = all wildcardFree ps
    wildcardFree (PatTup ps) = all wildcardFree ps

-- | Match a pattern against an expression syntactically. On success, return
-- the substitution from pattern variables to the matched subexpressions.
--
-- A repeated pattern variable must match equal subexpressions (see
-- 'reconcile'). 'PatCon' matches an 'App' whose head is a 'Ref' of the
-- constructor, with the same number of arguments. A nullary 'PatCon' also
-- matches a bare 'Ref' of the constructor, because 'normalise' rewrites
-- @App c []@ to @c@. 'PatTup' matches a 'Tup' of equal length.
-- Anything else fails with 'Nothing'.
unify :: Pat -> Expr -> Maybe (Map.Map Name Expr)
unify PatAny _ = Just Map.empty
unify (PatVar n) e = Just (Map.singleton n e)
unify (PatCon n []) (Ref n')
    | n == n' = Just Map.empty
unify (PatCon n ps) (App n' es)
    | n' == Ref n, length ps == length es = unifyAll ps es
unify (PatTup ps) (Tup es)
    | length ps == length es = unifyAll ps es
unify _ _ = Nothing

-- | Unify equal-length lists of patterns and expressions pairwise, merging
-- the resulting substitutions with 'reconcile'.
unifyAll :: [Pat] -> [Expr] -> Maybe (Map.Map Name Expr)
unifyAll ps es =
    foldM (\m (p, e) -> unify p e >>= reconcile m) Map.empty (zip ps es)

-- | Merge two substitutions. Returns 'Nothing' if some name is bound to
-- different expressions in the two maps; otherwise returns the union of
-- both maps.
reconcile :: Map.Map Name Expr -> Map.Map Name Expr -> Maybe (Map.Map Name Expr)
reconcile m1 m2
    | and (Map.intersectionWith (==) m1 m2) = Just (Map.union m1 m2)
    | otherwise = Nothing

-- | Remove degenerate wrappers everywhere: an application with no arguments
-- and a lambda with no parameters are both replaced by their (normalised)
-- inner expression.
normalise :: Expr -> Expr
normalise expr = case expr of
    App f [] -> normalise f
    Lam [] body -> normalise body
    _ -> recurse id normalise expr

-- | Apply an expression transformer to each immediate sub-expression, and a
-- pattern transformer to each case-arm pattern, one level deep only.
-- 'Ref' and 'Num' leaves are rebuilt unchanged.
recurse :: (Pat -> Pat) -> (Expr -> Expr) -> Expr -> Expr
recurse onPat onExpr = \case
    App f args -> App (onExpr f) (map onExpr args)
    Ref n -> Ref n
    Num k -> Num k
    Tup items -> Tup (map onExpr items)
    Lam params body -> Lam params (onExpr body)
    Case scrut alts ->
        Case (onExpr scrut) [(onPat p, onExpr rhs) | (p, rhs) <- alts]

-- | Apply @f@ repeatedly, starting from @initVal@, until reaching a value
-- @x@ with @f x == x@; return that @x@. Does not terminate if no such value
-- is ever reached.
fixpoint :: Eq a => (a -> a) -> a -> a
fixpoint f initVal = go initVal
  where
    go current
        | current /= next = go next
        | otherwise = current
      where
        next = f current