diff options
Diffstat (limited to 'cbits')
| -rw-r--r-- | cbits/arith.c | 55 | 
1 files changed, 54 insertions, 1 deletions
| diff --git a/cbits/arith.c b/cbits/arith.c index 02c8ce1..ca16bf8 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -1,5 +1,7 @@  #include <stdint.h>  #include <stdlib.h> +#include <stdbool.h> +#include <string.h>  #include <math.h>  typedef int32_t i32; @@ -35,6 +37,55 @@ typedef int64_t i64;  // This does not result in multiple loads with GCC 13.  #define GEN_SIGNUM(x) ((x) < 0 ? -1 : (x) > 0 ? 1 : 0) +#define TARRAY_WALK(again_label_name, rank, shape, strides, body) \ +  do { \ +    i64 idx[(rank) - 1]; \ +    memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \ +    i64 arrlinidx = 0; \ +    i64 outlinidx = 0; \ +  again_label_name: \ +    { \ +      body \ +    } \ +    for (i64 dim = (rank) - 2; dim >= 0; dim--) { \ +      if (++idx[dim] < (shape)[dim]) { \ +        arrlinidx += (strides)[dim]; \ +        outlinidx++; \ +        goto again_label_name; \ +      } \ +      arrlinidx -= (idx[dim] - 1) * (strides)[dim]; \ +      idx[dim] = 0; \ +    } \ +  } while (false) + +// preconditions: +// - all strides are >0 +// - shape is everywhere >0 +// - rank is >= 1 +// - out has capacity for (shape[0] * ... * shape[rank - 2]) elements +// Reduces along the innermost dimension. +// 'out' will be filled densely in linearisation order. +#define REDUCE1_OP(name, op, typ) \ +  void oxarop_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \ +    if (strides[rank - 1] == 1) { \ +      TARRAY_WALK(again1, rank, shape, strides, { \ +        typ accum = arr[arrlinidx]; \ +        for (i64 i = 1; i < shape[rank - 1]; i++) { \ +          accum = accum op arr[arrlinidx + i]; \ +        } \ +        out[outlinidx] = accum; \ +      }); \ +    } else { \ +      TARRAY_WALK(again2, rank, shape, strides, { \ +        typ accum = arr[arrlinidx]; \ +        for (i64 i = 1; i < shape[rank - 1]; i++) { \ +          accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ +        } \ +        out[outlinidx] = accum; \ +      }); \ +    } \ +  } +  #define NUM_TYPES_LOOP_XLIST \    X(i32) X(i64) X(double) X(float) @@ -44,6 +95,8 @@ typedef int64_t i64;    COMM_OP(mul, *, typ) \    UNARY_OP(neg, -, typ) \    UNARY_OP(abs, GEN_ABS, typ) \ -  UNARY_OP(signum, GEN_SIGNUM, typ) +  UNARY_OP(signum, GEN_SIGNUM, typ) \ +  REDUCE1_OP(sum1, +, typ) \ +  REDUCE1_OP(product1, *, typ)  NUM_TYPES_LOOP_XLIST  #undef X | 
