diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-05-26 00:11:00 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-26 00:11:00 +0200 | 
| commit | 34a9ac8e4497e776c3ca499c41ef749f4edf8383 (patch) | |
| tree | f2b2e34d830d66d23ae19909c71771e810c262d0 /cbits | |
| parent | 85593969debadbf11ad3c159de71e7b480ca367c (diff) | |
Refactor C interface to pass operation as enum
This is hmatrix style, less proliferation of functions as the number of
ops increases
Diffstat (limited to 'cbits')
| -rw-r--r-- | cbits/arith.c | 114 | ||||
| -rw-r--r-- | cbits/arith_lists.h | 10 | 
2 files changed, 118 insertions, 6 deletions
diff --git a/cbits/arith.c b/cbits/arith.c index 002910c..65cdb41 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -1,28 +1,42 @@ +#include <stdio.h>  #include <stdint.h>  #include <stdlib.h>  #include <stdbool.h>  #include <string.h>  #include <math.h> +// These are the wrapper macros used in arith_lists.h. Preset them to empty to +// avoid having to touch macros unrelated to the particular operation set below. +#define LIST_BINOP(name, id, hsop) +#define LIST_UNOP(name, id, _) +#define LIST_REDOP(name, id, _) + + +// Shorter names, due to CPP used both in function names and in C types.  typedef int32_t i32;  typedef int64_t i64; + +/***************************************************************************** + *                             Kernel functions                              * + *****************************************************************************/ +  #define COMM_OP(name, op, typ) \ -  void oxarop_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \ +  static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \      for (i64 i = 0; i < n; i++) out[i] = x op y[i]; \    } \ -  void oxarop_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \ +  static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \      for (i64 i = 0; i < n; i++) out[i] = x[i] op y[i]; \    }  #define NONCOMM_OP(name, op, typ) \    COMM_OP(name, op, typ) \ -  void oxarop_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \ +  static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \      for (i64 i = 0; i < n; i++) out[i] = x[i] op y; \    }  #define UNARY_OP(name, op, typ) \ -  void oxarop_ ## name ## _ ## typ(i64 n, typ *out, const typ *x) { \ +  static void oxarop_op_ ## name ## _ ## typ(i64 n, typ *out, const typ *x) { \      for (i64 i = 0; i < n; i++) out[i] = op(x[i]); \    } @@ -68,7 +82,7 @@ typedef int64_t i64;  // Reduces along the innermost dimension.  // 'out' will be filled densely in linearisation order.  #define REDUCE1_OP(name, op, typ) \ -  void oxarop_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \ +  static void oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \      if (strides[rank - 1] == 1) { \        TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \          typ accum = arr[arrlinidx]; \ @@ -88,6 +102,91 @@ typedef int64_t i64;      } \    } + +/***************************************************************************** + *                           Entry point functions                           * + *****************************************************************************/ + +__attribute__((noreturn, cold)) +static void wrong_op(const char *name, int tag) { +  fprintf(stderr, "ox-arrays: Invalid operation tag passed to %s C code: %d\n", name, tag); +  abort(); +} + +enum binop_tag_t { +#undef LIST_BINOP +#define LIST_BINOP(name, id, hsop) name = id, +#include "arith_lists.h" +#undef LIST_BINOP +#define LIST_BINOP(name, id, hsop) +}; + +#define ENTRY_BINARY_OPS(typ) \ +  void oxarop_binary_ ## typ ## _sv(enum binop_tag_t tag, i64 n, typ *out, typ x, const typ *y) { \ +    switch (tag) { \ +      case BO_ADD: oxarop_op_add_ ## typ ## _sv(n, out, x, y); break; \ +      case BO_SUB: oxarop_op_sub_ ## typ ## _sv(n, out, x, y); break; \ +      case BO_MUL: oxarop_op_mul_ ## typ ## _sv(n, out, x, y); break; \ +      default: wrong_op("binary_sv", tag); \ +    } \ +  } \ +  void oxarop_binary_ ## typ ## _vs(enum binop_tag_t tag, i64 n, typ *out, const typ *x, typ y) { \ +    switch (tag) { \ +      case BO_ADD: oxarop_op_add_ ## typ ## _sv(n, out, y, x); break; \ +      case BO_SUB: oxarop_op_sub_ ## typ ## _vs(n, out, x, y); break; \ +      case BO_MUL: oxarop_op_mul_ ## typ ## _sv(n, out, y, x); break; \ +      default: wrong_op("binary_vs", tag); \ +    } \ +  } \ +  void oxarop_binary_ ## typ ## _vv(enum binop_tag_t tag, i64 n, typ *out, const typ *x, const typ *y) { \ +    switch (tag) { \ +      case BO_ADD: oxarop_op_add_ ## typ ## _vv(n, out, x, y); break; \ +      case BO_SUB: oxarop_op_sub_ ## typ ## _vv(n, out, x, y); break; \ +      case BO_MUL: oxarop_op_mul_ ## typ ## _vv(n, out, x, y); break; \ +      default: wrong_op("binary_vv", tag); \ +    } \ +  } + +enum unop_tag_t { +#undef LIST_UNOP +#define LIST_UNOP(name, id, _) name = id, +#include "arith_lists.h" +#undef LIST_UNOP +#define LIST_UNOP(name, id, _) +}; + +#define ENTRY_UNARY_OPS(typ) \ +  void oxarop_unary_ ## typ(enum unop_tag_t tag, i64 n, typ *out, const typ *x) { \ +    switch (tag) { \ +      case UO_NEG: oxarop_op_neg_ ## typ(n, out, x); break; \ +      case UO_ABS: oxarop_op_abs_ ## typ(n, out, x); break; \ +      case UO_SIGNUM: oxarop_op_signum_ ## typ(n, out, x); break; \ +      default: wrong_op("unary", tag); \ +    } \ +  } + +enum redop_tag_t { +#undef LIST_REDOP +#define LIST_REDOP(name, id, _) name = id, +#include "arith_lists.h" +#undef LIST_REDOP +#define LIST_REDOP(name, id, _) +}; + +#define ENTRY_REDUCE_OPS(typ) \ +  void oxarop_reduce_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \ +    switch (tag) { \ +      case RO_SUM1: oxarop_op_sum1_ ## typ(rank, shape, strides, out, arr); break; \ +      case RO_PRODUCT1: oxarop_op_product1_ ## typ(rank, shape, strides, out, arr); break; \ +      default: wrong_op("reduce", tag); \ +    } \ +  } + + +/***************************************************************************** + *                        Generate all the functions                         * + *****************************************************************************/ +  #define NUM_TYPES_LOOP_XLIST \    X(i32) X(i64) X(double) X(float) @@ -99,6 +198,9 @@ typedef int64_t i64;    UNARY_OP(abs, GEN_ABS, typ) \    UNARY_OP(signum, GEN_SIGNUM, typ) \    REDUCE1_OP(sum1, +, typ) \ -  REDUCE1_OP(product1, *, typ) +  REDUCE1_OP(product1, *, typ) \ +  ENTRY_BINARY_OPS(typ) \ +  ENTRY_UNARY_OPS(typ) \ +  ENTRY_REDUCE_OPS(typ)  NUM_TYPES_LOOP_XLIST  #undef X diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h new file mode 100644 index 0000000..c7495e8 --- /dev/null +++ b/cbits/arith_lists.h @@ -0,0 +1,10 @@ +LIST_BINOP(BO_ADD, 1, +) +LIST_BINOP(BO_SUB, 2, -) +LIST_BINOP(BO_MUL, 3, *) + +LIST_UNOP(UO_NEG, 1,) +LIST_UNOP(UO_ABS, 2,) +LIST_UNOP(UO_SIGNUM, 3,) + +LIST_REDOP(RO_SUM1, 1,) +LIST_REDOP(RO_PRODUCT1, 2,)  | 
