{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}

-- | Capture-avoiding substitution plus a handful of local simplification
-- rules (beta, eta, case-of-known-constructor, case-eta, case-of-case)
-- over 'Expr'.
module Haskell.Rewrite
  ( rewrite
  , betared
  , etared
  , casered
  , etacase
  , casecase
  , autoSimp
  , normalise
  ) where

import Control.Monad
import Data.List
import Data.Maybe
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set

import Haskell.AST
import Util

-- | @rewrite x e body@ substitutes @e@ for every free occurrence of the
-- variable @x@ in @body@, avoiding variable capture.
--
-- When the substitution descends into the scope of a binder @y@ that is
-- free in @e@, @y@ is alpha-renamed to a name that is
--
--   * not free in @e@;
--   * not one of the variables (free or bound) of the binder's scope, as
--     reported by 'allVars';
--   * not already used by a renaming of an enclosing binder.
--
-- Binders whose scope contains no free occurrence of @x@ (or that shadow
-- @x@) are not renamed, so no needless alpha-renaming happens. Renamings
-- of enclosing binders are still applied inside them.
rewrite :: Name -> Expr -> Expr -> Expr
rewrite target repl topexpr = go Map.empty topexpr
  where
    frees :: Set.Set Name
    frees = freeVariables repl

    -- 'mapping' holds the alpha-renamings introduced by enclosing binders.
    go :: Map.Map Name Name -> Expr -> Expr
    go mapping expr = case expr of
      App f args -> App (go mapping f) (map (go mapping) args)
      Ref n
        -- A renamed variable is a binder name other than 'target', so it is
        -- never the substitution target.
        | Just n' <- Map.lookup n mapping -> Ref n'
        | n == target -> repl
        | otherwise -> expr
      Con _ -> expr
      Num _ -> expr
      Tup es -> Tup (map (go mapping) es)
      Lam ns body
        | reaches ns body ->
            let mapping' = enter mapping ns body
            in Lam (map (rename mapping') ns) (go mapping' body)
        -- Nothing to substitute in here, but outer renamings still apply.
        | otherwise -> renameVars mapping expr
      Case scrut arms -> Case (go mapping scrut) (map (goArm mapping) arms)

    -- Processes one case arm; the pattern's variables scope over the rhs.
    goArm :: Map.Map Name Name -> (Pat, Expr) -> (Pat, Expr)
    goArm mapping (p, rhs)
      | reaches bound rhs =
          let mapping' = enter mapping bound rhs
          in (renamePat mapping' p, go mapping' rhs)
      | otherwise = (p, renameVars (Map.withoutKeys mapping (boundVars p)) rhs)
      where
        bound = Set.toList (boundVars p)

    -- True when the substitution actually happens inside a scope that binds
    -- 'ns' over 'body': 'target' is not shadowed and occurs free in 'body'.
    reaches :: [Name] -> Expr -> Bool
    reaches ns body =
      target `notElem` ns && target `Set.member` freeVariables body

    -- Mapping to use inside a scope binding 'ns' over 'body'. Binders that
    -- are free in 'repl' get fresh names; outer renamings of names that
    -- 'ns' shadows are dropped.
    enter :: Map.Map Name Name -> [Name] -> Expr -> Map.Map Name Name
    enter mapping ns body =
      Map.union fresh (Map.withoutKeys mapping (Set.fromList ns))
      where
        avoid =
          Set.unions
            [frees, allVars body, Set.fromList ns, Set.fromList (Map.elems mapping)]
        fresh = freshenAll avoid (filter (`Set.member` frees) ns)

-- | Applies a variable renaming to the free variables of an expression.
-- Inner binders shadow entries of the mapping. Callers must ensure the new
-- names are not bound inside the expression; every caller in this module
-- picks them outside 'allVars' of the expression.
renameVars :: Map.Map Name Name -> Expr -> Expr
renameVars mapping expr
  | Map.null mapping = expr
  | otherwise = case expr of
      Ref n -> Ref (rename mapping n)
      Lam ns body ->
        Lam ns (renameVars (Map.withoutKeys mapping (Set.fromList ns)) body)
      Case scrut arms ->
        Case
          (renameVars mapping scrut)
          [ (p, renameVars (Map.withoutKeys mapping (boundVars p)) rhs)
          | (p, rhs) <- arms
          ]
      _ -> recurse id (renameVars mapping) expr

-- | A name of the form @n_R_<k>@ (smallest @k >= 1@) that is not in 'avoid'.
freshName :: Set.Set Name -> Name -> Name
freshName avoid n = go (1 :: Int)
  where
    go i
      | candidate `Set.member` avoid = go (i + 1)
      | otherwise = candidate
      where
        candidate = n ++ "_R_" ++ show i

-- | Fresh names for each of 'ns'. They avoid 'avoid0' and are pairwise
-- distinct. Later duplicates in 'ns' override earlier ones.
freshenAll :: Set.Set Name -> [Name] -> Map.Map Name Name
freshenAll avoid0 = snd . foldl' step (avoid0, Map.empty)
  where
    step (avoid, acc) n =
      let n' = freshName avoid n
      in (Set.insert n' avoid, Map.insert n n' acc)

-- | Renames the variables bound by a pattern. Constructor names are never
-- bound, so they are left untouched.
renamePat :: Map.Map Name Name -> Pat -> Pat
renamePat mapping = \case
  PatAny -> PatAny
  PatVar n -> PatVar (rename mapping n)
  PatCon n ps -> PatCon n (map (renamePat mapping) ps)
  PatTup ps -> PatTup (map (renamePat mapping) ps)

rename :: Map.Map Name Name -> Name -> Name
rename mapping n = fromMaybe n (Map.lookup n mapping)

-- | Simultaneous capture-avoiding substitution. Every key of the map is
-- replaced by its expression at once, so a substituted expression is never
-- itself rewritten by a later entry. With two or more entries, the keys are
-- first parked on fresh placeholder names, which are then substituted one
-- by one.
substitute :: Map.Map Name Expr -> Expr -> Expr
substitute sub body = case Map.toList sub of
  [] -> body
  [(n, v)] -> rewrite n v body
  _ ->
    let avoid =
          Set.unions (allVars body : Map.keysSet sub : map allVars (Map.elems sub))
        placeholders = freshenAll avoid (Map.keys sub)
        parked =
          foldl' (\e (n, n') -> rewrite n (Ref n') e) body (Map.toList placeholders)
    in foldl'
         (\e (n', v) -> rewrite n' v e)
         parked
         (Map.elems (Map.intersectionWith (,) placeholders sub))

-- | Beta reduction: applies a lambda to its first argument.
betared :: Expr -> Expr
betared (App (Lam (n:as) bd) (arg:args)) = App (rewrite n arg (Lam as bd)) args
betared e = recurse id betared e

-- | Eta reduction: @\\xs.. x -> f as.. x@ becomes @\\xs.. -> f as..@. This
-- requires @x@ to be the LAST binder and not free in @f as..@.
etared :: Expr -> Expr
etared (Lam ns@(_ : _) (App f args@(_ : _)))
  -- 'last'/'init' are safe here: both lists are matched non-empty above.
  | Ref n <- last args
  , n == last ns
  , n `Set.notMember` freeVariables (App f (init args))
  = Lam (init ns) (App f (init args))
etared e = recurse id etared e

-- | Case-of-known-constructor. Arms whose pattern is definitely refuted by
-- the (reduced) subject are skipped. An arm whose match cannot be decided
-- blocks the reduction.
--
-- With @ambig = True@, the first non-refuted arm is taken if it matches.
-- With @ambig = False@, reduction happens only when exactly one arm is not
-- refuted and that arm matches.
casered :: Bool -> Expr -> Expr
casered ambig orig@(Case subj arms) =
  case (ambig, live) of
    (True, (Just mp, rhs) : _) -> substitute mp rhs
    (False, [(Just mp, rhs)]) -> substitute mp rhs
    _ -> recurse id (casered ambig) orig
  where
    subj' = casered ambig subj
    live = [(unify p subj', rhs) | (p, rhs) <- arms, not (refutes p subj')]
casered ambig e = recurse id (casered ambig) e

-- | True only when the pattern certainly cannot match the expression,
-- because of a constructor clash at a known constructor head. Returns False
-- whenever 'unify' succeeds.
refutes :: Pat -> Expr -> Bool
refutes (PatCon n ps) e
  | Just (n', es) <- conApp e =
      n /= n' || (length ps == length es && or (zipWith refutes ps es))
refutes (PatTup ps) (Tup es) =
  length ps == length es && or (zipWith refutes ps es)
refutes _ _ = False

-- | Splits a (possibly nullary) constructor application into head and args.
conApp :: Expr -> Maybe (Name, [Expr])
conApp (Con n) = Just (n, [])
conApp (App (Con n) es) = Just (n, es)
conApp _ = Nothing

-- | Case eta: @case s of p -> p@ (every arm rebuilds its pattern) is @s@.
etacase :: Expr -> Expr
etacase (Case subj arms) | all (uncurry eqPE) arms = subj
etacase e = recurse id etacase e

-- | Case-of-case. Variables bound by the inner arms' patterns that are free
-- in the outer arms are alpha-renamed first, so they cannot capture those
-- outer occurrences.
casecase :: Expr -> Expr
casecase (Case (Case subj inner) outer) = Case subj (map push inner)
  where
    outerFrees =
      Set.unions [freeVariables rhs `Set.difference` boundVars q | (q, rhs) <- outer]
    push (p, e)
      | Set.null clash = (p, Case e outer)
      | otherwise = (renamePat ren p, Case (renameVars ren e) outer)
      where
        clash = boundVars p `Set.intersection` outerFrees
        ren =
          freshenAll
            (Set.unions [allVars e, boundVars p, outerFrees])
            (Set.toList clash)
casecase e = recurse id casecase e

-- | Applies all simplification steps, normalising between them, to a fixpoint.
autoSimp :: Expr -> Expr
autoSimp expr =
  let steps = [betared, casered False, etared, etacase, casecase]
  in fixpoint (normalise . foldl1 (.) (intersperse normalise steps)) expr

-- | Does the expression rebuild exactly the value the pattern matched? A
-- wildcard never qualifies, because the value it matches cannot be named.
eqPE :: Pat -> Expr -> Bool
eqPE PatAny _ = False
eqPE (PatVar n) e = e == Ref n
eqPE (PatCon n ps) e = case conApp e of
  Just (n', es) -> n == n' && length ps == length es && and (zipWith eqPE ps es)
  Nothing -> False
eqPE (PatTup ps) (Tup es) = length ps == length es && and (zipWith eqPE ps es)
eqPE _ _ = False

-- | Syntactic pattern matching. On success, returns the binding for each
-- pattern variable. 'Nothing' means "no syntactic match", which is not the
-- same as "definitely does not match" (see 'refutes').
unify :: Pat -> Expr -> Maybe (Map.Map Name Expr)
unify PatAny _ = Just Map.empty
unify (PatVar n) e = Just (Map.singleton n e)
unify (PatCon n ps) (App n' es)
  | n' == Con n, length ps == length es =
      foldM (\m (p, e) -> unify p e >>= reconcile m) Map.empty (zip ps es)
unify (PatCon n []) (Con n') | n == n' = Just Map.empty
unify (PatTup ps) (Tup es)
  | length ps == length es =
      foldM (\m (p, e) -> unify p e >>= reconcile m) Map.empty (zip ps es)
unify _ _ = Nothing

-- | Merges two binding maps; fails if a variable is bound to two
-- different expressions.
reconcile :: Map.Map Name Expr -> Map.Map Name Expr -> Maybe (Map.Map Name Expr)
reconcile m1 m2 = foldM func m1 (Map.assocs m2)
  where
    func m (k, v) = case Map.lookup k m of
      Nothing -> Just (Map.insert k v m)
      Just v'
        | v == v' -> Just m
        | otherwise -> Nothing

-- | Removes empty applications and empty lambdas everywhere.
normalise :: Expr -> Expr
normalise (App e []) = normalise e
normalise (Lam [] e) = normalise e
normalise e = recurse id normalise e

-- | Applies 'f' to the immediate sub-expressions (and 'fp' to case patterns).
recurse :: (Pat -> Pat) -> (Expr -> Expr) -> Expr -> Expr
recurse _ f (App e as) = App (f e) (map f as)
recurse _ _ (Ref n) = Ref n
recurse _ _ (Con n) = Con n
recurse _ _ (Num k) = Num k
recurse _ f (Tup es) = Tup (map f es)
recurse _ f (Lam ns e) = Lam ns (f e)
recurse fp f (Case e as) = Case (f e) (map (\(p, e') -> (fp p, f e')) as)