From 1f3d57e13441f86b97ee7ff213bb4a677e31f2db Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 9 Jun 2024 23:09:19 +0200 Subject: argmin and argmax --- cbits/arith.c | 40 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) (limited to 'cbits') diff --git a/cbits/arith.c b/cbits/arith.c index e20578b..6ac49b8 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -164,6 +164,42 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); } } \ } +// preconditions +// - all strides are >0 +// - shape is everywhere >0 +// - rank is >= 1 +// Writes extreme index to outidx. If 'cmp' is '<', computes argmin; if '>', argmax. +#define EXTREMUM_OP(name, cmp, typ) \ + void oxarop_extremum_ ## name ## _ ## typ(i64 *outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ + typ best = arr[0]; \ + memset(outidx, 0, rank * sizeof(i64)); \ + if (strides[rank - 1] == 1) { \ + TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \ + bool found = false; \ + for (i64 i = 0; i < shape[rank - 1]; i++) { \ + if (arr[arrlinidx + i] cmp best) { \ + best = arr[arrlinidx + i]; \ + found = true; \ + outidx[rank - 1] = i; \ + } \ + } \ + if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \ + }); \ + } else { \ + TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \ + bool found = false; \ + for (i64 i = 0; i < shape[rank - 1]; i++) { \ + if (arr[arrlinidx + i] cmp best) { \ + best = arr[arrlinidx + strides[rank - 1] * i]; \ + found = true; \ + outidx[rank - 1] = i; \ + } \ + } \ + if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \ + }); \ + } \ + } + /***************************************************************************** * Entry point functions * @@ -332,7 +368,9 @@ enum redop_tag_t { REDUCE1_OP(product1, *, typ) \ ENTRY_BINARY_OPS(typ) \ ENTRY_UNARY_OPS(typ) \ - ENTRY_REDUCE_OPS(typ) + ENTRY_REDUCE_OPS(typ) \ + EXTREMUM_OP(min, <, typ) \ + EXTREMUM_OP(max, >, typ) NUM_TYPES_XLIST #undef X -- cgit v1.2.3-70-g09d2