diff options
Diffstat (limited to 'src/ForwardAD/DualNumbers.hs')
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 3587378..f2ded6e 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -144,11 +144,10 @@ dfwdDN = \case | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwdDN e) EOp _ op e -> dop op (dfwdDN e) - ECustom _ s t _ du _ e1 e2 -> - -- TODO: we need a bit of codegen here that projects the primals out from the dual number result of e1. Note that a non-differentiating code transformation does not eliminate the need for this, because then the need just shifts to free variable adaptor code. - ELet ext (_ e1) $ + ECustom _ _ _ _ pr _ _ e1 e2 -> + ELet ext (dfwdDN e1) $ ELet ext (weakenExpr WSink (dfwdDN e2)) $ - weakenExpr (WCopy (WCopy WClosed)) du + weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr) EError t s -> EError (dn t) s EWith{} -> err_accum |