diff options
| -rw-r--r-- | bench/Main.hs | 13 | ||||
| -rw-r--r-- | cbits/arith.c | 28 | 
2 files changed, 40 insertions, 1 deletions
| diff --git a/bench/Main.hs b/bench/Main.hs index eb3c6d7..cc4e11f 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -65,6 +65,14 @@ main = defaultMain        let n = 1_000_000        in nf (\a -> runScalar (rsumOuter1 a))              (riota @Double n) +    ,bench "dotprod Float [1e6]" $ +      let n = 1_000_000 +      in nf (\(a, b) -> rdot a b) +            (riota @Float n, riota @Float n) +    ,bench "dotprod Float [1e6] stride 1; -1" $ +      let n = 1_000_000 +      in nf (\(a, b) -> rdot a b) +            (riota @Float n, rrev1 (riota @Float n))      ,bench "dotprod Double [1e6]" $        let n = 1_000_000        in nf (\(a, b) -> rdot a b) @@ -103,6 +111,11 @@ main = defaultMain        let n = 1_000_000        in nf (\a -> LA.sumElements a)              (LA.linspace @Double n (0.0, fromIntegral (n - 1))) +    ,bench "dotprod Float [1e6]" $ +      let n = 1_000_000 +      in nf (\(a, b) -> a LA.<.> b) +            (LA.linspace @Double n (0.0, fromIntegral (n - 1)) +            ,LA.linspace @Double n (fromIntegral (n - 1), 0.0))      ,bench "dotprod Double [1e6]" $        let n = 1_000_000        in nf (\(a, b) -> a LA.<.> b) diff --git a/cbits/arith.c b/cbits/arith.c index 751fe33..d487cfd 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -4,6 +4,7 @@  #include <stdbool.h>  #include <string.h>  #include <math.h> +#include <emmintrin.h>  // These are the wrapper macros used in arith_lists.h. Preset them to empty to  // avoid having to touch macros unrelated to the particular operation set below. @@ -214,6 +215,32 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }      return res; \    } +// The 'double' version here is about 2x as fast as gcc's own vectorisation. +DOTPROD_OP(i32) +DOTPROD_OP(i64) +float oxarop_dotprod_float(i64 length, const float *arr1, const float *arr2) { +  __m128 accum = _mm_setzero_ps(); +  i64 i; +  for (i = 0; i + 3 < length; i += 4) { +    accum = _mm_add_ps(accum, _mm_mul_ps(_mm_load_ps(arr1 + i), _mm_load_ps(arr2 + i))); +  } +  float dest[4]; +  _mm_storeu_ps(dest, accum); +  float tot = dest[0] + dest[1] + dest[2] + dest[3]; +  for (; i < length; i++) tot += arr1[i] * arr2[i]; +  return tot; +} +double oxarop_dotprod_double(i64 length, const double *arr1, const double *arr2) { +  __m128d accum = _mm_setzero_pd(); +  i64 i; +  for (i = 0; i + 1 < length; i += 2) { +    accum = _mm_add_pd(accum, _mm_mul_pd(_mm_load_pd(arr1 + i), _mm_load_pd(arr2 + i))); +  } +  double tot = _mm_cvtsd_f64(accum) + _mm_cvtsd_f64(_mm_unpackhi_pd(accum, accum)); +  if (i < length) tot += arr1[i] * arr2[i]; +  return tot; +} +  /*****************************************************************************   *                           Entry point functions                           * @@ -385,7 +412,6 @@ enum redop_tag_t {    ENTRY_REDUCE_OPS(typ) \    EXTREMUM_OP(min, <, typ) \    EXTREMUM_OP(max, >, typ) \ -  DOTPROD_OP(typ) \    DOTPROD_STRIDED_OP(typ)  NUM_TYPES_XLIST  #undef X | 
