diff options
Diffstat (limited to 'src/Haskell/Rewrite.hs')
-rw-r--r-- | src/Haskell/Rewrite.hs | 54 |
1 files changed, 25 insertions, 29 deletions
diff --git a/src/Haskell/Rewrite.hs b/src/Haskell/Rewrite.hs index bcec6f7..9812231 100644 --- a/src/Haskell/Rewrite.hs +++ b/src/Haskell/Rewrite.hs @@ -15,12 +15,12 @@ import Util rewrite :: Name -> Expr -> Expr -> Expr rewrite name repl = \case - App e as -> App (rewrite name repl e) (map (rewrite name repl) as) - Ref n -> if n == name then repl else Ref n - Num k -> Num k - Tup es -> Tup (map (rewrite name repl) es) - Lam ns e -> if name `elem` ns then Lam ns e else Lam ns (rewrite name repl e) - Case e as -> Case (rewrite name repl e) (map caseArm as) + App e as ty -> App (rewrite name repl e) (map (rewrite name repl) as) ty + Ref n ty -> if n == name then repl else Ref n ty + Num k ty -> Num k ty + Tup es ty -> Tup (map (rewrite name repl) es) ty + Lam ns e ty -> if name `elem` ns then Lam ns e ty else Lam ns (rewrite name repl e) ty + Case e as ty -> Case (rewrite name repl e) (map caseArm as) ty where caseArm (p, e') = if name `elem` boundVars p then (p, e') else (p, rewrite name repl e') @@ -32,18 +32,21 @@ boundVars (PatCon n ps) = nub $ [n] ++ concatMap boundVars ps boundVars (PatTup ps) = nub $ concatMap boundVars ps betared :: Expr -> Expr -betared (App (Lam (n:as) bd) (arg:args)) = - App (Lam as (rewrite n arg bd)) args +betared (App (Lam (n:as) bd lty) (arg:args) ty) = + App (Lam as + (rewrite n arg bd) + (Just $ fromJust $ typeApply (fromJust lty) (typeOf arg))) + args ty betared e = recurse id betared e etared :: Expr -> Expr -etared (Lam (n:as) (App e es@(_:_))) - | last es == Ref n = - Lam as (App e (init es)) +etared (Lam (n:as) (App e es@(_:_) aty) ty) + | Ref n' _ <- last es, n == n' = + Lam as (App e (init es) aty) ty etared e = recurse id etared e casered :: Bool -> Expr -> Expr -casered ambig orig@(Case subj arms) = +casered ambig orig@(Case subj arms _) = case catMaybes [(,rhs) <$> unify p subj | (p, rhs) <- arms] of [] -> recurse id (casered ambig) orig ((mp, rhs):rest) -> @@ -55,12 +58,12 @@ casered ambig orig@(Case subj arms) = casered ambig e = recurse id (casered ambig) e etacase :: Expr -> Expr -etacase (Case subj arms) | all (uncurry eqPE) arms = subj +etacase (Case subj arms _) | all (uncurry eqPE) arms = subj etacase e = recurse id etacase e casecase :: Expr -> Expr -casecase (Case (Case subj arms1) arms2) = - Case subj [(p, Case e arms2) | (p, e) <- arms1] +casecase (Case (Case subj arms1 _) arms2 ty) = + Case subj [(p, Case e arms2 ty) | (p, e) <- arms1] ty casecase e = recurse id casecase e autoSimp :: Expr -> Expr @@ -73,15 +76,16 @@ eqPE pat expr = case unify pat expr of Nothing -> False Just mp -> all (uncurry isIdmap) (Map.assocs mp) where - isIdmap n = (== Ref n) + isIdmap n (Ref n' _) = n == n' + isIdmap _ _ = False unify :: Pat -> Expr -> Maybe (Map.Map Name Expr) unify PatAny _ = Just Map.empty unify (PatVar n) e = Just (Map.singleton n e) -unify (PatCon n ps) (App n' es) - | n' == Ref n, length ps == length es = +unify (PatCon n ps) (App (Ref n' _) es _) + | n == n', length ps == length es = foldM (\m (p, e) -> unify p e >>= reconcile m) Map.empty (zip ps es) -unify (PatTup ps) (Tup es) +unify (PatTup ps) (Tup es _) | length ps == length es = foldM (\m (p, e) -> unify p e >>= reconcile m) Map.empty (zip ps es) unify _ _ = Nothing @@ -94,14 +98,6 @@ reconcile m1 m2 = foldM func m1 (Map.assocs m2) | otherwise -> Nothing normalise :: Expr -> Expr -normalise (App e []) = normalise e -normalise (Lam [] e) = normalise e +normalise (App e [] _) = normalise e +normalise (Lam [] e _) = normalise e normalise e = recurse id normalise e - -recurse :: (Pat -> Pat) -> (Expr -> Expr) -> Expr -> Expr -recurse _ f (App e as) = App (f e) (map f as) -recurse _ _ (Ref n) = Ref n -recurse _ _ (Num k) = Num k -recurse _ f (Tup es) = Tup (map f es) -recurse _ f (Lam ns e) = Lam ns (f e) -recurse fp f (Case e as) = Case (f e) (map (\(p, e') -> (fp p, f e')) as) |