diff options
| -rw-r--r-- | ad-dual.cabal | 15 | ||||
| -rw-r--r-- | bench/Main.hs | 57 | ||||
| -rw-r--r-- | src/Numeric/ADDual.hs | 207 | ||||
| -rw-r--r-- | src/Numeric/ADDual/Internal.hs | 226 | 
4 files changed, 305 insertions, 200 deletions
diff --git a/ad-dual.cabal b/ad-dual.cabal index 09697f0..5d3ca39 100644 --- a/ad-dual.cabal +++ b/ad-dual.cabal @@ -8,6 +8,7 @@ build-type:      Simple  library    exposed-modules:      Numeric.ADDual +    Numeric.ADDual.Internal    other-modules:    build-depends:      base >= 4.14.3, @@ -28,3 +29,17 @@ test-suite test    hs-source-dirs: test    default-language: Haskell2010    ghc-options: -Wall + +benchmark bench +  type: exitcode-stdio-1.0 +  main-is: Main.hs +  build-depends: +    base, +    ad-dual, +    ad, +    criterion, +    deepseq, +    vector +  hs-source-dirs: bench +  default-language: Haskell2010 +  ghc-options: -Wall diff --git a/bench/Main.hs b/bench/Main.hs new file mode 100644 index 0000000..cb5e829 --- /dev/null +++ b/bench/Main.hs @@ -0,0 +1,57 @@ +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE DeriveGeneric #-} +module Main where + +import Control.DeepSeq +import Criterion +import Criterion.Main +import qualified Data.Vector as V +import GHC.Generics (Generic) + +import qualified Numeric.ADDual as ADD + + +type Matrix s = V.Vector s + +data FNeural a = FNeural [(Matrix a, V.Vector a)] (V.Vector a) +  deriving (Show, Functor, Foldable, Traversable, Generic) + +instance NFData a => NFData (FNeural a) + +fneural :: (Floating a, Ord a) => FNeural a -> a +fneural (FNeural layers input) = +  let dotp v1 v2 = V.sum (V.zipWith (*) v1 v2) + +      mat @. vec = +        let n = V.length vec +            m = V.length mat `div` n +        in V.fromListN m $ map (\i -> dotp (V.slice (n*i) n mat) vec) [0 .. m-1] +      (+.) = V.zipWith (+) + +      relu x = if x >= 0.0 then x else 0.0 +      safeSoftmax vec = let m = V.maximum vec +                            factor = V.sum (V.map (\z -> exp (z - m)) vec) +                        in V.map (\z -> exp (z - m) / factor) vec +      forward [] x = safeSoftmax x +      forward ((weights, bias) : lys) x = +        let x' = V.map relu ((weights @. x) +. bias) +        in forward lys x' +  in V.sum $ forward layers input + +makeNeuralInput :: FNeural Double +makeNeuralInput = +  let genMatrix nin nout = +        V.fromListN (nin*nout) [sin (fromIntegral @Int (i+j)) +                               | i <- [0..nout-1], j <- [0..nin-1]] +      genVector nout = V.fromListN nout [sin (0.41 * fromIntegral @Int i) | i <- [0..nout-1]] +      -- 50 inputs; 2 hidden layers (100; 50); final softmax, then sum the outputs. +      nIn = 50; n1 = 100; n2 = 50 +  in FNeural [(genMatrix nIn n1, genVector n1) +             ,(genMatrix n1 n2, genVector n2)] +             (genVector nIn) + +main :: IO () +main = defaultMain +  [env (pure makeNeuralInput) $ \input -> +      bench "neural" $ nf (\inp -> ADD.gradient' @Double fneural inp 1.0) input] diff --git a/src/Numeric/ADDual.hs b/src/Numeric/ADDual.hs index d9c5b74..2b69025 100644 --- a/src/Numeric/ADDual.hs +++ b/src/Numeric/ADDual.hs @@ -1,201 +1,8 @@ -{-# LANGUAGE DeriveTraversable #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -module Numeric.ADDual where +module Numeric.ADDual ( +  gradient', +  Dual, +  Taping, +  constant, +) where -import Control.Monad (when) -import Control.Monad.Trans.Class (lift) -import Control.Monad.Trans.State.Strict -import Data.IORef -import Data.Proxy -import qualified Data.Vector.Storable as VS -import qualified Data.Vector.Storable.Mutable as VSM -import Foreign.Ptr -import Foreign.Storable -import GHC.Exts (withDict) -import System.IO.Unsafe - --- import System.IO (hPutStrLn, stderr) - --- import Numeric.ADDual.IDGen - - --- TODO: full vjp (just some more Traversable mess) -{-# NOINLINE gradient' #-} -gradient' :: forall a f. (Traversable f, Num a, Storable a) -          -- => Show a  -- TODO: remove -          => (forall s. Taping s a => f (Dual s a) -> Dual s a) -          -> f a -> a -> (a, f a) -gradient' f inp topctg = unsafePerformIO $ do -  -- hPutStrLn stderr "Preparing input" -  let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0 -  idref <- newIORef starti -  vec1 <- VSM.new (max 128 (2 * starti)) -  taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil) - -  -- hPutStrLn stderr "Running function" -  let Dual result outi = withDict @(Taping () a) taperef $ f @() inp' -  -- hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi -  MLog _ lastChunk tapeTail <- readIORef taperef - -  -- do tapestr <- showTape (tapeTail `Snoc` lastChunk) -  --    hPutStrLn stderr $ "tape = " ++ tapestr "" - -  -- hPutStrLn stderr "Backpropagating" -  accums <- VSM.new (outi+1) -  VSM.write accums outi topctg - -  let backpropagate i chunk@(Chunk ci0 vec) tape -        | i >= ci0 = do -            ctg <- VSM.read accums i -            Contrib i1 dx i2 dy <- VSM.read vec (i - ci0) -            when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1 -            when (i2 /= -1) $ VSM.modify accums (+ ctg*dy) i2 -            backpropagate (i-1) chunk tape -        | otherwise = case tape of -                        SLNil -> return ()  -- reached end of tape we should loop over -                        tape'@Snoc{} `Snoc` chunk' -> backpropagate i chunk' tape' -                        -- When we reach the last chunk, modify it so that its -                        -- starting index is after the inputs. -                        SLNil `Snoc` Chunk _ vec' -> -                          backpropagate i (Chunk starti (VSM.slice starti (VSM.length vec' - starti) vec')) SLNil - -  -- Ensure that if there are no more chunks in the tape tail, the starting -  -- index of the first chunk is adjusted so that backpropagate stops in time. -  case tapeTail of -    SLNil -> backpropagate outi (let Chunk _ vec = lastChunk -                                 in Chunk starti (VSM.slice starti (VSM.length vec - starti) vec)) -                                SLNil -    Snoc{} -> backpropagate outi lastChunk tapeTail - -  -- do accums' <- VS.freeze accums -  --    hPutStrLn stderr $ "accums = " ++ show accums' - -  -- hPutStrLn stderr "Reconstructing gradient" -  let readDeriv = do i <- get -                     d <- lift $ VSM.read accums i -                     put (i+1) -                     return d -  grad <- evalStateT (traverse (\_ -> readDeriv) inp) 0 - -  return (result, grad) - -data Snoclist a = SLNil | Snoc !(Snoclist a) !a -  deriving (Show, Eq, Ord, Functor, Foldable, Traversable) - -data Contrib a = Contrib {-# UNPACK #-} !Int a  -- ^ ID == -1 -> no contribution -                         {-# UNPACK #-} !Int a  -- ^ idem -  deriving (Show) - -instance Storable a => Storable (Contrib a) where -  sizeOf _ = 2 * (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) -  alignment _ = alignment (undefined :: Int) -  peek ptr = Contrib <$> peek (castPtr ptr) -                     <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int)) -                     <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) -                     <*> peekByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a)) -  poke ptr (Contrib i1 dx i2 dy) = do -    poke (castPtr ptr) i1 -    pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int)) dx -    pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) i2 -    pokeByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a)) dy - -data Chunk a = Chunk {-# UNPACK #-} !Int  -- ^ First ID in this chunk -                     {-# UNPACK #-} !(VSM.IOVector (Contrib a)) - -data MLog s a = MLog !(IORef Int)  -- ^ next ID to generate -                     {-# UNPACK #-} !(Chunk a)  -- ^ current running chunk -                     !(Snoclist (Chunk a))  -- ^ tape - -showChunk :: (Storable a, Show a) => Chunk a -> IO ShowS -showChunk (Chunk ci0 vec) = do -  vec' <- VS.freeze vec -  return (showString ("Chunk " ++ show ci0 ++ " ") . shows vec') - -showTape :: (Storable a, Show a) => Snoclist (Chunk a) -> IO ShowS -showTape SLNil = return (showString "SLNil") -showTape (tape `Snoc` chunk) = do -  s1 <- showTape tape -  s2 <- showChunk chunk -  return (s1 . showString " `Snoc` " . s2) - --- | This class does not have any instances defined, on purpose. You'll get one --- magically when you differentiate. -class Taping s a where -  getTape :: IORef (MLog s a) - -data Dual s a = Dual !a -                     {-# UNPACK #-} !Int  -- ^ -1 if this is a constant - -instance Eq a => Eq (Dual s a) where -  Dual x _ == Dual y _ = x == y - -instance Ord a => Ord (Dual s a) where -  compare (Dual x _) (Dual y _) = compare x y - -instance (Num a, Storable a, Taping s a) => Num (Dual s a) where -  Dual x i1 + Dual y i2 = Dual (x + y)    (writeTape @a (Proxy @s) i1 1 i2 1) -  Dual x i1 - Dual y i2 = Dual (x - y)    (writeTape @a (Proxy @s) i1 1 i2 (-1)) -  Dual x i1 * Dual y i2 = Dual (x * y)    (writeTape    (Proxy @s) i1 y i2 x) -  negate (Dual x i1)    = Dual (negate x) (writeTape @a (Proxy @s) i1 (-1) (-1) 0) -  abs (Dual x i1)       = Dual (abs x)    (writeTape    (Proxy @s) i1 (x * signum x) (-1) 0) -  signum (Dual x _) = Dual (signum x) (-1) -  fromInteger n = Dual (fromInteger n) (-1) - -data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a)) -                       | WTAOldTape (Snoclist (Chunk a)) - -writeTape :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int -writeTape _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy - -writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> IO Int -writeTapeIO _ i1 dx i2 dy = do -  MLog idref (Chunk ci0 vec) _ <- readIORef (getTape @s) -  let n = VSM.length vec -  i <- atomicModifyIORef' idref (\i -> (i + 1, i)) -  let idx = i - ci0 - -  if | idx < n -> do -         VSM.write vec idx (Contrib i1 dx i2 dy) -         return i - -     -- check if we'd fit in the next chunk (overwhelmingly likely) -     | let newlen = 3 * n `div` 2 -     , idx < n + newlen -> do -         newvec <- VSM.new newlen -         action <- atomicModifyIORef' (getTape @s) $ \(MLog idref' chunk@(Chunk ci0' vec') tape) -> -           if | ci0 == ci0' -> -                 -- Likely (certain when single-threaded): no race condition, -                 -- we get the chance to put the new chunk in place. -                 (MLog idref' (Chunk (ci0 + n) newvec) (tape `Snoc` chunk), WTANewvec newvec) -              | i < ci0' + VSM.length vec' -> -                 -- Race condition; need to write to appropriate position in this vector. -                 (MLog idref' chunk tape, WTANewvec vec') -              | i < ci0' -> -                 -- Very unlikely; need to write to old chunk in tape. -                 (MLog idref' chunk tape, WTAOldTape tape) -              | otherwise -> -                 -- We got an ID so far in the future that it doesn't even fit -                 -- in the next chunk. But that can't happen, because we're -                 -- only in this branch if the ID would have fit in the next -                 -- chunk in the first place! -                 error "writeTape: impossible" -         case action of -           WTANewvec vec' -> VSM.write vec' (idx - n) (Contrib i1 dx i2 dy) -           WTAOldTape tape -> -             let go SLNil = error "writeTape: no appropriate tape chunk?" -                 go (tape' `Snoc` Chunk ci0' vec') -                   -- The first comparison here is technically unnecessary, but -                   -- I'm not courageous enough to remove it. -                   | ci0' <= i, i < ci0' + VSM.length vec' = -                       VSM.write vec' (i - ci0') (Contrib i1 dx i2 dy) -                   | otherwise = -                       go tape' -             in go tape -         return i - -     -- there's a tremendous amount of competition, let's just try again -     | otherwise -> writeTapeIO (Proxy @s) i1 dx i2 dy +import Numeric.ADDual.Internal diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs new file mode 100644 index 0000000..8228694 --- /dev/null +++ b/src/Numeric/ADDual/Internal.hs @@ -0,0 +1,226 @@ +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +module Numeric.ADDual.Internal where + +import Control.Monad (when) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State.Strict +import Data.IORef +import Data.Proxy +import qualified Data.Vector.Storable as VS +import qualified Data.Vector.Storable.Mutable as VSM +import Foreign.Ptr +import Foreign.Storable +import GHC.Exts (withDict) +import System.IO.Unsafe + +-- import System.IO (hPutStrLn, stderr) + +-- import Numeric.ADDual.IDGen + + +-- TODO: full vjp (just some more Traversable mess) +{-# NOINLINE gradient' #-} +gradient' :: forall a f. (Traversable f, Num a, Storable a) +          -- => Show a  -- TODO: remove +          => (forall s. Taping s a => f (Dual s a) -> Dual s a) +          -> f a -> a -> (a, f a) +gradient' f inp topctg = unsafePerformIO $ do +  -- hPutStrLn stderr "Preparing input" +  let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0 +  idref <- newIORef starti +  vec1 <- VSM.new (max 128 (2 * starti)) +  taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil) + +  -- hPutStrLn stderr "Running function" +  let Dual result outi = withDict @(Taping () a) taperef $ f @() inp' +  -- hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi +  MLog _ lastChunk tapeTail <- readIORef taperef + +  -- do tapestr <- showTape (tapeTail `Snoc` lastChunk) +  --    hPutStrLn stderr $ "tape = " ++ tapestr "" + +  -- hPutStrLn stderr "Backpropagating" +  accums <- VSM.new (outi+1) +  VSM.write accums outi topctg + +  let backpropagate i chunk@(Chunk ci0 vec) tape +        | i >= ci0 = do +            ctg <- VSM.read accums i +            Contrib i1 dx i2 dy <- VSM.read vec (i - ci0) +            when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1 +            when (i2 /= -1) $ VSM.modify accums (+ ctg*dy) i2 +            backpropagate (i-1) chunk tape +        | otherwise = case tape of +                        SLNil -> return ()  -- reached end of tape we should loop over +                        tape'@Snoc{} `Snoc` chunk' -> backpropagate i chunk' tape' +                        -- When we reach the last chunk, modify it so that its +                        -- starting index is after the inputs. +                        SLNil `Snoc` Chunk _ vec' -> +                          backpropagate i (Chunk starti (VSM.slice starti (VSM.length vec' - starti) vec')) SLNil + +  -- Ensure that if there are no more chunks in the tape tail, the starting +  -- index of the first chunk is adjusted so that backpropagate stops in time. +  case tapeTail of +    SLNil -> backpropagate outi (let Chunk _ vec = lastChunk +                                 in Chunk starti (VSM.slice starti (VSM.length vec - starti) vec)) +                                SLNil +    Snoc{} -> backpropagate outi lastChunk tapeTail + +  -- do accums' <- VS.freeze accums +  --    hPutStrLn stderr $ "accums = " ++ show accums' + +  -- hPutStrLn stderr "Reconstructing gradient" +  let readDeriv = do i <- get +                     d <- lift $ VSM.read accums i +                     put (i+1) +                     return d +  grad <- evalStateT (traverse (\_ -> readDeriv) inp) 0 + +  return (result, grad) + +data Snoclist a = SLNil | Snoc !(Snoclist a) !a +  deriving (Show, Eq, Ord, Functor, Foldable, Traversable) + +data Contrib a = Contrib {-# UNPACK #-} !Int a  -- ^ ID == -1 -> no contribution +                         {-# UNPACK #-} !Int a  -- ^ idem +  deriving (Show) + +instance Storable a => Storable (Contrib a) where +  sizeOf _ = 2 * (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) +  alignment _ = alignment (undefined :: Int) +  peek ptr = Contrib <$> peek (castPtr ptr) +                     <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int)) +                     <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) +                     <*> peekByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a)) +  poke ptr (Contrib i1 dx i2 dy) = do +    poke (castPtr ptr) i1 +    pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int)) dx +    pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) i2 +    pokeByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a)) dy + +data Chunk a = Chunk {-# UNPACK #-} !Int  -- ^ First ID in this chunk +                     {-# UNPACK #-} !(VSM.IOVector (Contrib a)) + +data MLog s a = MLog !(IORef Int)  -- ^ next ID to generate +                     {-# UNPACK #-} !(Chunk a)  -- ^ current running chunk +                     !(Snoclist (Chunk a))  -- ^ tape + +showChunk :: (Storable a, Show a) => Chunk a -> IO ShowS +showChunk (Chunk ci0 vec) = do +  vec' <- VS.freeze vec +  return (showString ("Chunk " ++ show ci0 ++ " ") . shows vec') + +showTape :: (Storable a, Show a) => Snoclist (Chunk a) -> IO ShowS +showTape SLNil = return (showString "SLNil") +showTape (tape `Snoc` chunk) = do +  s1 <- showTape tape +  s2 <- showChunk chunk +  return (s1 . showString " `Snoc` " . s2) + +-- | This class does not have any instances defined, on purpose. You'll get one +-- magically when you differentiate. +class Taping s a where +  getTape :: IORef (MLog s a) + +data Dual s a = Dual !a +                     {-# UNPACK #-} !Int  -- ^ -1 if this is a constant + +instance Eq a => Eq (Dual s a) where +  Dual x _ == Dual y _ = x == y + +instance Ord a => Ord (Dual s a) where +  compare (Dual x _) (Dual y _) = compare x y + +instance (Num a, Storable a, Taping s a) => Num (Dual s a) where +  Dual x i1 + Dual y i2 = Dual (x + y)    (writeTape @a (Proxy @s) i1 1 i2 1) +  Dual x i1 - Dual y i2 = Dual (x - y)    (writeTape @a (Proxy @s) i1 1 i2 (-1)) +  Dual x i1 * Dual y i2 = Dual (x * y)    (writeTape    (Proxy @s) i1 y i2 x) +  negate (Dual x i1)    = Dual (negate x) (writeTape @a (Proxy @s) i1 (-1) (-1) 0) +  abs (Dual x i1)       = Dual (abs x)    (writeTape    (Proxy @s) i1 (x * signum x) (-1) 0) +  signum (Dual x _) = Dual (signum x) (-1) +  fromInteger n = Dual (fromInteger n) (-1) + +instance (Fractional a, Storable a, Taping s a) => Fractional (Dual s a) where +  Dual x i1 / Dual y i2 = Dual (x / y)   (writeTape (Proxy @s) i1 (recip y) i2 (-x/(y*y))) +  recip (Dual x i1)     = Dual (recip x) (writeTape (Proxy @s) i1 (-1/(x*x)) (-1) 0) +  fromRational r = Dual (fromRational r) (-1) + +instance (Floating a, Storable a, Taping s a) => Floating (Dual s a) where +  pi = Dual pi (-1) +  exp (Dual x i1) = Dual (exp x) (writeTape (Proxy @s) i1 (exp x) (-1) 0) +  log (Dual x i1) = Dual (log x) (writeTape (Proxy @s) i1 (recip x) (-1) 0) +  sqrt (Dual x i1) = Dual (sqrt x) (writeTape (Proxy @s) i1 (recip (2*sqrt x)) (-1) 0) +  -- d/dx (x ^ y) = d/dx (e ^ (y ln x)) = e ^ (y ln x) * d/dx (y ln x) = e ^ (y ln x) * y/x +  -- d/dy (x ^ y) = d/dy (e ^ (y ln x)) = e ^ (y ln x) * d/dy (y ln x) = e ^ (y ln x) * ln x +  Dual x i1 ** Dual y i2 = +    let z = x ** y +    in Dual z (writeTape (Proxy @s) i1 (z * y/x) i2 (z * log x)) +  logBase = undefined ; sin = undefined ; cos = undefined ; tan = undefined +  asin = undefined ; acos = undefined ; atan = undefined ; sinh = undefined +  cosh = undefined ; tanh = undefined ; asinh = undefined ; acosh = undefined +  atanh = undefined + +constant :: a -> Dual s a +constant x = Dual x (-1) + +data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a)) +                       | WTAOldTape (Snoclist (Chunk a)) + +writeTape :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int +writeTape _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy + +writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> IO Int +writeTapeIO _ i1 dx i2 dy = do +  MLog idref (Chunk ci0 vec) _ <- readIORef (getTape @s) +  let n = VSM.length vec +  i <- atomicModifyIORef' idref (\i -> (i + 1, i)) +  let idx = i - ci0 + +  if | idx < n -> do +         VSM.write vec idx (Contrib i1 dx i2 dy) +         return i + +     -- check if we'd fit in the next chunk (overwhelmingly likely) +     | let newlen = 3 * n `div` 2 +     , idx < n + newlen -> do +         newvec <- VSM.new newlen +         action <- atomicModifyIORef' (getTape @s) $ \(MLog idref' chunk@(Chunk ci0' vec') tape) -> +           if | ci0 == ci0' -> +                 -- Likely (certain when single-threaded): no race condition, +                 -- we get the chance to put the new chunk in place. +                 (MLog idref' (Chunk (ci0 + n) newvec) (tape `Snoc` chunk), WTANewvec newvec) +              | i < ci0' + VSM.length vec' -> +                 -- Race condition; need to write to appropriate position in this vector. +                 (MLog idref' chunk tape, WTANewvec vec') +              | i < ci0' -> +                 -- Very unlikely; need to write to old chunk in tape. +                 (MLog idref' chunk tape, WTAOldTape tape) +              | otherwise -> +                 -- We got an ID so far in the future that it doesn't even fit +                 -- in the next chunk. But that can't happen, because we're +                 -- only in this branch if the ID would have fit in the next +                 -- chunk in the first place! +                 error "writeTape: impossible" +         case action of +           WTANewvec vec' -> VSM.write vec' (idx - n) (Contrib i1 dx i2 dy) +           WTAOldTape tape -> +             let go SLNil = error "writeTape: no appropriate tape chunk?" +                 go (tape' `Snoc` Chunk ci0' vec') +                   -- The first comparison here is technically unnecessary, but +                   -- I'm not courageous enough to remove it. +                   | ci0' <= i, i < ci0' + VSM.length vec' = +                       VSM.write vec' (i - ci0') (Contrib i1 dx i2 dy) +                   | otherwise = +                       go tape' +             in go tape +         return i + +     -- there's a tremendous amount of competition, let's just try again +     | otherwise -> writeTapeIO (Proxy @s) i1 dx i2 dy + +  | 
