diff options
-rw-r--r-- | .gitignore | 7 | ||||
-rw-r--r-- | ad-dual.cabal | 3 | ||||
-rw-r--r-- | bench/Main.hs | 57 | ||||
-rw-r--r-- | cbits/backprop.c | 25 | ||||
-rw-r--r-- | src/Numeric/ADDual/Internal.hs | 59 |
5 files changed, 116 insertions, 35 deletions
@@ -1,2 +1,9 @@ dist-newstyle/ +.ccls-cache/ + cabal.project.local + +data.txt +plot.png + +test.prof diff --git a/ad-dual.cabal b/ad-dual.cabal index fe14d31..9880744 100644 --- a/ad-dual.cabal +++ b/ad-dual.cabal @@ -16,6 +16,8 @@ library Numeric.ADDual Numeric.ADDual.Internal other-modules: + c-sources: cbits/backprop.c + cc-options: -O3 -Wall -Wextra -std=c99 build-depends: transformers, vector @@ -54,6 +56,7 @@ benchmark bench ad-dual, ad-dual-examples, ad, + clock, criterion, deepseq, vector diff --git a/bench/Main.hs b/bench/Main.hs index 99c3f1d..1174a3a 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -1,8 +1,14 @@ {-# LANGUAGE TypeApplications #-} module Main where +import Control.DeepSeq +import Control.Exception (evaluate) +import Control.Monad (forM_) import Criterion import Criterion.Main +import qualified System.Clock as Clock +import System.Environment (getArgs) +import System.Mem (performGC) import qualified Numeric.AD as AD @@ -11,17 +17,42 @@ import Numeric.ADDual.Examples main :: IO () -main = defaultMain - [env (pure (makeNeuralInput 100)) $ \input -> - bgroup "neural-100" - [bench "dual" $ nf (\inp -> ADD.gradient' @Double fneural inp 1.0) input - ,bench "ad" $ nf (\inp -> AD.grad fneural inp) input] - ,env (pure (makeNeuralInput 500)) $ \input -> - bgroup "neural-500" - [bench "dual" $ nf (\inp -> ADD.gradient' @Double fneural inp 1.0) input - ,bench "ad" $ nf (\inp -> AD.grad fneural inp) input] - ,env (pure (makeNeuralInput 2000)) $ \input -> - bgroup "neural-2000" - [bench "dual" $ nf (\inp -> ADD.gradient' @Double fneural inp 1.0) input - ,bench "ad" $ nf (\inp -> AD.grad fneural inp) input] +main = do + args <- getArgs + case args of + ["--neural-graph"] -> mainNeuralGraph + _ -> mainCriterion + +mainCriterion :: IO () +mainCriterion = defaultMain + [benchNeural 100 + ,benchNeural 180 -- rather stably 2 GCs + ,benchNeural 500 + ,benchNeural 2000 ] + where + benchNeural :: Int -> Benchmark + benchNeural n = + env (pure (makeNeuralInput n)) $ \input -> + bgroup ("neural-" ++ show n) + [bench "dual" $ nf (\inp -> ADD.gradient' fneural inp 1.0) input + ,bench "ad" $ nf (\inp -> AD.grad fneural inp) input] + +mainNeuralGraph :: IO () +mainNeuralGraph = do + forM_ [10, 20 .. 300] $ \n -> do + let input = makeNeuralInput n + _ <- evaluate (force input) + + performGC + t1 <- Clock.getTime Clock.Monotonic + _ <- evaluate $ force (ADD.gradient' fneural input 1.0) + t2 <- Clock.getTime Clock.Monotonic + + performGC + t3 <- Clock.getTime Clock.Monotonic + _ <- evaluate $ force (AD.grad fneural input) + t4 <- Clock.getTime Clock.Monotonic + + let diff a b = fromIntegral (Clock.toNanoSecs (Clock.diffTimeSpec a b)) / 1e9 :: Double + putStrLn $ show n ++ " " ++ show (diff t1 t2) ++ " " ++ show (diff t3 t4) diff --git a/cbits/backprop.c b/cbits/backprop.c new file mode 100644 index 0000000..0ca62e3 --- /dev/null +++ b/cbits/backprop.c @@ -0,0 +1,25 @@ +// #include <stdio.h> +#include <stdint.h> +// #include <inttypes.h> + +struct Contrib { + int64_t i1; + double dx; + int64_t i2; + double dy; +}; + +void ad_dual_hs_backpropagate_double( + double *accums, + int64_t id_base, int64_t topid, const void *contribs_buf +) { + // fprintf(stderr, "< ci0=%" PRIi64 " topid=%" PRIi64 " >\n", id_base, topid); + const struct Contrib *contribs = (const struct Contrib*)contribs_buf; + + for (int64_t i = topid - id_base; i >= 0; i--) { + double d = accums[id_base + i]; + // fprintf(stderr, "ACC i=%" PRIi64 " d=%g C={%" PRIi64 ", %g, %" PRIi64 ", %g}\n", id_base + i, d, contribs[i].i1, contribs[i].dx, contribs[i].i2, contribs[i].dy); + if (contribs[i].i1 != -1) accums[contribs[i].i1] += d * contribs[i].dx; + if (contribs[i].i2 != -1) accums[contribs[i].i2] += d * contribs[i].dy; + } +} diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs index 1ea3132..5dd84aa 100644 --- a/src/Numeric/ADDual/Internal.hs +++ b/src/Numeric/ADDual/Internal.hs @@ -1,5 +1,6 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE RankNTypes #-} @@ -11,9 +12,13 @@ import Control.Monad (when) import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.State.Strict import Data.IORef +import Data.Int import Data.Proxy +import Data.Typeable import qualified Data.Vector.Storable as VS import qualified Data.Vector.Storable.Mutable as VSM +import Foreign.C.Types +import Foreign.ForeignPtr import Foreign.Ptr import Foreign.Storable import GHC.Stack @@ -27,9 +32,14 @@ debug :: Bool debug = toEnum 0 +foreign import ccall unsafe "ad_dual_hs_backpropagate_double" + c_backpropagate_double :: Ptr CDouble -> Int64 -> Int64 -> Ptr () -> IO () + + -- TODO: full vjp (just some more Traversable mess) +-- TODO: if non-scalar output types are allowed, ensure that all its scalar components are WHNF evaluated before we backpropagate {-# NOINLINE gradient' #-} -gradient' :: forall a f. (Traversable f, Num a, Storable a) +gradient' :: forall a f. (Traversable f, Num a, Storable a, Typeable a) => HasCallStack => Show a -- TODO: remove => (forall s. Taping s a => f (Dual s a) -> Dual s a) @@ -38,8 +48,9 @@ gradient' f inp topctg = unsafePerformIO $ do when debug $ hPutStrLn stderr "Preparing input" let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0 idref <- newIORef starti - vec1 <- VSM.new (max 128 (2 * starti)) - taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil) + -- The first chunk starts after the input IDs. + vec1 <- VSM.unsafeNew 128 + taperef <- newIORef (MLog idref (Chunk starti vec1) SLNil) when debug $ hPutStrLn stderr "Running function" let !(Dual result outi) = withDict @(Taping () a) taperef $ f @() inp' @@ -63,19 +74,23 @@ gradient' f inp topctg = unsafePerformIO $ do backpropagate (i-1) chunk tape | otherwise = case tape of SLNil -> return () -- reached end of tape we should loop over - tape'@Snoc{} `Snoc` chunk' -> backpropagate i chunk' tape' - -- When we reach the last chunk, modify it so that its - -- starting index is after the inputs. - SLNil `Snoc` Chunk _ vec' -> - backpropagate i (Chunk starti (VSM.slice starti (VSM.length vec' - starti) vec')) SLNil - - -- Ensure that if there are no more chunks in the tape tail, the starting - -- index of the first chunk is adjusted so that backpropagate stops in time. - case tapeTail of - SLNil -> backpropagate outi (let Chunk _ vec = lastChunk - in Chunk starti (VSM.slice starti (VSM.length vec - starti) vec)) - SLNil - Snoc{} -> backpropagate outi lastChunk tapeTail + tape' `Snoc` chunk' -> backpropagate i chunk' tape' + + backpropagate_via_c :: Ptr CDouble -> Int -> Chunk Double -> Snoclist (Chunk Double) -> IO () + backpropagate_via_c accums_ptr i (Chunk ci0 vec) tape = do + let (vec_fptr, _) = VSM.unsafeToForeignPtr0 vec + withForeignPtr vec_fptr $ \vec_ptr -> + c_backpropagate_double accums_ptr (fromIntegral ci0) (fromIntegral i) (castPtr @(Contrib Double) @() vec_ptr) + case tape of + SLNil -> return () + tape' `Snoc` chunk' -> backpropagate_via_c accums_ptr (ci0 - 1) chunk' tape' + + case (eqT @a @Double, sizeOf (undefined :: Int)) of + (Just Refl, 8) -> do + let (accums_fptr, _) = VSM.unsafeToForeignPtr0 accums + withForeignPtr accums_fptr $ \accums_ptr -> + backpropagate_via_c (castPtr @Double @CDouble accums_ptr) outi lastChunk tapeTail + _ -> backpropagate outi lastChunk tapeTail when debug $ do accums' <- VS.freeze accums @@ -93,8 +108,8 @@ gradient' f inp topctg = unsafePerformIO $ do data Snoclist a = SLNil | Snoc !(Snoclist a) !a deriving (Show, Eq, Ord, Functor, Foldable, Traversable) -data Contrib a = Contrib {-# UNPACK #-} !Int a -- ^ ID == -1 -> no contribution - {-# UNPACK #-} !Int a -- ^ idem +data Contrib a = Contrib {-# UNPACK #-} !Int !a -- ^ ID == -1 -> no contribution + {-# UNPACK #-} !Int !a -- ^ idem deriving (Show) instance Storable a => Storable (Contrib a) where @@ -147,7 +162,7 @@ instance (Num a, Storable a, Taping s a) => Num (Dual s a) where Dual x i1 + Dual y i2 = mkDual (x + y) i1 1 i2 1 Dual x i1 - Dual y i2 = mkDual (x - y) i1 1 i2 (-1) Dual x i1 * Dual y i2 = mkDual (x * y) i1 y i2 x - negate (Dual x i1) = mkDual (negate x) i1 (-1) (-1) 0 + negate (Dual x i1) = mkDual (negate x) i1 (-1) (-1) 0 abs (Dual x i1) = mkDual (abs x) i1 (x * signum x) (-1) 0 signum (Dual x _) = Dual (signum x) (-1) fromInteger n = Dual (fromInteger n) (-1) @@ -181,6 +196,8 @@ mkDual res i1 dx i2 dy = Dual res (writeTapeUnsafe @a (Proxy @s) i1 dx i2 dy) data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a)) | WTAOldTape (Snoclist (Chunk a)) +-- This NOINLINE really doesn't seem to matter for performance, so let's be safe +{-# NOINLINE writeTapeUnsafe #-} writeTapeUnsafe :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int writeTapeUnsafe _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy @@ -200,7 +217,7 @@ writeTapeIO _ i1 dx i2 dy = do -- check if we'd fit in the next chunk (overwhelmingly likely) | let newlen = 3 * n `div` 2 , idx < n + newlen -> do - newvec <- VSM.new newlen + newvec <- VSM.unsafeNew newlen action <- atomicModifyIORef' (getTape @s) $ \(MLog idref' chunk@(Chunk ci0' vec') tape) -> if | ci0 == ci0' -> -- Likely (certain when single-threaded): no race condition, @@ -236,5 +253,3 @@ writeTapeIO _ i1 dx i2 dy = do -- there's a tremendous amount of competition, let's just try again | otherwise -> writeTapeIO (Proxy @s) i1 dx i2 dy - - |