diff options
Diffstat (limited to 'cbits/arith.c')
-rw-r--r-- | cbits/arith.c | 103 |
1 files changed, 76 insertions, 27 deletions
diff --git a/cbits/arith.c b/cbits/arith.c index b9c86ab..f08e456 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -5,6 +5,7 @@ #include <stdio.h> #include <stdint.h> +#include <inttypes.h> #include <stdlib.h> #include <stdbool.h> #include <string.h> @@ -89,38 +90,23 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); } /***************************************************************************** - * Kernel functions * + * Helper functions * *****************************************************************************/ -#define COMM_OP(name, op, typ) \ - static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \ - for (i64 i = 0; i < n; i++) out[i] = x op y[i]; \ - } \ - static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \ - for (i64 i = 0; i < n; i++) out[i] = x[i] op y[i]; \ - } - -#define NONCOMM_OP(name, op, typ) \ - COMM_OP(name, op, typ) \ - static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \ - for (i64 i = 0; i < n; i++) out[i] = x[i] op y; \ +__attribute__((used)) +static void print_shape(FILE *stream, i64 rank, const i64 *shape) { + fputc('[', stream); + for (i64 i = 0; i < rank; i++) { + if (i != 0) fputc(',', stream); + fprintf(stream, "%" PRIi64, shape[i]); } + fputc(']', stream); +} -#define PREFIX_BINOP(name, op, typ) \ - static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \ - for (i64 i = 0; i < n; i++) out[i] = op(x, y[i]); \ - } \ - static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \ - for (i64 i = 0; i < n; i++) out[i] = op(x[i], y[i]); \ - } \ - static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \ - for (i64 i = 0; i < n; i++) out[i] = op(x[i], y); \ - } -#define UNARY_OP(name, op, typ) \ - static void oxarop_op_ ## name ## _ ## typ(i64 n, typ *out, const typ *x) { \ - for (i64 i = 0; i < n; i++) out[i] = op(x[i]); \ - } +/***************************************************************************** + * Skeletons * + *****************************************************************************/ // Walk a orthotope-style strided array, except for the inner dimension. The // body is run for every "inner vector". @@ -184,6 +170,55 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); } } \ } while (false) + +/***************************************************************************** + * Kernel functions * + *****************************************************************************/ + +#define COMM_OP(name, op, typ) \ + static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \ + for (i64 i = 0; i < n; i++) out[i] = x op y[i]; \ + } \ + static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \ + for (i64 i = 0; i < n; i++) out[i] = x[i] op y[i]; \ + } + +#define NONCOMM_OP(name, op, typ) \ + COMM_OP(name, op, typ) \ + static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \ + for (i64 i = 0; i < n; i++) out[i] = x[i] op y; \ + } + +#define PREFIX_BINOP(name, op, typ) \ + static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \ + for (i64 i = 0; i < n; i++) out[i] = op(x, y[i]); \ + } \ + static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \ + for (i64 i = 0; i < n; i++) out[i] = op(x[i], y[i]); \ + } \ + static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \ + for (i64 i = 0; i < n; i++) out[i] = op(x[i], y); \ + } + +#define UNARY_OP(name, op, typ) \ + static void oxarop_op_ ## name ## _ ## typ(i64 n, typ *out, const typ *x) { \ + for (i64 i = 0; i < n; i++) out[i] = op(x[i]); \ + } + +#define UNARY_OP_STRIDED(name, op, typ) \ + static void oxarop_op_ ## name ## _ ## typ ## _strided(i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \ + /* fprintf(stderr, "oxarop_op_" #name "_" #typ "_strided: rank=%ld shape=", rank); \ + print_shape(stderr, rank, shape); \ + fprintf(stderr, " strides="); \ + print_shape(stderr, rank, strides); \ + fprintf(stderr, "\n"); */ \ + TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ + for (i64 i = 0; i < shape[rank - 1]; i++) { \ + out[outlinidx * shape[rank - 1] + i] = op(arr[arrlinidx + strides[rank - 1] * i]); \ + } \ + }); \ + } + // preconditions: // - all strides are >0 // - shape is everywhere >0 @@ -408,6 +443,16 @@ enum unop_tag_t { } \ } +#define ENTRY_UNARY_STRIDED_OPS(typ) \ + void oxarop_unary_ ## typ ## _strided(enum unop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *x) { \ + switch (tag) { \ + case UO_NEG: oxarop_op_neg_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case UO_ABS: oxarop_op_abs_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + case UO_SIGNUM: oxarop_op_signum_ ## typ ## _strided(rank, out, shape, strides, x); break; \ + default: wrong_op("unary_strided", tag); \ + } \ + } + enum funop_tag_t { #undef LIST_FUNOP #define LIST_FUNOP(name, id, _) name = id, @@ -484,12 +529,16 @@ enum redop_tag_t { UNARY_OP(neg, -, typ) \ UNARY_OP(abs, GEN_ABS, typ) \ UNARY_OP(signum, GEN_SIGNUM, typ) \ + UNARY_OP_STRIDED(neg, -, typ) \ + UNARY_OP_STRIDED(abs, GEN_ABS, typ) \ + UNARY_OP_STRIDED(signum, GEN_SIGNUM, typ) \ REDUCE1_OP(sum1, +, typ) \ REDUCE1_OP(product1, *, typ) \ REDUCEFULL_OP(sumfull, +, typ) \ REDUCEFULL_OP(productfull, *, typ) \ ENTRY_BINARY_OPS(typ) \ ENTRY_UNARY_OPS(typ) \ + ENTRY_UNARY_STRIDED_OPS(typ) \ ENTRY_REDUCE1_OPS(typ) \ ENTRY_REDUCEFULL_OPS(typ) \ EXTREMUM_OP(min, <, typ) \ |