From 9b0651bf19e889dfb28ba81b6ada25b27b0e6071 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 17 Jun 2024 13:08:13 +0200 Subject: sumAllPrim --- cbits/arith.c | 49 ++++++++++++++++++++++++++++++++++++++++++++----- cbits/arith_lists.h | 4 ++-- 2 files changed, 46 insertions(+), 7 deletions(-) (limited to 'cbits') diff --git a/cbits/arith.c b/cbits/arith.c index fb993c8..5d74c01 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -173,6 +173,33 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); } } \ } +// preconditions +// - all strides are >0 +// - shape is everywhere >0 +// - rank is >= 1 +#define REDUCEFULL_OP(name, op, typ) \ + typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ + typ res = 0; \ + if (strides[rank - 1] == 1) { \ + TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \ + typ accum = arr[arrlinidx]; \ + for (i64 i = 1; i < shape[rank - 1]; i++) { \ + accum = accum op arr[arrlinidx + i]; \ + } \ + res = res op accum; \ + }); \ + } else { \ + TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \ + typ accum = arr[arrlinidx]; \ + for (i64 i = 1; i < shape[rank - 1]; i++) { \ + accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ + } \ + res = res op accum; \ + }); \ + } \ + return res; \ + } + // preconditions // - all strides are >0 // - shape is everywhere >0 @@ -394,11 +421,20 @@ enum redop_tag_t { #define LIST_REDOP(name, id, _) }; -#define ENTRY_REDUCE_OPS(typ) \ - void oxarop_reduce_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \ +#define ENTRY_REDUCE1_OPS(typ) \ + void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \ + switch (tag) { \ + case RO_SUM: oxarop_op_sum1_ ## typ(rank, shape, strides, out, arr); break; \ + case RO_PRODUCT: oxarop_op_product1_ ## typ(rank, shape, strides, out, arr); break; \ + default: wrong_op("reduce", tag); \ + } \ + } + +#define ENTRY_REDUCEFULL_OPS(typ) \ + typ oxarop_reducefull_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ switch (tag) { \ - case RO_SUM1: oxarop_op_sum1_ ## typ(rank, shape, strides, out, arr); break; \ - case RO_PRODUCT1: oxarop_op_product1_ ## typ(rank, shape, strides, out, arr); break; \ + case RO_SUM: return oxarop_op_sumfull_ ## typ(rank, shape, strides, arr); \ + case RO_PRODUCT: return oxarop_op_productfull_ ## typ(rank, shape, strides, arr); \ default: wrong_op("reduce", tag); \ } \ } @@ -420,9 +456,12 @@ enum redop_tag_t { UNARY_OP(signum, GEN_SIGNUM, typ) \ REDUCE1_OP(sum1, +, typ) \ REDUCE1_OP(product1, *, typ) \ + REDUCEFULL_OP(sumfull, +, typ) \ + REDUCEFULL_OP(productfull, *, typ) \ ENTRY_BINARY_OPS(typ) \ ENTRY_UNARY_OPS(typ) \ - ENTRY_REDUCE_OPS(typ) \ + ENTRY_REDUCE1_OPS(typ) \ + ENTRY_REDUCEFULL_OPS(typ) \ EXTREMUM_OP(min, <, typ) \ EXTREMUM_OP(max, >, typ) \ DOTPROD_STRIDED_OP(typ) diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h index 2e37575..58de65a 100644 --- a/cbits/arith_lists.h +++ b/cbits/arith_lists.h @@ -31,5 +31,5 @@ LIST_FUNOP(FU_EXPM1, 18,) LIST_FUNOP(FU_LOG1PEXP, 19,) LIST_FUNOP(FU_LOG1MEXP, 20,) -LIST_REDOP(RO_SUM1, 1,) -LIST_REDOP(RO_PRODUCT1, 2,) +LIST_REDOP(RO_SUM, 1,) +LIST_REDOP(RO_PRODUCT, 2,) -- cgit v1.2.3-70-g09d2