aboutsummaryrefslogtreecommitdiff
path: root/cbits/arith.c
diff options
context:
space:
mode:
Diffstat (limited to 'cbits/arith.c')
-rw-r--r--cbits/arith.c40
1 files changed, 39 insertions, 1 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index e20578b..6ac49b8 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -164,6 +164,42 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
} \
}
+// preconditions
+// - all strides are >0
+// - shape is everywhere >0
+// - rank is >= 1
+// Writes extreme index to outidx. If 'cmp' is '<', computes argmin; if '>', argmax.
+#define EXTREMUM_OP(name, cmp, typ) \
+ void oxarop_extremum_ ## name ## _ ## typ(i64 *outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
+ typ best = arr[0]; \
+ memset(outidx, 0, rank * sizeof(i64)); \
+ if (strides[rank - 1] == 1) { \
+ TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \
+ bool found = false; \
+ for (i64 i = 0; i < shape[rank - 1]; i++) { \
+ if (arr[arrlinidx + i] cmp best) { \
+ best = arr[arrlinidx + i]; \
+ found = true; \
+ outidx[rank - 1] = i; \
+ } \
+ } \
+ if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \
+ }); \
+ } else { \
+ TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \
+ bool found = false; \
+ for (i64 i = 0; i < shape[rank - 1]; i++) { \
+ if (arr[arrlinidx + i] cmp best) { \
+ best = arr[arrlinidx + strides[rank - 1] * i]; \
+ found = true; \
+ outidx[rank - 1] = i; \
+ } \
+ } \
+ if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \
+ }); \
+ } \
+ }
+
/*****************************************************************************
* Entry point functions *
@@ -332,7 +368,9 @@ enum redop_tag_t {
REDUCE1_OP(product1, *, typ) \
ENTRY_BINARY_OPS(typ) \
ENTRY_UNARY_OPS(typ) \
- ENTRY_REDUCE_OPS(typ)
+ ENTRY_REDUCE_OPS(typ) \
+ EXTREMUM_OP(min, <, typ) \
+ EXTREMUM_OP(max, >, typ)
NUM_TYPES_XLIST
#undef X