diff options
Diffstat (limited to 'cbits')
| -rw-r--r-- | cbits/arith.c | 58 | 
1 files changed, 56 insertions, 2 deletions
| diff --git a/cbits/arith.c b/cbits/arith.c index ea06ac1..b9c86ab 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -124,6 +124,7 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }  // Walk a orthotope-style strided array, except for the inner dimension. The  // body is run for every "inner vector". +// Provides idx, outlinidx, arrlinidx.  #define TARRAY_WALK_NOINNER(again_label_name, rank, shape, strides, body) \    do { \      i64 idx[(rank) - 1]; \ @@ -145,11 +146,38 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }      } \    } while (false) +// Walk TWO orthotope-style strided arrays simultaneously, except for their +// inner dimension. The arrays must have the same shape, but may have different +// strides. The body is run for every pair of "inner vectors". +// Provides idx, outlinidx, arrlinidx1, arrlinidx2. +#define TARRAY_WALK2_NOINNER(again_label_name, rank, shape, strides1, strides2, body) \ +  do { \ +    i64 idx[(rank) - 1]; \ +    memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \ +    i64 arrlinidx1 = 0, arrlinidx2 = 0; \ +    i64 outlinidx = 0; \ +  again_label_name: \ +    { \ +      body \ +    } \ +    for (i64 dim = (rank) - 2; dim >= 0; dim--) { \ +      if (++idx[dim] < (shape)[dim]) { \ +        arrlinidx1 += (strides1)[dim]; \ +        arrlinidx2 += (strides2)[dim]; \ +        outlinidx++; \ +        goto again_label_name; \ +      } \ +      arrlinidx1 -= (idx[dim] - 1) * (strides1)[dim]; \ +      arrlinidx2 -= (idx[dim] - 1) * (strides2)[dim]; \ +      idx[dim] = 0; \ +    } \ +  } while (false) +  // Same as TARRAY_WALK_NOINNER, except the body is specialised twice: once on  // strides[rank-1] == 1 and a fallback case.  #define TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, body) \    do { \ -    if (strides[rank - 1] == 1) { \ +    if ((strides)[(rank) - 1] == 1) { \        TARRAY_WALK_NOINNER(tar_wa_again1, rank, shape, strides, body); \      } else { \        TARRAY_WALK_NOINNER(tar_wa_again2, rank, shape, strides, body); \ @@ -258,6 +286,31 @@ DOTPROD_OP(float)  DOTPROD_OP(double)  #endif +// preconditions: +// - all strides are >0 +// - shape is everywhere >0 +// - rank is >= 1 +// - out has capacity for (shape[0] * ... * shape[rank - 2]) elements +// Reduces along the innermost dimension. +// 'out' will be filled densely in linearisation order. +#define DOTPROD_INNER_OP(typ) \ +  void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \ +    if (strides1[rank - 1] == 1 && strides2[rank - 1] == 1) { \ +      TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \ +        out[outlinidx] = oxarop_dotprod_ ## typ(shape[rank - 1], arr1 + arrlinidx1, arr2 + arrlinidx2); \ +      }); \ +    } else if (strides1[rank - 1] == -1 && strides2[rank - 1] == -1) { \ +      TARRAY_WALK2_NOINNER(again2, rank, shape, strides1, strides2, { \ +        const i64 len = shape[rank - 1]; \ +        out[outlinidx] = oxarop_dotprod_ ## typ(len, arr1 + arrlinidx1 - (len - 1), arr2 + arrlinidx2 - (len - 1)); \ +      }); \ +    } else { \ +      TARRAY_WALK2_NOINNER(again3, rank, shape, strides1, strides2, { \ +        out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(shape[rank - 1], arrlinidx1, strides1[rank - 1], arr1, arrlinidx2, strides2[rank - 1], arr2); \ +      }); \ +    } \ +  } +  /*****************************************************************************   *                           Entry point functions                           * @@ -441,7 +494,8 @@ enum redop_tag_t {    ENTRY_REDUCEFULL_OPS(typ) \    EXTREMUM_OP(min, <, typ) \    EXTREMUM_OP(max, >, typ) \ -  DOTPROD_STRIDED_OP(typ) +  DOTPROD_STRIDED_OP(typ) \ +  DOTPROD_INNER_OP(typ)  NUM_TYPES_XLIST  #undef X | 
