diff options
Diffstat (limited to 'cbits')
-rw-r--r-- | cbits/arith.c | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/cbits/arith.c b/cbits/arith.c index 5594c80..751fe33 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -200,6 +200,20 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); } } \ } +#define DOTPROD_OP(typ) \ + typ oxarop_dotprod_ ## typ(i64 length, const typ *arr1, const typ *arr2) { \ + typ res = 0; \ + for (i64 i = 0; i < length; i++) res += arr1[i] * arr2[i]; \ + return res; \ + } + +#define DOTPROD_STRIDED_OP(typ) \ + typ oxarop_dotprod_ ## typ ## _strided(i64 length, i64 offset1, i64 stride1, const typ *arr1, i64 offset2, i64 stride2, const typ *arr2) { \ + typ res = 0; \ + for (i64 i = 0; i < length; i++) res += arr1[offset1 + stride1 * i] * arr2[offset2 + stride2 * i]; \ + return res; \ + } + /***************************************************************************** * Entry point functions * @@ -370,7 +384,9 @@ enum redop_tag_t { ENTRY_UNARY_OPS(typ) \ ENTRY_REDUCE_OPS(typ) \ EXTREMUM_OP(min, <, typ) \ - EXTREMUM_OP(max, >, typ) + EXTREMUM_OP(max, >, typ) \ + DOTPROD_OP(typ) \ + DOTPROD_STRIDED_OP(typ) NUM_TYPES_XLIST #undef X |