aboutsummaryrefslogtreecommitdiff
path: root/src/Numeric/ADDual
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-02-23 21:44:23 +0100
committerTom Smeding <t.j.smeding@uu.nl>2025-02-23 21:44:23 +0100
commit5f7a81acc7f75415d62dac86c5b50c848ab15341 (patch)
tree641ed54ce426ed8a1d9a5da12a9cde512b32bedc /src/Numeric/ADDual
parenta17bd53598ee5266fc3a1c45f8f4bb4798dc495e (diff)
Optimise by backpropagating scalar computation in C
Diffstat (limited to 'src/Numeric/ADDual')
-rw-r--r--src/Numeric/ADDual/Internal.hs59
1 files changed, 37 insertions, 22 deletions
diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs
index 1ea3132..5dd84aa 100644
--- a/src/Numeric/ADDual/Internal.hs
+++ b/src/Numeric/ADDual/Internal.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveTraversable #-}
+{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
@@ -11,9 +12,13 @@ import Control.Monad (when)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict
import Data.IORef
+import Data.Int
import Data.Proxy
+import Data.Typeable
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
+import Foreign.C.Types
+import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import GHC.Stack
@@ -27,9 +32,14 @@ debug :: Bool
debug = toEnum 0
+foreign import ccall unsafe "ad_dual_hs_backpropagate_double"
+ c_backpropagate_double :: Ptr CDouble -> Int64 -> Int64 -> Ptr () -> IO ()
+
+
-- TODO: full vjp (just some more Traversable mess)
+-- TODO: if non-scalar output types are allowed, ensure that all of their scalar components are evaluated to WHNF before we backpropagate
{-# NOINLINE gradient' #-}
-gradient' :: forall a f. (Traversable f, Num a, Storable a)
+gradient' :: forall a f. (Traversable f, Num a, Storable a, Typeable a)
=> HasCallStack
=> Show a -- TODO: remove
=> (forall s. Taping s a => f (Dual s a) -> Dual s a)
@@ -38,8 +48,9 @@ gradient' f inp topctg = unsafePerformIO $ do
when debug $ hPutStrLn stderr "Preparing input"
let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0
idref <- newIORef starti
- vec1 <- VSM.new (max 128 (2 * starti))
- taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil)
+ -- The first chunk starts after the input IDs.
+ vec1 <- VSM.unsafeNew 128
+ taperef <- newIORef (MLog idref (Chunk starti vec1) SLNil)
when debug $ hPutStrLn stderr "Running function"
let !(Dual result outi) = withDict @(Taping () a) taperef $ f @() inp'
@@ -63,19 +74,23 @@ gradient' f inp topctg = unsafePerformIO $ do
backpropagate (i-1) chunk tape
| otherwise = case tape of
SLNil -> return () -- reached end of tape we should loop over
- tape'@Snoc{} `Snoc` chunk' -> backpropagate i chunk' tape'
- -- When we reach the last chunk, modify it so that its
- -- starting index is after the inputs.
- SLNil `Snoc` Chunk _ vec' ->
- backpropagate i (Chunk starti (VSM.slice starti (VSM.length vec' - starti) vec')) SLNil
-
- -- Ensure that if there are no more chunks in the tape tail, the starting
- -- index of the first chunk is adjusted so that backpropagate stops in time.
- case tapeTail of
- SLNil -> backpropagate outi (let Chunk _ vec = lastChunk
- in Chunk starti (VSM.slice starti (VSM.length vec - starti) vec))
- SLNil
- Snoc{} -> backpropagate outi lastChunk tapeTail
+ tape' `Snoc` chunk' -> backpropagate i chunk' tape'
+
+ backpropagate_via_c :: Ptr CDouble -> Int -> Chunk Double -> Snoclist (Chunk Double) -> IO ()
+ backpropagate_via_c accums_ptr i (Chunk ci0 vec) tape = do
+ let (vec_fptr, _) = VSM.unsafeToForeignPtr0 vec
+ withForeignPtr vec_fptr $ \vec_ptr ->
+ c_backpropagate_double accums_ptr (fromIntegral ci0) (fromIntegral i) (castPtr @(Contrib Double) @() vec_ptr)
+ case tape of
+ SLNil -> return ()
+ tape' `Snoc` chunk' -> backpropagate_via_c accums_ptr (ci0 - 1) chunk' tape'
+
+ case (eqT @a @Double, sizeOf (undefined :: Int)) of
+ (Just Refl, 8) -> do
+ let (accums_fptr, _) = VSM.unsafeToForeignPtr0 accums
+ withForeignPtr accums_fptr $ \accums_ptr ->
+ backpropagate_via_c (castPtr @Double @CDouble accums_ptr) outi lastChunk tapeTail
+ _ -> backpropagate outi lastChunk tapeTail
when debug $ do
accums' <- VS.freeze accums
@@ -93,8 +108,8 @@ gradient' f inp topctg = unsafePerformIO $ do
data Snoclist a = SLNil | Snoc !(Snoclist a) !a
deriving (Show, Eq, Ord, Functor, Foldable, Traversable)
-data Contrib a = Contrib {-# UNPACK #-} !Int a -- ^ ID == -1 -> no contribution
- {-# UNPACK #-} !Int a -- ^ idem
+data Contrib a = Contrib {-# UNPACK #-} !Int !a -- ^ ID == -1 -> no contribution
+ {-# UNPACK #-} !Int !a -- ^ idem
deriving (Show)
instance Storable a => Storable (Contrib a) where
@@ -147,7 +162,7 @@ instance (Num a, Storable a, Taping s a) => Num (Dual s a) where
Dual x i1 + Dual y i2 = mkDual (x + y) i1 1 i2 1
Dual x i1 - Dual y i2 = mkDual (x - y) i1 1 i2 (-1)
Dual x i1 * Dual y i2 = mkDual (x * y) i1 y i2 x
- negate (Dual x i1) = mkDual (negate x) i1 (-1) (-1) 0
+ negate (Dual x i1) = mkDual (negate x) i1 (-1) (-1) 0
abs (Dual x i1) = mkDual (abs x) i1 (x * signum x) (-1) 0
signum (Dual x _) = Dual (signum x) (-1)
fromInteger n = Dual (fromInteger n) (-1)
@@ -181,6 +196,8 @@ mkDual res i1 dx i2 dy = Dual res (writeTapeUnsafe @a (Proxy @s) i1 dx i2 dy)
data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a))
| WTAOldTape (Snoclist (Chunk a))
+-- This NOINLINE doesn't seem to matter for performance, so keep it to be safe: the function wraps unsafePerformIO
+{-# NOINLINE writeTapeUnsafe #-}
writeTapeUnsafe :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int
writeTapeUnsafe _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy
@@ -200,7 +217,7 @@ writeTapeIO _ i1 dx i2 dy = do
-- check if we'd fit in the next chunk (overwhelmingly likely)
| let newlen = 3 * n `div` 2
, idx < n + newlen -> do
- newvec <- VSM.new newlen
+ newvec <- VSM.unsafeNew newlen
action <- atomicModifyIORef' (getTape @s) $ \(MLog idref' chunk@(Chunk ci0' vec') tape) ->
if | ci0 == ci0' ->
-- Likely (certain when single-threaded): no race condition,
@@ -236,5 +253,3 @@ writeTapeIO _ i1 dx i2 dy = do
-- there's a tremendous amount of competition, let's just try again
| otherwise -> writeTapeIO (Proxy @s) i1 dx i2 dy
-
-