From 4c86a3a4231cecc5b7c31491398f43b4ba667eea Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 23 May 2024 13:47:18 +0200 Subject: Fast sum Also fast product, but that's currently unused --- cbits/arith.c | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) (limited to 'cbits') diff --git a/cbits/arith.c b/cbits/arith.c index 02c8ce1..ca16bf8 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -1,5 +1,7 @@ #include #include +#include +#include #include typedef int32_t i32; @@ -35,6 +37,55 @@ typedef int64_t i64; // This does not result in multiple loads with GCC 13. #define GEN_SIGNUM(x) ((x) < 0 ? -1 : (x) > 0 ? 1 : 0) +#define TARRAY_WALK(again_label_name, rank, shape, strides, body) \ + do { \ + i64 idx[(rank) - 1]; \ + memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \ + i64 arrlinidx = 0; \ + i64 outlinidx = 0; \ + again_label_name: \ + { \ + body \ + } \ + for (i64 dim = (rank) - 2; dim >= 0; dim--) { \ + if (++idx[dim] < (shape)[dim]) { \ + arrlinidx += (strides)[dim]; \ + outlinidx++; \ + goto again_label_name; \ + } \ + arrlinidx -= (idx[dim] - 1) * (strides)[dim]; \ + idx[dim] = 0; \ + } \ + } while (false) + +// preconditions: +// - all strides are >0 +// - shape is everywhere >0 +// - rank is >= 1 +// - out has capacity for (shape[0] * ... * shape[rank - 2]) elements +// Reduces along the innermost dimension. +// 'out' will be filled densely in linearisation order. +#define REDUCE1_OP(name, op, typ) \ + void oxarop_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \ + if (strides[rank - 1] == 1) { \ + TARRAY_WALK(again1, rank, shape, strides, { \ + typ accum = arr[arrlinidx]; \ + for (i64 i = 1; i < shape[rank - 1]; i++) { \ + accum = accum op arr[arrlinidx + i]; \ + } \ + out[outlinidx] = accum; \ + }); \ + } else { \ + TARRAY_WALK(again2, rank, shape, strides, { \ + typ accum = arr[arrlinidx]; \ + for (i64 i = 1; i < shape[rank - 1]; i++) { \ + accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ + } \ + out[outlinidx] = accum; \ + }); \ + } \ + } + #define NUM_TYPES_LOOP_XLIST \ X(i32) X(i64) X(double) X(float) @@ -44,6 +95,8 @@ typedef int64_t i64; COMM_OP(mul, *, typ) \ UNARY_OP(neg, -, typ) \ UNARY_OP(abs, GEN_ABS, typ) \ - UNARY_OP(signum, GEN_SIGNUM, typ) + UNARY_OP(signum, GEN_SIGNUM, typ) \ + REDUCE1_OP(sum1, +, typ) \ + REDUCE1_OP(product1, *, typ) NUM_TYPES_LOOP_XLIST #undef X -- cgit v1.2.3-70-g09d2