aboutsummaryrefslogtreecommitdiff
path: root/src/Haskell/Rewrite.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Haskell/Rewrite.hs')
-rw-r--r--src/Haskell/Rewrite.hs100
1 files changed, 100 insertions, 0 deletions
diff --git a/src/Haskell/Rewrite.hs b/src/Haskell/Rewrite.hs
new file mode 100644
index 0000000..d53eb6f
--- /dev/null
+++ b/src/Haskell/Rewrite.hs
@@ -0,0 +1,100 @@
+module Haskell.Rewrite
+ (rewrite
+ ,betared, etared, casered
+ ,etacase, casecase
+ ,normalise) where
+
+import Control.Monad
+import Data.List
+import Data.Maybe
+import qualified Data.Map.Strict as Map
+import Haskell.AST
+
+
+rewrite :: Name -> Expr -> Expr -> Expr
+rewrite name repl = \case
+ App e as -> App (rewrite name repl e) (map (rewrite name repl) as)
+ Ref n -> if n == name then repl else Ref n
+ Num k -> Num k
+ Tup es -> Tup (map (rewrite name repl) es)
+ Lam ns e -> if name `elem` ns then Lam ns e else Lam ns (rewrite name repl e)
+ Case e as -> Case (rewrite name repl e) (map caseArm as)
+ where
+ caseArm (p, e') =
+ if name `elem` boundVars p then (p, e') else (p, rewrite name repl e')
+
+boundVars :: Pat -> [Name]
+boundVars PatAny = []
+boundVars (PatVar n) = [n]
+boundVars (PatCon n ps) = nub $ [n] ++ concatMap boundVars ps
+boundVars (PatTup ps) = nub $ concatMap boundVars ps
+
+betared :: Expr -> Expr
+betared (App (Lam (n:as) bd) (arg:args)) =
+ App (Lam as (rewrite n arg bd)) args
+betared e = recurse id betared e
+
+etared :: Expr -> Expr
+etared (Lam (n:as) (App e es@(_:_)))
+ | last es == Ref n =
+ Lam as (App e (init es))
+etared e = recurse id etared e
+
+casered :: Bool -> Expr -> Expr
+casered ambig orig@(Case subj arms) =
+ case catMaybes [(,rhs) <$> unify p subj | (p, rhs) <- arms] of
+ [] -> recurse id (casered ambig) orig
+ ((mp, rhs):rest) ->
+ let res = foldl' (\e' (n, rhs') -> rewrite n rhs' e') rhs (Map.assocs mp)
+ in case (ambig, rest) of
+ (True, _) -> res
+ (False, []) -> res
+ (False, _) -> recurse id (casered ambig) orig
+casered ambig e = recurse id (casered ambig) e
+
+etacase :: Expr -> Expr
+etacase (Case subj arms) | all (uncurry eqPE) arms = subj
+etacase e = recurse id etacase e
+
+casecase :: Expr -> Expr
+casecase (Case (Case subj arms1) arms2) =
+ Case subj [(p, Case e arms2) | (p, e) <- arms1]
+casecase e = recurse id casecase e
+
+eqPE :: Pat -> Expr -> Bool
+eqPE pat expr = case unify pat expr of
+ Nothing -> False
+ Just mp -> all (uncurry isIdmap) (Map.assocs mp)
+ where
+ isIdmap n = (== Ref n)
+
+unify :: Pat -> Expr -> Maybe (Map.Map Name Expr)
+unify PatAny _ = Just Map.empty
+unify (PatVar n) e = Just (Map.singleton n e)
+unify (PatCon n ps) (App n' es)
+ | n' == Ref n, length ps == length es =
+ foldM (\m (p, e) -> unify p e >>= reconcile m) Map.empty (zip ps es)
+unify (PatTup ps) (Tup es)
+ | length ps == length es =
+ foldM (\m (p, e) -> unify p e >>= reconcile m) Map.empty (zip ps es)
+unify _ _ = Nothing
+
+reconcile :: Map.Map Name Expr -> Map.Map Name Expr -> Maybe (Map.Map Name Expr)
+reconcile m1 m2 = foldM func m1 (Map.assocs m2)
+ where func m (k, v) = case Map.lookup k m of
+ Nothing -> Just (Map.insert k v m)
+ Just v' | v == v' -> Just m
+ | otherwise -> Nothing
+
+normalise :: Expr -> Expr
+normalise (App e []) = normalise e
+normalise (Lam [] e) = normalise e
+normalise e = recurse id normalise e
+
+recurse :: (Pat -> Pat) -> (Expr -> Expr) -> Expr -> Expr
+recurse _ f (App e as) = App (f e) (map f as)
+recurse _ _ (Ref n) = Ref n
+recurse _ _ (Num k) = Num k
+recurse _ f (Tup es) = Tup (map f es)
+recurse _ f (Lam ns e) = Lam ns (f e)
+recurse fp f (Case e as) = Case (f e) (map (\(p, e') -> (fp p, f e')) as)