diff options
Diffstat (limited to 'src/Numeric/ADDual')
| -rw-r--r-- | src/Numeric/ADDual/Internal.hs | 57 | 
1 file changed, 36 insertions, 21 deletions
diff --git a/src/Numeric/ADDual/Internal.hs b/src/Numeric/ADDual/Internal.hs index 1ea3132..5dd84aa 100644 --- a/src/Numeric/ADDual/Internal.hs +++ b/src/Numeric/ADDual/Internal.hs @@ -1,5 +1,6 @@  {-# LANGUAGE BangPatterns #-}  {-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE GADTs #-}  {-# LANGUAGE MultiParamTypeClasses #-}  {-# LANGUAGE MultiWayIf #-}  {-# LANGUAGE RankNTypes #-} @@ -11,9 +12,13 @@ import Control.Monad (when)  import Control.Monad.Trans.Class (lift)  import Control.Monad.Trans.State.Strict  import Data.IORef +import Data.Int  import Data.Proxy +import Data.Typeable  import qualified Data.Vector.Storable as VS  import qualified Data.Vector.Storable.Mutable as VSM +import Foreign.C.Types +import Foreign.ForeignPtr  import Foreign.Ptr  import Foreign.Storable  import GHC.Stack @@ -27,9 +32,14 @@ debug :: Bool  debug = toEnum 0 +foreign import ccall unsafe "ad_dual_hs_backpropagate_double" +  c_backpropagate_double :: Ptr CDouble -> Int64 -> Int64 -> Ptr () -> IO () + +  -- TODO: full vjp (just some more Traversable mess) +-- TODO: if non-scalar output types are allowed, ensure that all its scalar components are WHNF evaluated before we backpropagate  {-# NOINLINE gradient' #-} -gradient' :: forall a f. (Traversable f, Num a, Storable a) +gradient' :: forall a f. (Traversable f, Num a, Storable a, Typeable a)            => HasCallStack            => Show a  -- TODO: remove            => (forall s. Taping s a => f (Dual s a) -> Dual s a) @@ -38,8 +48,9 @@ gradient' f inp topctg = unsafePerformIO $ do    when debug $ hPutStrLn stderr "Preparing input"    let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0    idref <- newIORef starti -  vec1 <- VSM.new (max 128 (2 * starti)) -  taperef <- newIORef (MLog idref (Chunk 0 vec1) SLNil) +  -- The first chunk starts after the input IDs. 
+  vec1 <- VSM.unsafeNew 128 +  taperef <- newIORef (MLog idref (Chunk starti vec1) SLNil)    when debug $ hPutStrLn stderr "Running function"    let !(Dual result outi) = withDict @(Taping () a) taperef $ f @() inp' @@ -63,19 +74,23 @@ gradient' f inp topctg = unsafePerformIO $ do              backpropagate (i-1) chunk tape          | otherwise = case tape of                          SLNil -> return ()  -- reached end of tape we should loop over -                        tape'@Snoc{} `Snoc` chunk' -> backpropagate i chunk' tape' -                        -- When we reach the last chunk, modify it so that its -                        -- starting index is after the inputs. -                        SLNil `Snoc` Chunk _ vec' -> -                          backpropagate i (Chunk starti (VSM.slice starti (VSM.length vec' - starti) vec')) SLNil +                        tape' `Snoc` chunk' -> backpropagate i chunk' tape' + +      backpropagate_via_c :: Ptr CDouble -> Int -> Chunk Double -> Snoclist (Chunk Double) -> IO () +      backpropagate_via_c accums_ptr i (Chunk ci0 vec) tape = do +        let (vec_fptr, _) = VSM.unsafeToForeignPtr0 vec +        withForeignPtr vec_fptr $ \vec_ptr -> +          c_backpropagate_double accums_ptr (fromIntegral ci0) (fromIntegral i) (castPtr @(Contrib Double) @() vec_ptr) +        case tape of +          SLNil -> return () +          tape' `Snoc` chunk' -> backpropagate_via_c accums_ptr (ci0 - 1) chunk' tape' -  -- Ensure that if there are no more chunks in the tape tail, the starting -  -- index of the first chunk is adjusted so that backpropagate stops in time. 
-  case tapeTail of -    SLNil -> backpropagate outi (let Chunk _ vec = lastChunk -                                 in Chunk starti (VSM.slice starti (VSM.length vec - starti) vec)) -                                SLNil -    Snoc{} -> backpropagate outi lastChunk tapeTail +  case (eqT @a @Double, sizeOf (undefined :: Int)) of +    (Just Refl, 8) -> do +      let (accums_fptr, _) = VSM.unsafeToForeignPtr0 accums +      withForeignPtr accums_fptr $ \accums_ptr -> +        backpropagate_via_c (castPtr @Double @CDouble accums_ptr) outi lastChunk tapeTail +    _ -> backpropagate outi lastChunk tapeTail    when debug $ do      accums' <- VS.freeze accums @@ -93,8 +108,8 @@ gradient' f inp topctg = unsafePerformIO $ do  data Snoclist a = SLNil | Snoc !(Snoclist a) !a    deriving (Show, Eq, Ord, Functor, Foldable, Traversable) -data Contrib a = Contrib {-# UNPACK #-} !Int a  -- ^ ID == -1 -> no contribution -                         {-# UNPACK #-} !Int a  -- ^ idem +data Contrib a = Contrib {-# UNPACK #-} !Int !a  -- ^ ID == -1 -> no contribution +                         {-# UNPACK #-} !Int !a  -- ^ idem    deriving (Show)  instance Storable a => Storable (Contrib a) where @@ -147,7 +162,7 @@ instance (Num a, Storable a, Taping s a) => Num (Dual s a) where    Dual x i1 + Dual y i2 = mkDual (x + y) i1 1 i2 1    Dual x i1 - Dual y i2 = mkDual (x - y) i1 1 i2 (-1)    Dual x i1 * Dual y i2 = mkDual (x * y) i1 y i2 x -  negate (Dual x i1)    = mkDual (negate x)  i1 (-1) (-1) 0 +  negate (Dual x i1)    = mkDual (negate x) i1 (-1) (-1) 0    abs (Dual x i1)       = mkDual (abs x) i1 (x * signum x) (-1) 0    signum (Dual x _) = Dual (signum x) (-1)    fromInteger n = Dual (fromInteger n) (-1) @@ -181,6 +196,8 @@ mkDual res i1 dx i2 dy = Dual res (writeTapeUnsafe @a (Proxy @s) i1 dx i2 dy)  data WriteTapeAction a = WTANewvec (VSM.IOVector (Contrib a))                         | WTAOldTape (Snoclist (Chunk a)) +-- This NOINLINE really doesn't seem to matter for performance, so let's 
be safe +{-# NOINLINE writeTapeUnsafe #-}  writeTapeUnsafe :: forall a s proxy. (Num a, Storable a, Taping s a) => proxy s -> Int -> a -> Int -> a -> Int  writeTapeUnsafe _ i1 dx i2 dy = unsafePerformIO $ writeTapeIO (Proxy @s) i1 dx i2 dy @@ -200,7 +217,7 @@ writeTapeIO _ i1 dx i2 dy = do       -- check if we'd fit in the next chunk (overwhelmingly likely)       | let newlen = 3 * n `div` 2       , idx < n + newlen -> do -         newvec <- VSM.new newlen +         newvec <- VSM.unsafeNew newlen           action <- atomicModifyIORef' (getTape @s) $ \(MLog idref' chunk@(Chunk ci0' vec') tape) ->             if | ci0 == ci0' ->                   -- Likely (certain when single-threaded): no race condition, @@ -236,5 +253,3 @@ writeTapeIO _ i1 dx i2 dy = do       -- there's a tremendous amount of competition, let's just try again       | otherwise -> writeTapeIO (Proxy @s) i1 dx i2 dy - -  | 
