aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-14 21:57:56 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-14 21:58:51 +0100
commit6276ed3c7bcd20c8b860e1275386ecd068671bcc (patch)
treeb2710f261d12a7a1b73962691c187752663543f6
parent308ca9fac150cd28d62afef852f26ae4c40fa5a0 (diff)
Optimise reductions and dotprod with more vectorisation
Turns out that if you don't supply -ffast-math, the C compiler will faithfully reproduce your linear reduction order, which is rather disastrous for parallelisation with vector units. This changes the summation order, so numerical results might differ slightly. To wit: the test suite needed adjustment.
-rw-r--r--cbits/arith.c113
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Foreign.hs1
-rw-r--r--test/Tests/C.hs8
3 files changed, 55 insertions, 67 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index 752fc1c..6380776 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -1,8 +1,3 @@
-// Architecture detection
-#if defined(__x86_64__) || defined(_M_X64)
-#define OX_ARCH_INTEL
-#endif
-
#include <stdio.h>
#include <stdint.h>
#include <inttypes.h>
@@ -11,10 +6,6 @@
#include <string.h>
#include <math.h>
-#ifdef OX_ARCH_INTEL
-#include <emmintrin.h>
-#endif
-
// These are the wrapper macros used in arith_lists.h. Preset them to empty to
// avoid having to touch macros unrelated to the particular operation set below.
#define LIST_BINOP(name, id, hsop)
@@ -229,6 +220,33 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
}); \
}
+// Used for reduction and dot product kernels below
+#define MANUAL_VECT_WID 8
+
+// Used in REDUCE1_OP and REDUCEFULL_OP below; requires the same preconditions
+#define REDUCE_BODY_CODE(op, typ, innerLen, innerStride, arr, arrlinidx, destination) \
+ do { \
+ const i64 n = innerLen; const i64 s = innerStride; \
+ if (n < MANUAL_VECT_WID) { \
+ typ accum = arr[arrlinidx]; \
+ for (i64 i = 1; i < n; i++) accum = accum op arr[arrlinidx + s * i]; \
+ destination = accum; \
+ } else { \
+ typ accum[MANUAL_VECT_WID]; \
+ for (i64 j = 0; j < MANUAL_VECT_WID; j++) accum[j] = arr[arrlinidx + s * j]; \
+ for (i64 i = 1; i < n / MANUAL_VECT_WID; i++) { \
+ for (i64 j = 0; j < MANUAL_VECT_WID; j++) { \
+ accum[j] = accum[j] op arr[arrlinidx + s * (MANUAL_VECT_WID * i + j)]; \
+ } \
+ } \
+ typ res = accum[0]; \
+ for (i64 j = 1; j < MANUAL_VECT_WID; j++) res = res op accum[j]; \
+ for (i64 i = n / MANUAL_VECT_WID * MANUAL_VECT_WID; i < n; i++) \
+ res = res op arr[arrlinidx + s * i]; \
+ destination = res; \
+ } \
+ } while (0)
+
// preconditions:
// - all strides are >0
// - shape is everywhere >0
@@ -239,11 +257,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
#define REDUCE1_OP(name, op, typ) \
static void oxarop_op_ ## name ## _ ## typ(i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \
TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
- typ accum = arr[arrlinidx]; \
- for (i64 i = 1; i < shape[rank - 1]; i++) { \
- accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \
- } \
- out[outlinidx] = accum; \
+ REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, out[outlinidx]); \
}); \
}
@@ -253,15 +267,11 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
// - rank is >= 1
#define REDUCEFULL_OP(name, op, typ) \
typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
- typ res = 0; \
+ typ result = 0; \
TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
- typ accum = arr[arrlinidx]; \
- for (i64 i = 1; i < shape[rank - 1]; i++) { \
- accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \
- } \
- res = res op accum; \
+ REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, result); \
}); \
- return res; \
+ return result; \
}
// preconditions
@@ -286,50 +296,25 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
}); \
}
-#define DOTPROD_OP(typ) \
- typ oxarop_dotprod_ ## typ(i64 length, const typ *arr1, const typ *arr2) { \
- typ res = 0; \
- for (i64 i = 0; i < length; i++) res += arr1[i] * arr2[i]; \
- return res; \
- }
-
#define DOTPROD_STRIDED_OP(typ) \
typ oxarop_dotprod_ ## typ ## _strided(i64 length, i64 stride1, const typ *arr1, i64 stride2, const typ *arr2) { \
- typ res = 0; \
- for (i64 i = 0; i < length; i++) res += arr1[stride1 * i] * arr2[stride2 * i]; \
- return res; \
- }
-
-// The 'double' version here is about 2x as fast as gcc's own vectorisation.
-DOTPROD_OP(i32)
-DOTPROD_OP(i64)
-#ifdef OX_ARCH_INTEL
-float oxarop_dotprod_float(i64 length, const float *arr1, const float *arr2) {
- __m128 accum = _mm_setzero_ps();
- i64 i;
- for (i = 0; i + 3 < length; i += 4) {
- accum = _mm_add_ps(accum, _mm_mul_ps(_mm_loadu_ps(arr1 + i), _mm_loadu_ps(arr2 + i)));
- }
- float dest[4];
- _mm_storeu_ps(dest, accum);
- float tot = dest[0] + dest[1] + dest[2] + dest[3];
- for (; i < length; i++) tot += arr1[i] * arr2[i];
- return tot;
-}
-double oxarop_dotprod_double(i64 length, const double *arr1, const double *arr2) {
- __m128d accum = _mm_setzero_pd();
- i64 i;
- for (i = 0; i + 1 < length; i += 2) {
- accum = _mm_add_pd(accum, _mm_mul_pd(_mm_loadu_pd(arr1 + i), _mm_loadu_pd(arr2 + i)));
+ if (length < MANUAL_VECT_WID) { \
+ typ res = 0; \
+ for (i64 i = 0; i < length; i++) res += arr1[stride1 * i] * arr2[stride2 * i]; \
+ return res; \
+ } else { \
+ typ accum[MANUAL_VECT_WID]; \
+ for (i64 j = 0; j < MANUAL_VECT_WID; j++) accum[j] = arr1[stride1 * j] * arr2[stride2 * j]; \
+ for (i64 i = 1; i < length / MANUAL_VECT_WID; i++) \
+ for (i64 j = 0; j < MANUAL_VECT_WID; j++) \
+ accum[j] += arr1[stride1 * (MANUAL_VECT_WID * i + j)] * arr2[stride2 * (MANUAL_VECT_WID * i + j)]; \
+ typ res = accum[0]; \
+ for (i64 j = 1; j < MANUAL_VECT_WID; j++) res += accum[j]; \
+ for (i64 i = length / MANUAL_VECT_WID * MANUAL_VECT_WID; i < length; i++) \
+ res += arr1[stride1 * i] * arr2[stride2 * i]; \
+ return res; \
+ } \
}
- double tot = _mm_cvtsd_f64(accum) + _mm_cvtsd_f64(_mm_unpackhi_pd(accum, accum));
- if (i < length) tot += arr1[i] * arr2[i];
- return tot;
-}
-#else
-DOTPROD_OP(float)
-DOTPROD_OP(double)
-#endif
// preconditions:
// - all strides are >0
@@ -342,12 +327,12 @@ DOTPROD_OP(double)
void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \
if (strides1[rank - 1] == 1 && strides2[rank - 1] == 1) { \
TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \
- out[outlinidx] = oxarop_dotprod_ ## typ(shape[rank - 1], arr1 + arrlinidx1, arr2 + arrlinidx2); \
+ out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(shape[rank - 1], 1, arr1 + arrlinidx1, 1, arr2 + arrlinidx2); \
}); \
} else if (strides1[rank - 1] == -1 && strides2[rank - 1] == -1) { \
TARRAY_WALK2_NOINNER(again2, rank, shape, strides1, strides2, { \
const i64 len = shape[rank - 1]; \
- out[outlinidx] = oxarop_dotprod_ ## typ(len, arr1 + arrlinidx1 - (len - 1), arr2 + arrlinidx2 - (len - 1)); \
+ out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(len, 1, arr1 + arrlinidx1 - (len - 1), 1, arr2 + arrlinidx2 - (len - 1)); \
}); \
} else { \
TARRAY_WALK2_NOINNER(again3, rank, shape, strides1, strides2, { \
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
index 15fbc79..969a25a 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -20,7 +20,6 @@ $(do
,("reducefull_" ++ tyn, [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |])
,("extremum_min_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
,("extremum_max_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
- ,("dotprod_" ++ tyn, [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |])
,("dotprod_" ++ tyn ++ "_strided", [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |])
,("dotprodinner_" ++ tyn, [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
]
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
index bc8e0de..a0f103d 100644
--- a/test/Tests/C.hs
+++ b/test/Tests/C.hs
@@ -35,6 +35,10 @@ import Gen
import Util
+-- | Appropriate for simple different summation orders
+fineTol :: Double
+fineTol = 1e-8
+
prop_sum_nonempty :: Property
prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do
-- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet.
@@ -46,7 +50,7 @@ prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do
genStorables (Range.singleton (product sh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
let rarr = rfromOrthotope inrank arr
- rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr
+ almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr)
prop_sum_empty :: Property
prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do
@@ -74,7 +78,7 @@ prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do
genStorables (Range.singleton (product insh))
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
let rarr = rfromOrthotope inrank arr
- rtoOrthotope (rsumOuter1 rarr) === orSumOuter1 outrank arr
+ almostEq fineTol (rtoOrthotope (rsumOuter1 rarr)) (orSumOuter1 outrank arr)
prop_sum_replicated :: Bool -> Property
prop_sum_replicated doTranspose = property $