aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore7
-rw-r--r--ad-dual.cabal3
-rw-r--r--bench/Main.hs57
-rw-r--r--cbits/backprop.c25
-rw-r--r--src/Numeric/ADDual/Internal.hs59
5 files changed, 116 insertions, 35 deletions
diff --git a/.gitignore b/.gitignore
index a3ac1fc..278e7e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,9 @@
dist-newstyle/
+.ccls-cache/
+
cabal.project.local
+
+data.txt
+plot.png
+
+test.prof
diff --git a/ad-dual.cabal b/ad-dual.cabal
index fe14d31..9880744 100644
--- a/ad-dual.cabal
+++ b/ad-dual.cabal
@@ -16,6 +16,8 @@ library
Numeric.ADDual
Numeric.ADDual.Internal
other-modules:
+ c-sources: cbits/backprop.c
+ cc-options: -O3 -Wall -Wextra -std=c99
build-depends:
transformers,
vector
@@ -54,6 +56,7 @@ benchmark bench
ad-dual,
ad-dual-examples,
ad,
+ clock,
criterion,
deepseq,
vector
diff --git a/bench/Main.hs b/bench/Main.hs
index 99c3f1d..1174a3a 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -1,8 +1,14 @@
{-# LANGUAGE TypeApplications #-}
module Main where
+import Control.DeepSeq
+import Control.Exception (evaluate)
+import Control.Monad (forM_)
import Criterion
import Criterion.Main
+import qualified System.Clock as Clock
+import System.Environment (getArgs)
+import System.Mem (performGC)
import qualified Numeric.AD as AD
@@ -11,17 +17,42 @@ import Numeric.ADDual.Examples
main :: IO ()
-main = defaultMain
- [env (pure (makeNeuralInput 100)) $ \input ->
- bgroup "neural-100"
- [bench "dual" $ nf (\inp -> ADD.gradient' @Double fneural inp 1.0) input
- ,bench "ad" $ nf (\inp -> AD.grad fneural inp) input]
- ,env (pure (makeNeuralInput 500)) $ \input ->
- bgroup "neural-500"
- [bench "dual" $ nf (\inp -> ADD.gradient' @Double fneural inp 1.0) input
- ,bench "ad" $ nf (\inp -> AD.grad fneural inp) input]
- ,env (pure (makeNeuralInput 2000)) $ \input ->
- bgroup "neural-2000"
- [bench "dual" $ nf (\inp -> ADD.gradient' @Double fneural inp 1.0) input
- ,bench "ad" $ nf (\inp -> AD.grad fneural inp) input]
+main = do
+ args <- getArgs
+ case args of
+ ["--neural-graph"] -> mainNeuralGraph
+ _ -> mainCriterion
+
+mainCriterion :: IO ()
+mainCriterion = defaultMain
+ [benchNeural 100
+ ,benchNeural 180 -- rather stably 2 GCs
+ ,benchNeural 500
+ ,benchNeural 2000
]
+ where
+ benchNeural :: Int -> Benchmark
+ benchNeural n =
+ env (pure (makeNeuralInput n)) $ \input ->
+ bgroup ("neural-" ++ show n)
+ [bench "dual" $ nf (\inp -> ADD.gradient' fneural inp 1.0) input
+ ,bench "ad" $ nf (\inp -> AD.grad fneural inp) input]
+
+mainNeuralGraph :: IO ()
+mainNeuralGraph = do
+ forM_ [10, 20 .. 300] $ \n -> do
+ let input = makeNeuralInput n
+ _ <- evaluate (force input)
+
+ performGC
+ t1 <- Clock.getTime Clock.Monotonic
+ _ <- evaluate $ force (ADD.gradient' fneural input 1.0)
+ t2 <- Clock.getTime Clock.Monotonic
+
+ performGC
+ t3 <- Clock.getTime Clock.Monotonic
+ _ <- evaluate $ force (AD.grad fneural input)
+ t4 <- Clock.getTime Clock.Monotonic
+
+ let diff a b = fromIntegral (Clock.toNanoSecs (Clock.diffTimeSpec a b)) / 1e9 :: Double
+ putStrLn $ show n ++ " " ++ show (diff t1 t2) ++ " " ++ show (diff t3 t4)
diff --git a/cbits/backprop.c b/cbits/backprop.c
new file mode 100644
index 0000000..0ca62e3
--- /dev/null
+++ b/cbits/backprop.c
@@ -0,0 +1,25 @@
+// #include <stdio.h>
+#include <stdint.h>
+// #include <inttypes.h>
+
+struct Contrib {
+ int64_t i1;
+ double dx;
+ int64_t i2;
+ double dy;
+};
+
+void ad_dual_hs_backpropagate_double(
+ double *accums,
+ int64_t id_base, int64_t topid, const void *contribs_buf
+) {
+ // fprintf(stderr, "< ci0=%" PRIi64 " topid=%" PRIi64 " >\n", id_base, topid);
+ const struct Contrib *contribs = (const struct Contrib*)contribs_buf;
+
+ for (int64_t i = topid - id_base; i >= 0; i--) {
+ double d = accums[id_base + i];
+ // fprintf(stderr, "ACC i=%" PRIi64 " d=%g C={%" PRIi64 ", %g, %" PRIi64 ", %g}\n", id_base + i, d, contribs[i].i1, contribs[i].dx, contribs[i].i2, contribs[i].dy);
+ if (contribs[i].i1 != -1) accums[contribs[i].i1] += d * contribs[i].dx;
+ if (contribs[i].i2 != -1) accums[contribs[i].i2] += d * contribs[i].dy;
+ }
+}
diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs
index 1ea3132..5dd84aa 100644
--- a/src/Numeric/ADDual/Internal.hs
+++ b/src/Numeric/ADDual/Internal.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveTraversable #-}
+{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
@@ -11,9 +12,13 @@ import Control.Monad (when)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict
import Data.IORef
+import Data.Int
import Data.Proxy
+import Data.Typeable
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
+import Foreign.C.Types
+import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import GHC.Stack
@@ -27,9 +32,14 @@ debug :: Bool
debug = toEnum 0
+foreign import ccall unsafe "ad_dual_hs_backpropagate_double"
+ c_backpropagate_double :: Ptr CDouble -> Int64 -> Int64 -> Ptr () -> IO ()
+
+
-- TODO: full vjp (just some more Traversable mess)
+-- TODO: if non-scalar output types are allowed, ensure that all its scalar components are WHNF evaluated before we backpropagate
{-# NOINLINE gradient' #-}
-gradient' :: forall a f. (Traversable f, Num a, Storable a)
+gradient' :: forall a f. (Traversable f, Num a, Storable a, Typeable a)
=> HasCallStack
=> Show a -- TODO: remove
=> (forall s. Taping s a => f (Dual s a) -> Dual s a)
@@ -38,8 +48,9 @@ gradient' f inp topctg = unsafePerformIO $ do
when debug $ hPutStrLn stderr "Preparing input"
let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0
idref <- newIORef starti
- vec1 <- VSM.new (max 128 (2 * starti))
- taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil)
+ -- The first chunk starts after the input IDs.
+ vec1 <- VSM.unsafeNew 128
+ taperef <- newIORef (MLog idref (Chunk starti vec1) SLNil)
when debug $ hPutStrLn stderr "Running function"
let !(Dual result outi) = withDict @(Taping () a) taperef $ f @() inp'
@@ -63,19 +74,23 @@ gradient' f inp topctg = unsafePerformIO $ do
backpropagate (i-1) chunk tape
| otherwise = case tape of
SLNil -> return () -- reached end of tape we should loop over
- tape'@Snoc{} `Snoc` chunk' -> backpropagate i chunk' tape'
- -- When we reach the last chunk, modify it so that its
- -- starting index is after the inputs.
- SLNil `Snoc` Chunk _ vec' ->
- backpropagate i (Chunk starti (VSM.slice starti (VSM.length vec' - starti) vec')) SLNil
-
- -- Ensure that if there are no more chunks in the tape tail, the starting
- -- index of the first chunk is adjusted so that backpropagate stops in time.
- case tapeTail of
- SLNil -> backpropagate outi (let Chunk _ vec = lastChunk
- in Chunk starti (VSM.slice starti (VSM.length vec - starti) vec))
- SLNil
- Snoc{} -> backpropagate outi lastChunk tapeTail
+ tape' `Snoc` chunk' -> backpropagate i chunk' tape'
+
+ backpropagate_via_c :: Ptr CDouble -> Int -> Chunk Double -> Snoclist (Chunk Double) -> IO ()
+ backpropagate_via_c accums_ptr i (Chunk ci0 vec) tape = do
+ let (vec_fptr, _) = VSM.unsafeToForeignPtr0 vec
+ withForeignPtr vec_fptr $ \vec_ptr ->
+ c_backpropagate_double accums_ptr (fromIntegral ci0) (fromIntegral i) (castPtr @(Contrib Double) @() vec_ptr)
+ case tape of
+ SLNil -> return ()
+ tape' `Snoc` chunk' -> backpropagate_via_c accums_ptr (ci0 - 1) chunk' tape'
+
+ case (eqT @a @Double, sizeOf (undefined :: Int)) of
+ (Just Refl, 8) -> do
+ let (accums_fptr, _) = VSM.unsafeToForeignPtr0 accums
+ withForeignPtr accums_fptr $ \accums_ptr ->
+ backpropagate_via_c (castPtr @Double @CDouble accums_ptr) outi lastChunk tapeTail
+ _ -> backpropagate outi lastChunk tapeTail
when debug $ do
accums' <- VS.freeze accums
@@ -93,8 +108,8 @@ gradient' f inp topctg = unsafePerformIO $ do
data Snoclist a = SLNil | Snoc !(Snoclist a) !a
deriving (Show, Eq, Ord, Functor, Foldable, Traversable)
-data Contrib a = Contrib {-# UNPACK #-} !Int a -- ^ ID == -1 -> no contribution
- {-# UNPACK #-} !Int a -- ^ idem
+data Contrib a = Contrib {-# UNPACK #-} !Int !a -- ^ ID == -1 -> no contribution
+ {-# UNPACK #-} !Int !a -- ^ idem
deriving (Show)
instance Storable a => Storable (Contrib a) where
@@ -147,7 +162,7 @@ instance (Num a, Storable a, Taping s a) => Num (Dual s a) where
Dual x i1 + Dual y i2 = mkDual (x + y) i1 1 i2 1
Dual x i1 - Dual y i2 = mkDual (x - y) i1 1 i2 (-1)
Dual x i1 * Dual y i2 = mkDual (x * y) i1 y i2 x
- negate (Dual x i1) = mkDual (negate x) i1 (-1) (-1) 0
+ negate (Dual x i1) = mkDual (negate x) i1 (-1) (-1) 0
abs (Dual x i1) = mkDual (abs x) i1 (x * signum x) (-1) 0
signum (Dual x _) = Dual (signum x) (-1)
fromInteger n = Dual (fromInteger n) (-1)
@@ -181,6 +196,8 @@ mkDual res i1 dx i2 dy = Dual res (writeTapeUnsafe @a (Proxy @s) i1 dx i2 dy)
data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a))
| WTAOldTape (Snoclist (Chunk a))
+-- This NOINLINE really doesn't seem to matter for performance, so let's be safe
+{-# NOINLINE writeTapeUnsafe #-}
writeTapeUnsafe :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int
writeTapeUnsafe _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy
@@ -200,7 +217,7 @@ writeTapeIO _ i1 dx i2 dy = do
-- check if we'd fit in the next chunk (overwhelmingly likely)
| let newlen = 3 * n `div` 2
, idx < n + newlen -> do
- newvec <- VSM.new newlen
+ newvec <- VSM.unsafeNew newlen
action <- atomicModifyIORef' (getTape @s) $ \(MLog idref' chunk@(Chunk ci0' vec') tape) ->
if | ci0 == ci0' ->
-- Likely (certain when single-threaded): no race condition,
@@ -236,5 +253,3 @@ writeTapeIO _ i1 dx i2 dy = do
-- there's a tremendous amount of competition, let's just try again
| otherwise -> writeTapeIO (Proxy @s) i1 dx i2 dy
-
-