diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-14 21:57:56 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-14 21:58:51 +0100 |
commit | 6276ed3c7bcd20c8b860e1275386ecd068671bcc (patch) | |
tree | b2710f261d12a7a1b73962691c187752663543f6 | |
parent | 308ca9fac150cd28d62afef852f26ae4c40fa5a0 (diff) |
Optimise reductions and dotprod with more vectorisation
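Turns out that if you don't supply -ffast-math, the C compiler will
faithfully reproduce your linear reduction order for floating-point types
(floating-point addition is not associative, so it may not reorder the sum),
which is rather disastrous for parallelisation with vector units.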
This changes the summation order, so numerical results might differ
slightly. To wit: the test suite needed adjustment.
-rw-r--r-- | cbits/arith.c | 113 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 1 | ||||
-rw-r--r-- | test/Tests/C.hs | 8 |
3 files changed, 55 insertions, 67 deletions
diff --git a/cbits/arith.c b/cbits/arith.c index 752fc1c..6380776 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -1,8 +1,3 @@ -// Architecture detection -#if defined(__x86_64__) || defined(_M_X64) -#define OX_ARCH_INTEL -#endif - #include <stdio.h> #include <stdint.h> #include <inttypes.h> @@ -11,10 +6,6 @@ #include <string.h> #include <math.h> -#ifdef OX_ARCH_INTEL -#include <emmintrin.h> -#endif - // These are the wrapper macros used in arith_lists.h. Preset them to empty to // avoid having to touch macros unrelated to the particular operation set below. #define LIST_BINOP(name, id, hsop) @@ -229,6 +220,33 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { }); \ } +// Used for reduction and dot product kernels below +#define MANUAL_VECT_WID 8 + +// Used in REDUCE1_OP and REDUCEFULL_OP below; requires the same preconditions +#define REDUCE_BODY_CODE(op, typ, innerLen, innerStride, arr, arrlinidx, destination) \ + do { \ + const i64 n = innerLen; const i64 s = innerStride; \ + if (n < MANUAL_VECT_WID) { \ + typ accum = arr[arrlinidx]; \ + for (i64 i = 1; i < n; i++) accum = accum op arr[arrlinidx + s * i]; \ + destination = accum; \ + } else { \ + typ accum[MANUAL_VECT_WID]; \ + for (i64 j = 0; j < MANUAL_VECT_WID; j++) accum[j] = arr[arrlinidx + s * j]; \ + for (i64 i = 1; i < n / MANUAL_VECT_WID; i++) { \ + for (i64 j = 0; j < MANUAL_VECT_WID; j++) { \ + accum[j] = accum[j] op arr[arrlinidx + s * (MANUAL_VECT_WID * i + j)]; \ + } \ + } \ + typ res = accum[0]; \ + for (i64 j = 1; j < MANUAL_VECT_WID; j++) res = res op accum[j]; \ + for (i64 i = n / MANUAL_VECT_WID * MANUAL_VECT_WID; i < n; i++) \ + res = res op arr[arrlinidx + s * i]; \ + destination = res; \ + } \ + } while (0) + // preconditions: // - all strides are >0 // - shape is everywhere >0 @@ -239,11 +257,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { #define REDUCE1_OP(name, op, typ) \ static void oxarop_op_ ## name ## _ ## typ(i64 rank, typ *restrict 
out, const i64 *shape, const i64 *strides, const typ *arr) { \ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ - typ accum = arr[arrlinidx]; \ - for (i64 i = 1; i < shape[rank - 1]; i++) { \ - accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ - } \ - out[outlinidx] = accum; \ + REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, out[outlinidx]); \ }); \ } @@ -253,15 +267,11 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { // - rank is >= 1 #define REDUCEFULL_OP(name, op, typ) \ typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ - typ res = 0; \ + typ result = 0; \ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ - typ accum = arr[arrlinidx]; \ - for (i64 i = 1; i < shape[rank - 1]; i++) { \ - accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ - } \ - res = res op accum; \ + REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, result); \ }); \ - return res; \ + return result; \ } // preconditions @@ -286,50 +296,25 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { }); \ } -#define DOTPROD_OP(typ) \ - typ oxarop_dotprod_ ## typ(i64 length, const typ *arr1, const typ *arr2) { \ - typ res = 0; \ - for (i64 i = 0; i < length; i++) res += arr1[i] * arr2[i]; \ - return res; \ - } - #define DOTPROD_STRIDED_OP(typ) \ typ oxarop_dotprod_ ## typ ## _strided(i64 length, i64 stride1, const typ *arr1, i64 stride2, const typ *arr2) { \ - typ res = 0; \ - for (i64 i = 0; i < length; i++) res += arr1[stride1 * i] * arr2[stride2 * i]; \ - return res; \ - } - -// The 'double' version here is about 2x as fast as gcc's own vectorisation. 
-DOTPROD_OP(i32) -DOTPROD_OP(i64) -#ifdef OX_ARCH_INTEL -float oxarop_dotprod_float(i64 length, const float *arr1, const float *arr2) { - __m128 accum = _mm_setzero_ps(); - i64 i; - for (i = 0; i + 3 < length; i += 4) { - accum = _mm_add_ps(accum, _mm_mul_ps(_mm_loadu_ps(arr1 + i), _mm_loadu_ps(arr2 + i))); - } - float dest[4]; - _mm_storeu_ps(dest, accum); - float tot = dest[0] + dest[1] + dest[2] + dest[3]; - for (; i < length; i++) tot += arr1[i] * arr2[i]; - return tot; -} -double oxarop_dotprod_double(i64 length, const double *arr1, const double *arr2) { - __m128d accum = _mm_setzero_pd(); - i64 i; - for (i = 0; i + 1 < length; i += 2) { - accum = _mm_add_pd(accum, _mm_mul_pd(_mm_loadu_pd(arr1 + i), _mm_loadu_pd(arr2 + i))); + if (length < MANUAL_VECT_WID) { \ + typ res = 0; \ + for (i64 i = 0; i < length; i++) res += arr1[stride1 * i] * arr2[stride2 * i]; \ + return res; \ + } else { \ + typ accum[MANUAL_VECT_WID]; \ + for (i64 j = 0; j < MANUAL_VECT_WID; j++) accum[j] = arr1[stride1 * j] * arr2[stride2 * j]; \ + for (i64 i = 1; i < length / MANUAL_VECT_WID; i++) \ + for (i64 j = 0; j < MANUAL_VECT_WID; j++) \ + accum[j] += arr1[stride1 * (MANUAL_VECT_WID * i + j)] * arr2[stride2 * (MANUAL_VECT_WID * i + j)]; \ + typ res = accum[0]; \ + for (i64 j = 1; j < MANUAL_VECT_WID; j++) res += accum[j]; \ + for (i64 i = length / MANUAL_VECT_WID * MANUAL_VECT_WID; i < length; i++) \ + res += arr1[stride1 * i] * arr2[stride2 * i]; \ + return res; \ + } \ } - double tot = _mm_cvtsd_f64(accum) + _mm_cvtsd_f64(_mm_unpackhi_pd(accum, accum)); - if (i < length) tot += arr1[i] * arr2[i]; - return tot; -} -#else -DOTPROD_OP(float) -DOTPROD_OP(double) -#endif // preconditions: // - all strides are >0 @@ -342,12 +327,12 @@ DOTPROD_OP(double) void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \ if (strides1[rank - 1] == 1 && strides2[rank - 1] == 1) { \ 
TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \ - out[outlinidx] = oxarop_dotprod_ ## typ(shape[rank - 1], arr1 + arrlinidx1, arr2 + arrlinidx2); \ + out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(shape[rank - 1], 1, arr1 + arrlinidx1, 1, arr2 + arrlinidx2); \ }); \ } else if (strides1[rank - 1] == -1 && strides2[rank - 1] == -1) { \ TARRAY_WALK2_NOINNER(again2, rank, shape, strides1, strides2, { \ const i64 len = shape[rank - 1]; \ - out[outlinidx] = oxarop_dotprod_ ## typ(len, arr1 + arrlinidx1 - (len - 1), arr2 + arrlinidx2 - (len - 1)); \ + out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(len, 1, arr1 + arrlinidx1 - (len - 1), 1, arr2 + arrlinidx2 - (len - 1)); \ }); \ } else { \ TARRAY_WALK2_NOINNER(again3, rank, shape, strides1, strides2, { \ diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index 15fbc79..969a25a 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -20,7 +20,6 @@ $(do ,("reducefull_" ++ tyn, [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]) ,("extremum_min_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) ,("extremum_max_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) - ,("dotprod_" ++ tyn, [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |]) ,("dotprod_" ++ tyn ++ "_strided", [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |]) ,("dotprodinner_" ++ tyn, [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |]) ] diff --git a/test/Tests/C.hs b/test/Tests/C.hs index bc8e0de..a0f103d 100644 --- a/test/Tests/C.hs +++ b/test/Tests/C.hs @@ -35,6 +35,10 @@ import Gen import Util +-- | Appropriate for simple different summation orders +fineTol :: Double +fineTol = 1e-8 + prop_sum_nonempty :: Property prop_sum_nonempty = property $ 
genRank $ \outrank@(SNat @n) -> do -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. @@ -46,7 +50,7 @@ prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do genStorables (Range.singleton (product sh)) (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) let rarr = rfromOrthotope inrank arr - rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr + almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr) prop_sum_empty :: Property prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do @@ -74,7 +78,7 @@ prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do genStorables (Range.singleton (product insh)) (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) let rarr = rfromOrthotope inrank arr - rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr + almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr) prop_sum_replicated :: Bool -> Property prop_sum_replicated doTranspose = property $ |