diff options
Diffstat (limited to 'cbits/arith.c')
-rw-r--r-- | cbits/arith.c | 113 |
1 files changed, 49 insertions, 64 deletions
diff --git a/cbits/arith.c b/cbits/arith.c index 752fc1c..6380776 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -1,8 +1,3 @@ -// Architecture detection -#if defined(__x86_64__) || defined(_M_X64) -#define OX_ARCH_INTEL -#endif - #include <stdio.h> #include <stdint.h> #include <inttypes.h> @@ -11,10 +6,6 @@ #include <string.h> #include <math.h> -#ifdef OX_ARCH_INTEL -#include <emmintrin.h> -#endif - // These are the wrapper macros used in arith_lists.h. Preset them to empty to // avoid having to touch macros unrelated to the particular operation set below. #define LIST_BINOP(name, id, hsop) @@ -229,6 +220,33 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { }); \ } +// Used for reduction and dot product kernels below +#define MANUAL_VECT_WID 8 + +// Used in REDUCE1_OP and REDUCEFULL_OP below; requires the same preconditions +#define REDUCE_BODY_CODE(op, typ, innerLen, innerStride, arr, arrlinidx, destination) \ + do { \ + const i64 n = innerLen; const i64 s = innerStride; \ + if (n < MANUAL_VECT_WID) { \ + typ accum = arr[arrlinidx]; \ + for (i64 i = 1; i < n; i++) accum = accum op arr[arrlinidx + s * i]; \ + destination = accum; \ + } else { \ + typ accum[MANUAL_VECT_WID]; \ + for (i64 j = 0; j < MANUAL_VECT_WID; j++) accum[j] = arr[arrlinidx + s * j]; \ + for (i64 i = 1; i < n / MANUAL_VECT_WID; i++) { \ + for (i64 j = 0; j < MANUAL_VECT_WID; j++) { \ + accum[j] = accum[j] op arr[arrlinidx + s * (MANUAL_VECT_WID * i + j)]; \ + } \ + } \ + typ res = accum[0]; \ + for (i64 j = 1; j < MANUAL_VECT_WID; j++) res = res op accum[j]; \ + for (i64 i = n / MANUAL_VECT_WID * MANUAL_VECT_WID; i < n; i++) \ + res = res op arr[arrlinidx + s * i]; \ + destination = res; \ + } \ + } while (0) + // preconditions: // - all strides are >0 // - shape is everywhere >0 @@ -239,11 +257,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { #define REDUCE1_OP(name, op, typ) \ static void oxarop_op_ ## name ## _ ## typ(i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ - typ accum = arr[arrlinidx]; \ - for (i64 i = 1; i < shape[rank - 1]; i++) { \ - accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ - } \ - out[outlinidx] = accum; \ + REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, out[outlinidx]); \ }); \ } @@ -253,15 +267,11 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { // - rank is >= 1 #define REDUCEFULL_OP(name, op, typ) \ typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ - typ res = 0; \ + typ result = 0; \ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ - typ accum = arr[arrlinidx]; \ - for (i64 i = 1; i < shape[rank - 1]; i++) { \ - accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ - } \ - res = res op accum; \ + REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, result); \ }); \ - return res; \ + return result; \ } // preconditions @@ -286,50 +296,25 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { }); \ } -#define DOTPROD_OP(typ) \ - typ oxarop_dotprod_ ## typ(i64 length, const typ *arr1, const typ *arr2) { \ - typ res = 0; \ - for (i64 i = 0; i < length; i++) res += arr1[i] * arr2[i]; \ - return res; \ - } - #define DOTPROD_STRIDED_OP(typ) \ typ oxarop_dotprod_ ## typ ## _strided(i64 length, i64 stride1, const typ *arr1, i64 stride2, const typ *arr2) { \ - typ res = 0; \ - for (i64 i = 0; i < length; i++) res += arr1[stride1 * i] * arr2[stride2 * i]; \ - return res; \ - } - -// The 'double' version here is about 2x as fast as gcc's own vectorisation. -DOTPROD_OP(i32) -DOTPROD_OP(i64) -#ifdef OX_ARCH_INTEL -float oxarop_dotprod_float(i64 length, const float *arr1, const float *arr2) { - __m128 accum = _mm_setzero_ps(); - i64 i; - for (i = 0; i + 3 < length; i += 4) { - accum = _mm_add_ps(accum, _mm_mul_ps(_mm_loadu_ps(arr1 + i), _mm_loadu_ps(arr2 + i))); - } - float dest[4]; - _mm_storeu_ps(dest, accum); - float tot = dest[0] + dest[1] + dest[2] + dest[3]; - for (; i < length; i++) tot += arr1[i] * arr2[i]; - return tot; -} -double oxarop_dotprod_double(i64 length, const double *arr1, const double *arr2) { - __m128d accum = _mm_setzero_pd(); - i64 i; - for (i = 0; i + 1 < length; i += 2) { - accum = _mm_add_pd(accum, _mm_mul_pd(_mm_loadu_pd(arr1 + i), _mm_loadu_pd(arr2 + i))); + if (length < MANUAL_VECT_WID) { \ + typ res = 0; \ + for (i64 i = 0; i < length; i++) res += arr1[stride1 * i] * arr2[stride2 * i]; \ + return res; \ + } else { \ + typ accum[MANUAL_VECT_WID]; \ + for (i64 j = 0; j < MANUAL_VECT_WID; j++) accum[j] = arr1[stride1 * j] * arr2[stride2 * j]; \ + for (i64 i = 1; i < length / MANUAL_VECT_WID; i++) \ + for (i64 j = 0; j < MANUAL_VECT_WID; j++) \ + accum[j] += arr1[stride1 * (MANUAL_VECT_WID * i + j)] * arr2[stride2 * (MANUAL_VECT_WID * i + j)]; \ + typ res = accum[0]; \ + for (i64 j = 1; j < MANUAL_VECT_WID; j++) res += accum[j]; \ + for (i64 i = length / MANUAL_VECT_WID * MANUAL_VECT_WID; i < length; i++) \ + res += arr1[stride1 * i] * arr2[stride2 * i]; \ + return res; \ + } \ } - double tot = _mm_cvtsd_f64(accum) + _mm_cvtsd_f64(_mm_unpackhi_pd(accum, accum)); - if (i < length) tot += arr1[i] * arr2[i]; - return tot; -} -#else -DOTPROD_OP(float) -DOTPROD_OP(double) -#endif // preconditions: // - all strides are >0 @@ -342,12 +327,12 @@ DOTPROD_OP(double) void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \ if (strides1[rank - 1] == 1 && strides2[rank - 1] == 1) { \ TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \ - out[outlinidx] = oxarop_dotprod_ ## typ(shape[rank - 1], arr1 + arrlinidx1, arr2 + arrlinidx2); \ + out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(shape[rank - 1], 1, arr1 + arrlinidx1, 1, arr2 + arrlinidx2); \ }); \ } else if (strides1[rank - 1] == -1 && strides2[rank - 1] == -1) { \ TARRAY_WALK2_NOINNER(again2, rank, shape, strides1, strides2, { \ const i64 len = shape[rank - 1]; \ - out[outlinidx] = oxarop_dotprod_ ## typ(len, arr1 + arrlinidx1 - (len - 1), arr2 + arrlinidx2 - (len - 1)); \ + out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(len, 1, arr1 + arrlinidx1 - (len - 1), 1, arr2 + arrlinidx2 - (len - 1)); \ }); \ } else { \ TARRAY_WALK2_NOINNER(again3, rank, shape, strides1, strides2, { \ |