diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-02-19 22:53:02 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-19 22:53:02 +0100 |
commit | 011bda94ea9ab0bdb43751d8d19963beb5a887a0 (patch) | |
tree | fd9dc640c8902ca53c741392b46beb7138f659fa /src/Numeric/ADDual.hs |
Initial
Diffstat (limited to 'src/Numeric/ADDual.hs')
-rw-r--r-- | src/Numeric/ADDual.hs | 201 |
1 files changed, 201 insertions, 0 deletions
diff --git a/src/Numeric/ADDual.hs b/src/Numeric/ADDual.hs new file mode 100644 index 0000000..d9c5b74 --- /dev/null +++ b/src/Numeric/ADDual.hs @@ -0,0 +1,201 @@ +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +module Numeric.ADDual where + +import Control.Monad (when) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State.Strict +import Data.IORef +import Data.Proxy +import qualified Data.Vector.Storable as VS +import qualified Data.Vector.Storable.Mutable as VSM +import Foreign.Ptr +import Foreign.Storable +import GHC.Exts (withDict) +import System.IO.Unsafe + +-- import System.IO (hPutStrLn, stderr) + +-- import Numeric.ADDual.IDGen + + +-- TODO: full vjp (just some more Traversable mess) +{-# NOINLINE gradient' #-} +gradient' :: forall a f. (Traversable f, Num a, Storable a) + -- => Show a -- TODO: remove + => (forall s. Taping s a => f (Dual s a) -> Dual s a) + -> f a -> a -> (a, f a) +gradient' f inp topctg = unsafePerformIO $ do + -- hPutStrLn stderr "Preparing input" + let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0 + idref <- newIORef starti + vec1 <- VSM.new (max 128 (2 * starti)) + taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil) + + -- hPutStrLn stderr "Running function" + let Dual result outi = withDict @(Taping () a) taperef $ f @() inp' + -- hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi + MLog _ lastChunk tapeTail <- readIORef taperef + + -- do tapestr <- showTape (tapeTail `Snoc` lastChunk) + -- hPutStrLn stderr $ "tape = " ++ tapestr "" + + -- hPutStrLn stderr "Backpropagating" + accums <- VSM.new (outi+1) + VSM.write accums outi topctg + + let backpropagate i chunk@(Chunk ci0 vec) tape + | i >= ci0 = do + ctg <- VSM.read accums i + Contrib i1 dx i2 dy <- VSM.read vec (i - ci0) + when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1 + when (i2 /= -1) $ VSM.modify accums (+ ctg*dy) i2 + backpropagate (i-1) chunk tape + | otherwise = case tape of + SLNil -> return () -- reached end of tape we should loop over + tape'@Snoc{} `Snoc` chunk' -> backpropagate i chunk' tape' + -- When we reach the last chunk, modify it so that its + -- starting index is after the inputs. + SLNil `Snoc` Chunk _ vec' -> + backpropagate i (Chunk starti (VSM.slice starti (VSM.length vec' - starti) vec')) SLNil + + -- Ensure that if there are no more chunks in the tape tail, the starting + -- index of the first chunk is adjusted so that backpropagate stops in time. + case tapeTail of + SLNil -> backpropagate outi (let Chunk _ vec = lastChunk + in Chunk starti (VSM.slice starti (VSM.length vec - starti) vec)) + SLNil + Snoc{} -> backpropagate outi lastChunk tapeTail + + -- do accums' <- VS.freeze accums + -- hPutStrLn stderr $ "accums = " ++ show accums' + + -- hPutStrLn stderr "Reconstructing gradient" + let readDeriv = do i <- get + d <- lift $ VSM.read accums i + put (i+1) + return d + grad <- evalStateT (traverse (\_ -> readDeriv) inp) 0 + + return (result, grad) + +data Snoclist a = SLNil | Snoc !(Snoclist a) !a + deriving (Show, Eq, Ord, Functor, Foldable, Traversable) + +data Contrib a = Contrib {-# UNPACK #-} !Int a -- ^ ID == -1 -> no contribution + {-# UNPACK #-} !Int a -- ^ idem + deriving (Show) + +instance Storable a => Storable (Contrib a) where + sizeOf _ = 2 * (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) + alignment _ = alignment (undefined :: Int) + peek ptr = Contrib <$> peek (castPtr ptr) + <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int)) + <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) + <*> peekByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a)) + poke ptr (Contrib i1 dx i2 dy) = do + poke (castPtr ptr) i1 + pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int)) dx + pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) i2 + pokeByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a)) dy + +data Chunk a = Chunk {-# UNPACK #-} !Int -- ^ First ID in this chunk + {-# UNPACK #-} !(VSM.IOVector (Contrib a)) + +data MLog s a = MLog !(IORef Int) -- ^ next ID to generate + {-# UNPACK #-} !(Chunk a) -- ^ current running chunk + !(Snoclist (Chunk a)) -- ^ tape + +showChunk :: (Storable a, Show a) => Chunk a -> IO ShowS +showChunk (Chunk ci0 vec) = do + vec' <- VS.freeze vec + return (showString ("Chunk " ++ show ci0 ++ " ") . shows vec') + +showTape :: (Storable a, Show a) => Snoclist (Chunk a) -> IO ShowS +showTape SLNil = return (showString "SLNil") +showTape (tape `Snoc` chunk) = do + s1 <- showTape tape + s2 <- showChunk chunk + return (s1 . showString " `Snoc` " . s2) + +-- | This class does not have any instances defined, on purpose. You'll get one +-- magically when you differentiate. +class Taping s a where + getTape :: IORef (MLog s a) + +data Dual s a = Dual !a + {-# UNPACK #-} !Int -- ^ -1 if this is a constant + +instance Eq a => Eq (Dual s a) where + Dual x _ == Dual y _ = x == y + +instance Ord a => Ord (Dual s a) where + compare (Dual x _) (Dual y _) = compare x y + +instance (Num a, Storable a, Taping s a) => Num (Dual s a) where + Dual x i1 + Dual y i2 = Dual (x + y) (writeTape @a (Proxy @s) i1 1 i2 1) + Dual x i1 - Dual y i2 = Dual (x - y) (writeTape @a (Proxy @s) i1 1 i2 (-1)) + Dual x i1 * Dual y i2 = Dual (x * y) (writeTape (Proxy @s) i1 y i2 x) + negate (Dual x i1) = Dual (negate x) (writeTape @a (Proxy @s) i1 (-1) (-1) 0) + abs (Dual x i1) = Dual (abs x) (writeTape (Proxy @s) i1 (x * signum x) (-1) 0) + signum (Dual x _) = Dual (signum x) (-1) + fromInteger n = Dual (fromInteger n) (-1) + +data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a)) + | WTAOldTape (Snoclist (Chunk a)) + +writeTape :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int +writeTape _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy + +writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> IO Int +writeTapeIO _ i1 dx i2 dy = do + MLog idref (Chunk ci0 vec) _ <- readIORef (getTape @s) + let n = VSM.length vec + i <- atomicModifyIORef' idref (\i -> (i + 1, i)) + let idx = i - ci0 + + if | idx < n -> do + VSM.write vec idx (Contrib i1 dx i2 dy) + return i + + -- check if we'd fit in the next chunk (overwhelmingly likely) + | let newlen = 3 * n `div` 2 + , idx < n + newlen -> do + newvec <- VSM.new newlen + action <- atomicModifyIORef' (getTape @s) $ \(MLog idref' chunk@(Chunk ci0' vec') tape) -> + if | ci0 == ci0' -> + -- Likely (certain when single-threaded): no race condition, + -- we get the chance to put the new chunk in place. + (MLog idref' (Chunk (ci0 + n) newvec) (tape `Snoc` chunk), WTANewvec newvec) + | i < ci0' + VSM.length vec' -> + -- Race condition; need to write to appropriate position in this vector. + (MLog idref' chunk tape, WTANewvec vec') + | i < ci0' -> + -- Very unlikely; need to write to old chunk in tape. + (MLog idref' chunk tape, WTAOldTape tape) + | otherwise -> + -- We got an ID so far in the future that it doesn't even fit + -- in the next chunk. But that can't happen, because we're + -- only in this branch if the ID would have fit in the next + -- chunk in the first place! + error "writeTape: impossible" + case action of + WTANewvec vec' -> VSM.write vec' (idx - n) (Contrib i1 dx i2 dy) + WTAOldTape tape -> + let go SLNil = error "writeTape: no appropriate tape chunk?" + go (tape' `Snoc` Chunk ci0' vec') + -- The first comparison here is technically unnecessary, but + -- I'm not courageous enough to remove it. + | ci0' <= i, i < ci0' + VSM.length vec' = + VSM.write vec' (i - ci0') (Contrib i1 dx i2 dy) + | otherwise = + go tape' + in go tape + return i + + -- there's a tremendous amount of competition, let's just try again + | otherwise -> writeTapeIO (Proxy @s) i1 dx i2 dy |