-- | Source-to-source simplification passes over the expression AST:
-- substitution ('rewrite'), beta\/eta reduction, case simplifications, and a
-- driver ('autoSimp') that combines them.
module Haskell.Rewrite
    ( rewrite
    , betared
    , etared
    , casered
    , etacase
    , casecase
    , autoSimp
    , normalise
    ) where

import Control.Monad
import Data.List
import Data.Maybe
import qualified Data.Map.Strict as Map
import Haskell.AST
import Util


-- | @rewrite name repl expr@ replaces every occurrence of @Ref name@ in
-- @expr@ by @repl@, except below a lambda or case arm that rebinds @name@.
--
-- NOTE: binders in @expr@ are never renamed, so variables occurring in
-- @repl@ can end up bound by a lambda or pattern of @expr@.
rewrite :: Name -> Expr -> Expr -> Expr
rewrite name repl = go
  where
    go (App e as) = App (go e) (map go as)
    go (Ref n)
        | n == name = repl
        | otherwise = Ref n
    -- Constructors and literals contain no variables; they are returned
    -- unchanged (a missing 'Con' equation would be a pattern-match failure).
    go (Con n) = Con n
    go (Num k) = Num k
    go (Tup es) = Tup (map go es)
    go lam@(Lam ns e)
        | name `elem` ns = lam
        | otherwise = Lam ns (go e)
    go (Case e as) = Case (go e) (map goArm as)

    -- An arm whose pattern binds @name@ shadows it: leave that arm alone.
    goArm arm@(p, e')
        | name `elem` boundVars p = arm
        | otherwise = (p, go e')

-- | Names that a pattern is considered to bind, without duplicates.
-- For 'PatCon' the constructor name itself is part of the result.
boundVars :: Pat -> [Name]
boundVars PatAny = []
boundVars (PatVar n) = [n]
boundVars (PatCon n ps) = nub (n : concatMap boundVars ps)
boundVars (PatTup ps) = nub (concatMap boundVars ps)

-- | One beta-reduction step: when a lambda is applied directly to at least
-- one argument, substitute the first argument for the first parameter.
-- Possibly empty parameter/argument lists are left in the result; 'normalise'
-- removes them. If the root is not such a redex, the step is tried on the
-- immediate subexpressions instead.
betared :: Expr -> Expr
betared (App (Lam (n : ns) body) (arg : args)) =
    App (rewrite n arg (Lam ns body)) args
betared e = recurse id betared e

-- | One eta-reduction step: @\\... x -> f a1 ... ak x@ becomes
-- @\\... -> f a1 ... ak@.
--
-- The variable dropped is the /last/ parameter (the innermost binder, since
-- 'betared' consumes parameters front to back), and the step only fires when
-- that variable does not occur in @f@ or in the remaining arguments;
-- otherwise the rewritten term would mean something different.
etared :: Expr -> Expr
etared (Lam ns (App f args))
    | (n : revNs) <- reverse ns
    , (Ref n' : revArgs) <- reverse args
    , n == n'
    , not (any (occursFree n) (f : revArgs))
    = Lam (reverse revNs) (App f (reverse revArgs))
  where
    -- Does @Ref name@ occur outside the scope of a binder for @name@?
    -- Uses the same notion of shadowing as 'rewrite'.
    occursFree name = go
      where
        go (App e as) = go e || any go as
        go (Ref m) = m == name
        go (Con _) = False
        go (Num _) = False
        go (Tup es) = any go es
        go (Lam ms e) = name `notElem` ms && go e
        go (Case e as) = go e || any goArm as

        goArm (p, e') = name `notElem` boundVars p && go e'
etared e = recurse id etared e

-- | Reduce a @case@ whose (already reduced) subject unifies with an arm's
-- pattern by substituting the pattern variables into that arm's body.
--
-- With @ambig == False@ the reduction only happens when exactly one arm
-- unifies; with @ambig == True@ the first unifying arm is taken even if later
-- arms unify as well. When no reduction happens at the root, the pass is
-- applied to the immediate subexpressions.
--
-- NOTE: the bindings are substituted one after another (in key order of the
-- map), not simultaneously.
casered :: Bool -> Expr -> Expr
casered ambig orig@(Case subj arms) =
    case matches of
        [] -> descend
        (mp, rhs) : rest
            | ambig || null rest -> foldl' substitute rhs (Map.assocs mp)
            | otherwise -> descend
  where
    -- Reduce the subject once and share it between all arms.
    subj' = casered ambig subj
    matches = [(mp, rhs) | (p, rhs) <- arms, Just mp <- [unify p subj']]
    descend = recurse id (casered ambig) orig
    substitute e (n, val) = rewrite n val e
casered ambig e = recurse id (casered ambig) e

-- | Replace a @case@ in which every arm just rebuilds the value it matched
-- (see 'eqPE') by its subject. A @case@ without arms is not touched: it has
-- no arm that could return the subject.
etacase :: Expr -> Expr
etacase (Case subj arms)
    | not (null arms), all (uncurry eqPE) arms = subj
etacase e = recurse id etacase e

-- | Case-of-case: push the outer arms into every arm of the inner @case@.
--
-- NOTE: the outer arms are copied below each inner pattern without renaming.
casecase :: Expr -> Expr
casecase (Case (Case subj arms1) arms2) =
    Case subj [(p, Case e arms2) | (p, e) <- arms1]
casecase e = recurse id casecase e

-- | Run all simplification steps, with 'normalise' in between and afterwards,
-- under 'fixpoint' (from "Util"; presumably iterates until the expression no
-- longer changes -- confirm there).
--
-- Function composition applies right to left, so within one round the order
-- is 'casecase', 'etacase', 'etared', @'casered' False@, 'betared'.
autoSimp :: Expr -> Expr
autoSimp expr = fixpoint (normalise . onePass) expr
  where
    steps = [betared, casered False, etared, etacase, casecase]
    onePass = foldr (.) id (intersperse normalise steps)

-- | Does the expression rebuild exactly what the pattern matched? True when
-- the pattern unifies with the expression and every pattern variable is
-- mapped to itself.
--
-- A pattern containing a wildcard never qualifies: 'unify' accepts /any/
-- expression against 'PatAny', so such an arm says nothing about returning
-- the matched value.
eqPE :: Pat -> Expr -> Bool
eqPE pat expr
    | hasWildcard pat = False
    | otherwise = maybe False (all isIdentity . Map.assocs) (unify pat expr)
  where
    isIdentity (n, e) = e == Ref n

    hasWildcard PatAny = True
    hasWildcard (PatVar _) = False
    hasWildcard (PatCon _ ps) = any hasWildcard ps
    hasWildcard (PatTup ps) = any hasWildcard ps

-- | Structurally match a pattern against an expression, producing the
-- bindings of the pattern variables. Returns 'Nothing' when the shapes do not
-- line up or when a repeated variable would get two different expressions.
unify :: Pat -> Expr -> Maybe (Map.Map Name Expr)
unify PatAny _ = Just Map.empty
unify (PatVar n) e = Just (Map.singleton n e)
unify (PatCon n ps) (App n' es)
    | n' == Con n, length ps == length es = unifyAll ps es
unify (PatCon n []) (Con n')
    | n == n' = Just Map.empty
unify (PatTup ps) (Tup es)
    | length ps == length es = unifyAll ps es
unify _ _ = Nothing

-- | Unify patterns and expressions pairwise and merge the resulting bindings
-- with 'reconcile'.
unifyAll :: [Pat] -> [Expr] -> Maybe (Map.Map Name Expr)
unifyAll ps es = foldM step Map.empty (zip ps es)
  where
    step acc (p, e) = unify p e >>= reconcile acc

-- | Merge two binding maps; fails if a name is bound to different
-- expressions in the two maps.
reconcile :: Map.Map Name Expr -> Map.Map Name Expr -> Maybe (Map.Map Name Expr)
reconcile m1 m2 = foldM merge m1 (Map.assocs m2)
  where
    merge acc (k, v) = case Map.lookup k acc of
        Nothing -> Just (Map.insert k v acc)
        Just v'
            | v == v' -> Just acc
            | otherwise -> Nothing

-- | Remove applications to zero arguments and lambdas with zero parameters,
-- throughout the expression.
normalise :: Expr -> Expr
normalise (App e []) = normalise e
normalise (Lam [] e) = normalise e
normalise e = recurse id normalise e

-- | Apply a function to the immediate subexpressions of an expression (and a
-- pattern transformation to the patterns of case arms), leaving the node
-- itself and all leaves as they are.
recurse :: (Pat -> Pat) -> (Expr -> Expr) -> Expr -> Expr
recurse _ f (App e as) = App (f e) (map f as)
recurse _ _ (Ref n) = Ref n
recurse _ _ (Con n) = Con n
recurse _ _ (Num k) = Num k
recurse _ f (Tup es) = Tup (map f es)
recurse _ f (Lam ns e) = Lam ns (f e)
recurse fp f (Case e as) = Case (f e) [(fp p, f e') | (p, e') <- as]