aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Haskell/AST.hs19
-rw-r--r--src/Main.hs10
2 files changed, 25 insertions, 4 deletions
diff --git a/src/Haskell/AST.hs b/src/Haskell/AST.hs
index 072fd97..2238b6d 100644
--- a/src/Haskell/AST.hs
+++ b/src/Haskell/AST.hs
@@ -1,6 +1,7 @@
module Haskell.AST where
import Data.List
+import qualified Data.Set as Set
import Pretty
@@ -134,3 +135,21 @@ instance AllRefs Expr where
instance AllRefs Inst where
allRefs (Inst _ _ ds) = nub $ concatMap allRefs ds
+
+
+boundVars :: Pat -> Set.Set Name
+boundVars PatAny = mempty
+boundVars (PatVar n) = Set.singleton n
+boundVars (PatCon _ ps) = Set.unions (map boundVars ps)
+boundVars (PatTup ps) = Set.unions (map boundVars ps)
+
+freeVariables :: Expr -> Set.Set Name
+freeVariables (App e es) = freeVariables e <> Set.unions (map freeVariables es)
+freeVariables (Ref n) = Set.singleton n
+freeVariables (Con _) = mempty
+freeVariables (Num _) = mempty
+freeVariables (Tup es) = Set.unions (map freeVariables es)
+freeVariables (Lam ns e) = freeVariables e Set.\\ Set.fromList ns
+freeVariables (Case e pairs) =
+ freeVariables e <> Set.unions [freeVariables e' Set.\\ boundVars p
+ | (p, e') <- pairs]
diff --git a/src/Main.hs b/src/Main.hs
index a6d80b3..f406f1b 100644
--- a/src/Main.hs
+++ b/src/Main.hs
@@ -3,6 +3,11 @@ module Main where
import Control.Monad
import Data.List
import Data.Maybe
+import qualified Data.Set as Set
+import System.Environment
+import System.Exit
+import System.IO
+
import Haskell.AST
import Haskell.Env
import Haskell.Env.Cmd
@@ -10,9 +15,6 @@ import Haskell.Env.Context
import Haskell.Rewrite
import Haskell.Parser
import Pretty
-import System.Environment
-import System.Exit
-import System.IO
import Util
@@ -115,7 +117,7 @@ applyUserCmd appstate = \case
Just focus ->
let nextStep ctx done =
let orig = fromRight (get (topEnv ctx) focus)
- in case filter (isRight . get (topEnv ctx)) $ allRefs orig \\ done of
+ in case filter (isRight . get (topEnv ctx)) $ Set.toList (freeVariables orig) \\ done of
[] -> Right ctx
(name:_) -> case apply ctx (Action (CRewrite name) focus) of
Left err -> Left ("Error rewriting '" ++ name ++ "': " ++ err)