aboutsummaryrefslogtreecommitdiff
path: root/cbits
diff options
context:
space:
mode:
Diffstat (limited to 'cbits')
-rw-r--r--cbits/arith.c55
1 files changed, 54 insertions, 1 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index 02c8ce1..ca16bf8 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -1,5 +1,7 @@
#include <stdint.h>
#include <stdlib.h>
+#include <stdbool.h>
+#include <string.h>
#include <math.h>
typedef int32_t i32;
@@ -35,6 +37,55 @@ typedef int64_t i64;
// This does not result in multiple loads with GCC 13.
#define GEN_SIGNUM(x) ((x) < 0 ? -1 : (x) > 0 ? 1 : 0)
+#define TARRAY_WALK(again_label_name, rank, shape, strides, body) \
+ do { \
+ i64 idx[(rank) - 1]; \
+ memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \
+ i64 arrlinidx = 0; \
+ i64 outlinidx = 0; \
+ again_label_name: \
+ { \
+ body \
+ } \
+ for (i64 dim = (rank) - 2; dim >= 0; dim--) { \
+ if (++idx[dim] < (shape)[dim]) { \
+ arrlinidx += (strides)[dim]; \
+ outlinidx++; \
+ goto again_label_name; \
+ } \
+ arrlinidx -= (idx[dim] - 1) * (strides)[dim]; \
+ idx[dim] = 0; \
+ } \
+ } while (false)
+
+// preconditions:
+// - all strides are >0
+// - shape is everywhere >0
+// - rank is >= 1
+// - out has capacity for (shape[0] * ... * shape[rank - 2]) elements
+// Reduces along the innermost dimension.
+// 'out' will be filled densely in linearisation order.
+#define REDUCE1_OP(name, op, typ) \
+ void oxarop_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \
+ if (strides[rank - 1] == 1) { \
+ TARRAY_WALK(again1, rank, shape, strides, { \
+ typ accum = arr[arrlinidx]; \
+ for (i64 i = 1; i < shape[rank - 1]; i++) { \
+ accum = accum op arr[arrlinidx + i]; \
+ } \
+ out[outlinidx] = accum; \
+ }); \
+ } else { \
+ TARRAY_WALK(again2, rank, shape, strides, { \
+ typ accum = arr[arrlinidx]; \
+ for (i64 i = 1; i < shape[rank - 1]; i++) { \
+ accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \
+ } \
+ out[outlinidx] = accum; \
+ }); \
+ } \
+ }
+
#define NUM_TYPES_LOOP_XLIST \
X(i32) X(i64) X(double) X(float)
@@ -44,6 +95,8 @@ typedef int64_t i64;
COMM_OP(mul, *, typ) \
UNARY_OP(neg, -, typ) \
UNARY_OP(abs, GEN_ABS, typ) \
- UNARY_OP(signum, GEN_SIGNUM, typ)
+ UNARY_OP(signum, GEN_SIGNUM, typ) \
+ REDUCE1_OP(sum1, +, typ) \
+ REDUCE1_OP(product1, *, typ)
NUM_TYPES_LOOP_XLIST
#undef X