diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-06-18 21:03:36 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-06-18 21:03:52 +0200 |
commit | 5dbabda5ef848e8b4b58b8dee55e5a22de7ee7d6 (patch) | |
tree | a2993713b7b32292063fdc30a54a3cec8392698c | |
parent | d3cff40181b2b68a97a26012e1f26f702d57e5f1 (diff) |
C cleanup: abstract strides[rank-1] case into macro
-rw-r--r-- | cbits/arith.c | 93 |
1 files changed, 35 insertions, 58 deletions
diff --git a/cbits/arith.c b/cbits/arith.c index 5d74c01..4d60228 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -145,6 +145,17 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); } } \ } while (false) +// Same as TARRAY_WALK_NOINNER, except the body is specialised twice: once on +// strides[rank-1] == 1 and a fallback case. +#define TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, body) \ + do { \ + if (strides[rank - 1] == 1) { \ + TARRAY_WALK_NOINNER(tar_wa_again1, rank, shape, strides, body); \ + } else { \ + TARRAY_WALK_NOINNER(tar_wa_again2, rank, shape, strides, body); \ + } \ + } while (false) + // preconditions: // - all strides are >0 // - shape is everywhere >0 @@ -154,23 +165,13 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); } // 'out' will be filled densely in linearisation order. #define REDUCE1_OP(name, op, typ) \ static void oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \ - if (strides[rank - 1] == 1) { \ - TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \ - typ accum = arr[arrlinidx]; \ - for (i64 i = 1; i < shape[rank - 1]; i++) { \ - accum = accum op arr[arrlinidx + i]; \ - } \ - out[outlinidx] = accum; \ - }); \ - } else { \ - TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \ - typ accum = arr[arrlinidx]; \ - for (i64 i = 1; i < shape[rank - 1]; i++) { \ - accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ - } \ - out[outlinidx] = accum; \ - }); \ - } \ + TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ + typ accum = arr[arrlinidx]; \ + for (i64 i = 1; i < shape[rank - 1]; i++) { \ + accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ + } \ + out[outlinidx] = accum; \ + }); \ } // preconditions @@ -180,23 +181,13 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); } #define REDUCEFULL_OP(name, op, typ) \ typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ typ res = 0; \ - if (strides[rank - 1] == 1) { \ - TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \ - typ accum = arr[arrlinidx]; \ - for (i64 i = 1; i < shape[rank - 1]; i++) { \ - accum = accum op arr[arrlinidx + i]; \ - } \ - res = res op accum; \ - }); \ - } else { \ - TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \ - typ accum = arr[arrlinidx]; \ - for (i64 i = 1; i < shape[rank - 1]; i++) { \ - accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ - } \ - res = res op accum; \ - }); \ - } \ + TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ + typ accum = arr[arrlinidx]; \ + for (i64 i = 1; i < shape[rank - 1]; i++) { \ + accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ + } \ + res = res op accum; \ + }); \ return res; \ } @@ -209,31 +200,17 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); } void oxarop_extremum_ ## name ## _ ## typ(i64 *outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ typ best = arr[0]; \ memset(outidx, 0, rank * sizeof(i64)); \ - if (strides[rank - 1] == 1) { \ - TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \ - bool found = false; \ - for (i64 i = 0; i < shape[rank - 1]; i++) { \ - if (arr[arrlinidx + i] cmp best) { \ - best = arr[arrlinidx + i]; \ - found = true; \ - outidx[rank - 1] = i; \ - } \ - } \ - if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \ - }); \ - } else { \ - TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \ - bool found = false; \ - for (i64 i = 0; i < shape[rank - 1]; i++) { \ - if (arr[arrlinidx + i] cmp best) { \ - best = arr[arrlinidx + strides[rank - 1] * i]; \ - found = true; \ - outidx[rank - 1] = i; \ - } \ + TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ + bool found = false; \ + for (i64 i = 0; i < shape[rank - 1]; i++) { \ + if (arr[arrlinidx + i] cmp best) { \ + best = arr[arrlinidx + strides[rank - 1] * i]; \ + found = true; \ + outidx[rank - 1] = i; \ } \ - if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \ - }); \ - } \ + } \ + if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \ + }); \ } #define DOTPROD_OP(typ) \ |