#include <math.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// These are the wrapper macros used in arith_lists.h. Preset them to empty to
// avoid having to touch macros unrelated to the particular operation set below.
#define LIST_BINOP(name, id, hsop)
#define LIST_FBINOP(name, id, hsop)
#define LIST_UNOP(name, id, _)
#define LIST_FUNOP(name, id, _)
#define LIST_REDOP(name, id, _)

// Short type names, because the same preprocessor token is pasted into
// generated function names and also used as the C element type.
typedef int32_t i32;
typedef int64_t i64;

/*****************************************************************************
 *                              Kernel functions                             *
 *****************************************************************************/

// Generates binary kernels for an operator 'op' over element type 'typ':
//   oxarop_op_<name>_<typ>_sv: out[i] = x op y[i]     (scalar, vector)
//   oxarop_op_<name>_<typ>_vv: out[i] = x[i] op y[i]  (vector, vector)
// Each writes 'n' elements of 'out'; n <= 0 writes nothing.
// NOTE(review): for the signed integer instantiations (i32, i64) an
// overflowing +, - or * is undefined behaviour in C unless the build uses
// -fwrapv; confirm the build flags if wrap-around semantics are intended.
#define COMM_OP(name, op, typ) \
  static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \
    for (i64 i = 0; i < n; i++) out[i] = x op y[i]; \
  } \
  static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \
    for (i64 i = 0; i < n; i++) out[i] = x[i] op y[i]; \
  }

// Like COMM_OP, but also generates the (vector, scalar) variant
//   oxarop_op_<name>_<typ>_vs: out[i] = x[i] op y
// which non-commutative operators cannot obtain by swapping the _sv operands.
#define NONCOMM_OP(name, op, typ) \
  COMM_OP(name, op, typ) \
  static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \
    for (i64 i = 0; i < n; i++) out[i] = x[i] op y; \
  }

// Generates oxarop_op_<name>_<typ>: out[i] = op(x[i]) for 0 <= i < n.
// 'op' may be a prefix operator (e.g. '-') or a function-like macro/function.
#define UNARY_OP(name, op, typ) \
  static void oxarop_op_ ## name ## _ ## typ(i64 n, typ *out, const typ *x) { \
    for (i64 i = 0; i < n; i++) out[i] = op(x[i]); \
  }

// Type-generic absolute value, dispatching on the static type of x.
// NOTE(review): abs/labs/llabs of the most negative integer is undefined
// behaviour; confirm that callers never pass it or that this is acceptable.
#define GEN_ABS(x) \
  _Generic((x), \
    int: abs, \
    long: labs, \
    long long: llabs, \
    float: fabsf, \
    double: fabs)(x)

// Sign of x: -1 if negative, 1 if positive, otherwise x itself. For integers
// the last case is always 0. For floating point, returning x (rather than a
// literal 0) keeps -0.0 as -0.0 and lets NaN propagate, consistent with GHC's
// 'signum' for Float/Double.
// This does not result in multiple loads with GCC 13.
#define GEN_SIGNUM(x) ((x) < 0 ? -1 : (x) > 0 ? 1 : (x))

// Walk an orthotope-style strided array, except for the inner dimension. The
// body is run for every "inner vector".
#define TARRAY_WALK_NOINNER(again_label_name, rank, shape, strides, body) \ do { \ i64 idx[(rank) - 1]; \ memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \ i64 arrlinidx = 0; \ i64 outlinidx = 0; \ again_label_name: \ { \ body \ } \ for (i64 dim = (rank) - 2; dim >= 0; dim--) { \ if (++idx[dim] < (shape)[dim]) { \ arrlinidx += (strides)[dim]; \ outlinidx++; \ goto again_label_name; \ } \ arrlinidx -= (idx[dim] - 1) * (strides)[dim]; \ idx[dim] = 0; \ } \ } while (false) // preconditions: // - all strides are >0 // - shape is everywhere >0 // - rank is >= 1 // - out has capacity for (shape[0] * ... * shape[rank - 2]) elements // Reduces along the innermost dimension. // 'out' will be filled densely in linearisation order. #define REDUCE1_OP(name, op, typ) \ static void oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \ if (strides[rank - 1] == 1) { \ TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \ typ accum = arr[arrlinidx]; \ for (i64 i = 1; i < shape[rank - 1]; i++) { \ accum = accum op arr[arrlinidx + i]; \ } \ out[outlinidx] = accum; \ }); \ } else { \ TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \ typ accum = arr[arrlinidx]; \ for (i64 i = 1; i < shape[rank - 1]; i++) { \ accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ } \ out[outlinidx] = accum; \ }); \ } \ } /***************************************************************************** * Entry point functions * *****************************************************************************/ __attribute__((noreturn, cold)) static void wrong_op(const char *name, int tag) { fprintf(stderr, "ox-arrays: Invalid operation tag passed to %s C code: %d\n", name, tag); abort(); } enum binop_tag_t { #undef LIST_BINOP #define LIST_BINOP(name, id, hsop) name = id, #include "arith_lists.h" #undef LIST_BINOP #define LIST_BINOP(name, id, hsop) }; #define ENTRY_BINARY_OPS(typ) \ void oxarop_binary_ ## typ ## _sv(enum binop_tag_t tag, i64 n, typ 
*out, typ x, const typ *y) { \ switch (tag) { \ case BO_ADD: oxarop_op_add_ ## typ ## _sv(n, out, x, y); break; \ case BO_SUB: oxarop_op_sub_ ## typ ## _sv(n, out, x, y); break; \ case BO_MUL: oxarop_op_mul_ ## typ ## _sv(n, out, x, y); break; \ default: wrong_op("binary_sv", tag); \ } \ } \ void oxarop_binary_ ## typ ## _vs(enum binop_tag_t tag, i64 n, typ *out, const typ *x, typ y) { \ switch (tag) { \ case BO_ADD: oxarop_op_add_ ## typ ## _sv(n, out, y, x); break; \ case BO_SUB: oxarop_op_sub_ ## typ ## _vs(n, out, x, y); break; \ case BO_MUL: oxarop_op_mul_ ## typ ## _sv(n, out, y, x); break; \ default: wrong_op("binary_vs", tag); \ } \ } \ void oxarop_binary_ ## typ ## _vv(enum binop_tag_t tag, i64 n, typ *out, const typ *x, const typ *y) { \ switch (tag) { \ case BO_ADD: oxarop_op_add_ ## typ ## _vv(n, out, x, y); break; \ case BO_SUB: oxarop_op_sub_ ## typ ## _vv(n, out, x, y); break; \ case BO_MUL: oxarop_op_mul_ ## typ ## _vv(n, out, x, y); break; \ default: wrong_op("binary_vv", tag); \ } \ } enum fbinop_tag_t { #undef LIST_FBINOP #define LIST_FBINOP(name, id, hsop) name = id, #include "arith_lists.h" #undef LIST_FBINOP #define LIST_FBINOP(name, id, hsop) }; #define ENTRY_FBINARY_OPS(typ) \ void oxarop_fbinary_ ## typ ## _sv(enum binop_tag_t tag, i64 n, typ *out, typ x, const typ *y) { \ switch (tag) { \ case FB_DIV: oxarop_op_fdiv_ ## typ ## _sv(n, out, x, y); break; \ default: wrong_op("binary_sv", tag); \ } \ } \ void oxarop_fbinary_ ## typ ## _vs(enum binop_tag_t tag, i64 n, typ *out, const typ *x, typ y) { \ switch (tag) { \ case FB_DIV: oxarop_op_fdiv_ ## typ ## _vs(n, out, x, y); break; \ default: wrong_op("binary_vs", tag); \ } \ } \ void oxarop_fbinary_ ## typ ## _vv(enum binop_tag_t tag, i64 n, typ *out, const typ *x, const typ *y) { \ switch (tag) { \ case FB_DIV: oxarop_op_fdiv_ ## typ ## _vv(n, out, x, y); break; \ default: wrong_op("binary_vv", tag); \ } \ } enum unop_tag_t { #undef LIST_UNOP #define LIST_UNOP(name, id, _) name = id, 
#include "arith_lists.h" #undef LIST_UNOP #define LIST_UNOP(name, id, _) }; #define ENTRY_UNARY_OPS(typ) \ void oxarop_unary_ ## typ(enum unop_tag_t tag, i64 n, typ *out, const typ *x) { \ switch (tag) { \ case UO_NEG: oxarop_op_neg_ ## typ(n, out, x); break; \ case UO_ABS: oxarop_op_abs_ ## typ(n, out, x); break; \ case UO_SIGNUM: oxarop_op_signum_ ## typ(n, out, x); break; \ default: wrong_op("unary", tag); \ } \ } enum funop_tag_t { #undef LIST_FUNOP #define LIST_FUNOP(name, id, _) name = id, #include "arith_lists.h" #undef LIST_FUNOP #define LIST_FUNOP(name, id, _) }; #define ENTRY_FUNARY_OPS(typ) \ void oxarop_funary_ ## typ(enum unop_tag_t tag, i64 n, typ *out, const typ *x) { \ switch (tag) { \ case FU_RECIP: oxarop_op_recip_ ## typ(n, out, x); break; \ default: wrong_op("unary", tag); \ } \ } enum redop_tag_t { #undef LIST_REDOP #define LIST_REDOP(name, id, _) name = id, #include "arith_lists.h" #undef LIST_REDOP #define LIST_REDOP(name, id, _) }; #define ENTRY_REDUCE_OPS(typ) \ void oxarop_reduce_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \ switch (tag) { \ case RO_SUM1: oxarop_op_sum1_ ## typ(rank, shape, strides, out, arr); break; \ case RO_PRODUCT1: oxarop_op_product1_ ## typ(rank, shape, strides, out, arr); break; \ default: wrong_op("reduce", tag); \ } \ } /***************************************************************************** * Generate all the functions * *****************************************************************************/ #define FLOAT_TYPES_XLIST X(double) X(float) #define NUM_TYPES_XLIST X(i32) X(i64) FLOAT_TYPES_XLIST #define X(typ) \ COMM_OP(add, +, typ) \ NONCOMM_OP(sub, -, typ) \ COMM_OP(mul, *, typ) \ UNARY_OP(neg, -, typ) \ UNARY_OP(abs, GEN_ABS, typ) \ UNARY_OP(signum, GEN_SIGNUM, typ) \ REDUCE1_OP(sum1, +, typ) \ REDUCE1_OP(product1, *, typ) \ ENTRY_BINARY_OPS(typ) \ ENTRY_UNARY_OPS(typ) \ ENTRY_REDUCE_OPS(typ) NUM_TYPES_XLIST #undef X #define X(typ) \ 
NONCOMM_OP(fdiv, /, typ) \ UNARY_OP(recip, 1.0/, typ) \ ENTRY_FBINARY_OPS(typ) \ ENTRY_FUNARY_OPS(typ) FLOAT_TYPES_XLIST #undef X