diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-02-20 10:11:57 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-20 10:11:57 +0100 |
commit | fe3132304b6c25e5bebc9fb327e3ea5d6018be7a (patch) | |
tree | edbf62755eab8103c3f39ee7dfe2b0006692e857 /src | |
parent | 011bda94ea9ab0bdb43751d8d19963beb5a887a0 (diff) |
Attempt at a benchmark (crashes)
Diffstat (limited to 'src')
-rw-r--r-- | src/Numeric/ADDual.hs | 209 | ||||
-rw-r--r-- | src/Numeric/ADDual/Internal.hs | 226 |
2 files changed, 234 insertions, 201 deletions
diff --git a/src/Numeric/ADDual.hs b/src/Numeric/ADDual.hs index d9c5b74..2b69025 100644 --- a/src/Numeric/ADDual.hs +++ b/src/Numeric/ADDual.hs @@ -1,201 +1,8 @@ -{-# LANGUAGE DeriveTraversable #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -module Numeric.ADDual where - -import Control.Monad (when) -import Control.Monad.Trans.Class (lift) -import Control.Monad.Trans.State.Strict -import Data.IORef -import Data.Proxy -import qualified Data.Vector.Storable as VS -import qualified Data.Vector.Storable.Mutable as VSM -import Foreign.Ptr -import Foreign.Storable -import GHC.Exts (withDict) -import System.IO.Unsafe - --- import System.IO (hPutStrLn, stderr) - --- import Numeric.ADDual.IDGen - - --- TODO: full vjp (just some more Traversable mess) -{-# NOINLINE gradient' #-} -gradient' :: forall a f. (Traversable f, Num a, Storable a) - -- => Show a -- TODO: remove - => (forall s. Taping s a => f (Dual s a) -> Dual s a) - -> f a -> a -> (a, f a) -gradient' f inp topctg = unsafePerformIO $ do - -- hPutStrLn stderr "Preparing input" - let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0 - idref <- newIORef starti - vec1 <- VSM.new (max 128 (2 * starti)) - taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil) - - -- hPutStrLn stderr "Running function" - let Dual result outi = withDict @(Taping () a) taperef $ f @() inp' - -- hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi - MLog _ lastChunk tapeTail <- readIORef taperef - - -- do tapestr <- showTape (tapeTail `Snoc` lastChunk) - -- hPutStrLn stderr $ "tape = " ++ tapestr "" - - -- hPutStrLn stderr "Backpropagating" - accums <- VSM.new (outi+1) - VSM.write accums outi topctg - - let backpropagate i chunk@(Chunk ci0 vec) tape - | i >= ci0 = do - ctg <- VSM.read accums i - Contrib i1 dx i2 dy <- VSM.read vec (i - ci0) - when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1 - when (i2 /= -1) $ VSM.modify accums (+ ctg*dy) i2 - backpropagate (i-1) chunk tape - | otherwise = case tape of - SLNil -> return () -- reached end of tape we should loop over - tape'@Snoc{} `Snoc` chunk' -> backpropagate i chunk' tape' - -- When we reach the last chunk, modify it so that its - -- starting index is after the inputs. - SLNil `Snoc` Chunk _ vec' -> - backpropagate i (Chunk starti (VSM.slice starti (VSM.length vec' - starti) vec')) SLNil - - -- Ensure that if there are no more chunks in the tape tail, the starting - -- index of the first chunk is adjusted so that backpropagate stops in time. - case tapeTail of - SLNil -> backpropagate outi (let Chunk _ vec = lastChunk - in Chunk starti (VSM.slice starti (VSM.length vec - starti) vec)) - SLNil - Snoc{} -> backpropagate outi lastChunk tapeTail - - -- do accums' <- VS.freeze accums - -- hPutStrLn stderr $ "accums = " ++ show accums' - - -- hPutStrLn stderr "Reconstructing gradient" - let readDeriv = do i <- get - d <- lift $ VSM.read accums i - put (i+1) - return d - grad <- evalStateT (traverse (\_ -> readDeriv) inp) 0 - - return (result, grad) - -data Snoclist a = SLNil | Snoc !(Snoclist a) !a - deriving (Show, Eq, Ord, Functor, Foldable, Traversable) - -data Contrib a = Contrib {-# UNPACK #-} !Int a -- ^ ID == -1 -> no contribution - {-# UNPACK #-} !Int a -- ^ idem - deriving (Show) - -instance Storable a => Storable (Contrib a) where - sizeOf _ = 2 * (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) - alignment _ = alignment (undefined :: Int) - peek ptr = Contrib <$> peek (castPtr ptr) - <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int)) - <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) - <*> peekByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a)) - poke ptr (Contrib i1 dx i2 dy) = do - poke (castPtr ptr) i1 - pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int)) dx - pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) i2 - pokeByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a)) dy - -data Chunk a = Chunk {-# UNPACK #-} !Int -- ^ First ID in this chunk - {-# UNPACK #-} !(VSM.IOVector (Contrib a)) - -data MLog s a = MLog !(IORef Int) -- ^ next ID to generate - {-# UNPACK #-} !(Chunk a) -- ^ current running chunk - !(Snoclist (Chunk a)) -- ^ tape - -showChunk :: (Storable a, Show a) => Chunk a -> IO ShowS -showChunk (Chunk ci0 vec) = do - vec' <- VS.freeze vec - return (showString ("Chunk " ++ show ci0 ++ " ") . shows vec') - -showTape :: (Storable a, Show a) => Snoclist (Chunk a) -> IO ShowS -showTape SLNil = return (showString "SLNil") -showTape (tape `Snoc` chunk) = do - s1 <- showTape tape - s2 <- showChunk chunk - return (s1 . showString " `Snoc` " . s2) - --- | This class does not have any instances defined, on purpose. You'll get one --- magically when you differentiate. -class Taping s a where - getTape :: IORef (MLog s a) - -data Dual s a = Dual !a - {-# UNPACK #-} !Int -- ^ -1 if this is a constant - -instance Eq a => Eq (Dual s a) where - Dual x _ == Dual y _ = x == y - -instance Ord a => Ord (Dual s a) where - compare (Dual x _) (Dual y _) = compare x y - -instance (Num a, Storable a, Taping s a) => Num (Dual s a) where - Dual x i1 + Dual y i2 = Dual (x + y) (writeTape @a (Proxy @s) i1 1 i2 1) - Dual x i1 - Dual y i2 = Dual (x - y) (writeTape @a (Proxy @s) i1 1 i2 (-1)) - Dual x i1 * Dual y i2 = Dual (x * y) (writeTape (Proxy @s) i1 y i2 x) - negate (Dual x i1) = Dual (negate x) (writeTape @a (Proxy @s) i1 (-1) (-1) 0) - abs (Dual x i1) = Dual (abs x) (writeTape (Proxy @s) i1 (x * signum x) (-1) 0) - signum (Dual x _) = Dual (signum x) (-1) - fromInteger n = Dual (fromInteger n) (-1) - -data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a)) - | WTAOldTape (Snoclist (Chunk a)) - -writeTape :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int -writeTape _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy - -writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> IO Int -writeTapeIO _ i1 dx i2 dy = do - MLog idref (Chunk ci0 vec) _ <- readIORef (getTape @s) - let n = VSM.length vec - i <- atomicModifyIORef' idref (\i -> (i + 1, i)) - let idx = i - ci0 - - if | idx < n -> do - VSM.write vec idx (Contrib i1 dx i2 dy) - return i - - -- check if we'd fit in the next chunk (overwhelmingly likely) - | let newlen = 3 * n `div` 2 - , idx < n + newlen -> do - newvec <- VSM.new newlen - action <- atomicModifyIORef' (getTape @s) $ \(MLog idref' chunk@(Chunk ci0' vec') tape) -> - if | ci0 == ci0' -> - -- Likely (certain when single-threaded): no race condition, - -- we get the chance to put the new chunk in place. - (MLog idref' (Chunk (ci0 + n) newvec) (tape `Snoc` chunk), WTANewvec newvec) - | i < ci0' + VSM.length vec' -> - -- Race condition; need to write to appropriate position in this vector. - (MLog idref' chunk tape, WTANewvec vec') - | i < ci0' -> - -- Very unlikely; need to write to old chunk in tape. - (MLog idref' chunk tape, WTAOldTape tape) - | otherwise -> - -- We got an ID so far in the future that it doesn't even fit - -- in the next chunk. But that can't happen, because we're - -- only in this branch if the ID would have fit in the next - -- chunk in the first place! - error "writeTape: impossible" - case action of - WTANewvec vec' -> VSM.write vec' (idx - n) (Contrib i1 dx i2 dy) - WTAOldTape tape -> - let go SLNil = error "writeTape: no appropriate tape chunk?" - go (tape' `Snoc` Chunk ci0' vec') - -- The first comparison here is technically unnecessary, but - -- I'm not courageous enough to remove it. - | ci0' <= i, i < ci0' + VSM.length vec' = - VSM.write vec' (i - ci0') (Contrib i1 dx i2 dy) - | otherwise = - go tape' - in go tape - return i - - -- there's a tremendous amount of competition, let's just try again - | otherwise -> writeTapeIO (Proxy @s) i1 dx i2 dy +module Numeric.ADDual ( + gradient', + Dual, + Taping, + constant, +) where + +import Numeric.ADDual.Internal diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs new file mode 100644 index 0000000..8228694 --- /dev/null +++ b/src/Numeric/ADDual/Internal.hs @@ -0,0 +1,226 @@ +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +module Numeric.ADDual.Internal where + +import Control.Monad (when) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State.Strict +import Data.IORef +import Data.Proxy +import qualified Data.Vector.Storable as VS +import qualified Data.Vector.Storable.Mutable as VSM +import Foreign.Ptr +import Foreign.Storable +import GHC.Exts (withDict) +import System.IO.Unsafe + +-- import System.IO (hPutStrLn, stderr) + +-- import Numeric.ADDual.IDGen + + +-- TODO: full vjp (just some more Traversable mess) +{-# NOINLINE gradient' #-} +gradient' :: forall a f. (Traversable f, Num a, Storable a) + -- => Show a -- TODO: remove + => (forall s. Taping s a => f (Dual s a) -> Dual s a) + -> f a -> a -> (a, f a) +gradient' f inp topctg = unsafePerformIO $ do + -- hPutStrLn stderr "Preparing input" + let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0 + idref <- newIORef starti + vec1 <- VSM.new (max 128 (2 * starti)) + taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil) + + -- hPutStrLn stderr "Running function" + let Dual result outi = withDict @(Taping () a) taperef $ f @() inp' + -- hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi + MLog _ lastChunk tapeTail <- readIORef taperef + + -- do tapestr <- showTape (tapeTail `Snoc` lastChunk) + -- hPutStrLn stderr $ "tape = " ++ tapestr "" + + -- hPutStrLn stderr "Backpropagating" + accums <- VSM.new (outi+1) + VSM.write accums outi topctg + + let backpropagate i chunk@(Chunk ci0 vec) tape + | i >= ci0 = do + ctg <- VSM.read accums i + Contrib i1 dx i2 dy <- VSM.read vec (i - ci0) + when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1 + when (i2 /= -1) $ VSM.modify accums (+ ctg*dy) i2 + backpropagate (i-1) chunk tape + | otherwise = case tape of + SLNil -> return () -- reached end of tape we should loop over + tape'@Snoc{} `Snoc` chunk' -> backpropagate i chunk' tape' + -- When we reach the last chunk, modify it so that its + -- starting index is after the inputs. + SLNil `Snoc` Chunk _ vec' -> + backpropagate i (Chunk starti (VSM.slice starti (VSM.length vec' - starti) vec')) SLNil + + -- Ensure that if there are no more chunks in the tape tail, the starting + -- index of the first chunk is adjusted so that backpropagate stops in time. + case tapeTail of + SLNil -> backpropagate outi (let Chunk _ vec = lastChunk + in Chunk starti (VSM.slice starti (VSM.length vec - starti) vec)) + SLNil + Snoc{} -> backpropagate outi lastChunk tapeTail + + -- do accums' <- VS.freeze accums + -- hPutStrLn stderr $ "accums = " ++ show accums' + + -- hPutStrLn stderr "Reconstructing gradient" + let readDeriv = do i <- get + d <- lift $ VSM.read accums i + put (i+1) + return d + grad <- evalStateT (traverse (\_ -> readDeriv) inp) 0 + + return (result, grad) + +data Snoclist a = SLNil | Snoc !(Snoclist a) !a + deriving (Show, Eq, Ord, Functor, Foldable, Traversable) + +data Contrib a = Contrib {-# UNPACK #-} !Int a -- ^ ID == -1 -> no contribution + {-# UNPACK #-} !Int a -- ^ idem + deriving (Show) + +instance Storable a => Storable (Contrib a) where + sizeOf _ = 2 * (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) + alignment _ = alignment (undefined :: Int) + peek ptr = Contrib <$> peek (castPtr ptr) + <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int)) + <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) + <*> peekByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a)) + poke ptr (Contrib i1 dx i2 dy) = do + poke (castPtr ptr) i1 + pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int)) dx + pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) i2 + pokeByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a)) dy + +data Chunk a = Chunk {-# UNPACK #-} !Int -- ^ First ID in this chunk + {-# UNPACK #-} !(VSM.IOVector (Contrib a)) + +data MLog s a = MLog !(IORef Int) -- ^ next ID to generate + {-# UNPACK #-} !(Chunk a) -- ^ current running chunk + !(Snoclist (Chunk a)) -- ^ tape + +showChunk :: (Storable a, Show a) => Chunk a -> IO ShowS +showChunk (Chunk ci0 vec) = do + vec' <- VS.freeze vec + return (showString ("Chunk " ++ show ci0 ++ " ") . shows vec') + +showTape :: (Storable a, Show a) => Snoclist (Chunk a) -> IO ShowS +showTape SLNil = return (showString "SLNil") +showTape (tape `Snoc` chunk) = do + s1 <- showTape tape + s2 <- showChunk chunk + return (s1 . showString " `Snoc` " . s2) + +-- | This class does not have any instances defined, on purpose. You'll get one +-- magically when you differentiate. +class Taping s a where + getTape :: IORef (MLog s a) + +data Dual s a = Dual !a + {-# UNPACK #-} !Int -- ^ -1 if this is a constant + +instance Eq a => Eq (Dual s a) where + Dual x _ == Dual y _ = x == y + +instance Ord a => Ord (Dual s a) where + compare (Dual x _) (Dual y _) = compare x y + +instance (Num a, Storable a, Taping s a) => Num (Dual s a) where + Dual x i1 + Dual y i2 = Dual (x + y) (writeTape @a (Proxy @s) i1 1 i2 1) + Dual x i1 - Dual y i2 = Dual (x - y) (writeTape @a (Proxy @s) i1 1 i2 (-1)) + Dual x i1 * Dual y i2 = Dual (x * y) (writeTape (Proxy @s) i1 y i2 x) + negate (Dual x i1) = Dual (negate x) (writeTape @a (Proxy @s) i1 (-1) (-1) 0) + abs (Dual x i1) = Dual (abs x) (writeTape (Proxy @s) i1 (x * signum x) (-1) 0) + signum (Dual x _) = Dual (signum x) (-1) + fromInteger n = Dual (fromInteger n) (-1) + +instance (Fractional a, Storable a, Taping s a) => Fractional (Dual s a) where + Dual x i1 / Dual y i2 = Dual (x / y) (writeTape (Proxy @s) i1 (recip y) i2 (-x/(y*y))) + recip (Dual x i1) = Dual (recip x) (writeTape (Proxy @s) i1 (-1/(x*x)) (-1) 0) + fromRational r = Dual (fromRational r) (-1) + +instance (Floating a, Storable a, Taping s a) => Floating (Dual s a) where + pi = Dual pi (-1) + exp (Dual x i1) = Dual (exp x) (writeTape (Proxy @s) i1 (exp x) (-1) 0) + log (Dual x i1) = Dual (log x) (writeTape (Proxy @s) i1 (recip x) (-1) 0) + sqrt (Dual x i1) = Dual (sqrt x) (writeTape (Proxy @s) i1 (recip (2*sqrt x)) (-1) 0) + -- d/dx (x ^ y) = d/dx (e ^ (y ln x)) = e ^ (y ln x) * d/dx (y ln x) = e ^ (y ln x) * y/x + -- d/dy (x ^ y) = d/dy (e ^ (y ln x)) = e ^ (y ln x) * d/dy (y ln x) = e ^ (y ln x) * ln x + Dual x i1 ** Dual y i2 = + let z = x ** y + in Dual z (writeTape (Proxy @s) i1 (z * y/x) i2 (z * log x)) + logBase = undefined ; sin = undefined ; cos = undefined ; tan = undefined + asin = undefined ; acos = undefined ; atan = undefined ; sinh = undefined + cosh = undefined ; tanh = undefined ; asinh = undefined ; acosh = undefined + atanh = undefined + +constant :: a -> Dual s a +constant x = Dual x (-1) + +data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a)) + | WTAOldTape (Snoclist (Chunk a)) + +writeTape :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int +writeTape _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy + +writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> IO Int +writeTapeIO _ i1 dx i2 dy = do + MLog idref (Chunk ci0 vec) _ <- readIORef (getTape @s) + let n = VSM.length vec + i <- atomicModifyIORef' idref (\i -> (i + 1, i)) + let idx = i - ci0 + + if | idx < n -> do + VSM.write vec idx (Contrib i1 dx i2 dy) + return i + + -- check if we'd fit in the next chunk (overwhelmingly likely) + | let newlen = 3 * n `div` 2 + , idx < n + newlen -> do + newvec <- VSM.new newlen + action <- atomicModifyIORef' (getTape @s) $ \(MLog idref' chunk@(Chunk ci0' vec') tape) -> + if | ci0 == ci0' -> + -- Likely (certain when single-threaded): no race condition, + -- we get the chance to put the new chunk in place. + (MLog idref' (Chunk (ci0 + n) newvec) (tape `Snoc` chunk), WTANewvec newvec) + | i < ci0' + VSM.length vec' -> + -- Race condition; need to write to appropriate position in this vector. + (MLog idref' chunk tape, WTANewvec vec') + | i < ci0' -> + -- Very unlikely; need to write to old chunk in tape. + (MLog idref' chunk tape, WTAOldTape tape) + | otherwise -> + -- We got an ID so far in the future that it doesn't even fit + -- in the next chunk. But that can't happen, because we're + -- only in this branch if the ID would have fit in the next + -- chunk in the first place! + error "writeTape: impossible" + case action of + WTANewvec vec' -> VSM.write vec' (idx - n) (Contrib i1 dx i2 dy) + WTAOldTape tape -> + let go SLNil = error "writeTape: no appropriate tape chunk?" + go (tape' `Snoc` Chunk ci0' vec') + -- The first comparison here is technically unnecessary, but + -- I'm not courageous enough to remove it. + | ci0' <= i, i < ci0' + VSM.length vec' = + VSM.write vec' (i - ci0') (Contrib i1 dx i2 dy) + | otherwise = + go tape' + in go tape + return i + + -- there's a tremendous amount of competition, let's just try again + | otherwise -> writeTapeIO (Proxy @s) i1 dx i2 dy + + |