aboutsummaryrefslogtreecommitdiff
path: root/cbits
diff options
context:
space:
mode:
Diffstat (limited to 'cbits')
-rw-r--r--cbits/arith.c58
1 files changed, 56 insertions, 2 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index ea06ac1..b9c86ab 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -124,6 +124,7 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
// Walk an orthotope-style strided array, except for the inner dimension. The
// body is run for every "inner vector".
+// Provides idx, outlinidx, arrlinidx.
#define TARRAY_WALK_NOINNER(again_label_name, rank, shape, strides, body) \
do { \
i64 idx[(rank) - 1]; \
@@ -145,11 +146,38 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
} \
} while (false)
+// Walk TWO orthotope-style strided arrays simultaneously, except for their
+// inner dimension. The arrays must have the same shape, but may have different
+// strides. The body is run for every pair of "inner vectors".
+// Provides idx, outlinidx, arrlinidx1, arrlinidx2.
+//
+// Inside 'body':
+// - idx[0 .. rank-2] is the current multi-index over the outer dimensions;
+// - arrlinidx1 / arrlinidx2 are the element offsets of the current inner
+//   vector in the first / second array (sum of idx[d] * stridesN[d] over the
+//   outer dimensions);
+// - outlinidx counts how many times the body has run before (0, 1, 2, ...).
+// The body always runs at least once. Only shape[0 .. rank-2] and
+// stridesN[0 .. rank-2] are read by the macro itself.
+//
+// 'again_label_name' becomes a goto label, so it must be unique within the
+// function that uses this macro. The rank/shape/strides arguments are
+// expanded multiple times and should therefore be side-effect free.
+//
+// NOTE(review): with rank == 1 the declaration below becomes 'i64 idx[0]'.
+// A zero-length array is a compiler extension (and a zero-sized VLA is not
+// valid ISO C) -- confirm this is acceptable for the supported compilers.
+#define TARRAY_WALK2_NOINNER(again_label_name, rank, shape, strides1, strides2, body) \
+ do { \
+ i64 idx[(rank) - 1]; \
+ memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \
+ i64 arrlinidx1 = 0, arrlinidx2 = 0; \
+ i64 outlinidx = 0; \
+ again_label_name: \
+ { \
+ body \
+ } \
+ /* Advance idx like an odometer, starting at the last outer dimension. */ \
+ for (i64 dim = (rank) - 2; dim >= 0; dim--) { \
+ if (++idx[dim] < (shape)[dim]) { \
+ arrlinidx1 += (strides1)[dim]; \
+ arrlinidx2 += (strides2)[dim]; \
+ outlinidx++; \
+ goto again_label_name; \
+ } \
+ /* This dimension overflowed: idx[dim] now equals shape[dim], so undo the */ \
+ /* (shape[dim] - 1) steps taken along it and carry into the next one.    */ \
+ arrlinidx1 -= (idx[dim] - 1) * (strides1)[dim]; \
+ arrlinidx2 -= (idx[dim] - 1) * (strides2)[dim]; \
+ idx[dim] = 0; \
+ } \
+ /* Every outer dimension overflowed: the walk is complete. */ \
+ } while (false)
+
// Same as TARRAY_WALK_NOINNER, except the body is specialised twice: once on
// strides[rank-1] == 1 and a fallback case.
#define TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, body) \
do { \
- if (strides[rank - 1] == 1) { \
+ if ((strides)[(rank) - 1] == 1) { \
TARRAY_WALK_NOINNER(tar_wa_again1, rank, shape, strides, body); \
} else { \
TARRAY_WALK_NOINNER(tar_wa_again2, rank, shape, strides, body); \
@@ -258,6 +286,31 @@ DOTPROD_OP(float)
DOTPROD_OP(double)
#endif
+// Defines oxarop_dotprodinner_<typ>: for every pair of inner vectors of arr1
+// and arr2 (which share 'shape' but have their own strides), compute one
+// value via the oxarop_dotprod_<typ> helpers and store it in 'out'.
+// preconditions:
+// - all strides are >0
+//   (NOTE(review): the code below has a dedicated branch for inner strides
+//   of -1, so this precondition looks stale at least for the inner
+//   dimension -- confirm what callers may pass.)
+// - shape is everywhere >0
+// - rank is >= 1
+// - out has capacity for (shape[0] * ... * shape[rank - 2]) elements
+//   (for rank == 1 the walk body runs exactly once and writes out[0])
+// Reduces along the innermost dimension.
+// 'out' will be filled densely in linearisation order.
+//
+// Three cases are distinguished on the inner strides of both arrays:
+// - both  1: the inner vectors are passed directly to oxarop_dotprod_<typ>;
+// - both -1: each inner vector is passed starting (len - 1) elements before
+//   its first element, i.e. at its lowest address. Both vectors are then
+//   read back to front, so the same elements are still paired up;
+// - otherwise: oxarop_dotprod_<typ>_strided receives the base pointers plus
+//   the per-vector offsets and inner strides.
+// The oxarop_dotprod_* helpers are defined elsewhere in this file;
+// presumably the non-strided one expects unit-stride input -- verify against
+// their definitions.
+#define DOTPROD_INNER_OP(typ) \
+ void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \
+ if (strides1[rank - 1] == 1 && strides2[rank - 1] == 1) { \
+ TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \
+ out[outlinidx] = oxarop_dotprod_ ## typ(shape[rank - 1], arr1 + arrlinidx1, arr2 + arrlinidx2); \
+ }); \
+ } else if (strides1[rank - 1] == -1 && strides2[rank - 1] == -1) { \
+ TARRAY_WALK2_NOINNER(again2, rank, shape, strides1, strides2, { \
+ const i64 len = shape[rank - 1]; \
+ out[outlinidx] = oxarop_dotprod_ ## typ(len, arr1 + arrlinidx1 - (len - 1), arr2 + arrlinidx2 - (len - 1)); \
+ }); \
+ } else { \
+ TARRAY_WALK2_NOINNER(again3, rank, shape, strides1, strides2, { \
+ out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(shape[rank - 1], arrlinidx1, strides1[rank - 1], arr1, arrlinidx2, strides2[rank - 1], arr2); \
+ }); \
+ } \
+ }
+
/*****************************************************************************
* Entry point functions *
@@ -441,7 +494,8 @@ enum redop_tag_t {
ENTRY_REDUCEFULL_OPS(typ) \
EXTREMUM_OP(min, <, typ) \
EXTREMUM_OP(max, >, typ) \
- DOTPROD_STRIDED_OP(typ)
+ DOTPROD_STRIDED_OP(typ) \
+ DOTPROD_INNER_OP(typ)
NUM_TYPES_XLIST
#undef X