diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-02-24 13:42:45 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-02-24 13:42:45 +0100 |
commit | cbb0dd08449cddd141145a2d2f280e3457279b47 (patch) | |
tree | 0997ddd61326e0bd5c74532bbbf1e6aba2bc7902 /src/Numeric | |
parent | f84c5b0cab7e819cdaae8288a06641973cf83437 (diff) |
WIP array stuff untested
Diffstat (limited to 'src/Numeric')
-rw-r--r-- | src/Numeric/ADDual/Array/Internal.hs | 104 |
1 files changed, 71 insertions, 33 deletions
diff --git a/src/Numeric/ADDual/Array/Internal.hs b/src/Numeric/ADDual/Array/Internal.hs index 1cc2796..5a4af4b 100644 --- a/src/Numeric/ADDual/Array/Internal.hs +++ b/src/Numeric/ADDual/Array/Internal.hs @@ -1,4 +1,5 @@ {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -7,13 +8,13 @@ module Numeric.ADDual.Array.Internal where import Control.Monad (when) -import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.State.Strict +import Data.Foldable (toList) import Data.IORef +import Data.List (foldl') +import qualified Data.IntMap.Strict as IM import Data.Proxy -import Data.Typeable import qualified Data.Vector.Storable as VS -import qualified Data.Vector.Storable.Mutable as VSM import Foreign.Storable import GHC.Stack import GHC.Exts (withDict) @@ -34,14 +35,15 @@ debug = toEnum 0 -- TODO: full vjp (just some more Traversable mess) -- TODO: if non-scalar output types are allowed, ensure that all its scalar components are WHNF evaluated before we backpropagate {-# NOINLINE gradient' #-} -gradient' :: forall a f. (Traversable f, Num a, Storable a, Typeable a) +gradient' :: forall a f. (Traversable f, Num a, Storable a) => HasCallStack => Show a -- TODO: remove - => (forall s. Taping s a => f (Dual s a) -> Dual s a) - -> f a -> a -> (a, f a) + => (forall s. Taping s a => f (VDual s a) -> Dual s a) + -> f (VS.Vector a) -> a -> (a, f (VS.Vector a)) gradient' f inp topctg = unsafePerformIO $ do when debug $ hPutStrLn stderr "Preparing input" - let (inp', starti) = runState (traverse (\x -> state (\i -> (Dual x i, i + 1))) inp) 0 + let (inp', starti) = runState (traverse (\x -> state (\i -> (VDual x i, i + 1))) inp) 0 + inpSizes = VS.fromListN starti (map VS.length (toList inp)) -- The tape starts after the input IDs. taperef <- newIORef (Log starti Start) @@ -55,36 +57,63 @@ gradient' f inp topctg = unsafePerformIO $ do -- hPutStrLn stderr $ "tape = " ++ tapestr "" when debug $ hPutStrLn stderr "Backpropagating" - accums <- VSM.new (outi+1) - VSM.write accums outi topctg - let backpropagate i (Cscalar i1 dx i2 dy tape') = do - ctg <- VSM.read accums i - when (i1 /= -1) $ VSM.modify accums (+ ctg*dx) i1 - when (i2 /= -1) $ VSM.modify accums (+ ctg*dy) i2 - backpropagate (i-1) tape' - backpropagate _ Start = return () + let (outaccS, outaccV) = backpropagate (IM.singleton outi topctg) IM.empty outi tape - backpropagate outi tape - - when debug $ do - accums' <- VS.freeze accums - hPutStrLn stderr $ "accums = " ++ show accums' + -- when debug $ do + -- accums' <- VS.freeze accums + -- hPutStrLn stderr $ "accums = " ++ show accums' when debug $ hPutStrLn stderr "Reconstructing gradient" let readDeriv = do i <- get - d <- lift $ VSM.read accums i + let d = IM.findWithDefault (VS.replicate (inpSizes VS.! i) 0) i outaccV put (i+1) return d - grad <- evalStateT (traverse (\_ -> readDeriv) inp) 0 + let grad = evalState (traverse (\_ -> readDeriv) inp) 0 return (result, grad) --- | Contribution to a vector-typed value -data VCon a = VCon {-# UNPACK #-} !Int -- ^ the ID of the vector value - {-# UNPACK #-} !(VS.Vector a) -- ^ the cotangent - | VConNothing - deriving (Show) +backpropagate :: (Num a, Storable a) + => IM.IntMap a -> IM.IntMap (VS.Vector a) -> Int -> Chain a -> (IM.IntMap a, IM.IntMap (VS.Vector a)) +backpropagate accS accV i (Cscalar i1 dx i2 dy tape) = + case IM.lookup i accS of + Nothing -> backpropagate accS accV (i-1) tape + Just ctg -> + let accS1 | i1 /= -1 = IM.insertWith (+) i1 (ctg*dx) accS + | otherwise = accS + accS2 | i2 /= -1 = IM.insertWith (+) i2 (ctg*dy) accS1 + | otherwise = accS1 + in backpropagate accS2 accV (i-1) tape +backpropagate accS accV i (VCfromList is tape) = + case IM.lookup i accV of + Nothing -> backpropagate accS accV (i-1) tape + Just ctg -> + let accS1 | VS.length ctg == VS.length is = + foldl' (\accS' idx -> IM.insertWith (+) (is VS.! idx) (ctg VS.! idx) accS') accS [0 .. VS.length ctg - 1] + | otherwise = error "Numeric.ADDual.Array: wrong cotangent length to vfromList" + in backpropagate accS1 accV (i-1) tape +backpropagate accS accV i (VCtoList j len tape) = + case IM.lookupGE (i - len) accS of + Just (smallid, _) | smallid < i -> + let ctg = VS.fromListN len [IM.findWithDefault 0 (i - len + idx) accS | idx <- [0 .. len-1]] + accV1 = IM.insertWith (VS.zipWith (+)) j ctg accV + in backpropagate accS accV1 (i - 1 - len) tape + _ -> backpropagate accS accV (i - 1 - len) tape +backpropagate accS accV i (VCsum j len tape) = + case IM.lookup i accS of + Nothing -> backpropagate accS accV (i-1) tape + Just ctg -> + let accV1 = IM.alter (\case Nothing -> Just (VS.replicate len ctg) + Just d -> Just (VS.map (+ ctg) d)) + j accV + in backpropagate accS accV1 (i - 1 - len) tape +backpropagate accS accV i (VCreplicate j len tape) = + case IM.lookup i accV of + Nothing -> backpropagate accS accV (i-1) tape + Just ctg -> + let accS1 = IM.insertWith (+) j (fromIntegral len * VS.sum ctg) accS + in backpropagate accS1 accV (i - 1 - len) tape +backpropagate accS accV _ Start = (accS, accV) data Chain a = Cscalar {-# UNPACK #-} !Int !a -- ^ ID == -1 -> no contribution {-# UNPACK #-} !Int !a -- ^ idem @@ -92,13 +121,13 @@ data Chain a = Cscalar {-# UNPACK #-} !Int !a -- ^ ID == -1 -> no contribution | VCfromList {-# UNPACK #-} !(VS.Vector Int) -- ^ IDs of scalars in the input list !(Chain a) | VCtoList {-# UNPACK #-} !Int -- ^ ID of the input vector - {-# UNPACK #-} !Int -- ^ start of the reserved output ID range {-# UNPACK #-} !Int -- ^ number of reserved output IDs (length of the vector) !(Chain a) | VCsum {-# UNPACK #-} !Int -- ^ ID of the input vector + {-# UNPACK #-} !Int -- ^ length of the vector !(Chain a) - | VCreplicate {-# UNPACK #-} !Int -- ^ length of the replicated vector - {-# UNPACK #-} !Int -- ^ ID of the input scalar + | VCreplicate {-# UNPACK #-} !Int -- ^ ID of the input scalar + {-# UNPACK #-} !Int -- ^ length of the replicated vector !(Chain a) | Start deriving (Show) @@ -166,11 +195,13 @@ instance (Storable a, Taping s a) => VectorOps (VDual s a) where vfromList l = let (xs, is) = unzip [(x, i) | Dual x i <- l] in mkVDual (VS.fromList xs) (VCfromList (VS.fromList is)) - vtoList (VDual v i) = _ - vreplicate n (Dual x i) = mkVDual (VS.replicate n x) (VCreplicate n i) + vtoList (VDual v i) = + let starti = allocTapeToListUnsafe (Proxy @a) (Proxy @s) i (VS.length v) + in zipWith Dual (VS.toList v) [starti..] + vreplicate n (Dual x i) = mkVDual (VS.replicate n x) (VCreplicate i n) instance (Storable a, Num a, Taping s a) => VectorOpsNum (VDual s a) where - vsum (VDual v i) = Dual (VS.sum v) (writeTapeUnsafe @a (Proxy @s) (VCsum i)) + vsum (VDual v i) = Dual (VS.sum v) (writeTapeUnsafe @a (Proxy @s) (VCsum i (VS.length v))) vconstant :: VS.Vector a -> VDual s a vconstant v = VDual v (-1) @@ -185,3 +216,10 @@ writeTapeUnsafe _ f = unsafePerformIO $ atomicModifyIORef' (getTape @s) $ \(Log i tape) -> (Log (i + 1) (f tape), i) + +{-# NOINLINE allocTapeToListUnsafe #-} +allocTapeToListUnsafe :: forall a s proxy. Taping s a => proxy a -> proxy s -> Int -> Int -> Int +allocTapeToListUnsafe _ _ vecid len = + unsafePerformIO $ + atomicModifyIORef' (getTape @s @a) $ \(Log i tape) -> + (Log (i + len + 1) (VCtoList vecid len tape), i) |