aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--reverse-ad.txt38
-rw-r--r--src/Haskell/AST.hs63
-rw-r--r--src/Haskell/Rewrite.hs95
3 files changed, 154 insertions, 42 deletions
diff --git a/reverse-ad.txt b/reverse-ad.txt
new file mode 100644
index 0000000..13438e7
--- /dev/null
+++ b/reverse-ad.txt
@@ -0,0 +1,38 @@
+unit x = Cons x Nil;
+append l1 l2 = case l1 of {
+ Nil -> l2;
+ Cons x l1' -> Cons x (append l1' l2)
+};
+
+gradient e = gradient' (Num 1) e output;
+
+gradient' adj e r = case e of {
+ Var x -> r (unit (x, unit adj));
+ Num a -> r Nil;
+ Add e1 e2 -> gradient' adj e1
+ (\m1 -> gradient' adj e2
+ (\m2 -> r (combine m1 m2)));
+ Mul e1 e2 -> gradient' (Mul adj e2) e1
+ (\m1 -> gradient' (Mul adj e1) e2
+ (\m2 -> r (combine m1 m2)))
+};
+
+combine m1 m2 = case m1 of {
+ Nil -> m2;
+ Cons (x, l) m1' -> combine m1' (insert x l m2)
+};
+insert x l m = case m of {
+ Nil -> unit (x, l);
+ Cons (y, l') m' -> case eq x y of {
+ True -> Cons (y, append l l') m';
+ False -> Cons (y, l') (insert x l m')
+ }
+};
+
+eq x y = case x of {
+ A -> case y of { A -> True; _ -> False };
+ B -> case y of { B -> True; _ -> False };
+ C -> case y of { C -> True; _ -> False }
+};
+
+expr = gradient (Mul (Add (Var A) (Num 2)) (Mul (Num 3) (Var B)));
diff --git a/src/Haskell/AST.hs b/src/Haskell/AST.hs
index 2238b6d..6d25153 100644
--- a/src/Haskell/AST.hs
+++ b/src/Haskell/AST.hs
@@ -108,33 +108,42 @@ instance Pretty Inst where
pretty (Inst n t ds) = Node ("instance " ++ n ++ " " ++ pprintOneline t ++ " where") [Bracket "{" "}" ";" (map pretty ds)]
-class AllRefs a where
- allRefs :: a -> [Name]
-
-instance AllRefs AST where
- allRefs (AST tops) = nub $ concatMap allRefs tops
-
-instance AllRefs Toplevel where
- allRefs (TopDef def) = allRefs def
- allRefs (TopDecl _) = []
- allRefs (TopData _) = []
- allRefs (TopClass _) = []
- allRefs (TopInst inst) = allRefs inst
-
-instance AllRefs Def where
- allRefs (Def _ e) = allRefs e
-
-instance AllRefs Expr where
- allRefs (App e es) = nub $ concatMap allRefs (e : es)
- allRefs (Ref n) = [n]
- allVars (Con _) = []
- allRefs (Num _) = []
- allRefs (Tup es) = nub $ concatMap allRefs es
- allRefs (Lam ns e) = allRefs e \\ ns
- allRefs (Case e pairs) = nub $ allRefs e ++ concatMap (allRefs . snd) pairs
-
-instance AllRefs Inst where
- allRefs (Inst _ _ ds) = nub $ concatMap allRefs ds
+-- This excludes constructor names, since those are not variables. This _does_
+-- include bound variables; if you don't want that, use freeVariables.
+class AllVars a where
+ allVars :: a -> Set.Set Name
+
+instance AllVars AST where
+ allVars (AST tops) = Set.unions (map allVars tops)
+
+instance AllVars Toplevel where
+ allVars (TopDef def) = allVars def
+ allVars (TopDecl _) = mempty
+ allVars (TopData _) = mempty
+ allVars (TopClass _) = mempty
+ allVars (TopInst inst) = allVars inst
+
+instance AllVars Def where
+ allVars (Def n e) = Set.insert n (allVars e)
+
+instance AllVars Inst where
+ allVars (Inst _ _ ds) = Set.unions (map allVars ds)
+
+instance AllVars Expr where
+ allVars (App e es) = Set.unions (map allVars (e : es))
+ allVars (Ref n) = Set.singleton n
+ allVars (Con _) = mempty
+ allVars (Num _) = mempty
+ allVars (Tup es) = Set.unions (map allVars es)
+ allVars (Lam ns e) = Set.fromList ns <> allVars e
+ allVars (Case e pairs) =
+ allVars e <> Set.unions [allVars p <> allVars e' | (p, e') <- pairs]
+
+instance AllVars Pat where
+ allVars PatAny = mempty
+ allVars (PatVar n) = Set.singleton n
+ allVars (PatCon _ ps) = Set.unions (map allVars ps)
+ allVars (PatTup ps) = Set.unions (map allVars ps)
boundVars :: Pat -> Set.Set Name
diff --git a/src/Haskell/Rewrite.hs b/src/Haskell/Rewrite.hs
index 498ccae..2e3f6c5 100644
--- a/src/Haskell/Rewrite.hs
+++ b/src/Haskell/Rewrite.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE TupleSections #-}
module Haskell.Rewrite
(rewrite
,betared, etared, casered
@@ -9,27 +10,91 @@ import Control.Monad
import Data.List
import Data.Maybe
import qualified Data.Map.Strict as Map
+import qualified Data.Set as Set
+
import Haskell.AST
import Util
rewrite :: Name -> Expr -> Expr -> Expr
-rewrite name repl = \case
- App e as -> App (rewrite name repl e) (map (rewrite name repl) as)
- Ref n -> if n == name then repl else Ref n
- Num k -> Num k
- Tup es -> Tup (map (rewrite name repl) es)
- Lam ns e -> if name `elem` ns then Lam ns e else Lam ns (rewrite name repl e)
- Case e as -> Case (rewrite name repl e) (map caseArm as)
+rewrite target repl topexpr = fst (rewrite' mempty topexpr)
where
- caseArm (p, e') =
- if name `elem` boundVars p then (p, e') else (p, rewrite name repl e')
-
-boundVars :: Pat -> [Name]
-boundVars PatAny = []
-boundVars (PatVar n) = [n]
-boundVars (PatCon n ps) = nub $ [n] ++ concatMap boundVars ps
-boundVars (PatTup ps) = nub $ concatMap boundVars ps
+ -- When moving into the subexpression E under a binding for some variable
+ -- 'x', if 'x' is free in 'repl', we need to alpha-rename 'x' to something
+ -- that is:
+ -- - Not free in 'repl';
+ -- - Not free in E;
+ -- - Not equal to a name 'y' bound in E if the subexpression under the
+ -- binder for 'y' has an occurrence of 'x'.
+ -- If we strengthen the third requirement to prohibit all bound variables
+ -- in E, the second and third together mean "all variables in E". This is
+ -- what we will use.
+ frees :: Set.Set Name
+ frees = freeVariables repl
+
+ -- Returns rewritten expression, and whether any actual rewrites were
+ -- performed. This allows preventing unnecessary alpha-renames.
+ rewrite' :: Map.Map Name Name -> Expr -> (Expr, Bool)
+ rewrite' mapping = \case
+ App e as -> let (e' : as', bs) = unzip (map (rewrite' mapping) (e : as))
+ in (App e' as', or bs)
+ Ref n -> case Map.lookup n mapping of
+ Just n' -> (Ref n', False) -- renamed variable cannot be target
+ Nothing | n == target -> (repl, True)
+ | otherwise -> (Ref n, False)
+ Con n -> (Con n, False)
+ Num k -> (Num k, False)
+ Tup es -> let (es', bs) = unzip (map (rewrite' mapping) es)
+ in (Tup es', or bs)
+ Lam ns e
+ | target `elem` ns -> (Lam ns e, False)
+ | otherwise ->
+ let forbidden = frees <> allVars e <> Set.fromList ns
+ -- Note that Map.<> is left-preferring
+ mapping' = freshenL ns frees forbidden <> mapping
+ ns' = map (rename mapping') ns
+ (e', b) = rewrite' mapping' e
+ in if b then (Lam ns' e', True) else (Lam ns e, False)
+ Case e as ->
+ let (scrutinee, b1) = rewrite' mapping e
+ (arms, bs) =
+ unzip [if target `elem` ns
+ then ((p, e'), False)
+ else let forbidden = frees <> allVars e' <> Set.fromList ns
+ mapping' = freshenL ns frees forbidden <> mapping
+ (rhs, b) = rewrite' mapping' e'
+ in if b
+ then ((renamePat mapping' p, rhs), True)
+ else ((p, e'), False)
+ | (p, e') <- as
+ , let ns = Set.toList (boundVars p)]
+ in (Case scrutinee arms, b1 || or bs)
+
+ -- Freshens all variables found in 'vars' to something that is not in
+ -- 'bnd', returning a replace mapping. Later names override earlier ones
+ -- in the mapping.
+ freshenL ns vars bnd =
+ foldl (\mp n ->
+ if n `Set.member` vars
+ then Map.insert n (freshName bnd n) mp
+ else mp)
+ mempty ns
+
+ -- Finds a fresh name for 'n' that is not in 'bnd'.
+ freshName bnd n =
+ head [n ++ "_R_" ++ show i
+ | i <- [1::Int ..]
+ , let n' = n ++ show i
+ , n' `Set.notMember` bnd]
+
+ renamePat mapping = \case
+ PatAny -> PatAny
+ PatVar n -> PatVar (rename mapping n)
+ PatCon n ps -> PatCon (rename mapping n) (map (renamePat mapping) ps)
+ PatTup ps -> PatTup (map (renamePat mapping) ps)
+
+ rename :: Map.Map Name Name -> Name -> Name
+ rename mapping n = fromMaybe n (Map.lookup n mapping)
betared :: Expr -> Expr
betared (App (Lam (n:as) bd) (arg:args)) =