author     Tom Smeding <tom@tomsmeding.com>    2025-02-24 22:10:47 +0100
committer  Tom Smeding <tom@tomsmeding.com>    2025-02-24 22:11:24 +0100
commit     3631b758acfb2585809fdb0755e1a8e7afe3b9b7
tree       cba5d6b110e3b7679d7558ad30e7454fc6716616
parent     cbb0dd08449cddd141145a2d2f280e3457279b47
ad:Numeric.AD.Double / ad-dual:Numeric.ADDual.Array.Internal
Prelude> 1.129e-3 / 41.89e-6 -- neural-100
26.951539746956314
Prelude> 34.67e-3 / 156.9e-6 -- neural-180
220.9687699171447
Prelude> 79.03e-3 / 178.6e-6 -- neural-500
442.4972004479283
Prelude> 365.3e-3 / 665.5e-6 -- neural-2000
548.9105935386928
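
The quotients above compare two benchmark timings (in seconds) for each input size; the "-- neural-N" comments name the corresponding Criterion benchmark group in bench/Main.hs. As rough orientation, here is a minimal sketch of the gradient calls that those benchmarks time; the main wrapper is hypothetical and not part of the commit, while the imported modules and functions are the ones used in bench/Main.hs below:

-- Hypothetical driver: invokes the same gradient computations that
-- bench/Main.hs benchmarks, here for the "neural-100" input size.
import qualified Numeric.AD as AD
import qualified Numeric.AD.Double as AD.Double
import qualified Numeric.ADDual as ADD
import Numeric.ADDual.Examples (fneural, makeNeuralInput)

main :: IO ()
main = do
  let input = makeNeuralInput 100
  print (sum (snd (ADD.gradient' fneural input 1.0)))  -- "dual" case (this library)
  print (sum (AD.grad fneural input))                  -- "ad" case
  print (sum (AD.Double.grad fneural input))           -- "ad.Double" case (new in this commit)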
 ad-dual.cabal                        |  1
 bench/Main.hs                        | 15
 examples/Numeric/ADDual/Examples.hs  | 61
 src/Numeric/ADDual/Array/Internal.hs | 53
 src/Numeric/ADDual/VectorOps.hs      | 72
 test/Main.hs                         | 44
 6 files changed, 223 insertions(+), 23 deletions(-)
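
Beyond the benchmark and test updates, the diff below introduces a small vector-operations class hierarchy (VectorOps, VectorOpsNum, VectorOpsFloating, VectorOpsOrd), a vectorised variant fneural_A of the neural-network example, and a VCarith tape constructor for elementwise arithmetic on the VDual vector dual numbers. Since the classes also get plain-vector instances, fneural_A can be evaluated directly on storable vectors; a minimal primal-only sketch (the main wrapper is hypothetical, the other names appear in the diff):

-- Hypothetical primal evaluation of the vectorised example function, using
-- the VS.Vector instances of the VectorOps classes added in this commit.
import Numeric.ADDual.Examples
  (fneural, fneural_A, cvtFNeuralAtoFNeural, makeNeuralInput_A)

main :: IO ()
main = do
  let inputA = makeNeuralInput_A 100
  print (fneural_A inputA)                       -- vectorised evaluation
  print (fneural (cvtFNeuralAtoFNeural inputA))  -- boxed evaluation; the test suite checks these agree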
diff --git a/ad-dual.cabal b/ad-dual.cabal
index 47340e5..82c1342 100644
--- a/ad-dual.cabal
+++ b/ad-dual.cabal
@@ -31,6 +31,7 @@ library ad-dual-examples
   exposed-modules: Numeric.ADDual.Examples
   build-depends:
+    ad-dual,
     deepseq,
     hedgehog,
     vector
diff --git a/bench/Main.hs b/bench/Main.hs
index 1174a3a..cccc686 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -11,8 +11,10 @@ import System.Environment (getArgs)
 import System.Mem (performGC)
 
 import qualified Numeric.AD as AD
+import qualified Numeric.AD.Double as AD.Double
 import qualified Numeric.ADDual as ADD
+import qualified Numeric.ADDual.Array.Internal as ADDA
 import Numeric.ADDual.Examples
 
@@ -29,6 +31,10 @@ mainCriterion = defaultMain
   ,benchNeural 180  -- rather stably 2 GCs
   ,benchNeural 500
   ,benchNeural 2000
+  ,benchNeuralA 100
+  ,benchNeuralA 180  -- rather stably 2 GCs
+  ,benchNeuralA 500
+  ,benchNeuralA 2000
   ]
   where
   benchNeural :: Int -> Benchmark
@@ -36,7 +42,14 @@ mainCriterion = defaultMain
     env (pure (makeNeuralInput n)) $ \input ->
     bgroup ("neural-" ++ show n)
       [bench "dual" $ nf (\inp -> ADD.gradient' fneural inp 1.0) input
-      ,bench "ad" $ nf (\inp -> AD.grad fneural inp) input]
+      ,bench "ad" $ nf (\inp -> AD.grad fneural inp) input
+      ,bench "ad.Double" $ nf (\inp -> AD.Double.grad fneural inp) input]
+
+  benchNeuralA :: Int -> Benchmark
+  benchNeuralA n =
+    env (pure (makeNeuralInput_A n)) $ \input ->
+    bgroup ("neuralA-" ++ show n)
+      [bench "dual" $ nf (\inp -> ADDA.gradient' fneural_A inp 1.0) input]
 
 mainNeuralGraph :: IO ()
 mainNeuralGraph = do
diff --git a/examples/Numeric/ADDual/Examples.hs b/examples/Numeric/ADDual/Examples.hs
index 819aec4..3835daa 100644
--- a/examples/Numeric/ADDual/Examples.hs
+++ b/examples/Numeric/ADDual/Examples.hs
@@ -1,12 +1,15 @@
 {-# LANGUAGE DeriveGeneric #-}
 {-# LANGUAGE DeriveTraversable #-}
+{-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeApplications #-}
 module Numeric.ADDual.Examples where
 
 import Control.DeepSeq
 import Control.Monad (replicateM)
+import Data.Bifunctor (bimap)
 import Data.Maybe (catMaybes)
 import qualified Data.Vector as V
+import qualified Data.Vector.Storable as VS
 import GHC.Generics (Generic)
 import Hedgehog (Gen, Size)
 import qualified Hedgehog.Gen as Gen
@@ -15,14 +18,21 @@ import qualified Hedgehog.Internal.Gen as HI.Gen
 import qualified Hedgehog.Internal.Seed as HI.Seed
 import qualified Hedgehog.Internal.Tree as HI.Tree
 
+import Numeric.ADDual.VectorOps
+
 type Matrix s = V.Vector s
 
 data FNeural a = FNeural [(Matrix a, V.Vector a)] (V.Vector a)
   deriving (Show, Eq, Functor, Foldable, Traversable, Generic)
-
 instance NFData a => NFData (FNeural a)
 
+type SMatrix s = VS.Vector s
+
+data FNeuralA v = FNeuralA [(V.Vector v, v)] v
+  deriving (Show, Eq, Functor, Foldable, Traversable, Generic)
+instance NFData v => NFData (FNeuralA v)
+
 fneural :: (Floating a, Ord a) => FNeural a -> a
 fneural (FNeural layers input) =
   let dotp v1 v2 = V.sum (V.zipWith (*) v1 v2)
@@ -44,20 +54,59 @@ fneural (FNeural layers input) =
   in V.sum $ forward layers input
 
 makeNeuralInput :: Int -> FNeural Double
-makeNeuralInput scale = sampleGenPure 100 (genNeuralInput scale)
+makeNeuralInput scale = cvtFNeuralAtoFNeural $ makeNeuralInput_A scale
 
 genNeuralInput :: Int -> Gen (FNeural Double)
-genNeuralInput scale = do
+genNeuralInput scale = cvtFNeuralAtoFNeural <$> genNeuralInput_A scale
+
+cvtFNeuralAtoFNeural :: FNeuralA (VS.Vector Double) -> FNeural Double
+cvtFNeuralAtoFNeural (FNeuralA layers input) =
+  FNeural (map (bimap (\m -> let nin = V.length m
+                                 nout = VS.length (m V.! 0)
+                             in V.fromListN (nin*nout) $ concatMap VS.toList $ V.toList m)
+                      (\v -> let n = VS.length v in V.fromListN n (VS.toList v)))
+               layers)
+          (let n = VS.length input in V.fromListN n (VS.toList input))
+
+fneural_A :: forall v. (VectorOpsFloating v, VectorOpsOrd v)
+          => FNeuralA v -> VectorOpsScalar v
+fneural_A (FNeuralA layers input) =
+  let dotp v1 v2 = vsum (vmul v1 v2)
+
+      (@.) :: V.Vector v -> v -> v
+      mat @. vec =
+        let n = vlength vec
+            m = V.length mat `div` n
+        in vfromListN m $ map (\row -> dotp row vec) (V.toList mat)
+      (+.) = vadd
+
+      batchrelu :: v -> v
+      batchrelu x = vselect (vcmpGE x (vreplicate (vlength x) 0.0)) x (vreplicate (vlength x) 0.0)
+      safeSoftmax vec = let m = vmaximum vec
+                            exps = vexp (vsub vec (vreplicate (vlength vec) m))
+                            factor = vsum exps
+                        in vmul exps (vreplicate (vlength vec) (recip factor))
+      forward [] x = safeSoftmax x
+      forward ((weights, bias) : lys) x =
+        let x' = batchrelu ((weights @. x) +. bias)
+        in forward lys x'
+  in vsum $ forward layers input
+
+makeNeuralInput_A :: Int -> FNeuralA (VS.Vector Double)
+makeNeuralInput_A scale = sampleGenPure 100 (genNeuralInput_A scale)
+
+genNeuralInput_A :: Int -> Gen (FNeuralA (VS.Vector Double))
+genNeuralInput_A scale = do
   let genScalar = Gen.double (Range.linearFracFrom 0 (-1) 1)
-      genMatrix nin nout = V.fromListN (nin*nout) <$> replicateM (nin*nout) genScalar
-      genVector nout = V.fromListN nout <$> replicateM nout genScalar
+      genVector nout = VS.fromListN nout <$> replicateM nout genScalar
+      genMatrix nin nout = V.fromListN nin <$> replicateM nin (genVector nout)
   nIn <- Gen.integral (Range.linear 1 scale)
   n1 <- Gen.integral (Range.linear 1 scale)
   n2 <- Gen.integral (Range.linear 1 scale)
   m1 <- genMatrix nIn n1; v1 <- genVector n1
   m2 <- genMatrix n1 n2; v2 <- genVector n2
   inp <- genVector nIn
-  pure $ FNeural [(m1, v1), (m2, v2)] inp
+  pure $ FNeuralA [(m1, v1), (m2, v2)] inp
 
 sampleGenPure :: Size -> Gen a -> a
diff --git a/src/Numeric/ADDual/Array/Internal.hs b/src/Numeric/ADDual/Array/Internal.hs
index 5a4af4b..227be3e 100644
--- a/src/Numeric/ADDual/Array/Internal.hs
+++ b/src/Numeric/ADDual/Array/Internal.hs
@@ -15,6 +15,7 @@ import Data.List (foldl')
 import qualified Data.IntMap.Strict as IM
 import Data.Proxy
 import qualified Data.Vector.Storable as VS
+import Foreign.Ptr (castPtr)
 import Foreign.Storable
 import GHC.Stack
 import GHC.Exts (withDict)
@@ -58,7 +59,7 @@ gradient' f inp topctg = unsafePerformIO $ do
 
   when debug $ hPutStrLn stderr "Backpropagating"
-  let (outaccS, outaccV) = backpropagate (IM.singleton outi topctg) IM.empty outi tape
+  let (_outaccS, outaccV) = backpropagate (IM.singleton outi topctg) IM.empty outi tape
 
   -- when debug $ do
   --   accums' <- VS.freeze accums
@@ -84,6 +85,21 @@ backpropagate accS accV i (Cscalar i1 dx i2 dy tape) =
       accS2 | i2 /= -1 = IM.insertWith (+) i2 (ctg*dy) accS1
             | otherwise = accS1
   in backpropagate accS2 accV (i-1) tape
+backpropagate accS accV i (VCarith i1 dx i2 dy tape) =
+  case IM.lookup i accV of
+    Nothing -> backpropagate accS accV (i-1) tape
+    Just ctg ->
+      let accV1 | i1 /= -1 =
+                    if VS.length ctg == VS.length dx
+                      then IM.insertWith vadd i1 (vmul ctg dx) accV
+                      else error "Numeric.ADDual.Array: wrong cotangent length to vectorised arithmetic operation"
+                | otherwise = accV
+          accV2 | i2 /= -1 =
+                    if VS.length ctg == VS.length dy
+                      then IM.insertWith vadd i2 (vmul ctg dy) accV1
+                      else error "Numeric.ADDual.Array: wrong cotangent length to vectorised arithmetic operation"
+                | otherwise = accV
+      in backpropagate accS accV2 (i-1) tape
 backpropagate accS accV i (VCfromList is tape) =
   case IM.lookup i accV of
     Nothing -> backpropagate accS accV (i-1) tape
@@ -96,7 +112,7 @@ backpropagate accS accV i (VCtoList j len tape) =
   case IM.lookupGE (i - len) accS of
     Just (smallid, _) | smallid < i ->
       let ctg = VS.fromListN len [IM.findWithDefault 0 (i - len + idx) accS | idx <- [0 .. len-1]]
-          accV1 = IM.insertWith (VS.zipWith (+)) j ctg accV
+          accV1 = IM.insertWith vadd j ctg accV
       in backpropagate accS accV1 (i - 1 - len) tape
     _ -> backpropagate accS accV (i - 1 - len) tape
 backpropagate accS accV i (VCsum j len tape) =
@@ -118,6 +134,9 @@ backpropagate accS accV _ Start = (accS, accV)
 data Chain a = Cscalar {-# UNPACK #-} !Int !a  -- ^ ID == -1 -> no contribution
                        {-# UNPACK #-} !Int !a  -- ^ idem
                        !(Chain a)
+             | VCarith {-# UNPACK #-} !Int {-# UNPACK #-} !(VS.Vector a)  -- ^ first argument with scale factors
+                       {-# UNPACK #-} !Int {-# UNPACK #-} !(VS.Vector a)  -- ^ second argument with scale factors
+                       !(Chain a)
             | VCfromList {-# UNPACK #-} !(VS.Vector Int)  -- ^ IDs of scalars in the input list
                           !(Chain a)
             | VCtoList {-# UNPACK #-} !Int  -- ^ ID of the input vector
@@ -178,6 +197,13 @@ instance (Floating a, Taping s a) => Floating (Dual s a) where
   cosh = undefined ; tanh = undefined ; asinh = undefined ; acosh = undefined
   atanh = undefined
 
+-- | This instance allows breaking the abstraction of 'Dual'. Don't inspect or modify the serialised representation, and DO NOT use serialised 'Dual' values from one 'gradient'' computation in another!
+instance Storable a => Storable (Dual s a) where
+  sizeOf _ = sizeOf (undefined :: a) + sizeOf (undefined :: Int)
+  alignment _ = alignment (undefined :: a)
+  peek ptr = Dual <$> peek (castPtr ptr) <*> peekByteOff ptr (sizeOf (undefined :: a))
+  poke ptr (Dual x i) = poke (castPtr ptr) x >> pokeByteOff ptr (sizeOf (undefined :: a)) i
+
 constant :: a -> Dual s a
 constant x = Dual x (-1)
@@ -187,7 +213,7 @@ mkDual res i1 dx i2 dy = Dual res (writeTapeUnsafe (Proxy @s) (Cscalar i1 dx i2
 data VDual s a = VDual !(VS.Vector a)
                        {-# UNPACK #-} !Int  -- ^ -1 if this is a constant vector
 
-instance (Storable a, Taping s a) => VectorOps (VDual s a) where
+instance (Storable a, Num a, Taping s a) => VectorOps (VDual s a) where
   type VectorOpsScalar (VDual s a) = Dual s a
   vfromListN n l =
     let (xs, is) = unzip [(x, i) | Dual x i <- l]
@@ -198,11 +224,32 @@ instance (Storable a, Taping s a) => VectorOps (VDual s a) where
   vtoList (VDual v i) =
     let starti = allocTapeToListUnsafe (Proxy @a) (Proxy @s) i (VS.length v)
     in zipWith Dual (VS.toList v) [starti..]
+  vlength (VDual v _) = VS.length v
   vreplicate n (Dual x i) = mkVDual (VS.replicate n x) (VCreplicate i n)
+  vselect bs (VDual a i) (VDual b j) =
+    mkVDual (vselect bs a b) (VCarith i (VS.map (fromIntegral . fromEnum) bs)
+                                      j (VS.map (fromIntegral . fromEnum . not) bs))
 
 instance (Storable a, Num a, Taping s a) => VectorOpsNum (VDual s a) where
+  vadd (VDual v i) (VDual w j) =
+    let len = VS.length v
+    in mkVDual (vadd v w) (VCarith i (VS.replicate len 1) j (VS.replicate len 1))
+  vsub (VDual v i) (VDual w j) =
+    let len = VS.length v
+    in mkVDual (vsub v w) (VCarith i (VS.replicate len 1) j (VS.replicate len (-1)))
+  vmul (VDual v i) (VDual w j) =
+    mkVDual (vmul v w) (VCarith i w j v)
   vsum (VDual v i) = Dual (VS.sum v) (writeTapeUnsafe @a (Proxy @s) (VCsum i (VS.length v)))
 
+instance (Storable a, Floating a, Taping s a) => VectorOpsFloating (VDual s a) where
+  vexp (VDual v i) = mkVDual (vexp v) (VCarith i v (-1) VS.empty)
+
+instance (Storable a, Num a, Ord a, Taping s a) => VectorOpsOrd (VDual s a) where
+  vcmpLE (VDual v _) (VDual w _) = vcmpLE v w
+  vmaximum (VDual v i) =
+    let w = vmaximum v
+    in Dual w (writeTapeUnsafe @a (Proxy @s) (VCarith i (VS.map (\x -> if x == w then 1 else 0) v) (-1) VS.empty))
+
 vconstant :: VS.Vector a -> VDual s a
 vconstant v = VDual v (-1)
diff --git a/src/Numeric/ADDual/VectorOps.hs b/src/Numeric/ADDual/VectorOps.hs
index 38063e0..9bedebe 100644
--- a/src/Numeric/ADDual/VectorOps.hs
+++ b/src/Numeric/ADDual/VectorOps.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE TypeFamilies #-}
 module Numeric.ADDual.VectorOps where
 
@@ -14,47 +15,116 @@ class VectorOps v where
   vfromListN :: Int -> [VectorOpsScalar v] -> v
   vfromList :: [VectorOpsScalar v] -> v
   vtoList :: v -> [VectorOpsScalar v]
+  vlength :: v -> Int
   vreplicate :: Int -> VectorOpsScalar v -> v
+  vselect :: VS.Vector Bool -> v -> v -> v  -- ^ True selects the first argument, False the second
 
-class VectorOpsNum v where
+class (VectorOps v, Num (VectorOpsScalar v)) => VectorOpsNum v where
+  vadd :: v -> v -> v
+  vsub :: v -> v -> v
+  vmul :: v -> v -> v
   vsum :: v -> VectorOpsScalar v
 
+class (VectorOpsNum v, Floating (VectorOpsScalar v)) => VectorOpsFloating v where
+  vexp :: v -> v
+
+class (VectorOps v, Ord (VectorOpsScalar v)) => VectorOpsOrd v where
+  vcmpLE :: v -> v -> VS.Vector Bool
+  vmaximum :: v -> VectorOpsScalar v
+
+  vcmpLT, vcmpGT, vcmpGE :: v -> v -> VS.Vector Bool
+  vcmpLT a b = VS.map not (vcmpLE b a)
+  vcmpGT a b = VS.map not (vcmpLE a b)
+  vcmpGE a b = vcmpLE b a
+
 instance VectorOps (V.Vector a) where
   type VectorOpsScalar (V.Vector a) = a
   vfromListN = V.fromListN
   vfromList = V.fromList
   vtoList = V.toList
+  vlength = V.length
   vreplicate = V.replicate
+  vselect bs a b = V.fromListN (VS.length bs) [if bs VS.! i then a V.! i else b V.! i
+                                              | i <- [0 .. VS.length bs - 1]]
 
 instance Num a => VectorOpsNum (V.Vector a) where
+  vadd = V.zipWith (+)
+  vsub = V.zipWith (-)
+  vmul = V.zipWith (*)
   vsum = V.sum
 
+instance Floating a => VectorOpsFloating (V.Vector a) where
+  vexp = V.map exp
+
+instance Ord a => VectorOpsOrd (V.Vector a) where
+  vcmpLE a b = VS.generate (V.length a) (\i -> a V.! i <= b V.! i)
+  vmaximum = V.maximum
+
 instance VectorOps (VSr.Vector a) where
   type VectorOpsScalar (VSr.Vector a) = a
   vfromListN = VSr.fromListN
   vfromList = VSr.fromList
   vtoList = VSr.toList
+  vlength = VSr.length
   vreplicate = VSr.replicate
+  vselect bs a b = VSr.fromListN (VS.length bs) [if bs VS.! i then a VSr.! i else b VSr.! i
+                                                | i <- [0 .. VS.length bs - 1]]
 
 instance Num a => VectorOpsNum (VSr.Vector a) where
+  vadd = VSr.zipWith (+)
+  vsub = VSr.zipWith (-)
+  vmul = VSr.zipWith (*)
   vsum = VSr.sum
 
+instance Floating a => VectorOpsFloating (VSr.Vector a) where
+  vexp = VSr.map exp
+
+instance Ord a => VectorOpsOrd (VSr.Vector a) where
+  vcmpLE a b = VS.generate (VSr.length a) (\i -> a VSr.! i <= b VSr.! i)
+  vmaximum = VSr.maximum
+
 instance Storable a => VectorOps (VS.Vector a) where
   type VectorOpsScalar (VS.Vector a) = a
   vfromListN = VS.fromListN
   vfromList = VS.fromList
   vtoList = VS.toList
+  vlength = VS.length
   vreplicate = VS.replicate
+  vselect bs a b = VS.fromListN (VS.length bs) [if bs VS.! i then a VS.! i else b VS.! i
+                                               | i <- [0 .. VS.length bs - 1]]
 
 instance (Storable a, Num a) => VectorOpsNum (VS.Vector a) where
+  vadd = VS.zipWith (+)
+  vsub = VS.zipWith (-)
+  vmul = VS.zipWith (*)
   vsum = VS.sum
 
+instance (Storable a, Floating a) => VectorOpsFloating (VS.Vector a) where
+  vexp = VS.map exp
+
+instance (Storable a, Ord a) => VectorOpsOrd (VS.Vector a) where
+  vcmpLE a b = VS.generate (VS.length a) (\i -> a VS.! i <= b VS.! i)
+  vmaximum = VS.maximum
+
 instance VU.Unbox a => VectorOps (VU.Vector a) where
   type VectorOpsScalar (VU.Vector a) = a
   vfromListN = VU.fromListN
   vfromList = VU.fromList
   vtoList = VU.toList
+  vlength = VU.length
   vreplicate = VU.replicate
+  vselect bs a b = VU.fromListN (VS.length bs) [if bs VS.! i then a VU.! i else b VU.! i
+                                               | i <- [0 .. VS.length bs - 1]]
 
 instance (VU.Unbox a, Num a) => VectorOpsNum (VU.Vector a) where
+  vadd = VU.zipWith (+)
+  vsub = VU.zipWith (-)
+  vmul = VU.zipWith (*)
   vsum = VU.sum
+
+instance (VU.Unbox a, Floating a) => VectorOpsFloating (VU.Vector a) where
+  vexp = VU.map exp
+
+instance (VU.Unbox a, Ord a) => VectorOpsOrd (VU.Vector a) where
+  vcmpLE a b = VS.generate (VU.length a) (\i -> a VU.! i <= b VU.! i)
+  vmaximum = VU.maximum
diff --git a/test/Main.hs b/test/Main.hs
index a04533f..f149ab7 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -8,35 +8,55 @@ import Test.Tasty.Hedgehog
 import Test.Tasty.HUnit
 
 import qualified Numeric.AD as AD
+import qualified Numeric.AD.Double as AD.Double
 import Numeric.ADDual
+import qualified Numeric.ADDual.Array.Internal as ADDA
 import Numeric.ADDual.Examples
 
-(~==) :: (Foldable t, Fractional a, Ord a, Show (t a)) => t a -> t a -> PropertyT IO ()
-a ~== b
-  | length (toList a) == length (toList b)
-  , and (zipWith close (toList a) (toList b))
-  = return ()
-  | otherwise
-  = diff a (\_ _ -> False) b
-  where
-    close x y = abs (x - y) < 1e-5 ||
-                (let m = max (abs x) (abs y) in m > 1e-5 && abs (x - y) / m < 1e-5)
+(~=) :: (Fractional a, Ord a) => a -> a -> Bool
+x ~= y = abs (x - y) < 1e-5 || (let m = max (abs x) (abs y) in m > 1e-5 && abs (x - y) / m < 1e-5)
+
+(~==) :: (Fractional a, Ord a, Show a) => a -> a -> PropertyT IO ()
+x ~== y = diff x (~=) y
+
+(~=!) :: (Foldable t, Fractional a, Ord a) => t a -> t a -> Bool
+a ~=! b = length (toList a) == length (toList b) && and (zipWith (~=) (toList a) (toList b))
+
+(~==!) :: (Foldable t, Fractional a, Ord a, Show (t a)) => t a -> t a -> PropertyT IO ()
+a ~==! b = diff a (~=!) b
 
 main :: IO ()
 main = defaultMain $ testGroup "Tests"
   [testCase "product [1..5]" $
     gradient' @Double product [1..5] 1 @?= (120, [120, 60, 40, 30, 24])
+
   ,testProperty "neural 80" $ property $ do
     input <- forAll (genNeuralInput 80)
     let (res, grad) = gradient' fneural input 1
     res === fneural input
-    grad ~== AD.grad fneural input
+    grad ~==! AD.grad fneural input
+    AD.grad fneural input === AD.Double.grad fneural input
+
   ,testProperty "neural 150" $ property $ do
    input <- forAll (genNeuralInput 150)
     let (res, grad) = gradient' fneural input 1
     res === fneural input
-    grad ~== AD.grad fneural input
+    grad ~==! AD.grad fneural input
+    AD.grad fneural input === AD.Double.grad fneural input
+
+  ,testProperty "primal neural == neural_A" $ property $ do
+    input <- forAll (genNeuralInput_A 100)
+    let resA = fneural_A input
+    let res = fneural (cvtFNeuralAtoFNeural input)
+    resA ~== res
+
+  ,testProperty "neural_A 100" $ property $ do
+    input <- forAll (genNeuralInput_A 100)
+    let (resA, gradA) = ADDA.gradient' fneural_A input 1
+    let (res, grad) = gradient' fneural (cvtFNeuralAtoFNeural input) 1
+    resA ~== res
+    cvtFNeuralAtoFNeural gradA ~==! grad
  ]
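
To tie the pieces together, a minimal end-to-end sketch of the new vectorised reverse-mode entry point, mirroring the "neural_A 100" test case above; the main wrapper is hypothetical, while ADDA.gradient', cvtFNeuralAtoFNeural and the generators are the functions added or used in this commit:

-- Hypothetical usage of the vectorised tape: compute the gradient via
-- Numeric.ADDual.Array.Internal and compare against the scalar-tape gradient.
import Numeric.ADDual (gradient')
import qualified Numeric.ADDual.Array.Internal as ADDA
import Numeric.ADDual.Examples
  (fneural, fneural_A, cvtFNeuralAtoFNeural, makeNeuralInput_A)

main :: IO ()
main = do
  let inputA = makeNeuralInput_A 100
      (resA, gradA) = ADDA.gradient' fneural_A inputA 1.0                  -- vectorised tape
      (res,  grad ) = gradient' fneural (cvtFNeuralAtoFNeural inputA) 1.0  -- scalar tape
  print (resA, res)                                   -- primal values should agree
  print (sum (cvtFNeuralAtoFNeural gradA), sum grad)  -- gradient sums should agree up to rounding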