aboutsummaryrefslogtreecommitdiff
path: root/bench/Main.hs
blob: 17a93e071f55deb2f6df5e11923416b3263f313b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ViewPatterns #-}
module Main where

import Control.Exception (bracket)
import Control.Monad (when)
import Data.Array.Internal qualified as OI
import Data.Array.Internal.RankedG qualified as RG
import Data.Array.Internal.RankedS qualified as RS
import Data.Foldable (toList)
import Data.Vector.Storable qualified as VS
import Numeric.LinearAlgebra qualified as LA
import Test.Tasty.Bench
import Text.Show (showListWith)

import Data.Array.Mixed.XArray (XArray(..))
import Data.Array.Nested
import Data.Array.Nested.Internal.Mixed (mliftPrim, mliftPrim2, Mixed(M_Primitive), toPrimitive)
import Data.Array.Nested.Internal.Ranked (liftRanked1, liftRanked2)
import qualified Data.Array.Strided.Arith.Internal as Arith


-- | Whether the @misc@ benchmark group (the vector fusion and concatenation
-- experiments in 'main_tests') is populated. When this is 'False',
-- 'bgroupIf' replaces that group with an empty group of the same name, so none
-- of its benchmarks run.
enableMisc :: Bool
enableMisc = False

-- | A conditional 'bgroup'.
--
-- When the flag is set, this builds the group from the given benchmarks. When
-- it is not set, the benchmarks are discarded and an empty group carrying the
-- same name is returned.
bgroupIf :: Bool -> String -> [Benchmark] -> Benchmark
bgroupIf enabled groupName benchmarks =
  bgroup groupName (if enabled then benchmarks else [])


-- | Benchmark entry point.
--
-- The order of operations is:
--
-- 1. Call @Arith.statisticsEnable@ with 'collectStats'.
-- 2. Run 'main_tests'.
-- 3. Call @Arith.statisticsEnable False@, even if the benchmarks throw.
-- 4. Call @Arith.statisticsPrintAll@ if 'collectStats' was set.
main :: IO ()
main = bracket statsOn statsOff runAll
  where
    -- Set to True to turn on Arith's statistics for this run and print them
    -- once the benchmarks are done.
    collectStats = False

    statsOn = Arith.statisticsEnable collectStats

    statsOff () = do
      Arith.statisticsEnable False
      when collectStats Arith.statisticsPrintAll

    runAll () = main_tests

-- | The complete benchmark tree handed to tasty-bench's 'defaultMain'.
--
-- * @compare@: elementwise and reduction kernels; see 'tests_compare'.
-- * @dotprod@: @rsumAllPrim (rdot1Inner a b)@ over operand pairs with a
--   variety of shapes and stride patterns (reversed, transposed, replicated).
-- * @orthotope@: 'RS.normalize' applied to a reversed vector and to a doubly
--   reversed vector.
-- * @misc@: vector fusion and concatenation experiments, gated by
--   'enableMisc'.
main_tests :: IO ()
main_tests = defaultMain [compareGroup, dotprodGroup, orthotopeGroup, miscGroup]
  where
    compareGroup = bgroup "compare" tests_compare

    -- Each dot-product benchmark's name records the shape of the left operand
    -- and the strides of both operands (see 'dotprodBench').
    dotprodGroup = bgroup "dotprod"
      [ dotprodBench "dot 1D"
          (dvec 10_000_000, dvec 10_000_000)
      , dotprodBench "revdot"
          (rrev1 (dvec 10_000_000), rrev1 (dvec 10_000_000))
      , dotprodBench "dot 2D"
          (dmat 1000 10_000, dmat 1000 10_000)
      , dotprodBench "batched dot"
          ( rreplicate (1000 :$: ZSR) (dvec 10_000)
          , rreplicate (1000 :$: ZSR) (dvec 10_000) )
      , dotprodBench "transposed dot"
          ( rtranspose [1, 0] (dmat 1000 10_000)
          , rtranspose [1, 0] (dmat 1000 10_000) )
      , dotprodBench "repdot"
          ( rtranspose [1, 0] (rreplicate (1000 :$: ZSR) (dvec 10_000))
          , rtranspose [1, 0] (rreplicate (1000 :$: ZSR) (dvec 10_000)) )
      , dotprodBench "matvec"
          (dmat 1000 10_000, rreplicate (1000 :$: ZSR) (dvec 10_000))
      , dotprodBench "vecmat"
          ( rreplicate (10_000 :$: ZSR) (dvec 1_000)
          , rtranspose [1, 0] (dmat 1000 10_000) )
      -- Matrix products, expressed as a batched inner dot product over the
      -- shared dimension nj.
      , dotprodBench "matmat"
          ( rtranspose [1, 0] (rreplicate (nk :$: ZSR) (dmat ni nj))
          , rreplicate (ni :$: ZSR) (rtranspose [1, 0] (dmat nj nk)) )
      , dotprodBench "matmatT"
          ( rtranspose [1, 0] (rreplicate (nk :$: ZSR) (dmat ni nj))
          , rreplicate (ni :$: ZSR) (dmat nk nj) )
      ]

    -- Dimensions for the matmat / matmatT benchmarks.
    (ni, nj, nk) = (100, 100, 1000)

    -- 'riota' at Double with argument n (presumably the vector 0 .. n-1).
    dvec n = riota @Double n

    -- An r-by-c Double matrix obtained by reshaping @dvec (r * c)@.
    dmat r c = rreshape (r :$: c :$: ZSR) (dvec (r * c))

    -- Time @rsumAllPrim (rdot1Inner a b)@ on the given operand pair. The
    -- benchmark name has the form
    -- "<name> <shape of lhs> str <strides of lhs> <strides of rhs>".
    dotprodBench name (lhs, rhs) =
      let label = unwords
            [ name
            , showDims (toList (rshape lhs))
            , "str"
            , showDims (stridesOf lhs)
            , showDims (stridesOf rhs)
            ]
      in bench label $ nf (\(a, b) -> rsumAllPrim (rdot1Inner a b)) (lhs, rhs)

    -- Reach through the nested-array wrappers down to orthotope's stride list.
    stridesOf (Ranked (toPrimitive -> M_Primitive _ (XArray (RS.A (RG.A _ (OI.T strides _ _)))))) = strides

    -- Render a list of integers in list syntax, abbreviating every exact power
    -- of ten greater than one as "1eK".
    showDims ds = showListWith showDim ds ""

    showDim d
      | d > 1, d == 10 ^ e = showString ("1e" ++ show e)
      | otherwise = shows d
      where
        e = round (logBase 10 (fromIntegral d :: Double)) :: Int

    orthotopeGroup = bgroup "orthotope"
      [ bench "normalize [1e6]" $
          nf RS.normalize (RS.rev [0] (RS.iota @Double 1_000_000))
      , bench "normalize noop [1e6]" $
          nf RS.normalize (RS.rev [0] (RS.rev [0] (RS.iota @Double 1_000_000)))
      ]

    miscGroup = bgroupIf enableMisc "misc" [fusionGroup, concatGroup]

    -- @count@ copies of the Int vector [1 .. len].
    chunks count len = replicate count (VS.enumFromTo (1 :: Int) len)

    fusionGroup =
      let copies = 1000
          width = 1000
      in bgroup ("fusion [" ++ show width ++ "]*" ++ show copies)
           [ bench "sum (concat)" $
               nf (VS.sum . VS.concat) (chunks copies width)
           , bench "sum (force (concat))" $
               nf (VS.sum . VS.force . VS.concat) (chunks copies width)
           ]

    -- Concatenation cost, varying either the number of chunks (N) or the chunk
    -- length (K) over powers of ten.
    concatGroup = bgroup "concat"
      [ bgroup "N"
          [ bgroup "hmatrix"
              [ bench ("LA.vjoin [500]*1e" ++ show e) $
                  nf LA.vjoin (chunks (10 ^ e) 500)
              | e <- exponents ]
          , bgroup "vectorStorable"
              [ bench ("VS.concat [500]*1e" ++ show e) $
                  nf VS.concat (chunks (10 ^ e) 500)
              | e <- exponents ]
          ]
      , bgroup "K"
          [ bgroup "hmatrix"
              [ bench ("LA.vjoin [1e" ++ show e ++ "]*500") $
                  nf LA.vjoin (chunks 500 (10 ^ e))
              | e <- exponents ]
          , bgroup "vectorStorable"
              [ bench ("VS.concat [1e" ++ show e ++ "]*500") $
                  nf VS.concat (chunks 500 (10 ^ e))
              | e <- exponents ]
          ]
      ]

    exponents = [1 :: Int .. 5]

-- | Head-to-head comparison of the same kernels on 1e6-element inputs, in three
-- flavours:
--
-- * @Num@: operations lifted by hand with 'mliftPrim' / 'mliftPrim2'.
-- * @NumElt@: the arrays' own arithmetic instances and 'rdot', including
--   operands with a reversed stride ('rrev1').
-- * @hmatrix@: the corresponding "Numeric.LinearAlgebra" operations, used as
--   a baseline.
--
-- Benchmarks that share a label across groups are meant to be compared
-- directly, so each label's element type must match the inputs it uses.
tests_compare :: [Benchmark]
tests_compare =
  let n = 1_000_000 in
  [bgroup "Num"
    [bench "sum(+) Double [1e6]" $
      nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (+)) a b)))
         (riota @Double n, riota n)
    ,bench "sum(*) Double [1e6]" $
      nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (*)) a b)))
         (riota @Double n, riota n)
    ,bench "sum(/) Double [1e6]" $
      nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (/)) a b)))
         (riota @Double n, riota n)
    ,bench "sum(**) Double [1e6]" $
      nf (\(a, b) -> runScalar (rsumOuter1 (liftRanked2 (mliftPrim2 (**)) a b)))
         (riota @Double n, riota n)
    ,bench "sum(sin) Double [1e6]" $
      nf (\a -> runScalar (rsumOuter1 (liftRanked1 (mliftPrim sin) a)))
         (riota @Double n)
    ,bench "sum Double [1e6]" $
      nf (\a -> runScalar (rsumOuter1 a))
         (riota @Double n)
    ]
  ,bgroup "NumElt"
    [bench "sum(+) Double [1e6]" $
      nf (\(a, b) -> runScalar (rsumOuter1 (a + b)))
         (riota @Double n, riota n)
    ,bench "sum(*) Double [1e6]" $
      nf (\(a, b) -> runScalar (rsumOuter1 (a * b)))
         (riota @Double n, riota n)
    ,bench "sum(/) Double [1e6]" $
      nf (\(a, b) -> runScalar (rsumOuter1 (a / b)))
         (riota @Double n, riota n)
    ,bench "sum(**) Double [1e6]" $
      nf (\(a, b) -> runScalar (rsumOuter1 (a ** b)))
         (riota @Double n, riota n)
    ,bench "sum(sin) Double [1e6]" $
      nf (\a -> runScalar (rsumOuter1 (sin a)))
         (riota @Double n)
    ,bench "sum Double [1e6]" $
      nf (\a -> runScalar (rsumOuter1 a))
         (riota @Double n)
    -- The second operand is reversed with rrev1, giving stride -1.
    ,bench "sum(*) Double [1e6] stride 1; -1" $
      nf (\(a, b) -> runScalar (rsumOuter1 (a * b)))
         (riota @Double n, rrev1 (riota n))
    ,bench "dotprod Float [1e6]" $
      nf (\(a, b) -> rdot a b)
         (riota @Float n, riota @Float n)
    ,bench "dotprod Float [1e6] stride 1; -1" $
      nf (\(a, b) -> rdot a b)
         (riota @Float n, rrev1 (riota @Float n))
    ,bench "dotprod Double [1e6]" $
      nf (\(a, b) -> rdot a b)
         (riota @Double n, riota @Double n)
    ,bench "dotprod Double [1e6] stride 1; -1" $
      nf (\(a, b) -> rdot a b)
         (riota @Double n, rrev1 (riota @Double n))
    ]
  ,bgroup "hmatrix"
    [bench "sum(+) Double [1e6]" $
      nf (\(a, b) -> LA.sumElements (a + b))
         (LA.linspace @Double n (0.0, fromIntegral (n - 1))
         ,LA.linspace @Double n (0.0, fromIntegral (n - 1)))
    ,bench "sum(*) Double [1e6]" $
      nf (\(a, b) -> LA.sumElements (a * b))
         (LA.linspace @Double n (0.0, fromIntegral (n - 1))
         ,LA.linspace @Double n (0.0, fromIntegral (n - 1)))
    ,bench "sum(/) Double [1e6]" $
      nf (\(a, b) -> LA.sumElements (a / b))
         (LA.linspace @Double n (0.0, fromIntegral (n - 1))
         ,LA.linspace @Double n (0.0, fromIntegral (n - 1)))
    ,bench "sum(**) Double [1e6]" $
      nf (\(a, b) -> LA.sumElements (a ** b))
         (LA.linspace @Double n (0.0, fromIntegral (n - 1))
         ,LA.linspace @Double n (0.0, fromIntegral (n - 1)))
    ,bench "sum(sin) Double [1e6]" $
      nf (\a -> LA.sumElements (sin a))
         (LA.linspace @Double n (0.0, fromIntegral (n - 1)))
    ,bench "sum Double [1e6]" $
      nf (\a -> LA.sumElements a)
         (LA.linspace @Double n (0.0, fromIntegral (n - 1)))
    -- Both operands must be Float here, matching the label and the NumElt
    -- "dotprod Float" benchmarks. They were previously built at Double, which
    -- made this an accidental duplicate of the Double benchmark below. The
    -- second operand counts down (n-1 .. 0) but is still stored contiguously.
    ,bench "dotprod Float [1e6]" $
      nf (\(a, b) -> a LA.<.> b)
         (LA.linspace @Float n (0.0, fromIntegral (n - 1))
         ,LA.linspace @Float n (fromIntegral (n - 1), 0.0))
    ,bench "dotprod Double [1e6]" $
      nf (\(a, b) -> a LA.<.> b)
         (LA.linspace @Double n (0.0, fromIntegral (n - 1))
         ,LA.linspace @Double n (fromIntegral (n - 1), 0.0))
    ]
  ]