diff options
Diffstat (limited to 'bench/Main.hs')
| -rw-r--r-- | bench/Main.hs | 55 |
1 files changed, 29 insertions, 26 deletions
diff --git a/bench/Main.hs b/bench/Main.hs index 5901d8b..2058e77 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -15,19 +15,12 @@ import Numeric.LinearAlgebra qualified as LA import Test.Tasty.Bench import Text.Show (showListWith) -import Data.Array.Mixed.XArray (XArray(..)) import Data.Array.Nested -import Data.Array.Nested.Internal.Mixed (Mixed (M_Primitive), mliftPrim, mliftPrim2, toPrimitive) -import Data.Array.Nested.Internal.Ranked (liftRanked1, liftRanked2) +import Data.Array.Nested.Mixed (Mixed(M_Primitive), mliftPrim, mliftPrim2, toPrimitive) +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked (liftRanked1, liftRanked2) import Data.Array.Strided.Arith.Internal qualified as Arith - - -enableMisc :: Bool -enableMisc = False - -bgroupIf :: Bool -> String -> [Benchmark] -> Benchmark -bgroupIf True = bgroup -bgroupIf False = \name _ -> bgroup name [] +import Data.Array.XArray (XArray(..)) main :: IO () @@ -51,7 +44,7 @@ main_tests = defaultMain " str " ++ showSh (stridesOf inp1) ++ " " ++ showSh (stridesOf inp2)) $ nf (\(a,b) -> rsumAllPrim (rdot1Inner a b)) (inp1, inp2) - iota n = riota @Double n + iota = riota @Double in [dotprodBench "dot 1D" (iota 10_000_000 @@ -104,7 +97,7 @@ main_tests = defaultMain in nf (\a -> RS.normalize a) (RS.rev [0] (RS.rev [0] (RS.iota @Double n))) ] - ,bgroupIf enableMisc "misc" + ,bgroup "misc" [let n = 1000 k = 1000 in bgroup ("fusion [" ++ show k ++ "]*" ++ show n) @@ -148,6 +141,16 @@ main_tests = defaultMain | ki <- [1::Int ..5]] ] ] + ,bench "ixxFromLinear 10000x" $ + let n = 10000 + sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX + in nf (\sh -> [ixxFromLinear @Int sh i | i <- [1..n]]) sh0 + ,bench "ixxFromLinear 1x" $ + let sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX + in nf (\sh -> ixxFromLinear @Int sh 1234) sh0 + ,bench "shxEnum" $ + let sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX + in nf (\sh -> shxEnum sh) sh0 ] ] @@ -156,45 +159,45 @@ tests_compare = let n = 1_000_000 in [bgroup "Num" [bench "sum(+) Double [1e6]" $ - nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (+)) a b))) + nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (+)) a b))) (riota @Double n, riota n) ,bench "sum(*) Double [1e6]" $ - nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (*)) a b))) + nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (*)) a b))) (riota @Double n, riota n) ,bench "sum(/) Double [1e6]" $ - nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (/)) a b))) + nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (/)) a b))) (riota @Double n, riota n) ,bench "sum(**) Double [1e6]" $ - nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (**)) a b))) + nf (\(a, b) -> runScalar (rsumOuter1Prim (liftRanked2 (mliftPrim2 (**)) a b))) (riota @Double n, riota n) ,bench "sum(sin) Double [1e6]" $ - nf (\a -> runScalar (rsumOuter1 (liftRanked1 (mliftPrim sin) a))) + nf (\a -> runScalar (rsumOuter1Prim (liftRanked1 (mliftPrim sin) a))) (riota @Double n) ,bench "sum Double [1e6]" $ - nf (\a -> runScalar (rsumOuter1 a)) + nf (\a -> runScalar (rsumOuter1Prim a)) (riota @Double n) ] ,bgroup "NumElt" [bench "sum(+) Double [1e6]" $ - nf (\(a, b) -> runScalar (rsumOuter1 (a + b))) + nf (\(a, b) -> runScalar (rsumOuter1Prim (a + b))) (riota @Double n, riota n) ,bench "sum(*) Double [1e6]" $ - nf (\(a, b) -> runScalar (rsumOuter1 (a * b))) + nf (\(a, b) -> runScalar (rsumOuter1Prim (a * b))) (riota @Double n, riota n) ,bench "sum(/) Double [1e6]" $ - nf (\(a, b) -> runScalar (rsumOuter1 (a / b))) + nf (\(a, b) -> runScalar (rsumOuter1Prim (a / b))) (riota @Double n, riota n) ,bench "sum(**) Double [1e6]" $ - nf (\(a, b) -> runScalar (rsumOuter1 (a ** b))) + nf (\(a, b) -> runScalar (rsumOuter1Prim (a ** b))) (riota @Double n, riota n) ,bench "sum(sin) Double [1e6]" $ - nf (\a -> runScalar (rsumOuter1 (sin a))) + nf (\a -> runScalar (rsumOuter1Prim (sin a))) (riota @Double n) ,bench "sum Double [1e6]" $ - nf (\a -> runScalar (rsumOuter1 a)) + nf (\a -> runScalar (rsumOuter1Prim a)) (riota @Double n) ,bench "sum(*) Double [1e6] stride 1; -1" $ - nf (\(a, b) -> runScalar (rsumOuter1 (a * b))) + nf (\(a, b) -> runScalar (rsumOuter1Prim (a * b))) (riota @Double n, rrev1 (riota n)) ,bench "dotprod Float [1e6]" $ nf (\(a, b) -> rdot a b) |
