diff options
Diffstat (limited to 'src/ForwardAD')
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 7d47e6d..9ed04bb 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -41,6 +41,12 @@ scalTyCase STI32 _ k2 = k2 scalTyCase STI64 _ k2 = k2 scalTyCase STBool _ k2 = k2 +floatingDual :: ScalIsFloating t ~ True + => SScalTy t + -> ((Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t), ScalIsNumeric t ~ True) => r) -> r +floatingDual STF32 k = k +floatingDual STF64 k = k + -- | Argument does not need to be duplicable. dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b) dop = \case @@ -68,6 +74,14 @@ dop = \case OIf -> EOp ext OIf ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg) OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0) + ORecip t -> floatingDual t $ unFloat (\(x, dx) -> + EPair ext (recip' t x) + (mul t (neg t (recip' t (mul t x x))) dx)) + OExp t -> floatingDual t $ unFloat (\(x, dx) -> + EPair ext (EOp ext (OExp t) x) (EOp ext (OExp t) dx)) + OLog t -> floatingDual t $ unFloat (\(x, dx) -> + EPair ext (EOp ext (OLog t) x) + (mul t (recip' t x) dx)) where add :: ScalIsNumeric t ~ True => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) @@ -81,6 +95,10 @@ dop = \case => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) neg t = EOp ext (ONeg t) + recip' :: ScalIsFloating t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) + recip' t = EOp ext (ORecip t) + unFloat :: DN a ~ TPair a a => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b)) -> Ex env (DN a) -> Ex env (DN b) @@ -99,10 +117,10 @@ dop = \case (EFst ext (ESnd ext var), ESnd ext (ESnd ext var)) zeroScalarConst :: ScalIsNumeric t ~ True => SScalTy t -> Ex env (TScal t) -zeroScalarConst STF32 = EConst ext STF32 0.0 -zeroScalarConst STF64 = EConst ext STF64 0.0 zeroScalarConst STI32 = EConst ext STI32 0 zeroScalarConst STI64 = EConst ext STI64 0 +zeroScalarConst STF32 = EConst ext STF32 0.0 +zeroScalarConst STF64 = EConst ext STF64 0.0 dfwdDN :: Ex env t -> Ex (DNE env) (DN t) dfwdDN = \case |