diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-25 16:27:14 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-25 17:08:44 +0100 |
commit | 575a218d1b23b454fcdcf2b6ad0018fdc32b64b6 (patch) | |
tree | fad7c6e5ba9c7e6e6bc0944a42c5d6f207030c77 | |
parent | 388df7878914666b43059f94ea9665f44937cf3c (diff) |
bench: Dot product benchmarks
-rw-r--r-- | bench/Main.hs | 64 |
1 files changed, 62 insertions, 2 deletions
diff --git a/bench/Main.hs b/bench/Main.hs index cf0e929..17a93e0 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -1,16 +1,23 @@ {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} module Main where import Control.Exception (bracket) -import Data.Array.RankedS qualified as RS +import Control.Monad (when) +import Data.Array.Internal qualified as OI +import Data.Array.Internal.RankedG qualified as RG +import Data.Array.Internal.RankedS qualified as RS +import Data.Foldable (toList) import Data.Vector.Storable qualified as VS import Numeric.LinearAlgebra qualified as LA import Test.Tasty.Bench +import Text.Show (showListWith) +import Data.Array.Mixed.XArray (XArray(..)) import Data.Array.Nested -import Data.Array.Nested.Internal.Mixed (mliftPrim, mliftPrim2) +import Data.Array.Nested.Internal.Mixed (mliftPrim, mliftPrim2, Mixed(M_Primitive), toPrimitive) import Data.Array.Nested.Internal.Ranked (liftRanked1, liftRanked2) import qualified Data.Array.Strided.Arith.Internal as Arith @@ -34,6 +41,59 @@ main = do main_tests :: IO () main_tests = defaultMain [bgroup "compare" tests_compare + ,bgroup "dotprod" $ + let stridesOf (Ranked (toPrimitive -> M_Primitive _ (XArray (RS.A (RG.A _ (OI.T strides _ _)))))) = strides + dotprodBench name (inp1, inp2) = + let showSh l = showListWith (\n -> let ln = round (logBase 10 (fromIntegral n :: Double)) :: Int + in if n > 1 && n == 10 ^ ln then showString ("1e" ++ show ln) else shows n) + l "" + in bench (name ++ " " ++ showSh (toList (rshape inp1)) ++ + " str " ++ showSh (stridesOf inp1) ++ " " ++ showSh (stridesOf inp2)) $ + nf (\(a,b) -> rsumAllPrim (rdot1Inner a b)) (inp1, inp2) + + iota n = riota @Double n + in + [dotprodBench "dot 1D" + (iota 10_000_000 + ,iota 10_000_000) + ,dotprodBench "revdot" + (rrev1 (iota 10_000_000) + ,rrev1 (iota 10_000_000)) + ,dotprodBench "dot 2D" + (rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000) + ,rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000)) + ,dotprodBench "batched dot" + (rreplicate (1000 :$: ZSR) (iota 10_000) + ,rreplicate (1000 :$: ZSR) (iota 10_000)) + ,dotprodBench "transposed dot" $ + let (a, b) = (rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000) + ,rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000)) + in (rtranspose [1,0] a, rtranspose [1,0] b) + ,dotprodBench "repdot" $ + let (a, b) = (rreplicate (1000 :$: ZSR) (iota 10_000) + ,rreplicate (1000 :$: ZSR) (iota 10_000)) + in (rtranspose [1,0] a, rtranspose [1,0] b) + ,dotprodBench "matvec" $ + let (m, v) = (rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000) + ,iota 10_000) + in (m, rreplicate (1000 :$: ZSR) v) + ,dotprodBench "vecmat" $ + let (v, m) = (iota 1_000 + ,rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000)) + in (rreplicate (10_000 :$: ZSR) v, rtranspose [1,0] m) + ,dotprodBench "matmat" $ + let (n,m,k) = (100, 100, 1000) + (m1, m2) = (rreshape (n :$: m :$: ZSR) (iota (n*m)) + ,rreshape (m :$: k :$: ZSR) (iota (m*k))) + in (rtranspose [1,0] (rreplicate (k :$: ZSR) m1) + ,rreplicate (n :$: ZSR) (rtranspose [1,0] m2)) + ,dotprodBench "matmatT" $ + let (n,m,k) = (100, 100, 1000) + (m1, m2) = (rreshape (n :$: m :$: ZSR) (iota (n*m)) + ,rreshape (k :$: m :$: ZSR) (iota (m*k))) + in (rtranspose [1,0] (rreplicate (k :$: ZSR) m1) + ,rreplicate (n :$: ZSR) m2) + ] ,bgroup "orthotope" [bench "normalize [1e6]" $ let n = 1_000_000 |