From 92d244786ee551ebba842567e07660efe478deab Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sat, 23 May 2020 14:36:39 +0200 Subject: Significantly improve rewrite correctness It's still not entirely correct, though. Case in point: conservative rewriting on 'expr' in 'reverse-ad.txt' gives the correct result (a non-zero partial derivative on both A and B), while iterating 'rewall; auto' only yields a partial derivative on A, ignoring B. I don't know how this happens. --- reverse-ad.txt | 38 ++++++++++++++++++++ src/Haskell/AST.hs | 63 +++++++++++++++++++-------------- src/Haskell/Rewrite.hs | 95 ++++++++++++++++++++++++++++++++++++++++++-------- 3 files changed, 154 insertions(+), 42 deletions(-) create mode 100644 reverse-ad.txt diff --git a/reverse-ad.txt b/reverse-ad.txt new file mode 100644 index 0000000..13438e7 --- /dev/null +++ b/reverse-ad.txt @@ -0,0 +1,38 @@ +unit x = Cons x Nil; +append l1 l2 = case l1 of { + Nil -> l2; + Cons x l1' -> Cons x (append l1' l2) +}; + +gradient e = gradient' (Num 1) e output; + +gradient' adj e r = case e of { + Var x -> r (unit (x, unit adj)); + Num a -> r Nil; + Add e1 e2 -> gradient' adj e1 + (\m1 -> gradient' adj e2 + (\m2 -> r (combine m1 m2))); + Mul e1 e2 -> gradient' (Mul adj e2) e1 + (\m1 -> gradient' (Mul adj e1) e2 + (\m2 -> r (combine m1 m2))) +}; + +combine m1 m2 = case m1 of { + Nil -> m2; + Cons (x, l) m1' -> combine m1' (insert x l m2) +}; +insert x l m = case m of { + Nil -> unit (x, l); + Cons (y, l') m' -> case eq x y of { + True -> Cons (y, append l l') m'; + False -> Cons (y, l') (insert x l m') + } +}; + +eq x y = case x of { + A -> case y of { A -> True; _ -> False }; + B -> case y of { B -> True; _ -> False }; + C -> case y of { C -> True; _ -> False } +}; + +expr = gradient (Mul (Add (Var A) (Num 2)) (Mul (Num 3) (Var B))); diff --git a/src/Haskell/AST.hs b/src/Haskell/AST.hs index 2238b6d..6d25153 100644 --- a/src/Haskell/AST.hs +++ b/src/Haskell/AST.hs @@ -108,33 +108,42 @@ instance Pretty Inst where pretty (Inst n t ds) = Node ("instance " ++ n ++ " " ++ pprintOneline t ++ " where") [Bracket "{" "}" ";" (map pretty ds)] -class AllRefs a where - allRefs :: a -> [Name] - -instance AllRefs AST where - allRefs (AST tops) = nub $ concatMap allRefs tops - -instance AllRefs Toplevel where - allRefs (TopDef def) = allRefs def - allRefs (TopDecl _) = [] - allRefs (TopData _) = [] - allRefs (TopClass _) = [] - allRefs (TopInst inst) = allRefs inst - -instance AllRefs Def where - allRefs (Def _ e) = allRefs e - -instance AllRefs Expr where - allRefs (App e es) = nub $ concatMap allRefs (e : es) - allRefs (Ref n) = [n] - allVars (Con _) = [] - allRefs (Num _) = [] - allRefs (Tup es) = nub $ concatMap allRefs es - allRefs (Lam ns e) = allRefs e \\ ns - allRefs (Case e pairs) = nub $ allRefs e ++ concatMap (allRefs . snd) pairs - -instance AllRefs Inst where - allRefs (Inst _ _ ds) = nub $ concatMap allRefs ds +-- This excludes constructor names, since those are not variables. This _does_ +-- include bound variables; if you don't want that, use freeVariables. +class AllVars a where + allVars :: a -> Set.Set Name + +instance AllVars AST where + allVars (AST tops) = Set.unions (map allVars tops) + +instance AllVars Toplevel where + allVars (TopDef def) = allVars def + allVars (TopDecl _) = mempty + allVars (TopData _) = mempty + allVars (TopClass _) = mempty + allVars (TopInst inst) = allVars inst + +instance AllVars Def where + allVars (Def n e) = Set.insert n (allVars e) + +instance AllVars Inst where + allVars (Inst _ _ ds) = Set.unions (map allVars ds) + +instance AllVars Expr where + allVars (App e es) = Set.unions (map allVars (e : es)) + allVars (Ref n) = Set.singleton n + allVars (Con _) = mempty + allVars (Num _) = mempty + allVars (Tup es) = Set.unions (map allVars es) + allVars (Lam ns e) = Set.fromList ns <> allVars e + allVars (Case e pairs) = + allVars e <> Set.unions [allVars p <> allVars e' | (p, e') <- pairs] + +instance AllVars Pat where + allVars PatAny = mempty + allVars (PatVar n) = Set.singleton n + allVars (PatCon _ ps) = Set.unions (map allVars ps) + allVars (PatTup ps) = Set.unions (map allVars ps) boundVars :: Pat -> Set.Set Name diff --git a/src/Haskell/Rewrite.hs b/src/Haskell/Rewrite.hs index 498ccae..2e3f6c5 100644 --- a/src/Haskell/Rewrite.hs +++ b/src/Haskell/Rewrite.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE TupleSections #-} module Haskell.Rewrite (rewrite ,betared, etared, casered @@ -9,27 +10,91 @@ import Control.Monad import Data.List import Data.Maybe import qualified Data.Map.Strict as Map +import qualified Data.Set as Set + import Haskell.AST import Util rewrite :: Name -> Expr -> Expr -> Expr -rewrite name repl = \case - App e as -> App (rewrite name repl e) (map (rewrite name repl) as) - Ref n -> if n == name then repl else Ref n - Num k -> Num k - Tup es -> Tup (map (rewrite name repl) es) - Lam ns e -> if name `elem` ns then Lam ns e else Lam ns (rewrite name repl e) - Case e as -> Case (rewrite name repl e) (map caseArm as) +rewrite target repl topexpr = fst (rewrite' mempty topexpr) where - caseArm (p, e') = - if name `elem` boundVars p then (p, e') else (p, rewrite name repl e') - -boundVars :: Pat -> [Name] -boundVars PatAny = [] -boundVars (PatVar n) = [n] -boundVars (PatCon n ps) = nub $ [n] ++ concatMap boundVars ps -boundVars (PatTup ps) = nub $ concatMap boundVars ps + -- When moving into the subexpression E under a binding for some variable + -- 'x', if 'x' is free in 'repl', we need to alpha-rename 'x' to something + -- that is: + -- - Not free in 'repl'; + -- - Not free in E; + -- - Not equal to a name 'y' bound in E if the subexpression under the + -- binder for 'y' has an occurrence of 'x'. + -- If we strengthen the third requirement to prohibit all bound variables + -- in E, the second and third together mean "all variables in E". This is + -- what we will use. + frees :: Set.Set Name + frees = freeVariables repl + + -- Returns rewritten expression, and whether any actual rewrites were + -- performed. This allows preventing unnecessary alpha-renames. + rewrite' :: Map.Map Name Name -> Expr -> (Expr, Bool) + rewrite' mapping = \case + App e as -> let (e' : as', bs) = unzip (map (rewrite' mapping) (e : as)) + in (App e' as', or bs) + Ref n -> case Map.lookup n mapping of + Just n' -> (Ref n', False) -- renamed variable cannot be target + Nothing | n == target -> (repl, True) + | otherwise -> (Ref n, False) + Con n -> (Con n, False) + Num k -> (Num k, False) + Tup es -> let (es', bs) = unzip (map (rewrite' mapping) es) + in (Tup es', or bs) + Lam ns e + | target `elem` ns -> (Lam ns e, False) + | otherwise -> + let forbidden = frees <> allVars e <> Set.fromList ns + -- Note that Map.<> is left-preferring + mapping' = freshenL ns frees forbidden <> mapping + ns' = map (rename mapping') ns + (e', b) = rewrite' mapping' e + in if b then (Lam ns' e', True) else (Lam ns e, False) + Case e as -> + let (scrutinee, b1) = rewrite' mapping e + (arms, bs) = + unzip [if target `elem` ns + then ((p, e'), False) + else let forbidden = frees <> allVars e' <> Set.fromList ns + mapping' = freshenL ns frees forbidden <> mapping + (rhs, b) = rewrite' mapping' e' + in if b + then ((renamePat mapping' p, rhs), True) + else ((p, e'), False) + | (p, e') <- as + , let ns = Set.toList (boundVars p)] + in (Case scrutinee arms, b1 || or bs) + + -- Freshens all variables found in 'vars' to something that is not in + -- 'bnd', returning a replace mapping. Later names override earlier ones + -- in the mapping. + freshenL ns vars bnd = + foldl (\mp n -> + if n `Set.member` vars + then Map.insert n (freshName bnd n) mp + else mp) + mempty ns + + -- Finds a fresh name for 'n' that is not in 'bnd'. + freshName bnd n = + head [n ++ "_R_" ++ show i + | i <- [1::Int ..] + , let n' = n ++ show i + , n' `Set.notMember` bnd] + + renamePat mapping = \case + PatAny -> PatAny + PatVar n -> PatVar (rename mapping n) + PatCon n ps -> PatCon (rename mapping n) (map (renamePat mapping) ps) + PatTup ps -> PatTup (map (renamePat mapping) ps) + + rename :: Map.Map Name Name -> Name -> Name + rename mapping n = fromMaybe n (Map.lookup n mapping) betared :: Expr -> Expr betared (App (Lam (n:as) bd) (arg:args)) = -- cgit v1.2.3-70-g09d2