diff options
Diffstat (limited to 'src/Numeric/ADDual/Array')
| -rw-r--r-- | src/Numeric/ADDual/Array/Internal.hs | 104 | 
1 files changed, 71 insertions, 33 deletions
diff --git a/src/Numeric/ADDual/Array/Internal.hs b/src/Numeric/ADDual/Array/Internal.hs index 1cc2796..5a4af4b 100644 --- a/src/Numeric/ADDual/Array/Internal.hs +++ b/src/Numeric/ADDual/Array/Internal.hs @@ -1,4 +1,5 @@  {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE LambdaCase #-}  {-# LANGUAGE MultiParamTypeClasses #-}  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-} @@ -7,13 +8,13 @@  module Numeric.ADDual.Array.Internal where  import Control.Monad (when) -import Control.Monad.Trans.Class (lift)  import Control.Monad.Trans.State.Strict +import Data.Foldable (toList)  import Data.IORef +import Data.List (foldl') +import qualified Data.IntMap.Strict as IM  import Data.Proxy -import Data.Typeable  import qualified Data.Vector.Storable as VS -import qualified Data.Vector.Storable.Mutable as VSM  import Foreign.Storable  import GHC.Stack  import GHC.Exts (withDict) @@ -34,14 +35,15 @@ debug = toEnum 0  -- TODO: full vjp (just some more Traversable mess)  -- TODO: if non-scalar output types are allowed, ensure that all its scalar components are WHNF evaluated before we backpropagate  {-# NOINLINE gradient' #-} -gradient' :: forall a f. (Traversable f, Num a, Storable a, Typeable a) +gradient' :: forall a f. (Traversable f, Num a, Storable a)            => HasCallStack            => Show a  -- TODO: remove -          => (forall s. Taping s a => f (Dual s a) -> Dual s a) -          -> f a -> a -> (a, f a) +          => (forall s. Taping s a => f (VDual s a) -> Dual s a) +          -> f (VS.Vector a) -> a -> (a, f (VS.Vector a))  gradient' f inp topctg = unsafePerformIO $ do    when debug $ hPutStrLn stderr "Preparing input" -  let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0 +  let (inp', starti) = runState (traverse (\x -> state (\i -> (VDual x i, i + 1))) inp) 0 +      inpSizes = VS.fromListN starti (map VS.length (toList inp))    -- The tape starts after the input IDs.    taperef <- newIORef (Log starti Start) @@ -55,36 +57,63 @@ gradient' f inp topctg = unsafePerformIO $ do    --   hPutStrLn stderr $ "tape = " ++ tapestr ""    when debug $ hPutStrLn stderr "Backpropagating" -  accums <- VSM.new (outi+1) -  VSM.write accums outi topctg -  let backpropagate i (Cscalar i1 dx i2 dy tape') = do -        ctg <- VSM.read accums i -        when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1 -        when (i2 /= -1) $ VSM.modify accums (+ ctg*dy) i2 -        backpropagate (i-1) tape' -      backpropagate _ Start = return () +  let (outaccS, outaccV) = backpropagate (IM.singleton outi topctg) IM.empty outi tape -  backpropagate outi tape - -  when debug $ do -    accums' <- VS.freeze accums -    hPutStrLn stderr $ "accums = " ++ show accums' +  -- when debug $ do +  --   accums' <- VS.freeze accums +  --   hPutStrLn stderr $ "accums = " ++ show accums'    when debug $ hPutStrLn stderr "Reconstructing gradient"    let readDeriv = do i <- get -                     d <- lift $ VSM.read accums i +                     let d = IM.findWithDefault (VS.replicate (inpSizes VS.! i) 0) i outaccV                       put (i+1)                       return d -  grad <- evalStateT (traverse (\_ -> readDeriv) inp) 0 +  let grad = evalState (traverse (\_ -> readDeriv) inp) 0    return (result, grad) --- | Contribution to a vector-typed value -data VCon a = VCon {-# UNPACK #-} !Int  -- ^ the ID of the vector value -                   {-# UNPACK #-} !(VS.Vector a)  -- ^ the cotangent -            | VConNothing -  deriving (Show) +backpropagate :: (Num a, Storable a) +              => IM.IntMap a -> IM.IntMap (VS.Vector a) -> Int -> Chain a -> (IM.IntMap a, IM.IntMap (VS.Vector a)) +backpropagate accS accV i (Cscalar i1 dx i2 dy tape) = +  case IM.lookup i accS of +    Nothing -> backpropagate accS accV (i-1) tape +    Just ctg -> +      let accS1 | i1 /= -1 = IM.insertWith (+) i1 (ctg*dx) accS +                | otherwise = accS +          accS2 | i2 /= -1 = IM.insertWith (+) i2 (ctg*dy) accS1 +                | otherwise = accS1 +      in backpropagate accS2 accV (i-1) tape +backpropagate accS accV i (VCfromList is tape) = +  case IM.lookup i accV of +    Nothing -> backpropagate accS accV (i-1) tape +    Just ctg -> +      let accS1 | VS.length ctg == VS.length is = +                    foldl' (\accS' idx -> IM.insertWith (+) (is VS.! idx) (ctg VS.! idx) accS') accS [0 .. VS.length ctg - 1] +                | otherwise = error "Numeric.ADDual.Array: wrong cotangent length to vfromList" +      in backpropagate accS1 accV (i-1) tape +backpropagate accS accV i (VCtoList j len tape) = +  case IM.lookupGE (i - len) accS of +    Just (smallid, _) | smallid < i -> +      let ctg = VS.fromListN len [IM.findWithDefault 0 (i - len + idx) accS | idx <- [0 .. len-1]] +          accV1 = IM.insertWith (VS.zipWith (+)) j ctg accV +      in backpropagate accS accV1 (i - 1 - len) tape +    _ -> backpropagate accS accV (i - 1 - len) tape +backpropagate accS accV i (VCsum j len tape) = +  case IM.lookup i accS of +    Nothing -> backpropagate accS accV (i-1) tape +    Just ctg -> +      let accV1 = IM.alter (\case Nothing -> Just (VS.replicate len ctg) +                                  Just d -> Just (VS.map (+ ctg) d)) +                           j accV +      in backpropagate accS accV1 (i - 1 - len) tape +backpropagate accS accV i (VCreplicate j len tape) = +  case IM.lookup i accV of +    Nothing -> backpropagate accS accV (i-1) tape +    Just ctg -> +      let accS1 = IM.insertWith (+) j (fromIntegral len * VS.sum ctg) accS +      in backpropagate accS1 accV (i - 1 - len) tape +backpropagate accS accV _ Start = (accS, accV)  data Chain a = Cscalar {-# UNPACK #-} !Int !a  -- ^ ID == -1 -> no contribution                         {-# UNPACK #-} !Int !a  -- ^ idem @@ -92,13 +121,13 @@ data Chain a = Cscalar {-# UNPACK #-} !Int !a  -- ^ ID == -1 -> no contribution               | VCfromList {-# UNPACK #-} !(VS.Vector Int)  -- ^ IDs of scalars in the input list                            !(Chain a)               | VCtoList {-# UNPACK #-} !Int  -- ^ ID of the input vector -                        {-# UNPACK #-} !Int  -- ^ start of the reserved output ID range                          {-# UNPACK #-} !Int  -- ^ number of reserved output IDs (length of the vector)                          !(Chain a)               | VCsum {-# UNPACK #-} !Int  -- ^ ID of the input vector +                     {-# UNPACK #-} !Int  -- ^ length of the vector                       !(Chain a) -             | VCreplicate {-# UNPACK #-} !Int  -- ^ length of the replicated vector -                           {-# UNPACK #-} !Int  -- ^ ID of the input scalar +             | VCreplicate {-# UNPACK #-} !Int  -- ^ ID of the input scalar +                           {-# UNPACK #-} !Int  -- ^ length of the replicated vector                             !(Chain a)               | Start    deriving (Show) @@ -166,11 +195,13 @@ instance (Storable a, Taping s a) => VectorOps (VDual s a) where    vfromList l =      let (xs, is) = unzip [(x, i) | Dual x i <- l]      in mkVDual (VS.fromList xs) (VCfromList (VS.fromList is)) -  vtoList (VDual v i) = _ -  vreplicate n (Dual x i) = mkVDual (VS.replicate n x) (VCreplicate n i) +  vtoList (VDual v i) = +    let starti = allocTapeToListUnsafe (Proxy @a) (Proxy @s) i (VS.length v) +    in zipWith Dual (VS.toList v) [starti..] +  vreplicate n (Dual x i) = mkVDual (VS.replicate n x) (VCreplicate i n)  instance (Storable a, Num a, Taping s a) => VectorOpsNum (VDual s a) where -  vsum (VDual v i) = Dual (VS.sum v) (writeTapeUnsafe @a (Proxy @s) (VCsum i)) +  vsum (VDual v i) = Dual (VS.sum v) (writeTapeUnsafe @a (Proxy @s) (VCsum i (VS.length v)))  vconstant :: VS.Vector a -> VDual s a  vconstant v = VDual v (-1) @@ -185,3 +216,10 @@ writeTapeUnsafe _ f =    unsafePerformIO $      atomicModifyIORef' (getTape @s) $ \(Log i tape) ->        (Log (i + 1) (f tape), i) + +{-# NOINLINE allocTapeToListUnsafe #-} +allocTapeToListUnsafe :: forall a s proxy. Taping s a => proxy a -> proxy s -> Int -> Int -> Int +allocTapeToListUnsafe _ _ vecid len = +  unsafePerformIO $ +    atomicModifyIORef' (getTape @s @a) $ \(Log i tape) -> +      (Log (i + len + 1) (VCtoList vecid len tape), i)  | 
