aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-25 16:27:14 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-25 17:08:44 +0100
commit575a218d1b23b454fcdcf2b6ad0018fdc32b64b6 (patch)
treefad7c6e5ba9c7e6e6bc0944a42c5d6f207030c77
parent388df7878914666b43059f94ea9665f44937cf3c (diff)
bench: Dot product benchmarks
-rw-r--r--bench/Main.hs64
1 files changed, 62 insertions, 2 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index cf0e929..17a93e0 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -1,16 +1,23 @@
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE ViewPatterns #-}
module Main where
import Control.Exception (bracket)
-import Data.Array.RankedS qualified as RS
+import Control.Monad (when)
+import Data.Array.Internal qualified as OI
+import Data.Array.Internal.RankedG qualified as RG
+import Data.Array.Internal.RankedS qualified as RS
+import Data.Foldable (toList)
import Data.Vector.Storable qualified as VS
import Numeric.LinearAlgebra qualified as LA
import Test.Tasty.Bench
+import Text.Show (showListWith)
+import Data.Array.Mixed.XArray (XArray(..))
import Data.Array.Nested
-import Data.Array.Nested.Internal.Mixed (mliftPrim, mliftPrim2)
+import Data.Array.Nested.Internal.Mixed (mliftPrim, mliftPrim2, Mixed(M_Primitive), toPrimitive)
import Data.Array.Nested.Internal.Ranked (liftRanked1, liftRanked2)
import qualified Data.Array.Strided.Arith.Internal as Arith
@@ -34,6 +41,59 @@ main = do
main_tests :: IO ()
main_tests = defaultMain
[bgroup "compare" tests_compare
+ ,bgroup "dotprod" $
+ let stridesOf (Ranked (toPrimitive -> M_Primitive _ (XArray (RS.A (RG.A _ (OI.T strides _ _)))))) = strides
+ dotprodBench name (inp1, inp2) =
+ let showSh l = showListWith (\n -> let ln = round (logBase 10 (fromIntegral n :: Double)) :: Int
+ in if n > 1 && n == 10 ^ ln then showString ("1e" ++ show ln) else shows n)
+ l ""
+ in bench (name ++ " " ++ showSh (toList (rshape inp1)) ++
+ " str " ++ showSh (stridesOf inp1) ++ " " ++ showSh (stridesOf inp2)) $
+ nf (\(a,b) -> rsumAllPrim (rdot1Inner a b)) (inp1, inp2)
+
+ iota n = riota @Double n
+ in
+ [dotprodBench "dot 1D"
+ (iota 10_000_000
+ ,iota 10_000_000)
+ ,dotprodBench "revdot"
+ (rrev1 (iota 10_000_000)
+ ,rrev1 (iota 10_000_000))
+ ,dotprodBench "dot 2D"
+ (rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000)
+ ,rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000))
+ ,dotprodBench "batched dot"
+ (rreplicate (1000 :$: ZSR) (iota 10_000)
+ ,rreplicate (1000 :$: ZSR) (iota 10_000))
+ ,dotprodBench "transposed dot" $
+ let (a, b) = (rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000)
+ ,rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000))
+ in (rtranspose [1,0] a, rtranspose [1,0] b)
+ ,dotprodBench "repdot" $
+ let (a, b) = (rreplicate (1000 :$: ZSR) (iota 10_000)
+ ,rreplicate (1000 :$: ZSR) (iota 10_000))
+ in (rtranspose [1,0] a, rtranspose [1,0] b)
+ ,dotprodBench "matvec" $
+ let (m, v) = (rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000)
+ ,iota 10_000)
+ in (m, rreplicate (1000 :$: ZSR) v)
+ ,dotprodBench "vecmat" $
+ let (v, m) = (iota 1_000
+ ,rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000))
+ in (rreplicate (10_000 :$: ZSR) v, rtranspose [1,0] m)
+ ,dotprodBench "matmat" $
+ let (n,m,k) = (100, 100, 1000)
+ (m1, m2) = (rreshape (n :$: m :$: ZSR) (iota (n*m))
+ ,rreshape (m :$: k :$: ZSR) (iota (m*k)))
+ in (rtranspose [1,0] (rreplicate (k :$: ZSR) m1)
+ ,rreplicate (n :$: ZSR) (rtranspose [1,0] m2))
+ ,dotprodBench "matmatT" $
+ let (n,m,k) = (100, 100, 1000)
+ (m1, m2) = (rreshape (n :$: m :$: ZSR) (iota (n*m))
+ ,rreshape (k :$: m :$: ZSR) (iota (m*k)))
+ in (rtranspose [1,0] (rreplicate (k :$: ZSR) m1)
+ ,rreplicate (n :$: ZSR) m2)
+ ]
,bgroup "orthotope"
[bench "normalize [1e6]" $
let n = 1_000_000