aboutsummaryrefslogtreecommitdiff
path: root/bench/Main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'bench/Main.hs')
-rw-r--r--bench/Main.hs55
1 files changed, 29 insertions, 26 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index 5901d8b..2058e77 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -15,19 +15,12 @@ import Numeric.LinearAlgebra qualified as LA
import Test.Tasty.Bench
import Text.Show (showListWith)
-import Data.Array.Mixed.XArray (XArray(..))
import Data.Array.Nested
-import Data.Array.Nested.Internal.Mixed (Mixed (M_Primitive), mliftPrim, mliftPrim2, toPrimitive)
-import Data.Array.Nested.Internal.Ranked (liftRanked1, liftRanked2)
+import Data.Array.Nested.Mixed (Mixed(M_Primitive), mliftPrim, mliftPrim2, toPrimitive)
+import Data.Array.Nested.Mixed.Shape
+import Data.Array.Nested.Ranked (liftRanked1, liftRanked2)
import Data.Array.Strided.Arith.Internal qualified as Arith
-
-
-enableMisc :: Bool
-enableMisc = False
-
-bgroupIf :: Bool -> String -> [Benchmark] -> Benchmark
-bgroupIf True = bgroup
-bgroupIf False = \name _ -> bgroup name []
+import Data.Array.XArray (XArray(..))
main :: IO ()
@@ -51,7 +44,7 @@ main_tests = defaultMain
" str " ++ showSh (stridesOf inp1) ++ " " ++ showSh (stridesOf inp2)) $
nf (\(a,b) -> rsumAllPrim (rdot1Inner a b)) (inp1, inp2)
- iota n = riota @Double n
+ iota = riota @Double
in
[dotprodBench "dot 1D"
(iota 10_000_000
@@ -104,7 +97,7 @@ main_tests = defaultMain
in nf (\a -> RS.normalize a)
(RS.rev [0] (RS.rev [0] (RS.iota @Double n)))
]
- ,bgroupIf enableMisc "misc"
+ ,bgroup "misc"
[let n = 1000
k = 1000
in bgroup ("fusion [" ++ show k ++ "]*" ++ show n)
@@ -148,6 +141,16 @@ main_tests = defaultMain
| ki <- [1::Int ..5]]
]
]
+ ,bench "ixxFromLinear 10000x" $
+ let n = 10000
+ sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX
+ in nf (\sh -> [ixxFromLinear @Int sh i | i <- [1..n]]) sh0
+ ,bench "ixxFromLinear 1x" $
+ let sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX
+ in nf (\sh -> ixxFromLinear @Int sh 1234) sh0
+ ,bench "shxEnum" $
+ let sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX
+ in nf (\sh -> shxEnum sh) sh0
]
]
@@ -156,45 +159,45 @@ tests_compare =
let n = 1_000_000 in
[bgroup "Num"
[bench "sum(+) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (+)) a b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (+)) a b)))
(riota @Double n, riota n)
,bench "sum(*) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (*)) a b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (*)) a b)))
(riota @Double n, riota n)
,bench "sum(/) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (/)) a b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (/)) a b)))
(riota @Double n, riota n)
,bench "sum(**) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (**)) a b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (**)) a b)))
(riota @Double n, riota n)
,bench "sum(sin) Double [1e6]" $
- nf (\a -> runScalar (rsumOuter1 (liftRanked1 (mliftPrim sin) a)))
+ nf (\a -> runScalar (rsumOuter1Prim (liftRanked1 (mliftPrim sin) a)))
(riota @Double n)
,bench "sum Double [1e6]" $
- nf (\a -> runScalar (rsumOuter1 a))
+ nf (\a -> runScalar (rsumOuter1Prim a))
(riota @Double n)
]
,bgroup "NumElt"
[bench "sum(+) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (a + b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (a + b)))
(riota @Double n, riota n)
,bench "sum(*) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (a * b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (a * b)))
(riota @Double n, riota n)
,bench "sum(/) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (a / b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (a / b)))
(riota @Double n, riota n)
,bench "sum(**) Double [1e6]" $
- nf (\(a, b) -> runScalar (rsumOuter1 (a ** b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (a ** b)))
(riota @Double n, riota n)
,bench "sum(sin) Double [1e6]" $
- nf (\a -> runScalar (rsumOuter1 (sin a)))
+ nf (\a -> runScalar (rsumOuter1Prim (sin a)))
(riota @Double n)
,bench "sum Double [1e6]" $
- nf (\a -> runScalar (rsumOuter1 a))
+ nf (\a -> runScalar (rsumOuter1Prim a))
(riota @Double n)
,bench "sum(*) Double [1e6] stride 1; -1" $
- nf (\(a, b) -> runScalar (rsumOuter1 (a * b)))
+ nf (\(a, b) -> runScalar (rsumOuter1Prim (a * b)))
(riota @Double n, rrev1 (riota n))
,bench "dotprod Float [1e6]" $
nf (\(a, b) -> rdot a b)