From 4fcdb7118e0084f192753ea6c70394352a27d5ed Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 8 Nov 2024 20:29:29 +0100 Subject: Custom derivatives --- src/ForwardAD/DualNumbers.hs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'src/ForwardAD/DualNumbers.hs') diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 3587378..f2ded6e 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -144,11 +144,10 @@ dfwdDN = \case | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwdDN e) EOp _ op e -> dop op (dfwdDN e) - ECustom _ s t _ du _ e1 e2 -> - -- TODO: we need a bit of codegen here that projects the primals out from the dual number result of e1. Note that a non-differentiating code transformation does not eliminate the need for this, because then the need just shifts to free variable adaptor code. - ELet ext (_ e1) $ + ECustom _ _ _ _ pr _ _ e1 e2 -> + ELet ext (dfwdDN e1) $ ELet ext (weakenExpr WSink (dfwdDN e2)) $ - weakenExpr (WCopy (WCopy WClosed)) du + weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr) EError t s -> EError (dn t) s EWith{} -> err_accum -- cgit v1.2.3-70-g09d2