aboutsummaryrefslogtreecommitdiff
path: root/cbits
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2024-06-10 16:22:11 +0200
committerTom Smeding <t.j.smeding@uu.nl>2024-06-10 16:22:11 +0200
commitc6ac2b69e15ff09622ac2bbc40ede8331866a559 (patch)
tree9b9386b135c55a7bfa37864d03586ac9fe1b2c5d /cbits
parentbb1e1fcd1b0f47747623f1497a4f4ae0f7a2a62d (diff)
Manual vectorisation of dot product for floating points
Diffstat (limited to 'cbits')
-rw-r--r--cbits/arith.c28
1 files changed, 27 insertions, 1 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index 751fe33..d487cfd 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -4,6 +4,7 @@
#include <stdbool.h>
#include <string.h>
#include <math.h>
+#include <emmintrin.h>
// These are the wrapper macros used in arith_lists.h. Preset them to empty to
// avoid having to touch macros unrelated to the particular operation set below.
@@ -214,6 +215,32 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
return res; \
}
+// The 'double' version here is about 2x as fast as gcc's own vectorisation.
+DOTPROD_OP(i32)
+DOTPROD_OP(i64)
+float oxarop_dotprod_float(i64 length, const float *arr1, const float *arr2) {
+ __m128 accum = _mm_setzero_ps();
+ i64 i;
+ for (i = 0; i + 3 < length; i += 4) {
+ accum = _mm_add_ps(accum, _mm_mul_ps(_mm_load_ps(arr1 + i), _mm_load_ps(arr2 + i)));
+ }
+ float dest[4];
+ _mm_storeu_ps(dest, accum);
+ float tot = dest[0] + dest[1] + dest[2] + dest[3];
+ for (; i < length; i++) tot += arr1[i] * arr2[i];
+ return tot;
+}
+double oxarop_dotprod_double(i64 length, const double *arr1, const double *arr2) {
+ __m128d accum = _mm_setzero_pd();
+ i64 i;
+ for (i = 0; i + 1 < length; i += 2) {
+ accum = _mm_add_pd(accum, _mm_mul_pd(_mm_load_pd(arr1 + i), _mm_load_pd(arr2 + i)));
+ }
+ double tot = _mm_cvtsd_f64(accum) + _mm_cvtsd_f64(_mm_unpackhi_pd(accum, accum));
+ if (i < length) tot += arr1[i] * arr2[i];
+ return tot;
+}
+
/*****************************************************************************
* Entry point functions *
@@ -385,7 +412,6 @@ enum redop_tag_t {
ENTRY_REDUCE_OPS(typ) \
EXTREMUM_OP(min, <, typ) \
EXTREMUM_OP(max, >, typ) \
- DOTPROD_OP(typ) \
DOTPROD_STRIDED_OP(typ)
NUM_TYPES_XLIST
#undef X