aboutsummaryrefslogtreecommitdiff
path: root/reverse-ad.txt
diff options
context:
space:
mode:
Diffstat (limited to 'reverse-ad.txt')
-rw-r--r--reverse-ad.txt38
1 files changed, 38 insertions, 0 deletions
diff --git a/reverse-ad.txt b/reverse-ad.txt
new file mode 100644
index 0000000..13438e7
--- /dev/null
+++ b/reverse-ad.txt
@@ -0,0 +1,38 @@
+unit x = Cons x Nil;
+append l1 l2 = case l1 of {
+ Nil -> l2;
+ Cons x l1' -> Cons x (append l1' l2)
+};
+
+gradient e = gradient' (Num 1) e output;
+
+gradient' adj e r = case e of {
+ Var x -> r (unit (x, unit adj));
+ Num a -> r Nil;
+ Add e1 e2 -> gradient' adj e1
+ (\m1 -> gradient' adj e2
+ (\m2 -> r (combine m1 m2)));
+ Mul e1 e2 -> gradient' (Mul adj e2) e1
+ (\m1 -> gradient' (Mul adj e1) e2
+ (\m2 -> r (combine m1 m2)))
+};
+
+combine m1 m2 = case m1 of {
+ Nil -> m2;
+ Cons (x, l) m1' -> combine m1' (insert x l m2)
+};
+insert x l m = case m of {
+ Nil -> unit (x, l);
+ Cons (y, l') m' -> case eq x y of {
+ True -> Cons (y, append l l') m';
+ False -> Cons (y, l') (insert x l m')
+ }
+};
+
+eq x y = case x of {
+ A -> case y of { A -> True; _ -> False };
+ B -> case y of { B -> True; _ -> False };
+ C -> case y of { C -> True; _ -> False }
+};
+
+expr = gradient (Mul (Add (Var A) (Num 2)) (Mul (Num 3) (Var B)));