aboutsummaryrefslogtreecommitdiff
path: root/cbits/arith.c
diff options
context:
space:
mode:
Diffstat (limited to 'cbits/arith.c')
-rw-r--r--cbits/arith.c49
1 files changed, 44 insertions, 5 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index fb993c8..5d74c01 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -177,6 +177,33 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
// - all strides are >0
// - shape is everywhere >0
// - rank is >= 1
+#define REDUCEFULL_OP(name, op, typ) \
+ typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
+ typ res = 0; \
+ if (strides[rank - 1] == 1) { \
+ TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \
+ typ accum = arr[arrlinidx]; \
+ for (i64 i = 1; i < shape[rank - 1]; i++) { \
+ accum = accum op arr[arrlinidx + i]; \
+ } \
+ res = res op accum; \
+ }); \
+ } else { \
+ TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \
+ typ accum = arr[arrlinidx]; \
+ for (i64 i = 1; i < shape[rank - 1]; i++) { \
+ accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \
+ } \
+ res = res op accum; \
+ }); \
+ } \
+ return res; \
+ }
+
+// preconditions
+// - all strides are >0
+// - shape is everywhere >0
+// - rank is >= 1
// Writes extreme index to outidx. If 'cmp' is '<', computes minindex ("argmin"); if '>', maxindex.
#define EXTREMUM_OP(name, cmp, typ) \
void oxarop_extremum_ ## name ## _ ## typ(i64 *outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
@@ -394,11 +421,20 @@ enum redop_tag_t {
#define LIST_REDOP(name, id, _)
};
-#define ENTRY_REDUCE_OPS(typ) \
- void oxarop_reduce_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \
+#define ENTRY_REDUCE1_OPS(typ) \
+ void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \
+ switch (tag) { \
+ case RO_SUM: oxarop_op_sum1_ ## typ(rank, shape, strides, out, arr); break; \
+ case RO_PRODUCT: oxarop_op_product1_ ## typ(rank, shape, strides, out, arr); break; \
+ default: wrong_op("reduce", tag); \
+ } \
+ }
+
+#define ENTRY_REDUCEFULL_OPS(typ) \
+ typ oxarop_reducefull_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
switch (tag) { \
- case RO_SUM1: oxarop_op_sum1_ ## typ(rank, shape, strides, out, arr); break; \
- case RO_PRODUCT1: oxarop_op_product1_ ## typ(rank, shape, strides, out, arr); break; \
+ case RO_SUM: return oxarop_op_sumfull_ ## typ(rank, shape, strides, arr); \
+ case RO_PRODUCT: return oxarop_op_productfull_ ## typ(rank, shape, strides, arr); \
default: wrong_op("reduce", tag); \
} \
}
@@ -420,9 +456,12 @@ enum redop_tag_t {
UNARY_OP(signum, GEN_SIGNUM, typ) \
REDUCE1_OP(sum1, +, typ) \
REDUCE1_OP(product1, *, typ) \
+ REDUCEFULL_OP(sumfull, +, typ) \
+ REDUCEFULL_OP(productfull, *, typ) \
ENTRY_BINARY_OPS(typ) \
ENTRY_UNARY_OPS(typ) \
- ENTRY_REDUCE_OPS(typ) \
+ ENTRY_REDUCE1_OPS(typ) \
+ ENTRY_REDUCEFULL_OPS(typ) \
EXTREMUM_OP(min, <, typ) \
EXTREMUM_OP(max, >, typ) \
DOTPROD_STRIDED_OP(typ)