diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-02-20 10:11:57 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-20 10:11:57 +0100 |
commit | fe3132304b6c25e5bebc9fb327e3ea5d6018be7a (patch) | |
tree | edbf62755eab8103c3f39ee7dfe2b0006692e857 /src/Numeric/ADDual.hs | |
parent | 011bda94ea9ab0bdb43751d8d19963beb5a887a0 (diff) |
Attempt at a benchmark (crashes)
Diffstat (limited to 'src/Numeric/ADDual.hs')
-rw-r--r-- | src/Numeric/ADDual.hs | 209 |
1 files changed, 8 insertions, 201 deletions
diff --git a/src/Numeric/ADDual.hs b/src/Numeric/ADDual.hs index d9c5b74..2b69025 100644 --- a/src/Numeric/ADDual.hs +++ b/src/Numeric/ADDual.hs @@ -1,201 +1,8 @@ -{-# LANGUAGE DeriveTraversable #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -module Numeric.ADDual where - -import Control.Monad (when) -import Control.Monad.Trans.Class (lift) -import Control.Monad.Trans.State.Strict -import Data.IORef -import Data.Proxy -import qualified Data.Vector.Storable as VS -import qualified Data.Vector.Storable.Mutable as VSM -import Foreign.Ptr -import Foreign.Storable -import GHC.Exts (withDict) -import System.IO.Unsafe - --- import System.IO (hPutStrLn, stderr) - --- import Numeric.ADDual.IDGen - - --- TODO: full vjp (just some more Traversable mess) -{-# NOINLINE gradient' #-} -gradient' :: forall a f. (Traversable f, Num a, Storable a) - -- => Show a -- TODO: remove - => (forall s. Taping s a => f (Dual s a) -> Dual s a) - -> f a -> a -> (a, f a) -gradient' f inp topctg = unsafePerformIO $ do - -- hPutStrLn stderr "Preparing input" - let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0 - idref <- newIORef starti - vec1 <- VSM.new (max 128 (2 * starti)) - taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil) - - -- hPutStrLn stderr "Running function" - let Dual result outi = withDict @(Taping () a) taperef $ f @() inp' - -- hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi - MLog _ lastChunk tapeTail <- readIORef taperef - - -- do tapestr <- showTape (tapeTail `Snoc` lastChunk) - -- hPutStrLn stderr $ "tape = " ++ tapestr "" - - -- hPutStrLn stderr "Backpropagating" - accums <- VSM.new (outi+1) - VSM.write accums outi topctg - - let backpropagate i chunk@(Chunk ci0 vec) tape - | i >= ci0 = do - ctg <- VSM.read accums i - Contrib i1 dx i2 dy <- VSM.read vec (i - ci0) - when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1 - when (i2 /= -1) $ VSM.modify accums (+ ctg*dy) i2 - backpropagate (i-1) chunk tape - | otherwise = case tape of - SLNil -> return () -- reached end of tape we should loop over - tape'@Snoc{} `Snoc` chunk' -> backpropagate i chunk' tape' - -- When we reach the last chunk, modify it so that its - -- starting index is after the inputs. - SLNil `Snoc` Chunk _ vec' -> - backpropagate i (Chunk starti (VSM.slice starti (VSM.length vec' - starti) vec')) SLNil - - -- Ensure that if there are no more chunks in the tape tail, the starting - -- index of the first chunk is adjusted so that backpropagate stops in time. - case tapeTail of - SLNil -> backpropagate outi (let Chunk _ vec = lastChunk - in Chunk starti (VSM.slice starti (VSM.length vec - starti) vec)) - SLNil - Snoc{} -> backpropagate outi lastChunk tapeTail - - -- do accums' <- VS.freeze accums - -- hPutStrLn stderr $ "accums = " ++ show accums' - - -- hPutStrLn stderr "Reconstructing gradient" - let readDeriv = do i <- get - d <- lift $ VSM.read accums i - put (i+1) - return d - grad <- evalStateT (traverse (\_ -> readDeriv) inp) 0 - - return (result, grad) - -data Snoclist a = SLNil | Snoc !(Snoclist a) !a - deriving (Show, Eq, Ord, Functor, Foldable, Traversable) - -data Contrib a = Contrib {-# UNPACK #-} !Int a -- ^ ID == -1 -> no contribution - {-# UNPACK #-} !Int a -- ^ idem - deriving (Show) - -instance Storable a => Storable (Contrib a) where - sizeOf _ = 2 * (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) - alignment _ = alignment (undefined :: Int) - peek ptr = Contrib <$> peek (castPtr ptr) - <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int)) - <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) - <*> peekByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a)) - poke ptr (Contrib i1 dx i2 dy) = do - poke (castPtr ptr) i1 - pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int)) dx - pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) i2 - pokeByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a)) dy - -data Chunk a = Chunk {-# UNPACK #-} !Int -- ^ First ID in this chunk - {-# UNPACK #-} !(VSM.IOVector (Contrib a)) - -data MLog s a = MLog !(IORef Int) -- ^ next ID to generate - {-# UNPACK #-} !(Chunk a) -- ^ current running chunk - !(Snoclist (Chunk a)) -- ^ tape - -showChunk :: (Storable a, Show a) => Chunk a -> IO ShowS -showChunk (Chunk ci0 vec) = do - vec' <- VS.freeze vec - return (showString ("Chunk " ++ show ci0 ++ " ") . shows vec') - -showTape :: (Storable a, Show a) => Snoclist (Chunk a) -> IO ShowS -showTape SLNil = return (showString "SLNil") -showTape (tape `Snoc` chunk) = do - s1 <- showTape tape - s2 <- showChunk chunk - return (s1 . showString " `Snoc` " . s2) - --- | This class does not have any instances defined, on purpose. You'll get one --- magically when you differentiate. -class Taping s a where - getTape :: IORef (MLog s a) - -data Dual s a = Dual !a - {-# UNPACK #-} !Int -- ^ -1 if this is a constant - -instance Eq a => Eq (Dual s a) where - Dual x _ == Dual y _ = x == y - -instance Ord a => Ord (Dual s a) where - compare (Dual x _) (Dual y _) = compare x y - -instance (Num a, Storable a, Taping s a) => Num (Dual s a) where - Dual x i1 + Dual y i2 = Dual (x + y) (writeTape @a (Proxy @s) i1 1 i2 1) - Dual x i1 - Dual y i2 = Dual (x - y) (writeTape @a (Proxy @s) i1 1 i2 (-1)) - Dual x i1 * Dual y i2 = Dual (x * y) (writeTape (Proxy @s) i1 y i2 x) - negate (Dual x i1) = Dual (negate x) (writeTape @a (Proxy @s) i1 (-1) (-1) 0) - abs (Dual x i1) = Dual (abs x) (writeTape (Proxy @s) i1 (x * signum x) (-1) 0) - signum (Dual x _) = Dual (signum x) (-1) - fromInteger n = Dual (fromInteger n) (-1) - -data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a)) - | WTAOldTape (Snoclist (Chunk a)) - -writeTape :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int -writeTape _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy - -writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> IO Int -writeTapeIO _ i1 dx i2 dy = do - MLog idref (Chunk ci0 vec) _ <- readIORef (getTape @s) - let n = VSM.length vec - i <- atomicModifyIORef' idref (\i -> (i + 1, i)) - let idx = i - ci0 - - if | idx < n -> do - VSM.write vec idx (Contrib i1 dx i2 dy) - return i - - -- check if we'd fit in the next chunk (overwhelmingly likely) - | let newlen = 3 * n `div` 2 - , idx < n + newlen -> do - newvec <- VSM.new newlen - action <- atomicModifyIORef' (getTape @s) $ \(MLog idref' chunk@(Chunk ci0' vec') tape) -> - if | ci0 == ci0' -> - -- Likely (certain when single-threaded): no race condition, - -- we get the chance to put the new chunk in place. - (MLog idref' (Chunk (ci0 + n) newvec) (tape `Snoc` chunk), WTANewvec newvec) - | i < ci0' + VSM.length vec' -> - -- Race condition; need to write to appropriate position in this vector. - (MLog idref' chunk tape, WTANewvec vec') - | i < ci0' -> - -- Very unlikely; need to write to old chunk in tape. - (MLog idref' chunk tape, WTAOldTape tape) - | otherwise -> - -- We got an ID so far in the future that it doesn't even fit - -- in the next chunk. But that can't happen, because we're - -- only in this branch if the ID would have fit in the next - -- chunk in the first place! - error "writeTape: impossible" - case action of - WTANewvec vec' -> VSM.write vec' (idx - n) (Contrib i1 dx i2 dy) - WTAOldTape tape -> - let go SLNil = error "writeTape: no appropriate tape chunk?" - go (tape' `Snoc` Chunk ci0' vec') - -- The first comparison here is technically unnecessary, but - -- I'm not courageous enough to remove it. - | ci0' <= i, i < ci0' + VSM.length vec' = - VSM.write vec' (i - ci0') (Contrib i1 dx i2 dy) - | otherwise = - go tape' - in go tape - return i - - -- there's a tremendous amount of competition, let's just try again - | otherwise -> writeTapeIO (Proxy @s) i1 dx i2 dy +module Numeric.ADDual ( + gradient', + Dual, + Taping, + constant, +) where + +import Numeric.ADDual.Internal |