summaryrefslogtreecommitdiff
path: root/src/ForwardAD/DualNumbers.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/ForwardAD/DualNumbers.hs')
-rw-r--r--src/ForwardAD/DualNumbers.hs7
1 files changed, 3 insertions, 4 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index 3587378..f2ded6e 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -144,11 +144,10 @@ dfwdDN = \case
| Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n)
-> EShape ext (dfwdDN e)
EOp _ op e -> dop op (dfwdDN e)
- ECustom _ s t _ du _ e1 e2 ->
- -- TODO: we need a bit of codegen here that projects the primals out from the dual number result of e1. Note that a non-differentiating code transformation does not eliminate the need for this, because then the need just shifts to free variable adaptor code.
- ELet ext (_ e1) $
+ ECustom _ _ _ _ pr _ _ e1 e2 ->
+ ELet ext (dfwdDN e1) $
ELet ext (weakenExpr WSink (dfwdDN e2)) $
- weakenExpr (WCopy (WCopy WClosed)) du
+ weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr)
EError t s -> EError (dn t) s
EWith{} -> err_accum