aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cbits/arith.c93
1 files changed, 35 insertions, 58 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index 5d74c01..4d60228 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -145,6 +145,17 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
} \
} while (false)
+// Same as TARRAY_WALK_NOINNER, except the body is specialised twice: once on
+// strides[rank-1] == 1 and a fallback case.
+#define TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, body) \
+ do { \
+ if (strides[rank - 1] == 1) { \
+ TARRAY_WALK_NOINNER(tar_wa_again1, rank, shape, strides, body); \
+ } else { \
+ TARRAY_WALK_NOINNER(tar_wa_again2, rank, shape, strides, body); \
+ } \
+ } while (false)
+
// preconditions:
// - all strides are >0
// - shape is everywhere >0
@@ -154,23 +165,13 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
// 'out' will be filled densely in linearisation order.
#define REDUCE1_OP(name, op, typ) \
static void oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \
- if (strides[rank - 1] == 1) { \
- TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \
- typ accum = arr[arrlinidx]; \
- for (i64 i = 1; i < shape[rank - 1]; i++) { \
- accum = accum op arr[arrlinidx + i]; \
- } \
- out[outlinidx] = accum; \
- }); \
- } else { \
- TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \
- typ accum = arr[arrlinidx]; \
- for (i64 i = 1; i < shape[rank - 1]; i++) { \
- accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \
- } \
- out[outlinidx] = accum; \
- }); \
- } \
+ TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+ typ accum = arr[arrlinidx]; \
+ for (i64 i = 1; i < shape[rank - 1]; i++) { \
+ accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \
+ } \
+ out[outlinidx] = accum; \
+ }); \
}
// preconditions
@@ -180,23 +181,13 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
#define REDUCEFULL_OP(name, op, typ) \
typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
typ res = 0; \
- if (strides[rank - 1] == 1) { \
- TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \
- typ accum = arr[arrlinidx]; \
- for (i64 i = 1; i < shape[rank - 1]; i++) { \
- accum = accum op arr[arrlinidx + i]; \
- } \
- res = res op accum; \
- }); \
- } else { \
- TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \
- typ accum = arr[arrlinidx]; \
- for (i64 i = 1; i < shape[rank - 1]; i++) { \
- accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \
- } \
- res = res op accum; \
- }); \
- } \
+ TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+ typ accum = arr[arrlinidx]; \
+ for (i64 i = 1; i < shape[rank - 1]; i++) { \
+ accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \
+ } \
+ res = res op accum; \
+ }); \
return res; \
}
@@ -209,31 +200,17 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
void oxarop_extremum_ ## name ## _ ## typ(i64 *outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
typ best = arr[0]; \
memset(outidx, 0, rank * sizeof(i64)); \
- if (strides[rank - 1] == 1) { \
- TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \
- bool found = false; \
- for (i64 i = 0; i < shape[rank - 1]; i++) { \
- if (arr[arrlinidx + i] cmp best) { \
- best = arr[arrlinidx + i]; \
- found = true; \
- outidx[rank - 1] = i; \
- } \
- } \
- if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \
- }); \
- } else { \
- TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \
- bool found = false; \
- for (i64 i = 0; i < shape[rank - 1]; i++) { \
- if (arr[arrlinidx + i] cmp best) { \
- best = arr[arrlinidx + strides[rank - 1] * i]; \
- found = true; \
- outidx[rank - 1] = i; \
- } \
+ TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+ bool found = false; \
+ for (i64 i = 0; i < shape[rank - 1]; i++) { \
+ if (arr[arrlinidx + i] cmp best) { \
+ best = arr[arrlinidx + strides[rank - 1] * i]; \
+ found = true; \
+ outidx[rank - 1] = i; \
} \
- if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \
- }); \
- } \
+ } \
+ if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \
+ }); \
}
#define DOTPROD_OP(typ) \