aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-02-20 10:11:57 +0100
committerTom Smeding <tom@tomsmeding.com>2025-02-20 10:11:57 +0100
commitfe3132304b6c25e5bebc9fb327e3ea5d6018be7a (patch)
treeedbf62755eab8103c3f39ee7dfe2b0006692e857
parent011bda94ea9ab0bdb43751d8d19963beb5a887a0 (diff)
Attempt at a benchmark (crashes)
-rw-r--r--ad-dual.cabal15
-rw-r--r--bench/Main.hs57
-rw-r--r--src/Numeric/ADDual.hs209
-rw-r--r--src/Numeric/ADDual/Internal.hs226
4 files changed, 306 insertions, 201 deletions
diff --git a/ad-dual.cabal b/ad-dual.cabal
index 09697f0..5d3ca39 100644
--- a/ad-dual.cabal
+++ b/ad-dual.cabal
@@ -8,6 +8,7 @@ build-type: Simple
library
exposed-modules:
Numeric.ADDual
+ Numeric.ADDual.Internal
other-modules:
build-depends:
base >= 4.14.3,
@@ -28,3 +29,17 @@ test-suite test
hs-source-dirs: test
default-language: Haskell2010
ghc-options: -Wall
+
+benchmark bench
+ type: exitcode-stdio-1.0
+ main-is: Main.hs
+ build-depends:
+ base,
+ ad-dual,
+ ad,
+ criterion,
+ deepseq,
+ vector
+ hs-source-dirs: bench
+ default-language: Haskell2010
+ ghc-options: -Wall
diff --git a/bench/Main.hs b/bench/Main.hs
new file mode 100644
index 0000000..cb5e829
--- /dev/null
+++ b/bench/Main.hs
@@ -0,0 +1,57 @@
+{-# LANGUAGE DeriveTraversable #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE DeriveGeneric #-}
+module Main where
+
+import Control.DeepSeq
+import Criterion
+import Criterion.Main
+import qualified Data.Vector as V
+import GHC.Generics (Generic)
+
+import qualified Numeric.ADDual as ADD
+
+
+type Matrix s = V.Vector s
+
+data FNeural a = FNeural [(Matrix a, V.Vector a)] (V.Vector a)
+ deriving (Show, Functor, Foldable, Traversable, Generic)
+
+instance NFData a => NFData (FNeural a)
+
+fneural :: (Floating a, Ord a) => FNeural a -> a
+fneural (FNeural layers input) =
+ let dotp v1 v2 = V.sum (V.zipWith (*) v1 v2)
+
+ mat @. vec =
+ let n = V.length vec
+ m = V.length mat `div` n
+ in V.fromListN m $ map (\i -> dotp (V.slice (n*i) n mat) vec) [0 .. m-1]
+ (+.) = V.zipWith (+)
+
+ relu x = if x >= 0.0 then x else 0.0
+ safeSoftmax vec = let m = V.maximum vec
+ factor = V.sum (V.map (\z -> exp (z - m)) vec)
+ in V.map (\z -> exp (z - m) / factor) vec
+ forward [] x = safeSoftmax x
+ forward ((weights, bias) : lys) x =
+ let x' = V.map relu ((weights @. x) +. bias)
+ in forward lys x'
+ in V.sum $ forward layers input
+
+makeNeuralInput :: FNeural Double
+makeNeuralInput =
+ let genMatrix nin nout =
+ V.fromListN (nin*nout) [sin (fromIntegral @Int (i+j))
+ | i <- [0..nout-1], j <- [0..nin-1]]
+ genVector nout = V.fromListN nout [sin (0.41 * fromIntegral @Int i) | i <- [0..nout-1]]
+ -- 50 inputs; 2 hidden layers (100; 50); final softmax, then sum the outputs.
+ nIn = 50; n1 = 100; n2 = 50
+ in FNeural [(genMatrix nIn n1, genVector n1)
+ ,(genMatrix n1 n2, genVector n2)]
+ (genVector nIn)
+
+main :: IO ()
+main = defaultMain
+ [env (pure makeNeuralInput) $ \input ->
+ bench "neural" $ nf (\inp -> ADD.gradient' @Double fneural inp 1.0) input]
diff --git a/src/Numeric/ADDual.hs b/src/Numeric/ADDual.hs
index d9c5b74..2b69025 100644
--- a/src/Numeric/ADDual.hs
+++ b/src/Numeric/ADDual.hs
@@ -1,201 +1,8 @@
-{-# LANGUAGE DeriveTraversable #-}
-{-# LANGUAGE MultiParamTypeClasses #-}
-{-# LANGUAGE MultiWayIf #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeApplications #-}
-module Numeric.ADDual where
-
-import Control.Monad (when)
-import Control.Monad.Trans.Class (lift)
-import Control.Monad.Trans.State.Strict
-import Data.IORef
-import Data.Proxy
-import qualified Data.Vector.Storable as VS
-import qualified Data.Vector.Storable.Mutable as VSM
-import Foreign.Ptr
-import Foreign.Storable
-import GHC.Exts (withDict)
-import System.IO.Unsafe
-
--- import System.IO (hPutStrLn, stderr)
-
--- import Numeric.ADDual.IDGen
-
-
--- TODO: full vjp (just some more Traversable mess)
-{-# NOINLINE gradient' #-}
-gradient' :: forall a f. (Traversable f, Num a, Storable a)
- -- => Show a -- TODO: remove
- => (forall s. Taping s a => f (Dual s a) -> Dual s a)
- -> f a -> a -> (a, f a)
-gradient' f inp topctg = unsafePerformIO $ do
- -- hPutStrLn stderr "Preparing input"
- let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0
- idref <- newIORef starti
- vec1 <- VSM.new (max 128 (2 * starti))
- taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil)
-
- -- hPutStrLn stderr "Running function"
- let Dual result outi = withDict @(Taping () a) taperef $ f @() inp'
- -- hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi
- MLog _ lastChunk tapeTail <- readIORef taperef
-
- -- do tapestr <- showTape (tapeTail `Snoc` lastChunk)
- -- hPutStrLn stderr $ "tape = " ++ tapestr ""
-
- -- hPutStrLn stderr "Backpropagating"
- accums <- VSM.new (outi+1)
- VSM.write accums outi topctg
-
- let backpropagate i chunk@(Chunk ci0 vec) tape
- | i >= ci0 = do
- ctg <- VSM.read accums i
- Contrib i1 dx i2 dy <- VSM.read vec (i - ci0)
- when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1
- when (i2 /= -1) $ VSM.modify accums (+ ctg*dy) i2
- backpropagate (i-1) chunk tape
- | otherwise = case tape of
- SLNil -> return () -- reached end of tape we should loop over
- tape'@Snoc{} `Snoc` chunk' -> backpropagate i chunk' tape'
- -- When we reach the last chunk, modify it so that its
- -- starting index is after the inputs.
- SLNil `Snoc` Chunk _ vec' ->
- backpropagate i (Chunk starti (VSM.slice starti (VSM.length vec' - starti) vec')) SLNil
-
- -- Ensure that if there are no more chunks in the tape tail, the starting
- -- index of the first chunk is adjusted so that backpropagate stops in time.
- case tapeTail of
- SLNil -> backpropagate outi (let Chunk _ vec = lastChunk
- in Chunk starti (VSM.slice starti (VSM.length vec - starti) vec))
- SLNil
- Snoc{} -> backpropagate outi lastChunk tapeTail
-
- -- do accums' <- VS.freeze accums
- -- hPutStrLn stderr $ "accums = " ++ show accums'
-
- -- hPutStrLn stderr "Reconstructing gradient"
- let readDeriv = do i <- get
- d <- lift $ VSM.read accums i
- put (i+1)
- return d
- grad <- evalStateT (traverse (\_ -> readDeriv) inp) 0
-
- return (result, grad)
-
-data Snoclist a = SLNil | Snoc !(Snoclist a) !a
- deriving (Show, Eq, Ord, Functor, Foldable, Traversable)
-
-data Contrib a = Contrib {-# UNPACK #-} !Int a -- ^ ID == -1 -> no contribution
- {-# UNPACK #-} !Int a -- ^ idem
- deriving (Show)
-
-instance Storable a => Storable (Contrib a) where
- sizeOf _ = 2 * (sizeOf (undefined :: Int) + sizeOf (undefined :: a))
- alignment _ = alignment (undefined :: Int)
- peek ptr = Contrib <$> peek (castPtr ptr)
- <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int))
- <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a))
- <*> peekByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a))
- poke ptr (Contrib i1 dx i2 dy) = do
- poke (castPtr ptr) i1
- pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int)) dx
- pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) i2
- pokeByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a)) dy
-
-data Chunk a = Chunk {-# UNPACK #-} !Int -- ^ First ID in this chunk
- {-# UNPACK #-} !(VSM.IOVector (Contrib a))
-
-data MLog s a = MLog !(IORef Int) -- ^ next ID to generate
- {-# UNPACK #-} !(Chunk a) -- ^ current running chunk
- !(Snoclist (Chunk a)) -- ^ tape
-
-showChunk :: (Storable a, Show a) => Chunk a -> IO ShowS
-showChunk (Chunk ci0 vec) = do
- vec' <- VS.freeze vec
- return (showString ("Chunk " ++ show ci0 ++ " ") . shows vec')
-
-showTape :: (Storable a, Show a) => Snoclist (Chunk a) -> IO ShowS
-showTape SLNil = return (showString "SLNil")
-showTape (tape `Snoc` chunk) = do
- s1 <- showTape tape
- s2 <- showChunk chunk
- return (s1 . showString " `Snoc` " . s2)
-
--- | This class does not have any instances defined, on purpose. You'll get one
--- magically when you differentiate.
-class Taping s a where
- getTape :: IORef (MLog s a)
-
-data Dual s a = Dual !a
- {-# UNPACK #-} !Int -- ^ -1 if this is a constant
-
-instance Eq a => Eq (Dual s a) where
- Dual x _ == Dual y _ = x == y
-
-instance Ord a => Ord (Dual s a) where
- compare (Dual x _) (Dual y _) = compare x y
-
-instance (Num a, Storable a, Taping s a) => Num (Dual s a) where
- Dual x i1 + Dual y i2 = Dual (x + y) (writeTape @a (Proxy @s) i1 1 i2 1)
- Dual x i1 - Dual y i2 = Dual (x - y) (writeTape @a (Proxy @s) i1 1 i2 (-1))
- Dual x i1 * Dual y i2 = Dual (x * y) (writeTape (Proxy @s) i1 y i2 x)
- negate (Dual x i1) = Dual (negate x) (writeTape @a (Proxy @s) i1 (-1) (-1) 0)
- abs (Dual x i1) = Dual (abs x) (writeTape (Proxy @s) i1 (x * signum x) (-1) 0)
- signum (Dual x _) = Dual (signum x) (-1)
- fromInteger n = Dual (fromInteger n) (-1)
-
-data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a))
- | WTAOldTape (Snoclist (Chunk a))
-
-writeTape :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int
-writeTape _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy
-
-writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> IO Int
-writeTapeIO _ i1 dx i2 dy = do
- MLog idref (Chunk ci0 vec) _ <- readIORef (getTape @s)
- let n = VSM.length vec
- i <- atomicModifyIORef' idref (\i -> (i + 1, i))
- let idx = i - ci0
-
- if | idx < n -> do
- VSM.write vec idx (Contrib i1 dx i2 dy)
- return i
-
- -- check if we'd fit in the next chunk (overwhelmingly likely)
- | let newlen = 3 * n `div` 2
- , idx < n + newlen -> do
- newvec <- VSM.new newlen
- action <- atomicModifyIORef' (getTape @s) $ \(MLog idref' chunk@(Chunk ci0' vec') tape) ->
- if | ci0 == ci0' ->
- -- Likely (certain when single-threaded): no race condition,
- -- we get the chance to put the new chunk in place.
- (MLog idref' (Chunk (ci0 + n) newvec) (tape `Snoc` chunk), WTANewvec newvec)
- | i < ci0' + VSM.length vec' ->
- -- Race condition; need to write to appropriate position in this vector.
- (MLog idref' chunk tape, WTANewvec vec')
- | i < ci0' ->
- -- Very unlikely; need to write to old chunk in tape.
- (MLog idref' chunk tape, WTAOldTape tape)
- | otherwise ->
- -- We got an ID so far in the future that it doesn't even fit
- -- in the next chunk. But that can't happen, because we're
- -- only in this branch if the ID would have fit in the next
- -- chunk in the first place!
- error "writeTape: impossible"
- case action of
- WTANewvec vec' -> VSM.write vec' (idx - n) (Contrib i1 dx i2 dy)
- WTAOldTape tape ->
- let go SLNil = error "writeTape: no appropriate tape chunk?"
- go (tape' `Snoc` Chunk ci0' vec')
- -- The first comparison here is technically unnecessary, but
- -- I'm not courageous enough to remove it.
- | ci0' <= i, i < ci0' + VSM.length vec' =
- VSM.write vec' (i - ci0') (Contrib i1 dx i2 dy)
- | otherwise =
- go tape'
- in go tape
- return i
-
- -- there's a tremendous amount of competition, let's just try again
- | otherwise -> writeTapeIO (Proxy @s) i1 dx i2 dy
+module Numeric.ADDual (
+ gradient',
+ Dual,
+ Taping,
+ constant,
+) where
+
+import Numeric.ADDual.Internal
diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs
new file mode 100644
index 0000000..8228694
--- /dev/null
+++ b/src/Numeric/ADDual/Internal.hs
@@ -0,0 +1,226 @@
+{-# LANGUAGE DeriveTraversable #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+module Numeric.ADDual.Internal where
+
+import Control.Monad (when)
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.State.Strict
+import Data.IORef
+import Data.Proxy
+import qualified Data.Vector.Storable as VS
+import qualified Data.Vector.Storable.Mutable as VSM
+import Foreign.Ptr
+import Foreign.Storable
+import GHC.Exts (withDict)
+import System.IO.Unsafe
+
+-- import System.IO (hPutStrLn, stderr)
+
+-- import Numeric.ADDual.IDGen
+
+
+-- TODO: full vjp (just some more Traversable mess)
+{-# NOINLINE gradient' #-}
+gradient' :: forall a f. (Traversable f, Num a, Storable a)
+ -- => Show a -- TODO: remove
+ => (forall s. Taping s a => f (Dual s a) -> Dual s a)
+ -> f a -> a -> (a, f a)
+gradient' f inp topctg = unsafePerformIO $ do
+ -- hPutStrLn stderr "Preparing input"
+ let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0
+ idref <- newIORef starti
+ vec1 <- VSM.new (max 128 (2 * starti))
+ taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil)
+
+ -- hPutStrLn stderr "Running function"
+ let Dual result outi = withDict @(Taping () a) taperef $ f @() inp'
+ -- hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi
+ MLog _ lastChunk tapeTail <- readIORef taperef
+
+ -- do tapestr <- showTape (tapeTail `Snoc` lastChunk)
+ -- hPutStrLn stderr $ "tape = " ++ tapestr ""
+
+ -- hPutStrLn stderr "Backpropagating"
+ accums <- VSM.new (outi+1)
+ VSM.write accums outi topctg
+
+ let backpropagate i chunk@(Chunk ci0 vec) tape
+ | i >= ci0 = do
+ ctg <- VSM.read accums i
+ Contrib i1 dx i2 dy <- VSM.read vec (i - ci0)
+ when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1
+ when (i2 /= -1) $ VSM.modify accums (+ ctg*dy) i2
+ backpropagate (i-1) chunk tape
+ | otherwise = case tape of
+ SLNil -> return () -- reached end of tape we should loop over
+ tape'@Snoc{} `Snoc` chunk' -> backpropagate i chunk' tape'
+ -- When we reach the last chunk, modify it so that its
+ -- starting index is after the inputs.
+ SLNil `Snoc` Chunk _ vec' ->
+ backpropagate i (Chunk starti (VSM.slice starti (VSM.length vec' - starti) vec')) SLNil
+
+ -- Ensure that if there are no more chunks in the tape tail, the starting
+ -- index of the first chunk is adjusted so that backpropagate stops in time.
+ case tapeTail of
+ SLNil -> backpropagate outi (let Chunk _ vec = lastChunk
+ in Chunk starti (VSM.slice starti (VSM.length vec - starti) vec))
+ SLNil
+ Snoc{} -> backpropagate outi lastChunk tapeTail
+
+ -- do accums' <- VS.freeze accums
+ -- hPutStrLn stderr $ "accums = " ++ show accums'
+
+ -- hPutStrLn stderr "Reconstructing gradient"
+ let readDeriv = do i <- get
+ d <- lift $ VSM.read accums i
+ put (i+1)
+ return d
+ grad <- evalStateT (traverse (\_ -> readDeriv) inp) 0
+
+ return (result, grad)
+
+data Snoclist a = SLNil | Snoc !(Snoclist a) !a
+ deriving (Show, Eq, Ord, Functor, Foldable, Traversable)
+
+data Contrib a = Contrib {-# UNPACK #-} !Int a -- ^ ID == -1 -> no contribution
+ {-# UNPACK #-} !Int a -- ^ idem
+ deriving (Show)
+
+instance Storable a => Storable (Contrib a) where
+ sizeOf _ = 2 * (sizeOf (undefined :: Int) + sizeOf (undefined :: a))
+ alignment _ = alignment (undefined :: Int)
+ peek ptr = Contrib <$> peek (castPtr ptr)
+ <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int))
+ <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a))
+ <*> peekByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a))
+ poke ptr (Contrib i1 dx i2 dy) = do
+ poke (castPtr ptr) i1
+ pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int)) dx
+ pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) i2
+ pokeByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a)) dy
+
+data Chunk a = Chunk {-# UNPACK #-} !Int -- ^ First ID in this chunk
+ {-# UNPACK #-} !(VSM.IOVector (Contrib a))
+
+data MLog s a = MLog !(IORef Int) -- ^ next ID to generate
+ {-# UNPACK #-} !(Chunk a) -- ^ current running chunk
+ !(Snoclist (Chunk a)) -- ^ tape
+
+showChunk :: (Storable a, Show a) => Chunk a -> IO ShowS
+showChunk (Chunk ci0 vec) = do
+ vec' <- VS.freeze vec
+ return (showString ("Chunk " ++ show ci0 ++ " ") . shows vec')
+
+showTape :: (Storable a, Show a) => Snoclist (Chunk a) -> IO ShowS
+showTape SLNil = return (showString "SLNil")
+showTape (tape `Snoc` chunk) = do
+ s1 <- showTape tape
+ s2 <- showChunk chunk
+ return (s1 . showString " `Snoc` " . s2)
+
+-- | This class does not have any instances defined, on purpose. You'll get one
+-- magically when you differentiate.
+class Taping s a where
+ getTape :: IORef (MLog s a)
+
+data Dual s a = Dual !a
+ {-# UNPACK #-} !Int -- ^ -1 if this is a constant
+
+instance Eq a => Eq (Dual s a) where
+ Dual x _ == Dual y _ = x == y
+
+instance Ord a => Ord (Dual s a) where
+ compare (Dual x _) (Dual y _) = compare x y
+
+instance (Num a, Storable a, Taping s a) => Num (Dual s a) where
+ Dual x i1 + Dual y i2 = Dual (x + y) (writeTape @a (Proxy @s) i1 1 i2 1)
+ Dual x i1 - Dual y i2 = Dual (x - y) (writeTape @a (Proxy @s) i1 1 i2 (-1))
+ Dual x i1 * Dual y i2 = Dual (x * y) (writeTape (Proxy @s) i1 y i2 x)
+ negate (Dual x i1) = Dual (negate x) (writeTape @a (Proxy @s) i1 (-1) (-1) 0)
+ abs (Dual x i1) = Dual (abs x) (writeTape (Proxy @s) i1 (x * signum x) (-1) 0)
+ signum (Dual x _) = Dual (signum x) (-1)
+ fromInteger n = Dual (fromInteger n) (-1)
+
+instance (Fractional a, Storable a, Taping s a) => Fractional (Dual s a) where
+ Dual x i1 / Dual y i2 = Dual (x / y) (writeTape (Proxy @s) i1 (recip y) i2 (-x/(y*y)))
+ recip (Dual x i1) = Dual (recip x) (writeTape (Proxy @s) i1 (-1/(x*x)) (-1) 0)
+ fromRational r = Dual (fromRational r) (-1)
+
+instance (Floating a, Storable a, Taping s a) => Floating (Dual s a) where
+ pi = Dual pi (-1)
+ exp (Dual x i1) = Dual (exp x) (writeTape (Proxy @s) i1 (exp x) (-1) 0)
+ log (Dual x i1) = Dual (log x) (writeTape (Proxy @s) i1 (recip x) (-1) 0)
+ sqrt (Dual x i1) = Dual (sqrt x) (writeTape (Proxy @s) i1 (recip (2*sqrt x)) (-1) 0)
+ -- d/dx (x ^ y) = d/dx (e ^ (y ln x)) = e ^ (y ln x) * d/dx (y ln x) = e ^ (y ln x) * y/x
+ -- d/dy (x ^ y) = d/dy (e ^ (y ln x)) = e ^ (y ln x) * d/dy (y ln x) = e ^ (y ln x) * ln x
+ Dual x i1 ** Dual y i2 =
+ let z = x ** y
+ in Dual z (writeTape (Proxy @s) i1 (z * y/x) i2 (z * log x))
+ logBase = undefined ; sin = undefined ; cos = undefined ; tan = undefined
+ asin = undefined ; acos = undefined ; atan = undefined ; sinh = undefined
+ cosh = undefined ; tanh = undefined ; asinh = undefined ; acosh = undefined
+ atanh = undefined
+
+constant :: a -> Dual s a
+constant x = Dual x (-1)
+
+data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a))
+ | WTAOldTape (Snoclist (Chunk a))
+
+writeTape :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int
+writeTape _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy
+
+writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> IO Int
+writeTapeIO _ i1 dx i2 dy = do
+ MLog idref (Chunk ci0 vec) _ <- readIORef (getTape @s)
+ let n = VSM.length vec
+ i <- atomicModifyIORef' idref (\i -> (i + 1, i))
+ let idx = i - ci0
+
+ if | idx < n -> do
+ VSM.write vec idx (Contrib i1 dx i2 dy)
+ return i
+
+ -- check if we'd fit in the next chunk (overwhelmingly likely)
+ | let newlen = 3 * n `div` 2
+ , idx < n + newlen -> do
+ newvec <- VSM.new newlen
+ action <- atomicModifyIORef' (getTape @s) $ \(MLog idref' chunk@(Chunk ci0' vec') tape) ->
+ if | ci0 == ci0' ->
+ -- Likely (certain when single-threaded): no race condition,
+ -- we get the chance to put the new chunk in place.
+ (MLog idref' (Chunk (ci0 + n) newvec) (tape `Snoc` chunk), WTANewvec newvec)
+ | i < ci0' + VSM.length vec' ->
+ -- Race condition; need to write to appropriate position in this vector.
+ (MLog idref' chunk tape, WTANewvec vec')
+ | i < ci0' ->
+ -- Very unlikely; need to write to old chunk in tape.
+ (MLog idref' chunk tape, WTAOldTape tape)
+ | otherwise ->
+ -- We got an ID so far in the future that it doesn't even fit
+ -- in the next chunk. But that can't happen, because we're
+ -- only in this branch if the ID would have fit in the next
+ -- chunk in the first place!
+ error "writeTape: impossible"
+ case action of
+ WTANewvec vec' -> VSM.write vec' (idx - n) (Contrib i1 dx i2 dy)
+ WTAOldTape tape ->
+ let go SLNil = error "writeTape: no appropriate tape chunk?"
+ go (tape' `Snoc` Chunk ci0' vec')
+ -- The first comparison here is technically unnecessary, but
+ -- I'm not courageous enough to remove it.
+ | ci0' <= i, i < ci0' + VSM.length vec' =
+ VSM.write vec' (i - ci0') (Contrib i1 dx i2 dy)
+ | otherwise =
+ go tape'
+ in go tape
+ return i
+
+ -- there's a tremendous amount of competition, let's just try again
+ | otherwise -> writeTapeIO (Proxy @s) i1 dx i2 dy
+
+