diff options
Diffstat (limited to 'cbits/arith.c')
-rw-r--r-- | cbits/arith.c | 28 |
1 files changed, 27 insertions, 1 deletions
diff --git a/cbits/arith.c b/cbits/arith.c index 751fe33..d487cfd 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -4,6 +4,7 @@ #include <stdbool.h> #include <string.h> #include <math.h> +#include <emmintrin.h> // These are the wrapper macros used in arith_lists.h. Preset them to empty to // avoid having to touch macros unrelated to the particular operation set below. @@ -214,6 +215,32 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); } return res; \ } +// The 'double' version here is about 2x as fast as gcc's own vectorisation. +DOTPROD_OP(i32) +DOTPROD_OP(i64) +float oxarop_dotprod_float(i64 length, const float *arr1, const float *arr2) { + __m128 accum = _mm_setzero_ps(); + i64 i; + for (i = 0; i + 3 < length; i += 4) { + accum = _mm_add_ps(accum, _mm_mul_ps(_mm_load_ps(arr1 + i), _mm_load_ps(arr2 + i))); + } + float dest[4]; + _mm_storeu_ps(dest, accum); + float tot = dest[0] + dest[1] + dest[2] + dest[3]; + for (; i < length; i++) tot += arr1[i] * arr2[i]; + return tot; +} +double oxarop_dotprod_double(i64 length, const double *arr1, const double *arr2) { + __m128d accum = _mm_setzero_pd(); + i64 i; + for (i = 0; i + 1 < length; i += 2) { + accum = _mm_add_pd(accum, _mm_mul_pd(_mm_load_pd(arr1 + i), _mm_load_pd(arr2 + i))); + } + double tot = _mm_cvtsd_f64(accum) + _mm_cvtsd_f64(_mm_unpackhi_pd(accum, accum)); + if (i < length) tot += arr1[i] * arr2[i]; + return tot; +} + /***************************************************************************** * Entry point functions * @@ -385,7 +412,6 @@ enum redop_tag_t { ENTRY_REDUCE_OPS(typ) \ EXTREMUM_OP(min, <, typ) \ EXTREMUM_OP(max, >, typ) \ - DOTPROD_OP(typ) \ DOTPROD_STRIDED_OP(typ) NUM_TYPES_XLIST #undef X |