From 575a218d1b23b454fcdcf2b6ad0018fdc32b64b6 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Tue, 25 Mar 2025 16:27:14 +0100
Subject: bench: Dot product benchmarks

---
 bench/Main.hs | 64 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 62 insertions(+), 2 deletions(-)

diff --git a/bench/Main.hs b/bench/Main.hs
index cf0e929..17a93e0 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -1,16 +1,23 @@
 {-# LANGUAGE ImportQualifiedPost #-}
 {-# LANGUAGE NumericUnderscores #-}
 {-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE ViewPatterns #-}
 module Main where
 
 import Control.Exception (bracket)
-import Data.Array.RankedS qualified as RS
+import Control.Monad (when)
+import Data.Array.Internal qualified as OI
+import Data.Array.Internal.RankedG qualified as RG
+import Data.Array.Internal.RankedS qualified as RS
+import Data.Foldable (toList)
 import Data.Vector.Storable qualified as VS
 import Numeric.LinearAlgebra qualified as LA
 import Test.Tasty.Bench
+import Text.Show (showListWith)
 
+import Data.Array.Mixed.XArray (XArray(..))
 import Data.Array.Nested
-import Data.Array.Nested.Internal.Mixed (mliftPrim, mliftPrim2)
+import Data.Array.Nested.Internal.Mixed (mliftPrim, mliftPrim2, Mixed(M_Primitive), toPrimitive)
 import Data.Array.Nested.Internal.Ranked (liftRanked1, liftRanked2)
 import qualified Data.Array.Strided.Arith.Internal as Arith
 
@@ -34,6 +41,59 @@ main = do
 main_tests :: IO ()
 main_tests = defaultMain
   [bgroup "compare" tests_compare
+  ,bgroup "dotprod" $  -- every case below times rsumAllPrim (rdot1Inner a b) on a pair of Double arrays
+    let stridesOf (Ranked (toPrimitive -> M_Primitive _ (XArray (RS.A (RG.A _ (OI.T strides _ _)))))) = strides  -- view pattern: applies toPrimitive, then unwraps down to the OI.T constructor and returns the field bound as strides
+        dotprodBench name (inp1, inp2) =  -- builds one benchmark from a label and a pair of input arrays
+          let showSh l = showListWith (\n -> let ln = round (logBase 10 (fromIntegral n :: Double)) :: Int  -- ln: base-10 logarithm of n, rounded to an Int
+                                             in if n > 1 && n == 10 ^ ln then showString ("1e" ++ show ln) else shows n)  -- exact powers of ten above 1 print as "1e<ln>", anything else via shows; && short-circuits, so ln is only demanded when n > 1
+                                      l ""  -- apply the resulting ShowS to "" to obtain a String
+          in bench (name ++ " " ++ showSh (toList (rshape inp1)) ++  -- benchmark name: label, then the shape of inp1 ...
+                      " str " ++ showSh (stridesOf inp1) ++ " " ++ showSh (stridesOf inp2)) $  -- ... then " str " and the strides of both inputs
+               nf (\(a,b) -> rsumAllPrim (rdot1Inner a b)) (inp1, inp2)  -- the input pair is handed to nf as the argument instead of being captured by the lambda
+
+        iota n = riota @Double n  -- riota fixed to element type Double via a type application
+    in
+    [dotprodBench "dot 1D"  -- both inputs are iota 10_000_000, otherwise unmodified
+        (iota 10_000_000
+        ,iota 10_000_000)
+    ,dotprodBench "revdot"  -- both inputs are passed through rrev1
+        (rrev1 (iota 10_000_000)
+        ,rrev1 (iota 10_000_000))
+    ,dotprodBench "dot 2D"  -- both inputs are rreshape'd to 1000 :$: 10_000
+        (rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000)
+        ,rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000))
+    ,dotprodBench "batched dot"  -- both inputs are rreplicate (1000 :$: ZSR) of iota 10_000
+        (rreplicate (1000 :$: ZSR) (iota 10_000)
+        ,rreplicate (1000 :$: ZSR) (iota 10_000))
+    ,dotprodBench "transposed dot" $  -- same construction as "dot 2D", followed by rtranspose [1,0] on both inputs
+        let (a, b) = (rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000)
+                     ,rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000))
+        in (rtranspose [1,0] a, rtranspose [1,0] b)
+    ,dotprodBench "repdot" $  -- same construction as "batched dot", followed by rtranspose [1,0] on both inputs
+        let (a, b) = (rreplicate (1000 :$: ZSR) (iota 10_000)
+                     ,rreplicate (1000 :$: ZSR) (iota 10_000))
+        in (rtranspose [1,0] a, rtranspose [1,0] b)
+    ,dotprodBench "matvec" $  -- reshaped m against v replicated with 1000 :$: ZSR; NOTE(review): presumably one dot product per row of m, confirm against rdot1Inner
+        let (m, v) = (rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000)
+                     ,iota 10_000)
+        in (m, rreplicate (1000 :$: ZSR) v)
+    ,dotprodBench "vecmat" $  -- v replicated with 10_000 :$: ZSR against m after rtranspose [1,0]
+        let (v, m) = (iota 1_000
+                     ,rreshape (1000 :$: 10_000 :$: ZSR) (iota 10_000_000))
+        in (rreplicate (10_000 :$: ZSR) v, rtranspose [1,0] m)
+    ,dotprodBench "matmat" $  -- m1 is reshaped to n :$: m and m2 to m :$: k
+       let (n,m,k) = (100, 100, 1000)
+           (m1, m2) = (rreshape (n :$: m :$: ZSR) (iota (n*m))
+                      ,rreshape (m :$: k :$: ZSR) (iota (m*k)))
+       in (rtranspose [1,0] (rreplicate (k :$: ZSR) m1)  -- NOTE(review): presumably both sides end up with shape [n,k,m], confirm against rreplicate/rtranspose
+          ,rreplicate (n :$: ZSR) (rtranspose [1,0] m2))
+    ,dotprodBench "matmatT" $  -- like "matmat", but m2 is reshaped to k :$: m and is not transposed before replication
+       let (n,m,k) = (100, 100, 1000)
+           (m1, m2) = (rreshape (n :$: m :$: ZSR) (iota (n*m))
+                      ,rreshape (k :$: m :$: ZSR) (iota (m*k)))
+       in (rtranspose [1,0] (rreplicate (k :$: ZSR) m1)
+          ,rreplicate (n :$: ZSR) m2)
+    ]
   ,bgroup "orthotope"
     [bench "normalize [1e6]" $
       let n = 1_000_000
-- 
cgit v1.2.3-70-g09d2