aboutsummaryrefslogtreecommitdiff
path: root/cbits/arith.c
diff options
context:
space:
mode:
Diffstat (limited to 'cbits/arith.c')
-rw-r--r--cbits/arith.c92
1 files changed, 43 insertions, 49 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index 6ea197d..8d0700d 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -200,11 +200,6 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
for (i64 i = 0; i < n; i++) out[i] = op(x[i], y); \
}
-#define UNARY_OP(name, op, typ) \
- static void oxarop_op_ ## name ## _ ## typ(i64 n, typ *out, const typ *x) { \
- for (i64 i = 0; i < n; i++) out[i] = op(x[i]); \
- }
-
#define UNARY_OP_STRIDED(name, op, typ) \
static void oxarop_op_ ## name ## _ ## typ ## _strided(i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \
/* fprintf(stderr, "oxarop_op_" #name "_" #typ "_strided: rank=%ld shape=", rank); \
@@ -451,29 +446,29 @@ enum funop_tag_t {
#define LIST_FUNOP(name, id, _)
};
-#define ENTRY_FUNARY_OPS(typ) \
- void oxarop_funary_ ## typ(enum funop_tag_t tag, i64 n, typ *out, const typ *x) { \
+#define ENTRY_FUNARY_STRIDED_OPS(typ) \
+ void oxarop_funary_ ## typ ## _strided(enum funop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *x) { \
switch (tag) { \
- case FU_RECIP: oxarop_op_recip_ ## typ(n, out, x); break; \
- case FU_EXP: oxarop_op_exp_ ## typ(n, out, x); break; \
- case FU_LOG: oxarop_op_log_ ## typ(n, out, x); break; \
- case FU_SQRT: oxarop_op_sqrt_ ## typ(n, out, x); break; \
- case FU_SIN: oxarop_op_sin_ ## typ(n, out, x); break; \
- case FU_COS: oxarop_op_cos_ ## typ(n, out, x); break; \
- case FU_TAN: oxarop_op_tan_ ## typ(n, out, x); break; \
- case FU_ASIN: oxarop_op_asin_ ## typ(n, out, x); break; \
- case FU_ACOS: oxarop_op_acos_ ## typ(n, out, x); break; \
- case FU_ATAN: oxarop_op_atan_ ## typ(n, out, x); break; \
- case FU_SINH: oxarop_op_sinh_ ## typ(n, out, x); break; \
- case FU_COSH: oxarop_op_cosh_ ## typ(n, out, x); break; \
- case FU_TANH: oxarop_op_tanh_ ## typ(n, out, x); break; \
- case FU_ASINH: oxarop_op_asinh_ ## typ(n, out, x); break; \
- case FU_ACOSH: oxarop_op_acosh_ ## typ(n, out, x); break; \
- case FU_ATANH: oxarop_op_atanh_ ## typ(n, out, x); break; \
- case FU_LOG1P: oxarop_op_log1p_ ## typ(n, out, x); break; \
- case FU_EXPM1: oxarop_op_expm1_ ## typ(n, out, x); break; \
- case FU_LOG1PEXP: oxarop_op_log1pexp_ ## typ(n, out, x); break; \
- case FU_LOG1MEXP: oxarop_op_log1mexp_ ## typ(n, out, x); break; \
+ case FU_RECIP: oxarop_op_recip_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_EXP: oxarop_op_exp_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_LOG: oxarop_op_log_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_SQRT: oxarop_op_sqrt_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_SIN: oxarop_op_sin_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_COS: oxarop_op_cos_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_TAN: oxarop_op_tan_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_ASIN: oxarop_op_asin_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_ACOS: oxarop_op_acos_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_ATAN: oxarop_op_atan_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_SINH: oxarop_op_sinh_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_COSH: oxarop_op_cosh_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_TANH: oxarop_op_tanh_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_ASINH: oxarop_op_asinh_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_ACOSH: oxarop_op_acosh_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_ATANH: oxarop_op_atanh_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_LOG1P: oxarop_op_log1p_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_EXPM1: oxarop_op_expm1_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_LOG1PEXP: oxarop_op_log1pexp_ ## typ ## _strided(rank, out, shape, strides, x); break; \
+ case FU_LOG1MEXP: oxarop_op_log1mexp_ ## typ ## _strided(rank, out, shape, strides, x); break; \
default: wrong_op("unary", tag); \
} \
}
@@ -538,29 +533,28 @@ NUM_TYPES_XLIST
NONCOMM_OP(fdiv, /, typ) \
PREFIX_BINOP(pow, GEN_POW, typ) \
PREFIX_BINOP(logbase, GEN_LOGBASE, typ) \
- /* TODO: when replaced with UNARY_OP_STRIDED, remove UNARY_OP entirely */ \
- UNARY_OP(recip, 1.0/, typ) \
- UNARY_OP(exp, GEN_EXP, typ) \
- UNARY_OP(log, GEN_LOG, typ) \
- UNARY_OP(sqrt, GEN_SQRT, typ) \
- UNARY_OP(sin, GEN_SIN, typ) \
- UNARY_OP(cos, GEN_COS, typ) \
- UNARY_OP(tan, GEN_TAN, typ) \
- UNARY_OP(asin, GEN_ASIN, typ) \
- UNARY_OP(acos, GEN_ACOS, typ) \
- UNARY_OP(atan, GEN_ATAN, typ) \
- UNARY_OP(sinh, GEN_SINH, typ) \
- UNARY_OP(cosh, GEN_COSH, typ) \
- UNARY_OP(tanh, GEN_TANH, typ) \
- UNARY_OP(asinh, GEN_ASINH, typ) \
- UNARY_OP(acosh, GEN_ACOSH, typ) \
- UNARY_OP(atanh, GEN_ATANH, typ) \
- UNARY_OP(log1p, GEN_LOG1P, typ) \
- UNARY_OP(expm1, GEN_EXPM1, typ) \
- UNARY_OP(log1pexp, GEN_LOG1PEXP, typ) \
- UNARY_OP(log1mexp, GEN_LOG1MEXP, typ) \
+ UNARY_OP_STRIDED(recip, 1.0/, typ) \
+ UNARY_OP_STRIDED(exp, GEN_EXP, typ) \
+ UNARY_OP_STRIDED(log, GEN_LOG, typ) \
+ UNARY_OP_STRIDED(sqrt, GEN_SQRT, typ) \
+ UNARY_OP_STRIDED(sin, GEN_SIN, typ) \
+ UNARY_OP_STRIDED(cos, GEN_COS, typ) \
+ UNARY_OP_STRIDED(tan, GEN_TAN, typ) \
+ UNARY_OP_STRIDED(asin, GEN_ASIN, typ) \
+ UNARY_OP_STRIDED(acos, GEN_ACOS, typ) \
+ UNARY_OP_STRIDED(atan, GEN_ATAN, typ) \
+ UNARY_OP_STRIDED(sinh, GEN_SINH, typ) \
+ UNARY_OP_STRIDED(cosh, GEN_COSH, typ) \
+ UNARY_OP_STRIDED(tanh, GEN_TANH, typ) \
+ UNARY_OP_STRIDED(asinh, GEN_ASINH, typ) \
+ UNARY_OP_STRIDED(acosh, GEN_ACOSH, typ) \
+ UNARY_OP_STRIDED(atanh, GEN_ATANH, typ) \
+ UNARY_OP_STRIDED(log1p, GEN_LOG1P, typ) \
+ UNARY_OP_STRIDED(expm1, GEN_EXPM1, typ) \
+ UNARY_OP_STRIDED(log1pexp, GEN_LOG1PEXP, typ) \
+ UNARY_OP_STRIDED(log1mexp, GEN_LOG1MEXP, typ) \
ENTRY_FBINARY_OPS(typ) \
- ENTRY_FUNARY_OPS(typ)
+ ENTRY_FUNARY_STRIDED_OPS(typ)
FLOAT_TYPES_XLIST
#undef X