From c6ac2b69e15ff09622ac2bbc40ede8331866a559 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 10 Jun 2024 16:22:11 +0200 Subject: Manual vectorisation of dot product for floating points --- bench/Main.hs | 13 +++++++++++++ cbits/arith.c | 28 +++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/bench/Main.hs b/bench/Main.hs index eb3c6d7..cc4e11f 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -65,6 +65,14 @@ main = defaultMain let n = 1_000_000 in nf (\a -> runScalar (rsumOuter1 a)) (riota @Double n) + ,bench "dotprod Float [1e6]" $ + let n = 1_000_000 + in nf (\(a, b) -> rdot a b) + (riota @Float n, riota @Float n) + ,bench "dotprod Float [1e6] stride 1; -1" $ + let n = 1_000_000 + in nf (\(a, b) -> rdot a b) + (riota @Float n, rrev1 (riota @Float n)) ,bench "dotprod Double [1e6]" $ let n = 1_000_000 in nf (\(a, b) -> rdot a b) @@ -103,6 +111,11 @@ main = defaultMain let n = 1_000_000 in nf (\a -> LA.sumElements a) (LA.linspace @Double n (0.0, fromIntegral (n - 1))) + ,bench "dotprod Float [1e6]" $ + let n = 1_000_000 + in nf (\(a, b) -> a LA.<.> b) + (LA.linspace @Double n (0.0, fromIntegral (n - 1)) + ,LA.linspace @Double n (fromIntegral (n - 1), 0.0)) ,bench "dotprod Double [1e6]" $ let n = 1_000_000 in nf (\(a, b) -> a LA.<.> b) diff --git a/cbits/arith.c b/cbits/arith.c index 751fe33..d487cfd 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -4,6 +4,7 @@ #include #include #include +#include // These are the wrapper macros used in arith_lists.h. Preset them to empty to // avoid having to touch macros unrelated to the particular operation set below. @@ -214,6 +215,32 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); } return res; \ } +// The 'double' version here is about 2x as fast as gcc's own vectorisation. +DOTPROD_OP(i32) +DOTPROD_OP(i64) +float oxarop_dotprod_float(i64 length, const float *arr1, const float *arr2) { + __m128 accum = _mm_setzero_ps(); + i64 i; + for (i = 0; i + 3 < length; i += 4) { + accum = _mm_add_ps(accum, _mm_mul_ps(_mm_load_ps(arr1 + i), _mm_load_ps(arr2 + i))); + } + float dest[4]; + _mm_storeu_ps(dest, accum); + float tot = dest[0] + dest[1] + dest[2] + dest[3]; + for (; i < length; i++) tot += arr1[i] * arr2[i]; + return tot; +} +double oxarop_dotprod_double(i64 length, const double *arr1, const double *arr2) { + __m128d accum = _mm_setzero_pd(); + i64 i; + for (i = 0; i + 1 < length; i += 2) { + accum = _mm_add_pd(accum, _mm_mul_pd(_mm_load_pd(arr1 + i), _mm_load_pd(arr2 + i))); + } + double tot = _mm_cvtsd_f64(accum) + _mm_cvtsd_f64(_mm_unpackhi_pd(accum, accum)); + if (i < length) tot += arr1[i] * arr2[i]; + return tot; +} + /***************************************************************************** * Entry point functions * @@ -385,7 +412,6 @@ enum redop_tag_t { ENTRY_REDUCE_OPS(typ) \ EXTREMUM_OP(min, <, typ) \ EXTREMUM_OP(max, >, typ) \ - DOTPROD_OP(typ) \ DOTPROD_STRIDED_OP(typ) NUM_TYPES_XLIST #undef X -- cgit v1.2.3-70-g09d2