diff options
Diffstat (limited to 'reverse-ad.txt')
-rw-r--r-- | reverse-ad.txt | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/reverse-ad.txt b/reverse-ad.txt new file mode 100644 index 0000000..13438e7 --- /dev/null +++ b/reverse-ad.txt @@ -0,0 +1,38 @@ +unit x = Cons x Nil; +append l1 l2 = case l1 of { + Nil -> l2; + Cons x l1' -> Cons x (append l1' l2) +}; + +gradient e = gradient' (Num 1) e output; + +gradient' adj e r = case e of { + Var x -> r (unit (x, unit adj)); + Num a -> r Nil; + Add e1 e2 -> gradient' adj e1 + (\m1 -> gradient' adj e2 + (\m2 -> r (combine m1 m2))); + Mul e1 e2 -> gradient' (Mul adj e2) e1 + (\m1 -> gradient' (Mul adj e1) e2 + (\m2 -> r (combine m1 m2))) +}; + +combine m1 m2 = case m1 of { + Nil -> m2; + Cons (x, l) m1' -> combine m1' (insert x l m2) +}; +insert x l m = case m of { + Nil -> unit (x, l); + Cons (y, l') m' -> case eq x y of { + True -> Cons (y, append l l') m'; + False -> Cons (y, l') (insert x l m') + } +}; + +eq x y = case x of { + A -> case y of { A -> True; _ -> False }; + B -> case y of { B -> True; _ -> False }; + C -> case y of { C -> True; _ -> False } +}; + +expr = gradient (Mul (Add (Var A) (Num 2)) (Mul (Num 3) (Var B))); |