diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Numeric/ADDual.hs | 201 | 
1 files changed, 201 insertions, 0 deletions
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
-- | Tape-based differentiation on 'Dual' numbers.
--
-- Every arithmetic operation on a 'Dual' appends a 'Contrib' record to a
-- chunked tape ('MLog') and yields the ID of that record; 'gradient'' then
-- walks the tape from the output ID downwards, accumulating cotangents.
module Numeric.ADDual where

import Control.Exception (evaluate)
import Control.Monad (when)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict
import Data.IORef
import Data.Proxy
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import Foreign.Ptr
import Foreign.Storable
import GHC.Exts (withDict)
import System.IO.Unsafe

-- import System.IO (hPutStrLn, stderr)

-- import Numeric.ADDual.IDGen


-- TODO: full vjp (just some more Traversable mess)

-- | Compute the value of a function together with its gradient.
--
-- Arguments: the function to differentiate (polymorphic in the tape tag @s@),
-- the input structure, and the cotangent with which the output is seeded.
-- Returns the function result and a structure of the same shape as the input
-- holding the accumulated derivative for each input element.
--
-- The inputs receive the IDs @0 .. starti-1@ in traversal order; every taped
-- operation receives a fresh ID from @starti@ upwards.
--
-- NOTE(review): the 'unsafePerformIO' here is presumably justified by the
-- tape and all mutable vectors being created locally in this call and not
-- escaping it; NOINLINE keeps the call from being duplicated. Confirm.
{-# NOINLINE gradient' #-}
gradient' :: forall a f. (Traversable f, Num a, Storable a)
          -- => Show a  -- TODO: remove
          => (forall s. Taping s a => f (Dual s a) -> Dual s a)
          -> f a -> a -> (a, f a)
gradient' f inp topctg = unsafePerformIO $ do
  -- hPutStrLn stderr "Preparing input"
  -- Number the inputs 0, 1, ...; starti is the first ID available to the tape.
  let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0
  idref <- newIORef starti
  vec1 <- VSM.new (max 128 (2 * starti))
  taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil)

  -- hPutStrLn stderr "Running function"
  -- The function must have run to completion *before* the tape is read back
  -- below. A lazy let-binding does not guarantee that (the tape reference
  -- could be read while it still holds its initial state), so force the
  -- result explicitly; the strict fields of 'Dual' then force both the value
  -- and the output ID, and thereby all tape writes they depend on.
  Dual result outi <- evaluate (withDict @(Taping () a) taperef $ f @() inp')
  -- hPutStrLn stderr $ "result = " ++ show result ++ "; outi = " ++ show outi
  MLog _ lastChunk tapeTail <- readIORef taperef

  -- do tapestr <- showTape (tapeTail `Snoc` lastChunk)
  --    hPutStrLn stderr $ "tape = " ++ tapestr ""

  -- hPutStrLn stderr "Backpropagating"
  -- One accumulator per ID. The vector must cover all input IDs as well as
  -- the output ID: the output may be a constant (ID -1) or one of the inputs
  -- themselves, in which case outi + 1 alone would be too short for the
  -- gradient read-out below.
  -- NOTE(review): relies on VSM.new handing out zero-filled memory, as the
  -- accumulation with (+) below already did before -- confirm against the
  -- vector documentation.
  accums <- VSM.new (max (outi + 1) starti)
  -- A constant output (ID -1) has no accumulator to seed.
  when (outi >= 0) $ VSM.write accums outi topctg

  -- Walk the IDs downwards from i, reading each tape entry from the chunk
  -- that contains it and pushing its cotangent to the (at most two) arguments.
  let backpropagate i chunk@(Chunk ci0 vec) tape
        | i >= ci0 = do
            ctg <- VSM.read accums i
            Contrib i1 dx i2 dy <- VSM.read vec (i - ci0)
            when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1
            when (i2 /= -1) $ VSM.modify accums (+ ctg*dy) i2
            backpropagate (i-1) chunk tape
        | otherwise = case tape of
                        SLNil -> return ()  -- reached end of tape we should loop over
                        tape'@Snoc{} `Snoc` chunk' -> backpropagate i chunk' tape'
                        -- When we reach the last chunk, modify it so that its
                        -- starting index is after the inputs.
                        SLNil `Snoc` Chunk _ vec' ->
                          backpropagate i (Chunk starti (VSM.slice starti (VSM.length vec' - starti) vec')) SLNil

  -- Ensure that if there are no more chunks in the tape tail, the starting
  -- index of the first chunk is adjusted so that backpropagate stops in time.
  case tapeTail of
    SLNil -> backpropagate outi (let Chunk _ vec = lastChunk
                                 in Chunk starti (VSM.slice starti (VSM.length vec - starti) vec))
                                SLNil
    Snoc{} -> backpropagate outi lastChunk tapeTail

  -- do accums' <- VS.freeze accums
  --    hPutStrLn stderr $ "accums = " ++ show accums'

  -- hPutStrLn stderr "Reconstructing gradient"
  -- Read the accumulators of the inputs back in the same traversal order in
  -- which the input IDs were handed out.
  let readDeriv = do i <- get
                     d <- lift $ VSM.read accums i
                     put (i+1)
                     return d
  grad <- evalStateT (traverse (\_ -> readDeriv) inp) 0

  return (result, grad)

-- | A list that grows at its right end; both fields of 'Snoc' are strict.
data Snoclist a = SLNil | Snoc !(Snoclist a) !a
  deriving (Show, Eq, Ord, Functor, Foldable, Traversable)

-- | One tape entry: up to two (argument ID, partial derivative) pairs.
data Contrib a = Contrib {-# UNPACK #-} !Int a  -- ^ ID == -1 -> no contribution
                         {-# UNPACK #-} !Int a  -- ^ idem
  deriving (Show)

-- | Stores the four fields back to back in the order ID, derivative, ID,
-- derivative, without padding.
--
-- NOTE(review): the declared alignment is that of 'Int' only, and the second
-- ID sits at offset @sizeOf Int + sizeOf a@; for element types whose size is
-- not a multiple of the alignment of 'Int' this is an unaligned access.
instance Storable a => Storable (Contrib a) where
  sizeOf _ = 2 * (sizeOf (undefined :: Int) + sizeOf (undefined :: a))
  alignment _ = alignment (undefined :: Int)
  peek ptr = Contrib <$> peek (castPtr ptr)
                     <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int))
                     <*> peekByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a))
                     <*> peekByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a))
  poke ptr (Contrib i1 dx i2 dy) = do
    poke (castPtr ptr) i1
    pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int)) dx
    pokeByteOff (castPtr ptr) (sizeOf (undefined :: Int) + sizeOf (undefined :: a)) i2
    pokeByteOff (castPtr ptr) (2 * sizeOf (undefined :: Int) + sizeOf (undefined :: a)) dy

-- | A contiguous piece of the tape: the entry for ID @i@ lives at vector
-- index @i - firstID@.
data Chunk a = Chunk {-# UNPACK #-} !Int  -- ^ First ID in this chunk
                     {-# UNPACK #-} !(VSM.IOVector (Contrib a))

-- | The complete tape state shared by all 'Dual' operations of one
-- differentiation run.
data MLog s a = MLog !(IORef Int)  -- ^ next ID to generate
                     {-# UNPACK #-} !(Chunk a)  -- ^ current running chunk
                     !(Snoclist (Chunk a))  -- ^ tape

-- | Render a chunk (its first ID and a frozen copy of its entries); used by
-- the commented-out debug output in 'gradient''.
showChunk :: (Storable a, Show a) => Chunk a -> IO ShowS
showChunk (Chunk ci0 vec) = do
  vec' <- VS.freeze vec
  return (showString ("Chunk " ++ show ci0 ++ " ") . shows vec')

-- | Render a whole snoc-list of chunks, oldest chunk first.
showTape :: (Storable a, Show a) => Snoclist (Chunk a) -> IO ShowS
showTape SLNil = return (showString "SLNil")
showTape (tape `Snoc` chunk) = do
  s1 <- showTape tape
  s2 <- showChunk chunk
  return (s1 . showString " `Snoc` " . s2)

-- | This class does not have any instances defined, on purpose. You'll get one
-- magically when you differentiate.
class Taping s a where
  getTape :: IORef (MLog s a)

-- | A value paired with the ID of its tape entry.
data Dual s a = Dual !a
                     {-# UNPACK #-} !Int  -- ^ -1 if this is a constant

-- | Compares the values only; IDs are ignored.
instance Eq a => Eq (Dual s a) where
  Dual x _ == Dual y _ = x == y

-- | Orders by the values only; IDs are ignored.
instance Ord a => Ord (Dual s a) where
  compare (Dual x _) (Dual y _) = compare x y

-- | Each operation computes the value directly and records the partial
-- derivatives with respect to its arguments on the tape. 'signum' and
-- 'fromInteger' produce constants (ID -1) and record nothing.
instance (Num a, Storable a, Taping s a) => Num (Dual s a) where
  Dual x i1 + Dual y i2 = Dual (x + y)    (writeTape @a (Proxy @s) i1 1 i2 1)
  Dual x i1 - Dual y i2 = Dual (x - y)    (writeTape @a (Proxy @s) i1 1 i2 (-1))
  Dual x i1 * Dual y i2 = Dual (x * y)    (writeTape    (Proxy @s) i1 y i2 x)
  negate (Dual x i1)    = Dual (negate x) (writeTape @a (Proxy @s) i1 (-1) (-1) 0)
  -- d|x|/dx = signum x (not x * signum x, which is |x| itself).
  abs (Dual x i1)       = Dual (abs x)    (writeTape    (Proxy @s) i1 (signum x) (-1) 0)
  signum (Dual x _) = Dual (signum x) (-1)
  fromInteger n = Dual (fromInteger n) (-1)

-- | Where 'writeTapeIO' has to store its entry after consulting the shared
-- tape state.
data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a))
                       | WTAOldTape (Snoclist (Chunk a))

-- | Pure wrapper around 'writeTapeIO': record an entry with the given two
-- (argument ID, partial derivative) pairs and return the new entry's ID.
--
-- NOTE(review): correctness of this 'unsafePerformIO' presumably hinges on
-- the result being stored in the strict ID field of 'Dual', so that the write
-- happens whenever the 'Dual' is evaluated -- confirm.
writeTape :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int
writeTape _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy

-- | Allocate a fresh ID and store the entry @Contrib i1 dx i2 dy@ for it on
-- the tape, starting a new chunk (1.5 times the size of the current one) when
-- the current chunk is full. Returns the fresh ID.
writeTapeIO :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> IO Int
writeTapeIO _ i1 dx i2 dy = do
  MLog idref (Chunk ci0 vec) _ <- readIORef (getTape @s)
  let n = VSM.length vec
  i <- atomicModifyIORef' idref (\i -> (i + 1, i))
  let idx = i - ci0

  if | idx < n -> do
         VSM.write vec idx (Contrib i1 dx i2 dy)
         return i

     -- check if we'd fit in the next chunk (overwhelmingly likely)
     | let newlen = 3 * n `div` 2
     , idx < n + newlen -> do
         newvec <- VSM.new newlen
         action <- atomicModifyIORef' (getTape @s) $ \(MLog idref' chunk@(Chunk ci0' vec') tape) ->
           if | ci0 == ci0' ->
                 -- Likely (certain when single-threaded): no race condition,
                 -- we get the chance to put the new chunk in place.
                 (MLog idref' (Chunk (ci0 + n) newvec) (tape `Snoc` chunk), WTANewvec newvec)
              | i < ci0' + VSM.length vec' ->
                 -- Race condition: somebody else already installed one or
                 -- more newer chunks. Our ID lies either in the current chunk
                 -- or in one that has already moved to the tape, so hand back
                 -- all of them and let the search below pick the chunk that
                 -- actually contains the ID. (The offset idx - n is only
                 -- valid for a chunk starting at ci0 + n, which is not
                 -- guaranteed here.)
                 (MLog idref' chunk tape, WTAOldTape (tape `Snoc` chunk))
              | otherwise ->
                 -- We got an ID so far in the future that it doesn't even fit
                 -- in the next chunk. But that can't happen, because we're
                 -- only in this branch if the ID would have fit in the next
                 -- chunk in the first place!
                 error "writeTape: impossible"
         case action of
           -- Only produced for the chunk we installed ourselves, which starts
           -- at ci0 + n; hence the offset idx - n.
           WTANewvec vec' -> VSM.write vec' (idx - n) (Contrib i1 dx i2 dy)
           WTAOldTape tape ->
             let go SLNil = error "writeTape: no appropriate tape chunk?"
                 go (tape' `Snoc` Chunk ci0' vec')
                   -- Walking from the newest chunk to the oldest, the lower
                   -- bound decides which chunk owns the ID; the upper bound
                   -- is kept as a sanity check.
                   | ci0' <= i, i < ci0' + VSM.length vec' =
                       VSM.write vec' (i - ci0') (Contrib i1 dx i2 dy)
                   | otherwise =
                       go tape'
             in go tape
         return i

     -- there's a tremendous amount of competition, let's just try again
     | otherwise -> writeTapeIO (Proxy @s) i1 dx i2 dy
