summaryrefslogtreecommitdiff
path: root/src/ForwardAD/DualNumbers.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/ForwardAD/DualNumbers.hs')
-rw-r--r--src/ForwardAD/DualNumbers.hs22
1 files changed, 20 insertions, 2 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index 7d47e6d..9ed04bb 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -41,6 +41,12 @@ scalTyCase STI32 _ k2 = k2
scalTyCase STI64 _ k2 = k2
scalTyCase STBool _ k2 = k2
+floatingDual :: ScalIsFloating t ~ True
+ => SScalTy t
+ -> ((Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t), ScalIsNumeric t ~ True) => r) -> r
+floatingDual STF32 k = k
+floatingDual STF64 k = k
+
-- | Argument does not need to be duplicable.
dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b)
dop = \case
@@ -68,6 +74,14 @@ dop = \case
OIf -> EOp ext OIf
ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg)
OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0)
+ ORecip t -> floatingDual t $ unFloat (\(x, dx) ->
+ EPair ext (recip' t x)
+ (mul t (neg t (recip' t (mul t x x))) dx))
+ OExp t -> floatingDual t $ unFloat (\(x, dx) ->
+ EPair ext (EOp ext (OExp t) x) (EOp ext (OExp t) dx))
+ OLog t -> floatingDual t $ unFloat (\(x, dx) ->
+ EPair ext (EOp ext (OLog t) x)
+ (mul t (recip' t x) dx))
where
add :: ScalIsNumeric t ~ True
=> SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
@@ -81,6 +95,10 @@ dop = \case
=> SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
neg t = EOp ext (ONeg t)
+ recip' :: ScalIsFloating t ~ True
+ => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
+ recip' t = EOp ext (ORecip t)
+
unFloat :: DN a ~ TPair a a
=> (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b))
-> Ex env (DN a) -> Ex env (DN b)
@@ -99,10 +117,10 @@ dop = \case
(EFst ext (ESnd ext var), ESnd ext (ESnd ext var))
zeroScalarConst :: ScalIsNumeric t ~ True => SScalTy t -> Ex env (TScal t)
-zeroScalarConst STF32 = EConst ext STF32 0.0
-zeroScalarConst STF64 = EConst ext STF64 0.0
zeroScalarConst STI32 = EConst ext STI32 0
zeroScalarConst STI64 = EConst ext STI64 0
+zeroScalarConst STF32 = EConst ext STF32 0.0
+zeroScalarConst STF64 = EConst ext STF64 0.0
dfwdDN :: Ex env t -> Ex (DNE env) (DN t)
dfwdDN = \case