diff options
Diffstat (limited to 'cbits/arith.c')
-rw-r--r-- | cbits/arith.c | 58 |
1 files changed, 56 insertions, 2 deletions
diff --git a/cbits/arith.c b/cbits/arith.c index ea06ac1..b9c86ab 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -124,6 +124,7 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); } // Walk a orthotope-style strided array, except for the inner dimension. The // body is run for every "inner vector". +// Provides idx, outlinidx, arrlinidx. #define TARRAY_WALK_NOINNER(again_label_name, rank, shape, strides, body) \ do { \ i64 idx[(rank) - 1]; \ @@ -145,11 +146,38 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); } } \ } while (false) +// Walk TWO orthotope-style strided arrays simultaneously, except for their +// inner dimension. The arrays must have the same shape, but may have different +// strides. The body is run for every pair of "inner vectors". +// Provides idx, outlinidx, arrlinidx1, arrlinidx2. +#define TARRAY_WALK2_NOINNER(again_label_name, rank, shape, strides1, strides2, body) \ + do { \ + i64 idx[(rank) - 1]; \ + memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \ + i64 arrlinidx1 = 0, arrlinidx2 = 0; \ + i64 outlinidx = 0; \ + again_label_name: \ + { \ + body \ + } \ + for (i64 dim = (rank) - 2; dim >= 0; dim--) { \ + if (++idx[dim] < (shape)[dim]) { \ + arrlinidx1 += (strides1)[dim]; \ + arrlinidx2 += (strides2)[dim]; \ + outlinidx++; \ + goto again_label_name; \ + } \ + arrlinidx1 -= (idx[dim] - 1) * (strides1)[dim]; \ + arrlinidx2 -= (idx[dim] - 1) * (strides2)[dim]; \ + idx[dim] = 0; \ + } \ + } while (false) + // Same as TARRAY_WALK_NOINNER, except the body is specialised twice: once on // strides[rank-1] == 1 and a fallback case. #define TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, body) \ do { \ - if (strides[rank - 1] == 1) { \ + if ((strides)[(rank) - 1] == 1) { \ TARRAY_WALK_NOINNER(tar_wa_again1, rank, shape, strides, body); \ } else { \ TARRAY_WALK_NOINNER(tar_wa_again2, rank, shape, strides, body); \ @@ -258,6 +286,31 @@ DOTPROD_OP(float) DOTPROD_OP(double) #endif +// preconditions: +// - all strides are >0 +// - shape is everywhere >0 +// - rank is >= 1 +// - out has capacity for (shape[0] * ... * shape[rank - 2]) elements +// Reduces along the innermost dimension. +// 'out' will be filled densely in linearisation order. +#define DOTPROD_INNER_OP(typ) \ + void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \ + if (strides1[rank - 1] == 1 && strides2[rank - 1] == 1) { \ + TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \ + out[outlinidx] = oxarop_dotprod_ ## typ(shape[rank - 1], arr1 + arrlinidx1, arr2 + arrlinidx2); \ + }); \ + } else if (strides1[rank - 1] == -1 && strides2[rank - 1] == -1) { \ + TARRAY_WALK2_NOINNER(again2, rank, shape, strides1, strides2, { \ + const i64 len = shape[rank - 1]; \ + out[outlinidx] = oxarop_dotprod_ ## typ(len, arr1 + arrlinidx1 - (len - 1), arr2 + arrlinidx2 - (len - 1)); \ + }); \ + } else { \ + TARRAY_WALK2_NOINNER(again3, rank, shape, strides1, strides2, { \ + out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(shape[rank - 1], arrlinidx1, strides1[rank - 1], arr1, arrlinidx2, strides2[rank - 1], arr2); \ + }); \ + } \ + } + /***************************************************************************** * Entry point functions * @@ -441,7 +494,8 @@ enum redop_tag_t { ENTRY_REDUCEFULL_OPS(typ) \ EXTREMUM_OP(min, <, typ) \ EXTREMUM_OP(max, >, typ) \ - DOTPROD_STRIDED_OP(typ) + DOTPROD_STRIDED_OP(typ) \ + DOTPROD_INNER_OP(typ) NUM_TYPES_XLIST #undef X |