aboutsummaryrefslogtreecommitdiff
path: root/cbits/arith.c
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-26 00:11:00 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-26 00:11:00 +0200
commit34a9ac8e4497e776c3ca499c41ef749f4edf8383 (patch)
treef2b2e34d830d66d23ae19909c71771e810c262d0 /cbits/arith.c
parent85593969debadbf11ad3c159de71e7b480ca367c (diff)
Refactor C interface to pass operation as enum
This is hmatrix style, less proliferation of functions as the number of ops increases
Diffstat (limited to 'cbits/arith.c')
-rw-r--r--cbits/arith.c114
1 files changed, 108 insertions, 6 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index 002910c..65cdb41 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -1,28 +1,42 @@
+#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
#include <math.h>
+// These are the wrapper macros used in arith_lists.h. Preset them to empty to
+// avoid having to touch macros unrelated to the particular operation set below.
+#define LIST_BINOP(name, id, hsop)
+#define LIST_UNOP(name, id, _)
+#define LIST_REDOP(name, id, _)
+
+
+// Shorter names, due to CPP used both in function names and in C types.
typedef int32_t i32;
typedef int64_t i64;
+
+/*****************************************************************************
+ * Kernel functions *
+ *****************************************************************************/
+
#define COMM_OP(name, op, typ) \
- void oxarop_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \
+ static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \
for (i64 i = 0; i < n; i++) out[i] = x op y[i]; \
} \
- void oxarop_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \
+ static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \
for (i64 i = 0; i < n; i++) out[i] = x[i] op y[i]; \
}
#define NONCOMM_OP(name, op, typ) \
COMM_OP(name, op, typ) \
- void oxarop_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \
+ static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \
for (i64 i = 0; i < n; i++) out[i] = x[i] op y; \
}
#define UNARY_OP(name, op, typ) \
- void oxarop_ ## name ## _ ## typ(i64 n, typ *out, const typ *x) { \
+ static void oxarop_op_ ## name ## _ ## typ(i64 n, typ *out, const typ *x) { \
for (i64 i = 0; i < n; i++) out[i] = op(x[i]); \
}
@@ -68,7 +82,7 @@ typedef int64_t i64;
// Reduces along the innermost dimension.
// 'out' will be filled densely in linearisation order.
#define REDUCE1_OP(name, op, typ) \
- void oxarop_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \
+ static void oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \
if (strides[rank - 1] == 1) { \
TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \
typ accum = arr[arrlinidx]; \
@@ -88,6 +102,91 @@ typedef int64_t i64;
} \
}
+
+/*****************************************************************************
+ * Entry point functions *
+ *****************************************************************************/
+
+__attribute__((noreturn, cold))
+static void wrong_op(const char *name, int tag) {
+ fprintf(stderr, "ox-arrays: Invalid operation tag passed to %s C code: %d\n", name, tag);
+ abort();
+}
+
+enum binop_tag_t {
+#undef LIST_BINOP
+#define LIST_BINOP(name, id, hsop) name = id,
+#include "arith_lists.h"
+#undef LIST_BINOP
+#define LIST_BINOP(name, id, hsop)
+};
+
+#define ENTRY_BINARY_OPS(typ) \
+ void oxarop_binary_ ## typ ## _sv(enum binop_tag_t tag, i64 n, typ *out, typ x, const typ *y) { \
+ switch (tag) { \
+ case BO_ADD: oxarop_op_add_ ## typ ## _sv(n, out, x, y); break; \
+ case BO_SUB: oxarop_op_sub_ ## typ ## _sv(n, out, x, y); break; \
+ case BO_MUL: oxarop_op_mul_ ## typ ## _sv(n, out, x, y); break; \
+ default: wrong_op("binary_sv", tag); \
+ } \
+ } \
+ void oxarop_binary_ ## typ ## _vs(enum binop_tag_t tag, i64 n, typ *out, const typ *x, typ y) { \
+ switch (tag) { \
+ case BO_ADD: oxarop_op_add_ ## typ ## _sv(n, out, y, x); break; \
+ case BO_SUB: oxarop_op_sub_ ## typ ## _vs(n, out, x, y); break; \
+ case BO_MUL: oxarop_op_mul_ ## typ ## _sv(n, out, y, x); break; \
+ default: wrong_op("binary_vs", tag); \
+ } \
+ } \
+ void oxarop_binary_ ## typ ## _vv(enum binop_tag_t tag, i64 n, typ *out, const typ *x, const typ *y) { \
+ switch (tag) { \
+ case BO_ADD: oxarop_op_add_ ## typ ## _vv(n, out, x, y); break; \
+ case BO_SUB: oxarop_op_sub_ ## typ ## _vv(n, out, x, y); break; \
+ case BO_MUL: oxarop_op_mul_ ## typ ## _vv(n, out, x, y); break; \
+ default: wrong_op("binary_vv", tag); \
+ } \
+ }
+
+enum unop_tag_t {
+#undef LIST_UNOP
+#define LIST_UNOP(name, id, _) name = id,
+#include "arith_lists.h"
+#undef LIST_UNOP
+#define LIST_UNOP(name, id, _)
+};
+
+#define ENTRY_UNARY_OPS(typ) \
+ void oxarop_unary_ ## typ(enum unop_tag_t tag, i64 n, typ *out, const typ *x) { \
+ switch (tag) { \
+ case UO_NEG: oxarop_op_neg_ ## typ(n, out, x); break; \
+ case UO_ABS: oxarop_op_abs_ ## typ(n, out, x); break; \
+ case UO_SIGNUM: oxarop_op_signum_ ## typ(n, out, x); break; \
+ default: wrong_op("unary", tag); \
+ } \
+ }
+
+enum redop_tag_t {
+#undef LIST_REDOP
+#define LIST_REDOP(name, id, _) name = id,
+#include "arith_lists.h"
+#undef LIST_REDOP
+#define LIST_REDOP(name, id, _)
+};
+
+#define ENTRY_REDUCE_OPS(typ) \
+ void oxarop_reduce_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \
+ switch (tag) { \
+ case RO_SUM1: oxarop_op_sum1_ ## typ(rank, shape, strides, out, arr); break; \
+ case RO_PRODUCT1: oxarop_op_product1_ ## typ(rank, shape, strides, out, arr); break; \
+ default: wrong_op("reduce", tag); \
+ } \
+ }
+
+
+/*****************************************************************************
+ * Generate all the functions *
+ *****************************************************************************/
+
#define NUM_TYPES_LOOP_XLIST \
X(i32) X(i64) X(double) X(float)
@@ -99,6 +198,9 @@ typedef int64_t i64;
UNARY_OP(abs, GEN_ABS, typ) \
UNARY_OP(signum, GEN_SIGNUM, typ) \
REDUCE1_OP(sum1, +, typ) \
- REDUCE1_OP(product1, *, typ)
+ REDUCE1_OP(product1, *, typ) \
+ ENTRY_BINARY_OPS(typ) \
+ ENTRY_UNARY_OPS(typ) \
+ ENTRY_REDUCE_OPS(typ)
NUM_TYPES_LOOP_XLIST
#undef X