diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-02-23 21:44:23 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-02-23 21:44:23 +0100 |
commit | 5f7a81acc7f75415d62dac86c5b50c848ab15341 (patch) | |
tree | 641ed54ce426ed8a1d9a5da12a9cde512b32bedc /src/Numeric/ADDual | |
parent | a17bd53598ee5266fc3a1c45f8f4bb4798dc495e (diff) |
Optimise by backpropagating scalar computation in C
Diffstat (limited to 'src/Numeric/ADDual')
-rw-r--r-- | src/Numeric/ADDual/Internal.hs | 59 |
1 files changed, 37 insertions, 22 deletions
diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs index 1ea3132..5dd84aa 100644 --- a/src/Numeric/ADDual/Internal.hs +++ b/src/Numeric/ADDual/Internal.hs @@ -1,5 +1,6 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE RankNTypes #-} @@ -11,9 +12,13 @@ import Control.Monad (when) import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.State.Strict import Data.IORef +import Data.Int import Data.Proxy +import Data.Typeable import qualified Data.Vector.Storable as VS import qualified Data.Vector.Storable.Mutable as VSM +import Foreign.C.Types +import Foreign.ForeignPtr import Foreign.Ptr import Foreign.Storable import GHC.Stack @@ -27,9 +32,14 @@ debug :: Bool debug = toEnum 0 +foreign import ccall unsafe "ad_dual_hs_backpropagate_double" + c_backpropagate_double :: Ptr CDouble -> Int64 -> Int64 -> Ptr () -> IO () + + -- TODO: full vjp (just some more Traversable mess) +-- TODO: if non-scalar output types are allowed, ensure that all its scalar components are WHNF evaluated before we backpropagate {-# NOINLINE gradient' #-} -gradient' :: forall a f. (Traversable f, Num a, Storable a) +gradient' :: forall a f. (Traversable f, Num a, Storable a, Typeable a) => HasCallStack => Show a -- TODO: remove => (forall s. Taping s a => f (Dual s a) -> Dual s a) @@ -38,8 +48,9 @@ gradient' f inp topctg = unsafePerformIO $ do when debug $ hPutStrLn stderr "Preparing input" let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0 idref <- newIORef starti - vec1 <- VSM.new (max 128 (2 * starti)) - taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil) + -- The first chunk starts after the input IDs. + vec1 <- VSM.unsafeNew 128 + taperef <- newIORef (MLog idref (Chunk starti vec1) SLNil) when debug $ hPutStrLn stderr "Running function" let !(Dual result outi) = withDict @(Taping () a) taperef $ f @() inp' @@ -63,19 +74,23 @@ gradient' f inp topctg = unsafePerformIO $ do backpropagate (i-1) chunk tape | otherwise = case tape of SLNil -> return () -- reached end of tape we should loop over - tape'@Snoc{} `Snoc` chunk' -> backpropagate i chunk' tape' - -- When we reach the last chunk, modify it so that its - -- starting index is after the inputs. - SLNil `Snoc` Chunk _ vec' -> - backpropagate i (Chunk starti (VSM.slice starti (VSM.length vec' - starti) vec')) SLNil - - -- Ensure that if there are no more chunks in the tape tail, the starting - -- index of the first chunk is adjusted so that backpropagate stops in time. - case tapeTail of - SLNil -> backpropagate outi (let Chunk _ vec = lastChunk - in Chunk starti (VSM.slice starti (VSM.length vec - starti) vec)) - SLNil - Snoc{} -> backpropagate outi lastChunk tapeTail + tape' `Snoc` chunk' -> backpropagate i chunk' tape' + + backpropagate_via_c :: Ptr CDouble -> Int -> Chunk Double -> Snoclist (Chunk Double) -> IO () + backpropagate_via_c accums_ptr i (Chunk ci0 vec) tape = do + let (vec_fptr, _) = VSM.unsafeToForeignPtr0 vec + withForeignPtr vec_fptr $ \vec_ptr -> + c_backpropagate_double accums_ptr (fromIntegral ci0) (fromIntegral i) (castPtr @(Contrib Double) @() vec_ptr) + case tape of + SLNil -> return () + tape' `Snoc` chunk' -> backpropagate_via_c accums_ptr (ci0 - 1) chunk' tape' + + case (eqT @a @Double, sizeOf (undefined :: Int)) of + (Just Refl, 8) -> do + let (accums_fptr, _) = VSM.unsafeToForeignPtr0 accums + withForeignPtr accums_fptr $ \accums_ptr -> + backpropagate_via_c (castPtr @Double @CDouble accums_ptr) outi lastChunk tapeTail + _ -> backpropagate outi lastChunk tapeTail when debug $ do accums' <- VS.freeze accums @@ -93,8 +108,8 @@ gradient' f inp topctg = unsafePerformIO $ do data Snoclist a = SLNil | Snoc !(Snoclist a) !a deriving (Show, Eq, Ord, Functor, Foldable, Traversable) -data Contrib a = Contrib {-# UNPACK #-} !Int a -- ^ ID == -1 -> no contribution - {-# UNPACK #-} !Int a -- ^ idem +data Contrib a = Contrib {-# UNPACK #-} !Int !a -- ^ ID == -1 -> no contribution + {-# UNPACK #-} !Int !a -- ^ idem deriving (Show) instance Storable a => Storable (Contrib a) where @@ -147,7 +162,7 @@ instance (Num a, Storable a, Taping s a) => Num (Dual s a) where Dual x i1 + Dual y i2 = mkDual (x + y) i1 1 i2 1 Dual x i1 - Dual y i2 = mkDual (x - y) i1 1 i2 (-1) Dual x i1 * Dual y i2 = mkDual (x * y) i1 y i2 x - negate (Dual x i1) = mkDual (negate x) i1 (-1) (-1) 0 + negate (Dual x i1) = mkDual (negate x) i1 (-1) (-1) 0 abs (Dual x i1) = mkDual (abs x) i1 (x * signum x) (-1) 0 signum (Dual x _) = Dual (signum x) (-1) fromInteger n = Dual (fromInteger n) (-1) @@ -181,6 +196,8 @@ mkDual res i1 dx i2 dy = Dual res (writeTapeUnsafe @a (Proxy @s) i1 dx i2 dy) data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a)) | WTAOldTape (Snoclist (Chunk a)) +-- This NOINLINE really doesn't seem to matter for performance, so let's be safe +{-# NOINLINE writeTapeUnsafe #-} writeTapeUnsafe :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int writeTapeUnsafe _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy @@ -200,7 +217,7 @@ writeTapeIO _ i1 dx i2 dy = do -- check if we'd fit in the next chunk (overwhelmingly likely) | let newlen = 3 * n `div` 2 , idx < n + newlen -> do - newvec <- VSM.new newlen + newvec <- VSM.unsafeNew newlen action <- atomicModifyIORef' (getTape @s) $ \(MLog idref' chunk@(Chunk ci0' vec') tape) -> if | ci0 == ci0' -> -- Likely (certain when single-threaded): no race condition, @@ -236,5 +253,3 @@ writeTapeIO _ i1 dx i2 dy = do -- there's a tremendous amount of competition, let's just try again | otherwise -> writeTapeIO (Proxy @s) i1 dx i2 dy - - |