diff options
author | Tom Smeding <tom.smeding@gmail.com> | 2020-05-23 14:36:39 +0200 |
---|---|---|
committer | Tom Smeding <tom.smeding@gmail.com> | 2020-05-23 14:37:43 +0200 |
commit | 92d244786ee551ebba842567e07660efe478deab (patch) | |
tree | 30f3c363a4ded3168b3ae177f9cc884afe30cc12 /reverse-ad.txt | |
parent | 18ea7b6804e09b1ae604b7fb9eadd542677f172d (diff) |
Significantly improve rewrite correctness
It's still not entirely correct, though. Case in point: conservative
rewriting on 'expr' in 'reverse-ad.txt' gives the correct result (a
non-zero partial derivative on both A and B), while iterating
'rewall; auto' only yields a partial derivative on A, ignoring B.
I don't know how this happens.
Diffstat (limited to 'reverse-ad.txt')
-rw-r--r-- | reverse-ad.txt | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/reverse-ad.txt b/reverse-ad.txt new file mode 100644 index 0000000..13438e7 --- /dev/null +++ b/reverse-ad.txt @@ -0,0 +1,38 @@ +unit x = Cons x Nil; +append l1 l2 = case l1 of { + Nil -> l2; + Cons x l1' -> Cons x (append l1' l2) +}; + +gradient e = gradient' (Num 1) e output; + +gradient' adj e r = case e of { + Var x -> r (unit (x, unit adj)); + Num a -> r Nil; + Add e1 e2 -> gradient' adj e1 + (\m1 -> gradient' adj e2 + (\m2 -> r (combine m1 m2))); + Mul e1 e2 -> gradient' (Mul adj e2) e1 + (\m1 -> gradient' (Mul adj e1) e2 + (\m2 -> r (combine m1 m2))) +}; + +combine m1 m2 = case m1 of { + Nil -> m2; + Cons (x, l) m1' -> combine m1' (insert x l m2) +}; +insert x l m = case m of { + Nil -> unit (x, l); + Cons (y, l') m' -> case eq x y of { + True -> Cons (y, append l l') m'; + False -> Cons (y, l') (insert x l m') + } +}; + +eq x y = case x of { + A -> case y of { A -> True; _ -> False }; + B -> case y of { B -> True; _ -> False }; + C -> case y of { C -> True; _ -> False } +}; + +expr = gradient (Mul (Add (Var A) (Num 2)) (Mul (Num 3) (Var B))); |