aboutsummaryrefslogtreecommitdiff
path: root/src/Numeric/ADDual
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-02-20 10:11:57 +0100
committerTom Smeding <tom@tomsmeding.com>2025-02-20 10:11:57 +0100
commitfe3132304b6c25e5bebc9fb327e3ea5d6018be7a (patch)
treeedbf62755eab8103c3f39ee7dfe2b0006692e857 /src/Numeric/ADDual
parent011bda94ea9ab0bdb43751d8d19963beb5a887a0 (diff)
Attempt at a benchmark (crashes)
Diffstat (limited to 'src/Numeric/ADDual')
-rw-r--r--src/Numeric/ADDual/Internal.hs226
1 files changed, 226 insertions, 0 deletions
diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs
new file mode 100644
index 0000000..8228694
--- /dev/null
+++ b/src/Numeric/ADDual/Internal.hs
@@ -0,0 +1,226 @@
+{-# LANGUAGE DeriveTraversable #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+module Numeric.ADDual.Internal where
+
+import Control.Exception (evaluate)
+import Control.Monad (when)
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.State.Strict
+import Data.IORef
+import Data.Proxy
+import qualified Data.Vector.Storable as VS
+import qualified Data.Vector.Storable.Mutable as VSM
+import Foreign.Ptr
+import Foreign.Storable
+import GHC.Exts (withDict)
+import System.IO.Unsafe
+
+-- import System.IO (hPutStrLn, stderr)
+
+-- import Numeric.ADDual.IDGen
+
+
+-- TODO: full vjp (just some more Traversable mess)
+{-# NOINLINE gradient' #-}
+-- | Reverse-mode gradient of a scalar-valued function.
+--
+-- @gradient' f inp topctg@ runs @f@ on @inp@ while every primitive operation
+-- on 'Dual' numbers is recorded on a tape, and then walks that tape backwards
+-- starting from the cotangent @topctg@ of the result. It returns the primal
+-- result of @f@ together with a structure shaped like @inp@ that holds the
+-- accumulated cotangent of each input.
+--
+-- The inputs receive the IDs @0 .. starti-1@ in traversal order. The backward
+-- pass visits IDs in decreasing order and therefore relies on every recorded
+-- operation having a larger ID than its arguments.
+gradient' :: forall a f. (Traversable f, Num a, Storable a)
+          -- => Show a  -- TODO: remove
+          => (forall s. Taping s a => f (Dual s a) -> Dual s a)
+          -> f a -> a -> (a, f a)
+gradient' f inp topctg = unsafePerformIO $ do
+  -- hPutStrLn stderr "Preparing input"
+  let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0
+  idref <- newIORef starti
+  vec1 <- VSM.new (max 128 (2 * starti))
+  taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil)
+
+  -- hPutStrLn stderr "Running function"
+  -- The function must be forced *before* the tape is read back: all tape
+  -- writes happen as a side effect of evaluating the resulting 'Dual'. A lazy
+  -- @let@ here would postpone the evaluation until 'outi' is first used,
+  -- i.e. until after 'readIORef', so that a stale tape (always just the
+  -- initial chunk) would be backpropagated through.
+  Dual result outi <- evaluate (withDict @(Taping () a) taperef $ f @() inp')
+  -- hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi
+  MLog _ lastChunk tapeTail <- readIORef taperef
+
+  -- do tapestr <- showTape (tapeTail `Snoc` lastChunk)
+  --    hPutStrLn stderr $ "tape = " ++ tapestr ""
+
+  -- hPutStrLn stderr "Backpropagating"
+  -- The accumulator needs a slot for every input even if the result has a
+  -- smaller ID than some input (the result is itself an input) or no ID at
+  -- all (the result is a constant, ID -1). It is filled with an explicit zero
+  -- of type @a@.
+  accums <- VSM.replicate (max (outi + 1) starti) 0
+  when (outi >= 0) $ VSM.write accums outi topctg
+
+  let -- Process IDs i, i-1, ... down to the first ID of the oldest chunk.
+      -- 'chunk' is the chunk containing ID i (if i >= ci0); 'tape' holds the
+      -- older chunks.
+      backpropagate i chunk@(Chunk ci0 vec) tape
+        | i >= ci0 = do
+            ctg <- VSM.read accums i
+            Contrib i1 dx i2 dy <- VSM.read vec (i - ci0)
+            when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1
+            when (i2 /= -1) $ VSM.modify accums (+ ctg*dy) i2
+            backpropagate (i-1) chunk tape
+        | otherwise = case tape of
+            SLNil -> return ()  -- reached end of tape we should loop over
+            tape'@Snoc{} `Snoc` chunk' -> backpropagate i chunk' tape'
+            -- When we reach the last chunk, modify it so that its
+            -- starting index is after the inputs.
+            SLNil `Snoc` Chunk _ vec' ->
+              backpropagate i (Chunk starti (VSM.slice starti (VSM.length vec' - starti) vec')) SLNil
+
+  -- Ensure that if there are no more chunks in the tape tail, the starting
+  -- index of the first chunk is adjusted so that backpropagate stops in time.
+  case tapeTail of
+    SLNil -> backpropagate outi (let Chunk _ vec = lastChunk
+                                 in Chunk starti (VSM.slice starti (VSM.length vec - starti) vec))
+                           SLNil
+    Snoc{} -> backpropagate outi lastChunk tapeTail
+
+  -- do accums' <- VS.freeze accums
+  --    hPutStrLn stderr $ "accums = " ++ show accums'
+
+  -- hPutStrLn stderr "Reconstructing gradient"
+  -- Read the input cotangents back in the same traversal order in which the
+  -- input IDs were handed out.
+  let readDeriv = do i <- get
+                     d <- lift $ VSM.read accums i
+                     put (i+1)
+                     return d
+  grad <- evalStateT (traverse (\_ -> readDeriv) inp) 0
+
+  return (result, grad)
+
+-- | A list that is extended at its end: @xs \`Snoc\` x@ is @xs@ followed by
+-- @x@, and 'SLNil' is the empty list. Both constructor fields are strict.
+data Snoclist a = SLNil | Snoc !(Snoclist a) !a
+ deriving (Show, Eq, Ord, Functor, Foldable, Traversable)
+
+-- | One tape entry: for a recorded operation, the IDs of its (at most two)
+-- arguments, each paired with the factor by which the cotangent of the
+-- operation's result is multiplied before being added to that argument's
+-- cotangent during backpropagation.
+data Contrib a = Contrib {-# UNPACK #-} !Int a -- ^ ID == -1 -> no contribution
+ {-# UNPACK #-} !Int a -- ^ idem
+ deriving (Show)
+
+-- | Byte layout used by the 'Storable' instance of 'Contrib':
+-- @(offset of second ID, offset of first factor, offset of second factor,
+-- total size, alignment)@. The first ID always lives at offset 0.
+--
+-- Both IDs are stored first, followed by the two factors. Every field is
+-- placed at a multiple of its own alignment, and the total size is rounded up
+-- to the alignment of the whole record so that consecutive elements in a
+-- vector stay aligned.
+contribLayout :: forall a proxy. Storable a => proxy a -> (Int, Int, Int, Int, Int)
+contribLayout _ = (offI2, offDx, offDy, total, align)
+  where
+    sizeI  = sizeOf (undefined :: Int)
+    sizeA  = sizeOf (undefined :: a)
+    alignI = max 1 (alignment (undefined :: Int))
+    alignA = max 1 (alignment (undefined :: a))
+    align  = max alignI alignA
+    -- Smallest multiple of m that is >= k.
+    roundUp k m = ((k + m - 1) `div` m) * m
+    offI2  = roundUp sizeI alignI
+    offDx  = roundUp (offI2 + sizeI) alignA
+    offDy  = roundUp (offDx + sizeA) alignA
+    total  = roundUp (offDy + sizeA) align
+{-# INLINE contribLayout #-}
+
+-- | Stores a 'Contrib' as two 'Int' IDs followed by the two factors; see
+-- 'contribLayout' for the exact offsets. The alignment is the larger of the
+-- alignments of 'Int' and @a@, so neither kind of field is accessed
+-- misaligned.
+instance Storable a => Storable (Contrib a) where
+  sizeOf _ = let (_, _, _, total, _) = contribLayout (Proxy @a) in total
+  alignment _ = let (_, _, _, _, align) = contribLayout (Proxy @a) in align
+  peek ptr = do
+    let (offI2, offDx, offDy, _, _) = contribLayout (Proxy @a)
+    i1 <- peekByteOff ptr 0
+    i2 <- peekByteOff ptr offI2
+    dx <- peekByteOff ptr offDx
+    dy <- peekByteOff ptr offDy
+    return (Contrib i1 dx i2 dy)
+  poke ptr (Contrib i1 dx i2 dy) = do
+    let (offI2, offDx, offDy, _, _) = contribLayout (Proxy @a)
+    pokeByteOff ptr 0 i1
+    pokeByteOff ptr offI2 i2
+    pokeByteOff ptr offDx dx
+    pokeByteOff ptr offDy dy
+
+-- | A contiguous piece of the tape: a mutable vector of 'Contrib' entries
+-- together with the ID that corresponds to index 0 of that vector. The entry
+-- for ID @i@ is stored at index @i - firstID@.
+data Chunk a = Chunk {-# UNPACK #-} !Int -- ^ First ID in this chunk
+ {-# UNPACK #-} !(VSM.IOVector (Contrib a))
+
+-- | The mutable log ("tape") of one differentiation run: the ID counter, the
+-- chunk that is currently being filled, and the older chunks in the order in
+-- which they were created (most recent last). The type parameter @s@ does not
+-- occur in any field; it only ties the log to a 'Taping' instance.
+data MLog s a = MLog !(IORef Int) -- ^ next ID to generate
+ {-# UNPACK #-} !(Chunk a) -- ^ current running chunk
+ !(Snoclist (Chunk a)) -- ^ tape
+
+-- | Debugging aid: snapshot the contents of a chunk and render it as
+-- @Chunk \<first ID\> \<entries\>@.
+showChunk :: (Storable a, Show a) => Chunk a -> IO ShowS
+showChunk (Chunk firstId entries) = render <$> VS.freeze entries
+  where
+    render frozen = showString ("Chunk " ++ show firstId ++ " ") . shows frozen
+
+-- | Debugging aid: render a whole tape, oldest chunk first, using the
+-- constructor syntax of 'Snoclist'.
+showTape :: (Storable a, Show a) => Snoclist (Chunk a) -> IO ShowS
+showTape chunks = case chunks of
+  SLNil -> return (showString "SLNil")
+  Snoc older newest ->
+    -- Effects run in the same order as the output: older chunks first.
+    (\olderS newestS -> olderS . showString " `Snoc` " . newestS)
+      <$> showTape older
+      <*> showChunk newest
+
+-- | Gives access to the tape on which operations on @'Dual' s a@ values are
+-- recorded. This class does not have any instances defined, on purpose. You'll
+-- get one magically when you differentiate: 'gradient'' supplies the
+-- dictionary (an 'IORef' to a fresh 'MLog') through 'withDict'.
+class Taping s a where
+ getTape :: IORef (MLog s a)
+
+-- | A number that is being differentiated: its primal value together with the
+-- ID of the tape entry that describes how it was computed. The 'Eq' and 'Ord'
+-- instances look at the primal value only.
+data Dual s a = Dual !a
+ {-# UNPACK #-} !Int -- ^ -1 if this is a constant
+
+-- | Equality of primal values; the tape IDs are ignored.
+instance Eq a => Eq (Dual s a) where
+  (==) (Dual lhs _) (Dual rhs _) = lhs == rhs
+
+-- | Ordering of primal values; the tape IDs are ignored.
+instance Ord a => Ord (Dual s a) where
+  Dual lhs _ `compare` Dual rhs _ = lhs `compare` rhs
+
+-- | Arithmetic on 'Dual' numbers. Every differentiable operation computes the
+-- primal result and records one tape entry holding the argument IDs and the
+-- partial derivatives of the result with respect to each argument; the ID
+-- returned by 'writeTape' becomes the ID of the result. 'signum' and
+-- 'fromInteger' produce constants (ID -1) and record nothing.
+instance (Num a, Storable a, Taping s a) => Num (Dual s a) where
+  Dual x i1 + Dual y i2 = Dual (x + y) (writeTape @a (Proxy @s) i1 1 i2 1)
+  Dual x i1 - Dual y i2 = Dual (x - y) (writeTape @a (Proxy @s) i1 1 i2 (-1))
+  -- d/dx (x * y) = y, d/dy (x * y) = x
+  Dual x i1 * Dual y i2 = Dual (x * y) (writeTape (Proxy @s) i1 y i2 x)
+  negate (Dual x i1) = Dual (negate x) (writeTape @a (Proxy @s) i1 (-1) (-1) 0)
+  -- d/dx |x| = signum x (not x * signum x, which is |x| itself)
+  abs (Dual x i1) = Dual (abs x) (writeTape (Proxy @s) i1 (signum x) (-1) 0)
+  signum (Dual x _) = Dual (signum x) (-1)
+  fromInteger n = Dual (fromInteger n) (-1)
+
+-- | Division and reciprocal on 'Dual' numbers; each records its partial
+-- derivatives on the tape. 'fromRational' yields a constant (ID -1).
+instance (Fractional a, Storable a, Taping s a) => Fractional (Dual s a) where
+  Dual num iNum / Dual den iDen =
+    let wrtNum = recip den             -- d/dx (x / y) = 1 / y
+        wrtDen = -num / (den * den)    -- d/dy (x / y) = -x / y^2
+    in Dual (num / den) (writeTape (Proxy @s) iNum wrtNum iDen wrtDen)
+  recip (Dual v iv) =
+    let slope = -1 / (v * v)           -- d/dx (1 / x) = -1 / x^2
+    in Dual (recip v) (writeTape (Proxy @s) iv slope (-1) 0)
+  fromRational q = Dual (fromRational q) (-1)
+
+-- | Transcendental functions on 'Dual' numbers. Each unary function records
+-- one tape entry with its derivative at the argument and no second
+-- contribution (ID -1, factor 0). 'logBase' is not defined here, so the class
+-- default in terms of 'log' and '/' is used.
+instance (Floating a, Storable a, Taping s a) => Floating (Dual s a) where
+  pi = Dual pi (-1)
+  exp (Dual x i1) =
+    let z = exp x
+    in Dual z (writeTape (Proxy @s) i1 z (-1) 0)
+  log (Dual x i1) = Dual (log x) (writeTape (Proxy @s) i1 (recip x) (-1) 0)
+  sqrt (Dual x i1) = Dual (sqrt x) (writeTape (Proxy @s) i1 (recip (2*sqrt x)) (-1) 0)
+  -- d/dx (x ^ y) = d/dx (e ^ (y ln x)) = e ^ (y ln x) * d/dx (y ln x) = e ^ (y ln x) * y/x
+  -- d/dy (x ^ y) = d/dy (e ^ (y ln x)) = e ^ (y ln x) * d/dy (y ln x) = e ^ (y ln x) * ln x
+  Dual x i1 ** Dual y i2 =
+    let z = x ** y
+    in Dual z (writeTape (Proxy @s) i1 (z * y/x) i2 (z * log x))
+  -- d/dx sin x = cos x;  d/dx cos x = -sin x;  d/dx tan x = 1 / cos^2 x
+  sin (Dual x i1) = Dual (sin x) (writeTape (Proxy @s) i1 (cos x) (-1) 0)
+  cos (Dual x i1) = Dual (cos x) (writeTape (Proxy @s) i1 (negate (sin x)) (-1) 0)
+  tan (Dual x i1) =
+    let c = cos x
+    in Dual (tan x) (writeTape (Proxy @s) i1 (recip (c * c)) (-1) 0)
+  -- d/dx asin x = 1 / sqrt (1 - x^2);  d/dx acos x = -1 / sqrt (1 - x^2);
+  -- d/dx atan x = 1 / (1 + x^2)
+  asin (Dual x i1) = Dual (asin x) (writeTape (Proxy @s) i1 (recip (sqrt (1 - x*x))) (-1) 0)
+  acos (Dual x i1) = Dual (acos x) (writeTape (Proxy @s) i1 (negate (recip (sqrt (1 - x*x)))) (-1) 0)
+  atan (Dual x i1) = Dual (atan x) (writeTape (Proxy @s) i1 (recip (1 + x*x)) (-1) 0)
+  -- d/dx sinh x = cosh x;  d/dx cosh x = sinh x;  d/dx tanh x = 1 - tanh^2 x
+  sinh (Dual x i1) = Dual (sinh x) (writeTape (Proxy @s) i1 (cosh x) (-1) 0)
+  cosh (Dual x i1) = Dual (cosh x) (writeTape (Proxy @s) i1 (sinh x) (-1) 0)
+  tanh (Dual x i1) =
+    let t = tanh x
+    in Dual t (writeTape (Proxy @s) i1 (1 - t*t) (-1) 0)
+  -- d/dx asinh x = 1 / sqrt (x^2 + 1);  d/dx acosh x = 1 / sqrt (x^2 - 1);
+  -- d/dx atanh x = 1 / (1 - x^2)
+  asinh (Dual x i1) = Dual (asinh x) (writeTape (Proxy @s) i1 (recip (sqrt (x*x + 1))) (-1) 0)
+  acosh (Dual x i1) = Dual (acosh x) (writeTape (Proxy @s) i1 (recip (sqrt (x*x - 1))) (-1) 0)
+  atanh (Dual x i1) = Dual (atanh x) (writeTape (Proxy @s) i1 (recip (1 - x*x)) (-1) 0)
+
+-- | Lift a plain value into 'Dual' as a constant: it carries the ID -1 and so
+-- never receives a contribution during backpropagation.
+constant :: a -> Dual s a
+constant primal = primal `Dual` (-1)
+
+-- | What 'writeTapeIO' has to do after it has inspected the shared tape state
+-- while trying to move on to the next chunk.
+--
+-- * 'WTANewvec': write the entry into the given vector, which is the chunk
+--   that directly follows the one the writer originally looked at.
+-- * 'WTAOldTape': the entry belongs to one of the given older chunks; search
+--   them for the chunk containing the ID.
+data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a))
+ | WTAOldTape (Snoclist (Chunk a))
+
+-- | Pure facade over 'writeTapeIO': record a tape entry with argument IDs and
+-- derivative factors, and return the fresh ID assigned to that entry. The
+-- proxy argument only fixes the type @s@; its value is never inspected.
+writeTape :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int
+writeTape _ argA factorA argB factorB =
+  unsafePerformIO (writeTapeIO (Proxy @s) argA factorA argB factorB)
+
+-- | Allocate a fresh ID and store the entry @Contrib i1 dx i2 dy@ for it on
+-- the tape obtained from 'getTape'; returns the new ID.
+--
+-- The common case writes into the current chunk. If the ID lies just beyond
+-- the current chunk, a new chunk of 1.5 times the size is appended (or, if
+-- another thread got there first, the entry is written into the chunk that
+-- thread installed, or into an older chunk). If the ID does not even fit in
+-- the next chunk, the whole operation is retried with a new ID.
+writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> IO Int
+writeTapeIO _ i1 dx i2 dy = do
+  MLog idref (Chunk ci0 vec) _ <- readIORef (getTape @s)
+  let n = VSM.length vec
+  i <- atomicModifyIORef' idref (\i -> (i + 1, i))
+  let idx = i - ci0
+
+  if | idx < n -> do
+         VSM.write vec idx (Contrib i1 dx i2 dy)
+         return i
+
+     -- check if we'd fit in the next chunk (overwhelmingly likely)
+     | let newlen = 3 * n `div` 2
+     , idx < n + newlen -> do
+         newvec <- VSM.new newlen
+         action <- atomicModifyIORef' (getTape @s) $ \(MLog idref' chunk@(Chunk ci0' vec') tape) ->
+           if | ci0 == ci0' ->
+                  -- Likely (certain when single-threaded): no race condition,
+                  -- we get the chance to put the new chunk in place.
+                  (MLog idref' (Chunk (ci0 + n) newvec) (tape `Snoc` chunk), WTANewvec newvec)
+              | i < ci0' ->
+                  -- Very unlikely; need to write to old chunk in tape. This
+                  -- must be tested before the upper bound of the current
+                  -- chunk, because i < ci0' implies i < ci0' + length vec'.
+                  (MLog idref' chunk tape, WTAOldTape tape)
+              | i < ci0' + VSM.length vec' ->
+                  -- Race condition; need to write to appropriate position in
+                  -- this vector. Here ci0' <= i < ci0 + n + newlen, and the
+                  -- chunk following ours always starts at ci0 + n and has
+                  -- length newlen, so this chunk must be that one: ci0' ==
+                  -- ci0 + n, and the position is idx - n.
+                  (MLog idref' chunk tape, WTANewvec vec')
+              | otherwise ->
+                  -- We got an ID so far in the future that it doesn't even fit
+                  -- in the next chunk. But that can't happen, because we're
+                  -- only in this branch if the ID would have fit in the next
+                  -- chunk in the first place!
+                  error "writeTape: impossible"
+         case action of
+           WTANewvec vec' -> VSM.write vec' (idx - n) (Contrib i1 dx i2 dy)
+           WTAOldTape tape ->
+             let go SLNil = error "writeTape: no appropriate tape chunk?"
+                 go (tape' `Snoc` Chunk ci0' vec')
+                   -- The first comparison here is technically unnecessary, but
+                   -- I'm not courageous enough to remove it.
+                   | ci0' <= i, i < ci0' + VSM.length vec' =
+                       VSM.write vec' (i - ci0') (Contrib i1 dx i2 dy)
+                   | otherwise =
+                       go tape'
+             in go tape
+         return i
+
+     -- there's a tremendous amount of competition, let's just try again
+     | otherwise -> writeTapeIO (Proxy @s) i1 dx i2 dy
+
+