diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-02-21 13:04:05 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-02-21 13:04:05 +0100 |
commit | 94a59b0d78ff16903f250989a6121d13dae23e2f (patch) | |
tree | f2e11f587fcb91eb9d50fa2d82c33b379e2d485f /src | |
parent | 4bd1890dccb45a90f10183a916f93f025a3f57d2 (diff) |
Nicer API for writing Dual instances
Diffstat (limited to 'src')
-rw-r--r-- | src/Numeric/ADDual/Internal.hs | 29 |
1 files changed, 16 insertions, 13 deletions
diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs index 55c47c2..858e0db 100644 --- a/src/Numeric/ADDual/Internal.hs +++ b/src/Numeric/ADDual/Internal.hs @@ -143,29 +143,29 @@ instance Ord a => Ord (Dual s a) where compare (Dual x _) (Dual y _) = compare x y instance (Num a, Storable a, Taping s a) => Num (Dual s a) where - Dual x i1 + Dual y i2 = Dual (x + y) (writeTape @a (Proxy @s) i1 1 i2 1) - Dual x i1 - Dual y i2 = Dual (x - y) (writeTape @a (Proxy @s) i1 1 i2 (-1)) - Dual x i1 * Dual y i2 = Dual (x * y) (writeTape (Proxy @s) i1 y i2 x) - negate (Dual x i1) = Dual (negate x) (writeTape @a (Proxy @s) i1 (-1) (-1) 0) - abs (Dual x i1) = Dual (abs x) (writeTape (Proxy @s) i1 (x * signum x) (-1) 0) + Dual x i1 + Dual y i2 = mkDual (x + y) i1 1 i2 1 + Dual x i1 - Dual y i2 = mkDual (x - y) i1 1 i2 (-1) + Dual x i1 * Dual y i2 = mkDual (x * y) i1 y i2 x + negate (Dual x i1) = mkDual (negate x) i1 (-1) (-1) 0 + abs (Dual x i1) = mkDual (abs x) i1 (x * signum x) (-1) 0 signum (Dual x _) = Dual (signum x) (-1) fromInteger n = Dual (fromInteger n) (-1) instance (Fractional a, Storable a, Taping s a) => Fractional (Dual s a) where - Dual x i1 / Dual y i2 = Dual (x / y) (writeTape (Proxy @s) i1 (recip y) i2 (-x/(y*y))) - recip (Dual x i1) = Dual (recip x) (writeTape (Proxy @s) i1 (-1/(x*x)) (-1) 0) + Dual x i1 / Dual y i2 = mkDual (x / y) i1 (recip y) i2 (-x/(y*y)) + recip (Dual x i1) = mkDual (recip x) i1 (-1/(x*x)) (-1) 0 fromRational r = Dual (fromRational r) (-1) instance (Floating a, Storable a, Taping s a) => Floating (Dual s a) where pi = Dual pi (-1) - exp (Dual x i1) = Dual (exp x) (writeTape (Proxy @s) i1 (exp x) (-1) 0) - log (Dual x i1) = Dual (log x) (writeTape (Proxy @s) i1 (recip x) (-1) 0) - sqrt (Dual x i1) = Dual (sqrt x) (writeTape (Proxy @s) i1 (recip (2*sqrt x)) (-1) 0) + exp (Dual x i1) = mkDual (exp x) i1 (exp x) (-1) 0 + log (Dual x i1) = mkDual (log x) i1 (recip x) (-1) 0 + sqrt (Dual x i1) = mkDual (sqrt x) i1 (recip (2*sqrt x)) (-1) 0 -- d/dx (x ^ y) = d/dx (e ^ (y ln x)) = e ^ (y ln x) * d/dx (y ln x) = e ^ (y ln x) * y/x -- d/dy (x ^ y) = d/dy (e ^ (y ln x)) = e ^ (y ln x) * d/dy (y ln x) = e ^ (y ln x) * ln x Dual x i1 ** Dual y i2 = let z = x ** y - in Dual z (writeTape (Proxy @s) i1 (z * y/x) i2 (z * log x)) + in mkDual z i1 (z * y/x) i2 (z * log x) logBase = undefined ; sin = undefined ; cos = undefined ; tan = undefined asin = undefined ; acos = undefined ; atan = undefined ; sinh = undefined cosh = undefined ; tanh = undefined ; asinh = undefined ; acosh = undefined @@ -174,11 +174,14 @@ instance (Floating a, Storable a, Taping s a) => Floating (Dual s a) where constant :: a -> Dual s a constant x = Dual x (-1) +mkDual :: forall a s. (Num a, Storable a, Taping s a) => a -> Int -> a -> Int -> a -> Dual s a +mkDual res i1 dx i2 dy = Dual res (writeTapeUnsafe @a (Proxy @s) i1 dx i2 dy) + data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a)) | WTAOldTape (Snoclist (Chunk a)) -writeTape :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int -writeTape _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy +writeTapeUnsafe :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int +writeTapeUnsafe _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a) => HasCallStack |