diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-02-24 22:10:47 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-24 22:11:24 +0100 | 
| commit | 3631b758acfb2585809fdb0755e1a8e7afe3b9b7 (patch) | |
| tree | cba5d6b110e3b7679d7558ad30e7454fc6716616 /src | |
| parent | cbb0dd08449cddd141145a2d2f280e3457279b47 (diff) | |
Dual arrays is >100x faster than 'ad' on large fneural
ad:Numeric.AD.Double / ad-dual:Numeric.ADDual.Array.Internal
Prelude> 1.129e-3 / 41.89e-6  -- neural-100
26.951539746956314
Prelude> 34.67e-3 / 156.9e-6  -- neural-180
220.9687699171447
Prelude> 79.03e-3 / 178.6e-6  -- neural-500
442.4972004479283
Prelude> 365.3e-3 / 665.5e-6  -- neural-2000
548.9105935386928
Diffstat (limited to 'src')
| -rw-r--r-- | src/Numeric/ADDual/Array/Internal.hs | 53 | ||||
| -rw-r--r-- | src/Numeric/ADDual/VectorOps.hs | 72 | 
2 files changed, 121 insertions, 4 deletions
diff --git a/src/Numeric/ADDual/Array/Internal.hs b/src/Numeric/ADDual/Array/Internal.hs index 5a4af4b..227be3e 100644 --- a/src/Numeric/ADDual/Array/Internal.hs +++ b/src/Numeric/ADDual/Array/Internal.hs @@ -15,6 +15,7 @@ import Data.List (foldl')  import qualified Data.IntMap.Strict as IM  import Data.Proxy  import qualified Data.Vector.Storable as VS +import Foreign.Ptr (castPtr)  import Foreign.Storable  import GHC.Stack  import GHC.Exts (withDict) @@ -58,7 +59,7 @@ gradient' f inp topctg = unsafePerformIO $ do    when debug $ hPutStrLn stderr "Backpropagating" -  let (outaccS, outaccV) = backpropagate (IM.singleton outi topctg) IM.empty outi tape +  let (_outaccS, outaccV) = backpropagate (IM.singleton outi topctg) IM.empty outi tape    -- when debug $ do    --   accums' <- VS.freeze accums @@ -84,6 +85,21 @@ backpropagate accS accV i (Cscalar i1 dx i2 dy tape) =            accS2 | i2 /= -1 = IM.insertWith (+) i2 (ctg*dy) accS1                  | otherwise = accS1        in backpropagate accS2 accV (i-1) tape +backpropagate accS accV i (VCarith i1 dx i2 dy tape) = +  case IM.lookup i accV of +    Nothing -> backpropagate accS accV (i-1) tape +    Just ctg -> +      let accV1 | i1 /= -1 = +                    if VS.length ctg == VS.length dx +                      then IM.insertWith vadd i1 (vmul ctg dx) accV +                      else error "Numeric.ADDual.Array: wrong cotangent length to vectorised arithmetic operation" +                | otherwise = accV +          accV2 | i2 /= -1 = +                    if VS.length ctg == VS.length dy +                      then IM.insertWith vadd i2 (vmul ctg dy) accV1 +                      else error "Numeric.ADDual.Array: wrong cotangent length to vectorised arithmetic operation" +                | otherwise = accV +      in backpropagate accS accV2 (i-1) tape  backpropagate accS accV i (VCfromList is tape) =    case IM.lookup i accV of      Nothing -> backpropagate accS accV (i-1) tape @@ -96,7 +112,7 @@ backpropagate accS accV i (VCtoList j len tape) =    case IM.lookupGE (i - len) accS of      Just (smallid, _) | smallid < i ->        let ctg = VS.fromListN len [IM.findWithDefault 0 (i - len + idx) accS | idx <- [0 .. len-1]] -          accV1 = IM.insertWith (VS.zipWith (+)) j ctg accV +          accV1 = IM.insertWith vadd j ctg accV        in backpropagate accS accV1 (i - 1 - len) tape      _ -> backpropagate accS accV (i - 1 - len) tape  backpropagate accS accV i (VCsum j len tape) = @@ -118,6 +134,9 @@ backpropagate accS accV _ Start = (accS, accV)  data Chain a = Cscalar {-# UNPACK #-} !Int !a  -- ^ ID == -1 -> no contribution                         {-# UNPACK #-} !Int !a  -- ^ idem                         !(Chain a) +             | VCarith {-# UNPACK #-} !Int {-# UNPACK #-} !(VS.Vector a)  -- ^ first argument with scale factors +                       {-# UNPACK #-} !Int {-# UNPACK #-} !(VS.Vector a)  -- ^ second argument with scale factors +                       !(Chain a)               | VCfromList {-# UNPACK #-} !(VS.Vector Int)  -- ^ IDs of scalars in the input list                            !(Chain a)               | VCtoList {-# UNPACK #-} !Int  -- ^ ID of the input vector @@ -178,6 +197,13 @@ instance (Floating a, Taping s a) => Floating (Dual s a) where    cosh = undefined ; tanh = undefined ; asinh = undefined ; acosh = undefined    atanh = undefined +-- | This instance allows breaking the abstraction of 'Dual'. Don't inspect or modify the serialised representation, and DO NOT use serialised 'Dual' values from one 'gradient'' computation in another! +instance Storable a => Storable (Dual s a) where +  sizeOf _ = sizeOf (undefined :: a) + sizeOf (undefined :: Int) +  alignment _ = alignment (undefined :: a) +  peek ptr = Dual <$> peek (castPtr ptr) <*> peekByteOff ptr (sizeOf (undefined :: a)) +  poke ptr (Dual x i) = poke (castPtr ptr) x >> pokeByteOff ptr (sizeOf (undefined :: a)) i +  constant :: a -> Dual s a  constant x = Dual x (-1) @@ -187,7 +213,7 @@ mkDual res i1 dx i2 dy = Dual res (writeTapeUnsafe (Proxy @s) (Cscalar i1 dx i2  data VDual s a = VDual !(VS.Vector a)                         {-# UNPACK #-} !Int  -- ^ -1 if this is a constant vector -instance (Storable a, Taping s a) => VectorOps (VDual s a) where +instance (Storable a, Num a, Taping s a) => VectorOps (VDual s a) where    type VectorOpsScalar (VDual s a) = Dual s a    vfromListN n l =      let (xs, is) = unzip [(x, i) | Dual x i <- l] @@ -198,11 +224,32 @@ instance (Storable a, Taping s a) => VectorOps (VDual s a) where    vtoList (VDual v i) =      let starti = allocTapeToListUnsafe (Proxy @a) (Proxy @s) i (VS.length v)      in zipWith Dual (VS.toList v) [starti..] +  vlength (VDual v _) = VS.length v    vreplicate n (Dual x i) = mkVDual (VS.replicate n x) (VCreplicate i n) +  vselect bs (VDual a i) (VDual b j) = +    mkVDual (vselect bs a b) (VCarith i (VS.map (fromIntegral . fromEnum) bs) +                                      j (VS.map (fromIntegral . fromEnum . not) bs))  instance (Storable a, Num a, Taping s a) => VectorOpsNum (VDual s a) where +  vadd (VDual v i) (VDual w j) = +    let len = VS.length v +    in mkVDual (vadd v w) (VCarith i (VS.replicate len 1) j (VS.replicate len 1)) +  vsub (VDual v i) (VDual w j) = +    let len = VS.length v +    in mkVDual (vsub v w) (VCarith i (VS.replicate len 1) j (VS.replicate len (-1))) +  vmul (VDual v i) (VDual w j) = +    mkVDual (vmul v w) (VCarith i w j v)    vsum (VDual v i) = Dual (VS.sum v) (writeTapeUnsafe @a (Proxy @s) (VCsum i (VS.length v))) +instance (Storable a, Floating a, Taping s a) => VectorOpsFloating (VDual s a) where +  vexp (VDual v i) = mkVDual (vexp v) (VCarith i v (-1) VS.empty) + +instance (Storable a, Num a, Ord a, Taping s a) => VectorOpsOrd (VDual s a) where +  vcmpLE (VDual v _) (VDual w _) = vcmpLE v w +  vmaximum (VDual v i) = +    let w = vmaximum v +    in Dual w (writeTapeUnsafe @a (Proxy @s) (VCarith i (VS.map (\x -> if x == w then 1 else 0) v) (-1) VS.empty)) +  vconstant :: VS.Vector a -> VDual s a  vconstant v = VDual v (-1) diff --git a/src/Numeric/ADDual/VectorOps.hs b/src/Numeric/ADDual/VectorOps.hs index 38063e0..9bedebe 100644 --- a/src/Numeric/ADDual/VectorOps.hs +++ b/src/Numeric/ADDual/VectorOps.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE FlexibleContexts #-}  {-# LANGUAGE TypeFamilies #-}  module Numeric.ADDual.VectorOps where @@ -14,47 +15,116 @@ class VectorOps v where    vfromListN :: Int -> [VectorOpsScalar v] -> v    vfromList :: [VectorOpsScalar v] -> v    vtoList :: v -> [VectorOpsScalar v] +  vlength :: v -> Int    vreplicate :: Int -> VectorOpsScalar v -> v +  vselect :: VS.Vector Bool -> v -> v -> v  -- ^ True selects the first argument, False the second -class VectorOpsNum v where +class (VectorOps v, Num (VectorOpsScalar v)) => VectorOpsNum v where +  vadd :: v -> v -> v +  vsub :: v -> v -> v +  vmul :: v -> v -> v    vsum :: v -> VectorOpsScalar v +class (VectorOpsNum v, Floating (VectorOpsScalar v)) => VectorOpsFloating v where +  vexp :: v -> v + +class (VectorOps v, Ord (VectorOpsScalar v)) => VectorOpsOrd v where +  vcmpLE :: v -> v -> VS.Vector Bool +  vmaximum :: v -> VectorOpsScalar v + +  vcmpLT, vcmpGT, vcmpGE :: v -> v -> VS.Vector Bool +  vcmpLT a b = VS.map not (vcmpLE b a) +  vcmpGT a b = VS.map not (vcmpLE a b) +  vcmpGE a b = vcmpLE b a +  instance VectorOps (V.Vector a) where    type VectorOpsScalar (V.Vector a) = a    vfromListN = V.fromListN    vfromList = V.fromList    vtoList = V.toList +  vlength = V.length    vreplicate = V.replicate +  vselect bs a b = V.fromListN (VS.length bs) [if bs VS.! i then a V.! i else b V.! i +                                              | i <- [0 .. VS.length bs - 1]]  instance Num a => VectorOpsNum (V.Vector a) where +  vadd = V.zipWith (+) +  vsub = V.zipWith (-) +  vmul = V.zipWith (*)    vsum = V.sum +instance Floating a => VectorOpsFloating (V.Vector a) where +  vexp = V.map exp + +instance Ord a => VectorOpsOrd (V.Vector a) where +  vcmpLE a b = VS.generate (V.length a) (\i -> a V.! i <= b V.! i) +  vmaximum = V.maximum +  instance VectorOps (VSr.Vector a) where    type VectorOpsScalar (VSr.Vector a) = a    vfromListN = VSr.fromListN    vfromList = VSr.fromList    vtoList = VSr.toList +  vlength = VSr.length    vreplicate = VSr.replicate +  vselect bs a b = VSr.fromListN (VS.length bs) [if bs VS.! i then a VSr.! i else b VSr.! i +                                                | i <- [0 .. VS.length bs - 1]]  instance Num a => VectorOpsNum (VSr.Vector a) where +  vadd = VSr.zipWith (+) +  vsub = VSr.zipWith (-) +  vmul = VSr.zipWith (*)    vsum = VSr.sum +instance Floating a => VectorOpsFloating (VSr.Vector a) where +  vexp = VSr.map exp + +instance Ord a => VectorOpsOrd (VSr.Vector a) where +  vcmpLE a b = VS.generate (VSr.length a) (\i -> a VSr.! i <= b VSr.! i) +  vmaximum = VSr.maximum +  instance Storable a => VectorOps (VS.Vector a) where    type VectorOpsScalar (VS.Vector a) = a    vfromListN = VS.fromListN    vfromList = VS.fromList    vtoList = VS.toList +  vlength = VS.length    vreplicate = VS.replicate +  vselect bs a b = VS.fromListN (VS.length bs) [if bs VS.! i then a VS.! i else b VS.! i +                                               | i <- [0 .. VS.length bs - 1]]  instance (Storable a, Num a) => VectorOpsNum (VS.Vector a) where +  vadd = VS.zipWith (+) +  vsub = VS.zipWith (-) +  vmul = VS.zipWith (*)    vsum = VS.sum +instance (Storable a, Floating a) => VectorOpsFloating (VS.Vector a) where +  vexp = VS.map exp + +instance (Storable a, Ord a) => VectorOpsOrd (VS.Vector a) where +  vcmpLE a b = VS.generate (VS.length a) (\i -> a VS.! i <= b VS.! i) +  vmaximum = VS.maximum +  instance VU.Unbox a => VectorOps (VU.Vector a) where    type VectorOpsScalar (VU.Vector a) = a    vfromListN = VU.fromListN    vfromList = VU.fromList    vtoList = VU.toList +  vlength = VU.length    vreplicate = VU.replicate +  vselect bs a b = VU.fromListN (VS.length bs) [if bs VS.! i then a VU.! i else b VU.! i +                                               | i <- [0 .. VS.length bs - 1]]  instance (VU.Unbox a, Num a) => VectorOpsNum (VU.Vector a) where +  vadd = VU.zipWith (+) +  vsub = VU.zipWith (-) +  vmul = VU.zipWith (*)    vsum = VU.sum + +instance (VU.Unbox a, Floating a) => VectorOpsFloating (VU.Vector a) where +  vexp = VU.map exp + +instance (VU.Unbox a, Ord a) => VectorOpsOrd (VU.Vector a) where +  vcmpLE a b = VS.generate (VU.length a) (\i -> a VU.! i <= b VU.! i) +  vmaximum = VU.maximum  | 
