diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-06-09 23:09:19 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-06-09 23:09:19 +0200 |
commit | 1f3d57e13441f86b97ee7ff213bb4a677e31f2db (patch) | |
tree | e72bfd568b032a9af611118038c2eeb6f347ea22 /cbits | |
parent | c8f99847359a92289cf0ded280069794f6abae6a (diff) |
argmin and argmax
Diffstat (limited to 'cbits')
-rw-r--r-- | cbits/arith.c | 40 |
1 files changed, 39 insertions, 1 deletions
diff --git a/cbits/arith.c b/cbits/arith.c index e20578b..6ac49b8 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -164,6 +164,42 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); } } \ } +// preconditions +// - all strides are >0 +// - shape is everywhere >0 +// - rank is >= 1 +// Writes extreme index to outidx. If 'cmp' is '<', computes argmin; if '>', argmax. +#define EXTREMUM_OP(name, cmp, typ) \ + void oxarop_extremum_ ## name ## _ ## typ(i64 *outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ + typ best = arr[0]; \ + memset(outidx, 0, rank * sizeof(i64)); \ + if (strides[rank - 1] == 1) { \ + TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \ + bool found = false; \ + for (i64 i = 0; i < shape[rank - 1]; i++) { \ + if (arr[arrlinidx + i] cmp best) { \ + best = arr[arrlinidx + i]; \ + found = true; \ + outidx[rank - 1] = i; \ + } \ + } \ + if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \ + }); \ + } else { \ + TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \ + bool found = false; \ + for (i64 i = 0; i < shape[rank - 1]; i++) { \ + if (arr[arrlinidx + i] cmp best) { \ + best = arr[arrlinidx + strides[rank - 1] * i]; \ + found = true; \ + outidx[rank - 1] = i; \ + } \ + } \ + if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \ + }); \ + } \ + } + /***************************************************************************** * Entry point functions * @@ -332,7 +368,9 @@ enum redop_tag_t { REDUCE1_OP(product1, *, typ) \ ENTRY_BINARY_OPS(typ) \ ENTRY_UNARY_OPS(typ) \ - ENTRY_REDUCE_OPS(typ) + ENTRY_REDUCE_OPS(typ) \ + EXTREMUM_OP(min, <, typ) \ + EXTREMUM_OP(max, >, typ) NUM_TYPES_XLIST #undef X |