diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-02-24 22:10:47 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-24 22:11:24 +0100 |
commit | 3631b758acfb2585809fdb0755e1a8e7afe3b9b7 (patch) | |
tree | cba5d6b110e3b7679d7558ad30e7454fc6716616 /src/Numeric/ADDual/Array | |
parent | cbb0dd08449cddd141145a2d2f280e3457279b47 (diff) |
ad:Numeric.AD.Double / ad-dual:Numeric.ADDual.Array.Internal
Prelude> 1.129e-3 / 41.89e-6 -- neural-100
26.951539746956314
Prelude> 34.67e-3 / 156.9e-6 -- neural-180
220.9687699171447
Prelude> 79.03e-3 / 178.6e-6 -- neural-500
442.4972004479283
Prelude> 365.3e-3 / 665.5e-6 -- neural-2000
548.9105935386928
Diffstat (limited to 'src/Numeric/ADDual/Array')
-rw-r--r-- | src/Numeric/ADDual/Array/Internal.hs | 53 |
1 files changed, 50 insertions, 3 deletions
diff --git a/src/Numeric/ADDual/Array/Internal.hs b/src/Numeric/ADDual/Array/Internal.hs index 5a4af4b..227be3e 100644 --- a/src/Numeric/ADDual/Array/Internal.hs +++ b/src/Numeric/ADDual/Array/Internal.hs @@ -15,6 +15,7 @@ import Data.List (foldl') import qualified Data.IntMap.Strict as IM import Data.Proxy import qualified Data.Vector.Storable as VS +import Foreign.Ptr (castPtr) import Foreign.Storable import GHC.Stack import GHC.Exts (withDict) @@ -58,7 +59,7 @@ gradient' f inp topctg = unsafePerformIO $ do when debug $ hPutStrLn stderr "Backpropagating" - let (outaccS, outaccV) = backpropagate (IM.singleton outi topctg) IM.empty outi tape + let (_outaccS, outaccV) = backpropagate (IM.singleton outi topctg) IM.empty outi tape -- when debug $ do -- accums' <- VS.freeze accums @@ -84,6 +85,21 @@ backpropagate accS accV i (Cscalar i1 dx i2 dy tape) = accS2 | i2 /= -1 = IM.insertWith (+) i2 (ctg*dy) accS1 | otherwise = accS1 in backpropagate accS2 accV (i-1) tape +backpropagate accS accV i (VCarith i1 dx i2 dy tape) = + case IM.lookup i accV of + Nothing -> backpropagate accS accV (i-1) tape + Just ctg -> + let accV1 | i1 /= -1 = + if VS.length ctg == VS.length dx + then IM.insertWith vadd i1 (vmul ctg dx) accV + else error "Numeric.ADDual.Array: wrong cotangent length to vectorised arithmetic operation" + | otherwise = accV + accV2 | i2 /= -1 = + if VS.length ctg == VS.length dy + then IM.insertWith vadd i2 (vmul ctg dy) accV1 + else error "Numeric.ADDual.Array: wrong cotangent length to vectorised arithmetic operation" + | otherwise = accV + in backpropagate accS accV2 (i-1) tape backpropagate accS accV i (VCfromList is tape) = case IM.lookup i accV of Nothing -> backpropagate accS accV (i-1) tape @@ -96,7 +112,7 @@ backpropagate accS accV i (VCtoList j len tape) = case IM.lookupGE (i - len) accS of Just (smallid, _) | smallid < i -> let ctg = VS.fromListN len [IM.findWithDefault 0 (i - len + idx) accS | idx <- [0 .. len-1]] - accV1 = IM.insertWith (VS.zipWith (+)) j ctg accV + accV1 = IM.insertWith vadd j ctg accV in backpropagate accS accV1 (i - 1 - len) tape _ -> backpropagate accS accV (i - 1 - len) tape backpropagate accS accV i (VCsum j len tape) = @@ -118,6 +134,9 @@ backpropagate accS accV _ Start = (accS, accV) data Chain a = Cscalar {-# UNPACK #-} !Int !a -- ^ ID == -1 -> no contribution {-# UNPACK #-} !Int !a -- ^ idem !(Chain a) + | VCarith {-# UNPACK #-} !Int {-# UNPACK #-} !(VS.Vector a) -- ^ first argument with scale factors + {-# UNPACK #-} !Int {-# UNPACK #-} !(VS.Vector a) -- ^ second argument with scale factors + !(Chain a) | VCfromList {-# UNPACK #-} !(VS.Vector Int) -- ^ IDs of scalars in the input list !(Chain a) | VCtoList {-# UNPACK #-} !Int -- ^ ID of the input vector @@ -178,6 +197,13 @@ instance (Floating a, Taping s a) => Floating (Dual s a) where cosh = undefined ; tanh = undefined ; asinh = undefined ; acosh = undefined atanh = undefined +-- | This instance allows breaking the abstraction of 'Dual'. Don't inspect or modify the serialised representation, and DO NOT use serialised 'Dual' values from one 'gradient'' computation in another! +instance Storable a => Storable (Dual s a) where + sizeOf _ = sizeOf (undefined :: a) + sizeOf (undefined :: Int) + alignment _ = alignment (undefined :: a) + peek ptr = Dual <$> peek (castPtr ptr) <*> peekByteOff ptr (sizeOf (undefined :: a)) + poke ptr (Dual x i) = poke (castPtr ptr) x >> pokeByteOff ptr (sizeOf (undefined :: a)) i + constant :: a -> Dual s a constant x = Dual x (-1) @@ -187,7 +213,7 @@ mkDual res i1 dx i2 dy = Dual res (writeTapeUnsafe (Proxy @s) (Cscalar i1 dx i2 data VDual s a = VDual !(VS.Vector a) {-# UNPACK #-} !Int -- ^ -1 if this is a constant vector -instance (Storable a, Taping s a) => VectorOps (VDual s a) where +instance (Storable a, Num a, Taping s a) => VectorOps (VDual s a) where type VectorOpsScalar (VDual s a) = Dual s a vfromListN n l = let (xs, is) = unzip [(x, i) | Dual x i <- l] @@ -198,11 +224,32 @@ instance (Storable a, Taping s a) => VectorOps (VDual s a) where vtoList (VDual v i) = let starti = allocTapeToListUnsafe (Proxy @a) (Proxy @s) i (VS.length v) in zipWith Dual (VS.toList v) [starti..] + vlength (VDual v _) = VS.length v vreplicate n (Dual x i) = mkVDual (VS.replicate n x) (VCreplicate i n) + vselect bs (VDual a i) (VDual b j) = + mkVDual (vselect bs a b) (VCarith i (VS.map (fromIntegral . fromEnum) bs) + j (VS.map (fromIntegral . fromEnum . not) bs)) instance (Storable a, Num a, Taping s a) => VectorOpsNum (VDual s a) where + vadd (VDual v i) (VDual w j) = + let len = VS.length v + in mkVDual (vadd v w) (VCarith i (VS.replicate len 1) j (VS.replicate len 1)) + vsub (VDual v i) (VDual w j) = + let len = VS.length v + in mkVDual (vsub v w) (VCarith i (VS.replicate len 1) j (VS.replicate len (-1))) + vmul (VDual v i) (VDual w j) = + mkVDual (vmul v w) (VCarith i w j v) vsum (VDual v i) = Dual (VS.sum v) (writeTapeUnsafe @a (Proxy @s) (VCsum i (VS.length v))) +instance (Storable a, Floating a, Taping s a) => VectorOpsFloating (VDual s a) where + vexp (VDual v i) = mkVDual (vexp v) (VCarith i v (-1) VS.empty) + +instance (Storable a, Num a, Ord a, Taping s a) => VectorOpsOrd (VDual s a) where + vcmpLE (VDual v _) (VDual w _) = vcmpLE v w + vmaximum (VDual v i) = + let w = vmaximum v + in Dual w (writeTapeUnsafe @a (Proxy @s) (VCarith i (VS.map (\x -> if x == w then 1 else 0) v) (-1) VS.empty)) + vconstant :: VS.Vector a -> VDual s a vconstant v = VDual v (-1) |