diff options
Diffstat (limited to 'cbits')
| -rw-r--r-- | cbits/arith.c | 43 | 
1 files changed, 42 insertions, 1 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index 751fe33..d487cfd 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -4,6 +4,7 @@
 #include <stdbool.h>
 #include <string.h>
 #include <math.h>
+#include <emmintrin.h>
 
 // These are the wrapper macros used in arith_lists.h. Preset them to empty to
 // avoid having to touch macros unrelated to the particular operation set below.
@@ -214,6 +215,47 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
     return res; \
   }
 
+// Dot products. The integer versions are instantiated from DOTPROD_OP; the
+// float and double versions below are written by hand with SSE intrinsics.
+// The 'double' version here is about 2x as fast as gcc's own vectorisation.
+DOTPROD_OP(i32)
+DOTPROD_OP(i64)
+
+// Dot product of the first 'length' elements of arr1 and arr2.
+// Four partial sums are kept in one SSE register and added together at the
+// end, so the summation order differs from a plain left-to-right loop.
+float oxarop_dotprod_float(i64 length, const float *arr1, const float *arr2) {
+  __m128 accum = _mm_setzero_ps();
+  i64 i;
+  for (i = 0; i + 3 < length; i += 4) {
+    // Unaligned loads: nothing in this function guarantees that the inputs
+    // are 16-byte aligned, which the aligned _mm_load_ps would require.
+    accum = _mm_add_ps(accum, _mm_mul_ps(_mm_loadu_ps(arr1 + i), _mm_loadu_ps(arr2 + i)));
+  }
+  float dest[4];
+  _mm_storeu_ps(dest, accum);
+  float tot = dest[0] + dest[1] + dest[2] + dest[3];
+  // Scalar tail for the up to three elements that do not fill a vector.
+  for (; i < length; i++) tot += arr1[i] * arr2[i];
+  return tot;
+}
+
+// Dot product of the first 'length' elements of arr1 and arr2.
+// Two partial sums are kept in one SSE register and added together at the end.
+double oxarop_dotprod_double(i64 length, const double *arr1, const double *arr2) {
+  __m128d accum = _mm_setzero_pd();
+  i64 i;
+  for (i = 0; i + 1 < length; i += 2) {
+    // Unaligned loads, for the same reason as in the float version above.
+    accum = _mm_add_pd(accum, _mm_mul_pd(_mm_loadu_pd(arr1 + i), _mm_loadu_pd(arr2 + i)));
+  }
+  // Horizontal sum: low lane plus high lane.
+  double tot = _mm_cvtsd_f64(accum) + _mm_cvtsd_f64(_mm_unpackhi_pd(accum, accum));
+  // At most one element is left over, when 'length' is odd.
+  if (i < length) tot += arr1[i] * arr2[i];
+  return tot;
+}
+
 
 /*****************************************************************************
  *                           Entry point functions                           *
@@ -385,7 +427,6 @@ enum redop_tag_t {
   ENTRY_REDUCE_OPS(typ) \
   EXTREMUM_OP(min, <, typ) \
   EXTREMUM_OP(max, >, typ) \
-  DOTPROD_OP(typ) \
   DOTPROD_STRIDED_OP(typ)
 NUM_TYPES_XLIST
 #undef X
