-- | Rewriting and simplification passes over the 'Expr' AST: substitution,
-- beta/eta reduction, case-of-known-constructor, case-of-identity and
-- case-of-case, plus the 'autoSimp' driver that iterates them.
module Haskell.Rewrite
  ( rewrite
  , betared
  , etared
  , casered
  , etacase
  , casecase
  , autoSimp
  , normalise
  ) where

import Control.Monad
import Data.List
import Data.Maybe
import qualified Data.Map.Strict as Map

import Haskell.AST
import Util

-- | @rewrite name repl e@ replaces every free occurrence of @name@ in @e@
-- with @repl@. Occurrences shadowed by a lambda parameter or by a variable
-- of a case-arm pattern are left untouched.
--
-- The substitution is not capture-avoiding: a free variable of @repl@ that
-- is bound on the way down to an occurrence of @name@ gets captured. Use
-- 'substCaptures' to detect that situation beforehand.
rewrite :: Name -> Expr -> Expr -> Expr
rewrite name repl = substMany (Map.singleton name repl)

-- | Simultaneous substitution: every free occurrence of a key of @sub@ is
-- replaced by its image in a single pass, so an inserted expression is never
-- rewritten again by another entry. (Applied one key at a time,
-- @{x := y, y := x}@ can turn @f x y@ into @f x x@ instead of @f y x@.)
-- Binders shadow the keys they bind. Not capture-avoiding; see
-- 'substCaptures'.
substMany :: Map.Map Name Expr -> Expr -> Expr
substMany sub
  | Map.null sub = id
  | otherwise = \case
      App f args ty -> App (go f) (map go args) ty
      Ref n ty -> Map.findWithDefault (Ref n ty) n sub
      Num k ty -> Num k ty
      Tup es ty -> Tup (map go es) ty
      Lam ns body ty -> Lam ns (substMany (shadow ns) body) ty
      Case subj arms ty ->
        Case (go subj) [(p, substMany (shadow (boundVars p)) rhs) | (p, rhs) <- arms] ty
  where
    go = substMany sub
    -- Keys rebound by an inner binder are no longer free below it.
    shadow bound = foldl' (flip Map.delete) sub bound

-- | Variables bound by a pattern, without duplicates, in order of first
-- occurrence. The constructor name of a 'PatCon' is not a binder and is
-- therefore not included.
boundVars :: Pat -> [Name]
boundVars = nub . go
  where
    go PatAny = []
    go (PatVar n) = [n]
    go (PatCon _ ps) = concatMap go ps
    go (PatTup ps) = concatMap go ps

-- | @isFreeIn v e@ holds when @v@ occurs in @e@ outside the scope of every
-- lambda parameter or pattern variable named @v@.
isFreeIn :: Name -> Expr -> Bool
isFreeIn v = \case
  App f args _ -> isFreeIn v f || any (isFreeIn v) args
  Ref n _ -> n == v
  Num _ _ -> False
  Tup es _ -> any (isFreeIn v) es
  Lam ns body _ -> v `notElem` ns && isFreeIn v body
  Case subj arms _ -> isFreeIn v subj || any (isFreeInArm v) arms

-- | @isFreeInArm v (p, rhs)@ holds when @v@ occurs free in the right-hand
-- side of a case arm and is not bound by the arm's pattern.
isFreeInArm :: Name -> (Pat, Expr) -> Bool
isFreeInArm v (p, rhs) = v `notElem` boundVars p && isFreeIn v rhs

-- | @substCaptures sub e@ holds when @'substMany' sub e@ would insert a
-- replacement underneath a binder (lambda parameter or pattern variable)
-- that also occurs free in that replacement, i.e. when the substitution
-- would change what the replacement refers to.
substCaptures :: Map.Map Name Expr -> Expr -> Bool
substCaptures sub
  | Map.null sub = const False
  | otherwise = \case
      App f args _ -> go f || any go args
      Ref _ _ -> False
      Num _ _ -> False
      Tup es _ -> any go es
      Lam ns body _ -> underBinders ns body
      Case subj arms _ ->
        go subj || any (\(p, rhs) -> underBinders (boundVars p) rhs) arms
  where
    go = substCaptures sub
    underBinders bound body =
      any captured (Map.elems live) || substCaptures live body
      where
        -- Only entries that are actually substituted somewhere in body matter.
        live = Map.filterWithKey (\k _ -> isFreeIn k body) (foldl' (flip Map.delete) sub bound)
        captured repl = any (`isFreeIn` repl) bound

-- | Beta reduction. For @App (Lam (n:ns) body) (arg:args)@ the argument is
-- substituted for the first parameter, leaving the remaining parameters and
-- arguments in place. If @n@ is rebound among @ns@ the body is left
-- unchanged, because its occurrences refer to the inner binder. The redex is
-- skipped when the substitution would capture a free variable of @arg@:
-- there is no fresh-name supply to rename binders with. Any other
-- expression is handed to 'recurse'.
betared :: Expr -> Expr
betared (App (Lam (n:ns) body lty) (arg:args) ty)
  | not (substCaptures sub rest) = App (substMany sub rest) args ty
  where
    sub = Map.singleton n arg
    -- The remaining lambda; its type is the old one applied to the argument
    -- type, or 'Nothing' if the lambda was untyped or the application fails.
    rest = Lam ns body (lty >>= \t -> typeApply t (typeOf arg))
betared e = recurse id betared e

-- | Eta reduction. A lambda whose body applies some @f@ to arguments ending
-- in its /last/ parameter @x@ drops both that parameter and that argument.
-- This only happens when @x@ does not occur free in @f@ or in the other
-- arguments.
--
-- NOTE(review): the 'App' type annotation is kept unchanged although the
-- shortened application now has a function type; confirm whether anything
-- relies on it.
etared :: Expr -> Expr
etared (Lam ns (App f args aty) ty)
  | Ref x _ : revArgs <- reverse args
  , lastParam : revParams <- reverse ns
  , x == lastParam
  , not (any (isFreeIn x) (f : revArgs))
  = Lam (reverse revParams) (App f (reverse revArgs) aty) ty
etared e = recurse id etared e

-- | Case-of-known-constructor. When the scrutinee matches an arm (according
-- to 'unify'), the whole case is replaced by that arm's right-hand side.
-- The pattern variables are substituted simultaneously.
--
-- The first arm that 'unify' accepts is used when
--
--   * @ambig@ is 'True' (even if other arms match too), or
--   * it is the very first arm of the case (first-match semantics), or
--   * it is the only arm that 'unify' accepts.
--
-- A reduction whose substitution would capture a variable is skipped.
--
-- NOTE(review): 'unify' returns 'Nothing' both for a definite mismatch and
-- for "cannot tell" (e.g. a constructor pattern against a variable or a
-- function call). The "only arm" rule can therefore select a later arm
-- while an earlier one might still match at run time. Fixing this needs to
-- know which names are constructors.
casered :: Bool -> Expr -> Expr
casered ambig orig@(Case subj arms _)
  | (mp, rhs) : others <- matches
  , ambig || null others || firstArmMatches
  , not (substCaptures mp rhs)
  = substMany mp rhs
  | otherwise = recurse id (casered ambig) orig
  where
    matches = mapMaybe (\(p, body) -> (,body) <$> unify p subj) arms
    firstArmMatches = case arms of
      (p, _) : _ -> isJust (unify p subj)
      [] -> False
casered ambig e = recurse id (casered ambig) e

-- | Case-of-identity. A case with at least one arm, where every arm rebuilds
-- exactly what its pattern matched (as judged by 'eqPE'), is replaced by
-- its scrutinee. An empty case is left alone, because no arm returns the
-- scrutinee.
etacase :: Expr -> Expr
etacase (Case subj arms@(_:_) _)
  | all (uncurry eqPE) arms = subj
etacase e = recurse id etacase e

-- | Case-of-case. @case (case s of p_i -> e_i) of alts@ becomes
-- @case s of p_i -> case e_i of alts@. The step is skipped when a variable
-- bound by some inner pattern @p_i@ occurs free in @alts@, because moving
-- @alts@ under @p_i@ would capture it.
casecase :: Expr -> Expr
casecase (Case (Case subj inner _) outer ty)
  | not (any (\v -> any (isFreeInArm v) outer) innerBound)
  = Case subj [(p, Case e outer ty) | (p, e) <- inner] ty
  where
    innerBound = concatMap (boundVars . fst) inner
casecase e = recurse id casecase e
-- | The automatic simplifier: repeats one sweep of the passes via
-- 'fixpoint' (from "Util"; presumably until the expression stops changing).
-- Composition applies right to left, so a sweep runs 'casecase',
-- 'etacase', 'etared', @'casered' False@ and then 'betared', with
-- 'normalise' after each of them.
autoSimp :: Expr -> Expr
autoSimp expr = fixpoint sweep expr
  where
    steps = [betared, casered False, etared, etacase, casecase]
    sweep = normalise . foldl1 (.) (intersperse normalise steps)

-- | @eqPE pat expr@ holds when @expr@ rebuilds exactly the value matched by
-- @pat@: every pattern variable is returned as itself, and every constructor
-- or tuple is reapplied to matching sub-patterns. A wildcard discards its
-- value, so no expression can rebuild it and it never qualifies. A nullary
-- constructor pattern is also matched by a bare reference to that
-- constructor (the form 'normalise' leaves after dropping an empty
-- application).
eqPE :: Pat -> Expr -> Bool
eqPE PatAny _ = False
eqPE (PatVar n) (Ref n' _) = n == n'
eqPE (PatCon n []) (Ref n' _) = n == n'
eqPE (PatCon n ps) (App (Ref n' _) es _) =
  n == n' && length ps == length es && and (zipWith eqPE ps es)
eqPE (PatTup ps) (Tup es _) =
  length ps == length es && and (zipWith eqPE ps es)
eqPE _ _ = False

-- | Match a pattern against an expression, returning the expression bound
-- to each pattern variable on success. 'Nothing' covers both a definite
-- mismatch and a match that cannot be decided syntactically (e.g. a
-- constructor pattern against a variable). A nullary constructor pattern
-- also matches a bare 'Ref' to that constructor, which is what 'normalise'
-- leaves after dropping an empty application.
unify :: Pat -> Expr -> Maybe (Map.Map Name Expr)
unify pat expr = case (pat, expr) of
  (PatAny, _) -> Just Map.empty
  (PatVar n, _) -> Just (Map.singleton n expr)
  (PatCon n [], Ref n' _) | n == n' -> Just Map.empty
  (PatCon n ps, App (Ref n' _) es _) | n == n' -> unifyAll ps es
  (PatTup ps, Tup es _) -> unifyAll ps es
  _ -> Nothing
  where
    -- Pairwise match of equally long sub-pattern / sub-expression lists,
    -- merging the bindings with 'reconcile'.
    unifyAll ps es
      | length ps == length es =
          foldM (\acc (p, e) -> unify p e >>= reconcile acc) Map.empty (zip ps es)
      | otherwise = Nothing

-- | Merge two binding maps. Fails when the same variable is bound to two
-- different expressions (as a repeated variable in a pattern would do).
reconcile :: Map.Map Name Expr -> Map.Map Name Expr -> Maybe (Map.Map Name Expr)
reconcile m1 m2 = foldM merge m1 (Map.assocs m2)
  where
    merge acc (k, v) = case Map.lookup k acc of
      Nothing -> Just (Map.insert k v acc)
      Just v'
        | v == v' -> Just acc
        | otherwise -> Nothing

-- | Remove trivial wrappers. An application to zero arguments and a lambda
-- with zero parameters are each replaced by their normalised body. Any
-- other node is handed to 'recurse', presumably to normalise its children.
normalise :: Expr -> Expr
normalise (App e [] _) = normalise e
normalise (Lam [] e _) = normalise e
normalise e = recurse id normalise e