aboutsummaryrefslogtreecommitdiff
path: root/reverse-ad.txt
blob: 13438e7dfb3d5a79e82e24adb51b50f9b473fdd7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
unit item = Cons item Nil;
append xs ys = case xs of {
	Cons hd tl -> Cons hd (append tl ys);
	Nil -> ys
};

gradient term = gradient' (Num 1) term output;

gradient' adjoint t k = case t of {
	Num c -> k Nil;
	Var v -> k (unit (v, unit adjoint));
	Mul a b -> gradient' (Mul adjoint b) a
		(\ma -> gradient' (Mul adjoint a) b
			(\mb -> k (combine ma mb)));
	Add a b -> gradient' adjoint a
		(\ma -> gradient' adjoint b
			(\mb -> k (combine ma mb)))
};

combine src dst = case src of {
	Cons (v, terms) rest -> combine rest (insert v terms dst);
	Nil -> dst
};
insert key terms tbl = case tbl of {
	Cons (k, ts) rest -> case eq key k of {
		False -> Cons (k, ts) (insert key terms rest);
		True -> Cons (k, append terms ts) rest
	};
	Nil -> unit (key, terms)
};

eq p q = case p of {
	C -> case q of { C -> True; _ -> False };
	B -> case q of { B -> True; _ -> False };
	A -> case q of { A -> True; _ -> False }
};

expr = gradient
	(Mul (Add (Var A) (Num 2))
		(Mul (Num 3) (Var B)));