aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2024-06-10 16:22:11 +0200
committerTom Smeding <t.j.smeding@uu.nl>2024-06-10 16:22:11 +0200
commitc6ac2b69e15ff09622ac2bbc40ede8331866a559 (patch)
tree9b9386b135c55a7bfa37864d03586ac9fe1b2c5d
parentbb1e1fcd1b0f47747623f1497a4f4ae0f7a2a62d (diff)
Manual vectorisation of dot product for floating points
-rw-r--r--bench/Main.hs13
-rw-r--r--cbits/arith.c28
2 files changed, 40 insertions, 1 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index eb3c6d7..cc4e11f 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -65,6 +65,14 @@ main = defaultMain
let n = 1_000_000
in nf (\a -> runScalar (rsumOuter1 a))
(riota @Double n)
+ ,bench "dotprod Float [1e6]" $
+ let n = 1_000_000
+ in nf (\(a, b) -> rdot a b)
+ (riota @Float n, riota @Float n)
+ ,bench "dotprod Float [1e6] stride 1; -1" $
+ let n = 1_000_000
+ in nf (\(a, b) -> rdot a b)
+ (riota @Float n, rrev1 (riota @Float n))
,bench "dotprod Double [1e6]" $
let n = 1_000_000
in nf (\(a, b) -> rdot a b)
@@ -103,6 +111,11 @@ main = defaultMain
let n = 1_000_000
in nf (\a -> LA.sumElements a)
(LA.linspace @Double n (0.0, fromIntegral (n - 1)))
+ -- The element type must be Float to match the benchmark label; the Double
+ -- case is measured by the "dotprod Double [1e6]" entry that follows.
+ ,bench "dotprod Float [1e6]" $
+   let n = 1_000_000
+   in nf (\(a, b) -> a LA.<.> b)
+         (LA.linspace @Float n (0.0, fromIntegral (n - 1))
+         ,LA.linspace @Float n (fromIntegral (n - 1), 0.0))
,bench "dotprod Double [1e6]" $
let n = 1_000_000
in nf (\(a, b) -> a LA.<.> b)
diff --git a/cbits/arith.c b/cbits/arith.c
index 751fe33..d487cfd 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -4,6 +4,7 @@
#include <stdbool.h>
#include <string.h>
#include <math.h>
+#include <emmintrin.h>
// These are the wrapper macros used in arith_lists.h. Preset them to empty to
// avoid having to touch macros unrelated to the particular operation set below.
@@ -214,6 +215,32 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
return res; \
}
+/* Dot product instantiations. The integer element types (i32, i64) use the
+ * DOTPROD_OP macro; float and double have hand-written SSE versions defined
+ * directly after these two lines. The 'double' version here is about 2x as
+ * fast as gcc's own vectorisation. */
+DOTPROD_OP(i32)
+DOTPROD_OP(i64)
+// Dot product of two contiguous float arrays of `length` elements each.
+//
+// Four elements are processed per iteration in one SSE register; the four
+// per-lane partial sums are then added together and the remaining
+// (length % 4) elements are handled by a scalar loop. Because of the per-lane
+// accumulation, the summation order differs from a plain left-to-right loop.
+// Returns 0 when length <= 0.
+//
+// Unaligned loads are used on purpose: the pointers are only guaranteed to be
+// aligned to sizeof(float) (for example when they point into the middle of a
+// buffer), and _mm_load_ps faults on addresses that are not 16-byte aligned.
+float oxarop_dotprod_float(i64 length, const float *arr1, const float *arr2) {
+  __m128 accum = _mm_setzero_ps();
+  i64 i;
+  for (i = 0; i + 3 < length; i += 4) {
+    accum = _mm_add_ps(accum, _mm_mul_ps(_mm_loadu_ps(arr1 + i), _mm_loadu_ps(arr2 + i)));
+  }
+  // Horizontal reduction of the four lanes via a small stack buffer.
+  float dest[4];
+  _mm_storeu_ps(dest, accum);
+  float tot = dest[0] + dest[1] + dest[2] + dest[3];
+  // Scalar tail: at most three elements remain.
+  for (; i < length; i++) tot += arr1[i] * arr2[i];
+  return tot;
+}
+// Dot product of two contiguous double arrays of `length` elements each.
+//
+// Two elements are processed per iteration in one SSE2 register; the two
+// per-lane partial sums are then added together and, for odd lengths, the
+// final element is added in scalar code. Because of the per-lane
+// accumulation, the summation order differs from a plain left-to-right loop.
+// Returns 0 when length <= 0.
+//
+// Unaligned loads are used on purpose: the pointers are only guaranteed to be
+// aligned to sizeof(double) (for example when they point into the middle of a
+// buffer), and _mm_load_pd faults on addresses that are not 16-byte aligned.
+double oxarop_dotprod_double(i64 length, const double *arr1, const double *arr2) {
+  __m128d accum = _mm_setzero_pd();
+  i64 i;
+  for (i = 0; i + 1 < length; i += 2) {
+    accum = _mm_add_pd(accum, _mm_mul_pd(_mm_loadu_pd(arr1 + i), _mm_loadu_pd(arr2 + i)));
+  }
+  // Horizontal reduction: low lane plus high lane (unpackhi moves it down).
+  double tot = _mm_cvtsd_f64(accum) + _mm_cvtsd_f64(_mm_unpackhi_pd(accum, accum));
+  // The vector loop leaves at most one element when length is odd.
+  if (i < length) tot += arr1[i] * arr2[i];
+  return tot;
+}
+
/*****************************************************************************
* Entry point functions *
@@ -385,7 +412,6 @@ enum redop_tag_t {
ENTRY_REDUCE_OPS(typ) \
EXTREMUM_OP(min, <, typ) \
EXTREMUM_OP(max, >, typ) \
- DOTPROD_OP(typ) \
DOTPROD_STRIDED_OP(typ)
NUM_TYPES_XLIST
#undef X