diff options
Diffstat (limited to 'cbits')
| -rw-r--r-- | cbits/arith.c | 40 | 
1 files changed, 39 insertions, 1 deletions
| diff --git a/cbits/arith.c b/cbits/arith.c index e20578b..6ac49b8 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -164,6 +164,42 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }      } \    } +// preconditions +// - all strides are >0 +// - shape is everywhere >0 +// - rank is >= 1 +// Writes extreme index to outidx. If 'cmp' is '<', computes argmin; if '>', argmax. +#define EXTREMUM_OP(name, cmp, typ) \ +  void oxarop_extremum_ ## name ## _ ## typ(i64 *outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ +    typ best = arr[0]; \ +    memset(outidx, 0, rank * sizeof(i64)); \ +    if (strides[rank - 1] == 1) { \ +      TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \ +        bool found = false; \ +        for (i64 i = 0; i < shape[rank - 1]; i++) { \ +          if (arr[arrlinidx + i] cmp best) { \ +            best = arr[arrlinidx + i]; \ +            found = true; \ +            outidx[rank - 1] = i; \ +          } \ +        } \ +        if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \ +      }); \ +    } else { \ +      TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \ +        bool found = false; \ +        for (i64 i = 0; i < shape[rank - 1]; i++) { \ +          if (arr[arrlinidx + i] cmp best) { \ +            best = arr[arrlinidx + strides[rank - 1] * i]; \ +            found = true; \ +            outidx[rank - 1] = i; \ +          } \ +        } \ +        if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \ +      }); \ +    } \ +  } +  /*****************************************************************************   *                           Entry point functions                           * @@ -332,7 +368,9 @@ enum redop_tag_t {    REDUCE1_OP(product1, *, typ) \    ENTRY_BINARY_OPS(typ) \    ENTRY_UNARY_OPS(typ) \ -  ENTRY_REDUCE_OPS(typ) +  ENTRY_REDUCE_OPS(typ) \ +  EXTREMUM_OP(min, <, typ) \ +  EXTREMUM_OP(max, >, typ)  NUM_TYPES_XLIST  #undef X | 
