aboutsummaryrefslogtreecommitdiff
path: root/src/Numeric/ADDual/Internal.hs
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-02-21 13:04:05 +0100
committerTom Smeding <t.j.smeding@uu.nl>2025-02-21 13:04:05 +0100
commit94a59b0d78ff16903f250989a6121d13dae23e2f (patch)
treef2e11f587fcb91eb9d50fa2d82c33b379e2d485f /src/Numeric/ADDual/Internal.hs
parent4bd1890dccb45a90f10183a916f93f025a3f57d2 (diff)
Nicer API for writing Dual instances
Diffstat (limited to 'src/Numeric/ADDual/Internal.hs')
-rw-r--r--src/Numeric/ADDual/Internal.hs29
1 files changed, 16 insertions, 13 deletions
diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs
index 55c47c2..858e0db 100644
--- a/src/Numeric/ADDual/Internal.hs
+++ b/src/Numeric/ADDual/Internal.hs
@@ -143,29 +143,29 @@ instance Ord a => Ord (Dual s a) where
compare (Dual x _) (Dual y _) = compare x y
instance (Num a, Storable a, Taping s a) => Num (Dual s a) where
- Dual x i1 + Dual y i2 = Dual (x + y) (writeTape @a (Proxy @s) i1 1 i2 1)
- Dual x i1 - Dual y i2 = Dual (x - y) (writeTape @a (Proxy @s) i1 1 i2 (-1))
- Dual x i1 * Dual y i2 = Dual (x * y) (writeTape (Proxy @s) i1 y i2 x)
- negate (Dual x i1) = Dual (negate x) (writeTape @a (Proxy @s) i1 (-1) (-1) 0)
- abs (Dual x i1) = Dual (abs x) (writeTape (Proxy @s) i1 (x * signum x) (-1) 0)
+ Dual x i1 + Dual y i2 = mkDual (x + y) i1 1 i2 1
+ Dual x i1 - Dual y i2 = mkDual (x - y) i1 1 i2 (-1)
+ Dual x i1 * Dual y i2 = mkDual (x * y) i1 y i2 x
+ negate (Dual x i1) = mkDual (negate x) i1 (-1) (-1) 0
+ abs (Dual x i1) = mkDual (abs x) i1 (x * signum x) (-1) 0
signum (Dual x _) = Dual (signum x) (-1)
fromInteger n = Dual (fromInteger n) (-1)
instance (Fractional a, Storable a, Taping s a) => Fractional (Dual s a) where
- Dual x i1 / Dual y i2 = Dual (x / y) (writeTape (Proxy @s) i1 (recip y) i2 (-x/(y*y)))
- recip (Dual x i1) = Dual (recip x) (writeTape (Proxy @s) i1 (-1/(x*x)) (-1) 0)
+ Dual x i1 / Dual y i2 = mkDual (x / y) i1 (recip y) i2 (-x/(y*y))
+ recip (Dual x i1) = mkDual (recip x) i1 (-1/(x*x)) (-1) 0
fromRational r = Dual (fromRational r) (-1)
instance (Floating a, Storable a, Taping s a) => Floating (Dual s a) where
pi = Dual pi (-1)
- exp (Dual x i1) = Dual (exp x) (writeTape (Proxy @s) i1 (exp x) (-1) 0)
- log (Dual x i1) = Dual (log x) (writeTape (Proxy @s) i1 (recip x) (-1) 0)
- sqrt (Dual x i1) = Dual (sqrt x) (writeTape (Proxy @s) i1 (recip (2*sqrt x)) (-1) 0)
+ exp (Dual x i1) = mkDual (exp x) i1 (exp x) (-1) 0
+ log (Dual x i1) = mkDual (log x) i1 (recip x) (-1) 0
+ sqrt (Dual x i1) = mkDual (sqrt x) i1 (recip (2*sqrt x)) (-1) 0
-- d/dx (x ^ y) = d/dx (e ^ (y ln x)) = e ^ (y ln x) * d/dx (y ln x) = e ^ (y ln x) * y/x
-- d/dy (x ^ y) = d/dy (e ^ (y ln x)) = e ^ (y ln x) * d/dy (y ln x) = e ^ (y ln x) * ln x
Dual x i1 ** Dual y i2 =
let z = x ** y
- in Dual z (writeTape (Proxy @s) i1 (z * y/x) i2 (z * log x))
+ in mkDual z i1 (z * y/x) i2 (z * log x)
logBase = undefined ; sin = undefined ; cos = undefined ; tan = undefined
asin = undefined ; acos = undefined ; atan = undefined ; sinh = undefined
cosh = undefined ; tanh = undefined ; asinh = undefined ; acosh = undefined
@@ -174,11 +174,14 @@ instance (Floating a, Storable a, Taping s a) => Floating (Dual s a) where
constant :: a -> Dual s a
constant x = Dual x (-1)
+mkDual :: forall a s. (Num a, Storable a, Taping s a) => a -> Int -> a -> Int -> a -> Dual s a
+mkDual res i1 dx i2 dy = Dual res (writeTapeUnsafe @a (Proxy @s) i1 dx i2 dy)
+
data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a))
| WTAOldTape (Snoclist (Chunk a))
-writeTape :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int
-writeTape _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy
+writeTapeUnsafe :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int
+writeTapeUnsafe _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy
writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a)
=> HasCallStack