aboutsummaryrefslogtreecommitdiff
path: root/src/Numeric/ADDual.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-02-20 10:11:57 +0100
committerTom Smeding <tom@tomsmeding.com>2025-02-20 10:11:57 +0100
commitfe3132304b6c25e5bebc9fb327e3ea5d6018be7a (patch)
treeedbf62755eab8103c3f39ee7dfe2b0006692e857 /src/Numeric/ADDual.hs
parent011bda94ea9ab0bdb43751d8d19963beb5a887a0 (diff)
Attempt at a benchmark (crashes)
Diffstat (limited to 'src/Numeric/ADDual.hs')
-rw-r--r--src/Numeric/ADDual.hs209
1 files changed, 8 insertions, 201 deletions
diff --git a/src/Numeric/ADDual.hs b/src/Numeric/ADDual.hs
index d9c5b74..2b69025 100644
--- a/src/Numeric/ADDual.hs
+++ b/src/Numeric/ADDual.hs
@@ -1,201 +1,8 @@
-{-# LANGUAGE DeriveTraversable #-}
-{-# LANGUAGE MultiParamTypeClasses #-}
-{-# LANGUAGE MultiWayIf #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeApplications #-}
-module Numeric.ADDual where
-
-import Control.Monad (when)
-import Control.Monad.Trans.Class (lift)
-import Control.Monad.Trans.State.Strict
-import Data.IORef
-import Data.Proxy
-import qualified Data.Vector.Storable as VS
-import qualified Data.Vector.Storable.Mutable as VSM
-import Foreign.Ptr
-import Foreign.Storable
-import GHC.Exts (withDict)
-import System.IO.Unsafe
-
--- import System.IO (hPutStrLn, stderr)
-
--- import Numeric.ADDual.IDGen
-
-
--- TODO: full vjp (just some more Traversable mess)
-{-# NOINLINE gradient' #-}
-gradient' :: forall a f. (Traversable f, Num a, Storable a)
- -- => Show a -- TODO: remove
- => (forall s. Taping s a => f (Dual s a) -> Dual s a)
- -> f a -> a -> (a, f a)
-gradient' f inp topctg = unsafePerformIO $ do
- -- hPutStrLn stderr "Preparing input"
- let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0
- idref <- newIORef starti
- vec1 <- VSM.new (max 128 (2 * starti))
- taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil)
-
- -- hPutStrLn stderr "Running function"
- let Dual result outi = withDict @(Taping () a) taperef $ f @() inp'
- -- hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi
- MLog _ lastChunk tapeTail <- readIORef taperef
-
- -- do tapestr <- showTape (tapeTail `Snoc` lastChunk)
- -- hPutStrLn stderr $ "tape = " ++ tapestr ""
-
- -- hPutStrLn stderr "Backpropagating"
- accums <- VSM.new (outi+1)
- VSM.write accums outi topctg
-
- let backpropagate i chunk@(Chunk ci0 vec) tape
- | i >= ci0 = do
- ctg <- VSM.read accums i
- Contrib i1 dx i2 dy <- VSM.read vec (i - ci0)
- when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1
- when (i2 /= -1) $ VSM.modify accums (+ ctg*dy) i2
- backpropagate (i-1) chunk tape
- | otherwise = case tape of
- SLNil -> return () -- reached end of tape we should loop over
- tape'@Snoc{} `Snoc` chunk' -> backpropagate i chunk' tape'
- -- When we reach the last chunk, modify it so that its
- -- starting index is after the inputs.
- SLNil `Snoc` Chunk _ vec' ->
- backpropagate i (Chunk starti (VSM.slice starti (VSM.length vec' - starti) vec')) SLNil
-
- -- Ensure that if there are no more chunks in the tape tail, the starting
- -- index of the first chunk is adjusted so that backpropagate stops in time.
- case tapeTail of
- SLNil -> backpropagate outi (let Chunk _ vec = lastChunk
- in Chunk starti (VSM.slice starti (VSM.length vec - starti) vec))
- SLNil
- Snoc{} -> backpropagate outi lastChunk tapeTail
-
- -- do accums' <- VS.freeze accums
- -- hPutStrLn stderr $ "accums = " ++ show accums'
-
- -- hPutStrLn stderr "Reconstructing gradient"
- let readDeriv = do i <- get
- d <- lift $ VSM.read accums i
- put (i+1)
- return d
- grad <- evalStateT (traverse (\_ -> readDeriv) inp) 0
-
- return (result, grad)
-
-data Snoclist a = SLNil | Snoc !(Snoclist a) !a
- deriving (Show, Eq, Ord, Functor, Foldable, Traversable)
-
-data Contrib a = Contrib {-# UNPACK #-} !Int a -- ^ ID == -1 -> no contribution
- {-# UNPACK #-} !Int a -- ^ idem
- deriving (Show)
-
-instance Storable a => Storable (Contrib a) where
- sizeOf _ = 2 * (sizeOf (undefined :: Int) + sizeOf (undefined :: a))
- alignment _ = alignment (undefined :: Int)
- peek ptr = Contrib <$> peek (castPtr ptr)
- <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int))
- <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a))
- <*> peekByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a))
- poke ptr (Contrib i1 dx i2 dy) = do
- poke (castPtr ptr) i1
- pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int)) dx
- pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) i2
- pokeByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a)) dy
-
-data Chunk a = Chunk {-# UNPACK #-} !Int -- ^ First ID in this chunk
- {-# UNPACK #-} !(VSM.IOVector (Contrib a))
-
-data MLog s a = MLog !(IORef Int) -- ^ next ID to generate
- {-# UNPACK #-} !(Chunk a) -- ^ current running chunk
- !(Snoclist (Chunk a)) -- ^ tape
-
-showChunk :: (Storable a, Show a) => Chunk a -> IO ShowS
-showChunk (Chunk ci0 vec) = do
- vec' <- VS.freeze vec
- return (showString ("Chunk " ++ show ci0 ++ " ") . shows vec')
-
-showTape :: (Storable a, Show a) => Snoclist (Chunk a) -> IO ShowS
-showTape SLNil = return (showString "SLNil")
-showTape (tape `Snoc` chunk) = do
- s1 <- showTape tape
- s2 <- showChunk chunk
- return (s1 . showString " `Snoc` " . s2)
-
--- | This class does not have any instances defined, on purpose. You'll get one
--- magically when you differentiate.
-class Taping s a where
- getTape :: IORef (MLog s a)
-
-data Dual s a = Dual !a
- {-# UNPACK #-} !Int -- ^ -1 if this is a constant
-
-instance Eq a => Eq (Dual s a) where
- Dual x _ == Dual y _ = x == y
-
-instance Ord a => Ord (Dual s a) where
- compare (Dual x _) (Dual y _) = compare x y
-
-instance (Num a, Storable a, Taping s a) => Num (Dual s a) where
- Dual x i1 + Dual y i2 = Dual (x + y) (writeTape @a (Proxy @s) i1 1 i2 1)
- Dual x i1 - Dual y i2 = Dual (x - y) (writeTape @a (Proxy @s) i1 1 i2 (-1))
- Dual x i1 * Dual y i2 = Dual (x * y) (writeTape (Proxy @s) i1 y i2 x)
- negate (Dual x i1) = Dual (negate x) (writeTape @a (Proxy @s) i1 (-1) (-1) 0)
- abs (Dual x i1) = Dual (abs x) (writeTape (Proxy @s) i1 (x * signum x) (-1) 0)
- signum (Dual x _) = Dual (signum x) (-1)
- fromInteger n = Dual (fromInteger n) (-1)
-
-data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a))
- | WTAOldTape (Snoclist (Chunk a))
-
-writeTape :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int
-writeTape _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy
-
-writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> IO Int
-writeTapeIO _ i1 dx i2 dy = do
- MLog idref (Chunk ci0 vec) _ <- readIORef (getTape @s)
- let n = VSM.length vec
- i <- atomicModifyIORef' idref (\i -> (i + 1, i))
- let idx = i - ci0
-
- if | idx < n -> do
- VSM.write vec idx (Contrib i1 dx i2 dy)
- return i
-
- -- check if we'd fit in the next chunk (overwhelmingly likely)
- | let newlen = 3 * n `div` 2
- , idx < n + newlen -> do
- newvec <- VSM.new newlen
- action <- atomicModifyIORef' (getTape @s) $ \(MLog idref' chunk@(Chunk ci0' vec') tape) ->
- if | ci0 == ci0' ->
- -- Likely (certain when single-threaded): no race condition,
- -- we get the chance to put the new chunk in place.
- (MLog idref' (Chunk (ci0 + n) newvec) (tape `Snoc` chunk), WTANewvec newvec)
- | i < ci0' + VSM.length vec' ->
- -- Race condition; need to write to appropriate position in this vector.
- (MLog idref' chunk tape, WTANewvec vec')
- | i < ci0' ->
- -- Very unlikely; need to write to old chunk in tape.
- (MLog idref' chunk tape, WTAOldTape tape)
- | otherwise ->
- -- We got an ID so far in the future that it doesn't even fit
- -- in the next chunk. But that can't happen, because we're
- -- only in this branch if the ID would have fit in the next
- -- chunk in the first place!
- error "writeTape: impossible"
- case action of
- WTANewvec vec' -> VSM.write vec' (idx - n) (Contrib i1 dx i2 dy)
- WTAOldTape tape ->
- let go SLNil = error "writeTape: no appropriate tape chunk?"
- go (tape' `Snoc` Chunk ci0' vec')
- -- The first comparison here is technically unnecessary, but
- -- I'm not courageous enough to remove it.
- | ci0' <= i, i < ci0' + VSM.length vec' =
- VSM.write vec' (i - ci0') (Contrib i1 dx i2 dy)
- | otherwise =
- go tape'
- in go tape
- return i
-
- -- there's a tremendous amount of competition, let's just try again
- | otherwise -> writeTapeIO (Proxy @s) i1 dx i2 dy
+module Numeric.ADDual (
+ gradient',
+ Dual,
+ Taping,
+ constant,
+) where
+
+import Numeric.ADDual.Internal