aboutsummaryrefslogtreecommitdiff
path: root/src/Haskell/Rewrite.hs
blob: d53eb6fa2f7f106e758f3ec8b9e0667fb55dd979 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
module Haskell.Rewrite
    (rewrite
    ,betared, etared, casered
    ,etacase, casecase
    ,normalise) where

import Control.Monad
import Data.List
import Data.Maybe
import qualified Data.Map.Strict as Map
import Haskell.AST


-- | @rewrite name repl expr@ replaces every free occurrence of @Ref name@ in
-- @expr@ by @repl@.
--
-- An occurrence is not free if it sits under a 'Lam' that lists @name@
-- among its parameters, or inside a case arm whose pattern binds @name@
-- (see 'boundVars'); those sub-trees are returned untouched.  The case
-- scrutinee is outside the arm scopes and is always rewritten.
--
-- NOTE(review): binders are never renamed, so free names of @repl@ can be
-- captured by a 'Lam' or pattern they are substituted underneath.
rewrite :: Name -> Expr -> Expr -> Expr
rewrite name repl = go
  where
    go expr = case expr of
        Ref n
            | n == name -> repl
            | otherwise -> Ref n
        Num k -> Num k
        App hd args -> App (go hd) (map go args)
        Tup items -> Tup (map go items)
        Lam params body
            | name `elem` params -> Lam params body
            | otherwise -> Lam params (go body)
        Case scrut arms -> Case (go scrut) (map goArm arms)

    -- Leave an arm alone when its pattern shadows the name being replaced.
    goArm (pat, rhs)
        | name `elem` boundVars pat = (pat, rhs)
        | otherwise = (pat, go rhs)

-- | The variables a pattern binds, without duplicates.
--
-- Only 'PatVar' introduces a binding.  The name carried by a 'PatCon' is a
-- reference to an existing constructor ('unify' matches it against
-- @Ref n@), not a new binding, so it is not reported.  Reporting it would
-- make 'rewrite' refuse to substitute inside any case arm whose pattern
-- uses a constructor with the name being rewritten.
boundVars :: Pat -> [Name]
boundVars PatAny = []
boundVars (PatVar n) = [n]
boundVars (PatCon _ ps) = nub $ concatMap boundVars ps
boundVars (PatTup ps) = nub $ concatMap boundVars ps

-- | One beta-reduction step.  A lambda applied to at least one argument has
-- its first parameter replaced by the first argument ('rewrite'); the other
-- parameters and arguments are kept.  So @(\\x y -> b) a1 a2@ becomes
-- @(\\y -> b[a1/x]) a2@.  Empty leftovers (@Lam []@, @App e []@) can be
-- cleaned up by 'normalise'.  Non-redexes are handled by recursing into
-- their immediate children.
--
-- NOTE(review): the substitution is not capture-avoiding; if @a1@ mentions
-- @y@ it becomes bound by the remaining lambda.
betared :: Expr -> Expr
betared expr = case expr of
    App (Lam (param:params) body) (arg:rest) ->
        App (Lam params (rewrite param arg body)) rest
    _ -> recurse id betared expr

-- | Eta reduction:
-- @\\a1 .. ak x -> f e1 .. em x@ becomes @\\a1 .. ak -> f e1 .. em@.
--
-- Parameters are consumed left to right (as in 'betared'), so the
-- parameter that may be dropped is the *last* one, and it must match the
-- *last* argument.  Removing the first parameter instead would turn
-- @\\x y -> f x@ into @\\y -> f@.  The step is also only sound when that
-- parameter occurs nowhere else in the application, so @\\x -> f x x@ is
-- left alone.
--
-- Anything that is not reducible at the top is handled by recursing into
-- its immediate children.
etared :: Expr -> Expr
etared (Lam params@(_:_) (App f args@(_:_)))
    | last args == Ref x
    , not (any (occursFree x) (f : init args)) =
        Lam (init params) (App f (init args))
  where
    x = last params
etared e = recurse id etared e

-- | Whether a name occurs free in an expression.  A 'Lam' binds its
-- parameters, and a case arm binds the variables of its pattern
-- ('boundVars'); the case scrutinee is outside every arm's scope.
occursFree :: Name -> Expr -> Bool
occursFree x expr = case expr of
    App f args -> occursFree x f || any (occursFree x) args
    Ref n -> n == x
    Num _ -> False
    Tup es -> any (occursFree x) es
    Lam ns body -> x `notElem` ns && occursFree x body
    Case subj arms ->
        occursFree x subj
            || any (\(p, rhs) -> x `notElem` boundVars p && occursFree x rhs) arms

-- | Case of known constructor.  @case subj of arms@ is reduced to the
-- right-hand side of the first arm whose pattern 'unify's with @subj@, with
-- that pattern's variables replaced by the matching parts of @subj@.
--
-- With @ambig = False@ the reduction fires only when exactly one arm
-- unifies.  With @ambig = True@ the first unifying arm is taken even if
-- later arms unify too.  In every other situation the rewrite recurses
-- into the immediate children.
--
-- NOTE(review): 'unify' returns 'Nothing' both for a definite mismatch and
-- for a subject that is simply not in constructor form (e.g. a variable).
-- For @case v of { Just x -> a; _ -> b }@ only the @_@ arm unifies, so the
-- expression is reduced to @b@.  Confirm this is acceptable for callers.
casered :: Bool -> Expr -> Expr
casered ambig orig@(Case subj arms) =
    case mapMaybe (\(p, rhs) -> (,rhs) <$> unify p subj) arms of
        (bindings, rhs) : rest
            | ambig || null rest -> substAll bindings rhs
        _ -> recurse id (casered ambig) orig
casered ambig e = recurse id (casered ambig) e

-- | Substitute all bindings of the map into an expression simultaneously.
--
-- Substituting one binding at a time goes wrong when a replacement mentions
-- another bound name.  For @case (y, 1) of (x, y) -> x + y@ the bindings
-- are @x := y@ and @y := 1@; applied in sequence they can produce @1 + 1@
-- instead of @y + 1@.  Scoping mirrors 'rewrite': a name is not replaced
-- underneath a 'Lam' that binds it, nor inside a case arm whose pattern
-- binds it ('boundVars').  Like 'rewrite', binders are not renamed, so free
-- names of the replacements can still be captured.
substAll :: Map.Map Name Expr -> Expr -> Expr
substAll mp expr
    | Map.null mp = expr
    | otherwise = case expr of
        App f args -> App (substAll mp f) (map (substAll mp) args)
        Ref n -> Map.findWithDefault (Ref n) n mp
        Num k -> Num k
        Tup es -> Tup (map (substAll mp) es)
        Lam ns body -> Lam ns (substAll (without ns) body)
        Case scrut arms ->
            Case (substAll mp scrut)
                 [(p, substAll (without (boundVars p)) rhs) | (p, rhs) <- arms]
  where
    -- The bindings that remain visible under binders for @ns@.
    without ns = foldr Map.delete mp ns

-- | Remove a case whose every arm returns exactly its own pattern, e.g.
-- @case e of (a, b) -> (a, b)@ becomes @e@ (arms are compared with
-- 'eqPE').  Anything else is handled by recursing into the immediate
-- children.
--
-- A case with no arms is left alone.  @all@ is vacuously true for it, but
-- @case e of {}@ does not have the type of @e@, so replacing it with @e@
-- would be wrong.
etacase :: Expr -> Expr
etacase (Case subj arms@(_:_)) | all (uncurry eqPE) arms = subj
etacase e = recurse id etacase e

-- | Case of case.  When a case scrutinises another case, the outer arms are
-- pushed into every inner arm:
-- @case (case s of { p_i -> e_i }) of alts@ becomes
-- @case s of { p_i -> case e_i of alts }@.
-- The outer arms are duplicated once per inner arm.  Anything else is
-- handled by recursing into the immediate children.
--
-- NOTE(review): variables bound by the inner patterns @p_i@ now scope over
-- @alts@; if one of them occurs free in @alts@ it gets captured.
casecase :: Expr -> Expr
casecase expr = case expr of
    Case (Case scrut inner) outer ->
        Case scrut (map (\(p, rhs) -> (p, Case rhs outer)) inner)
    _ -> recurse id casecase expr

-- | Whether an expression is exactly the pattern read back as an
-- expression.  A variable must be @Ref@ of the same name.  A constructor
-- pattern must be an 'App' of @Ref@ of that constructor with pairwise-equal
-- arguments, and a tuple pattern must be a 'Tup' of the same length.
--
-- A wildcard has no expression form, so it (and any pattern containing it)
-- never matches.  'unify' accepts anything for 'PatAny' and returns no
-- binding, which would make e.g. @case x of _ -> 5@ look like an identity
-- arm to 'etacase'.
eqPE :: Pat -> Expr -> Bool
eqPE pat expr = case (pat, expr) of
    (PatVar n, _) -> expr == Ref n
    (PatCon c ps, App hd es) -> hd == Ref c && pairwise ps es
    (PatTup ps, Tup es) -> pairwise ps es
    _ -> False  -- includes PatAny
  where
    pairwise ps es = length ps == length es && and (zipWith eqPE ps es)

-- | Syntactically match a pattern against an expression.  On success,
-- returns the sub-expression bound to each pattern variable.
--
-- A wildcard matches anything and binds nothing, and a variable binds the
-- whole expression.  A constructor pattern needs an 'App' headed by @Ref@
-- of that constructor with one argument per sub-pattern, and a tuple
-- pattern needs a 'Tup' of the same length.  Sub-results are merged with
-- 'reconcile', which fails when one name would be bound to two different
-- expressions.  Any other combination gives 'Nothing'.
unify :: Pat -> Expr -> Maybe (Map.Map Name Expr)
unify pat expr = case (pat, expr) of
    (PatAny, _) -> Just Map.empty
    (PatVar n, _) -> Just (Map.singleton n expr)
    (PatCon c ps, App hd es) | hd == Ref c -> unifyAll ps es
    (PatTup ps, Tup es) -> unifyAll ps es
    _ -> Nothing
  where
    unifyAll ps es
        | length ps /= length es = Nothing
        | otherwise = foldM step Map.empty (zip ps es)
    step acc (p, e) = unify p e >>= reconcile acc

-- | Merge the bindings of the second map into the first.  A name bound in
-- both maps must be bound to equal expressions; otherwise the merge fails
-- with 'Nothing'.
reconcile :: Map.Map Name Expr -> Map.Map Name Expr -> Maybe (Map.Map Name Expr)
reconcile m1 m2 = foldM addBinding m1 (Map.toList m2)
  where
    addBinding acc (k, v) =
        maybe (Just (Map.insert k v acc)) (agree acc v) (Map.lookup k acc)
    agree acc v existing
        | v == existing = Just acc
        | otherwise = Nothing

-- | Tidy up degenerate nodes left by the other rewrites.  Two kinds of node
-- are replaced by their own normalised body: an application to no
-- arguments, and a lambda with no parameters.  Every other node is rebuilt
-- with normalised children.
normalise :: Expr -> Expr
normalise expr = case expr of
    App f [] -> normalise f
    Lam [] body -> normalise body
    _ -> recurse id normalise expr

-- | Rebuild a node after applying @onExpr@ to each of its immediate
-- sub-expressions and @onPat@ to each case-arm pattern.  Leaves ('Ref',
-- 'Num') come back unchanged.  This walks only one level; the rewrites in
-- this module pass themselves as @onExpr@ to traverse the whole tree.
recurse :: (Pat -> Pat) -> (Expr -> Expr) -> Expr -> Expr
recurse onPat onExpr node = case node of
    Ref n -> Ref n
    Num k -> Num k
    App hd args -> App (onExpr hd) (map onExpr args)
    Tup items -> Tup (map onExpr items)
    Lam params body -> Lam params (onExpr body)
    Case scrut arms ->
        Case (onExpr scrut) [(onPat p, onExpr rhs) | (p, rhs) <- arms]