aboutsummaryrefslogtreecommitdiff
path: root/cbits/arith.c
diff options
context:
space:
mode:
Diffstat (limited to 'cbits/arith.c')
-rw-r--r--cbits/arith.c130
1 files changed, 118 insertions, 12 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index a71c1b9..e20578b 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -18,6 +18,66 @@
typedef int32_t i32;
typedef int64_t i64;
+/*****************************************************************************
+ * Additional math functions *
+ *****************************************************************************/
+
+#define GEN_ABS(x) \
+ _Generic((x), \
+ int: abs, \
+ long: labs, \
+ long long: llabs, \
+ float: fabsf, \
+ double: fabs)(x)
+
+// This does not result in multiple loads with GCC 13.
+#define GEN_SIGNUM(x) ((x) < 0 ? -1 : (x) > 0 ? 1 : 0)
+
+#define GEN_POW(x, y) _Generic((x), float: powf, double: pow)(x, y)
+#define GEN_LOGBASE(x, y) _Generic((x), float: logf(y) / logf(x), double: log(y) / log(x))
+#define GEN_EXP(x) _Generic((x), float: expf, double: exp)(x)
+#define GEN_LOG(x) _Generic((x), float: logf, double: log)(x)
+#define GEN_SQRT(x) _Generic((x), float: sqrtf, double: sqrt)(x)
+#define GEN_SIN(x) _Generic((x), float: sinf, double: sin)(x)
+#define GEN_COS(x) _Generic((x), float: cosf, double: cos)(x)
+#define GEN_TAN(x) _Generic((x), float: tanf, double: tan)(x)
+#define GEN_ASIN(x) _Generic((x), float: asinf, double: asin)(x)
+#define GEN_ACOS(x) _Generic((x), float: acosf, double: acos)(x)
+#define GEN_ATAN(x) _Generic((x), float: atanf, double: atan)(x)
+#define GEN_SINH(x) _Generic((x), float: sinhf, double: sinh)(x)
+#define GEN_COSH(x) _Generic((x), float: coshf, double: cosh)(x)
+#define GEN_TANH(x) _Generic((x), float: tanhf, double: tanh)(x)
+#define GEN_ASINH(x) _Generic((x), float: asinhf, double: asinh)(x)
+#define GEN_ACOSH(x) _Generic((x), float: acoshf, double: acosh)(x)
+#define GEN_ATANH(x) _Generic((x), float: atanhf, double: atanh)(x)
+#define GEN_LOG1P(x) _Generic((x), float: log1pf, double: log1p)(x)
+#define GEN_EXPM1(x) _Generic((x), float: expm1f, double: expm1)(x)
+
+// Taken from Haskell's implementation:
+// https://hackage.haskell.org/package/ghc-internal-9.1001.0/docs/src//GHC.Internal.Float.html#log1mexpOrd
+#define LOG1MEXP_IMPL(x) do { \
+ if (x > _Generic((x), float: logf, double: log)(2)) return GEN_LOG(-GEN_EXPM1(x)); \
+ else return GEN_LOG1P(-GEN_EXP(x)); \
+ } while (0)
+
+static float log1mexp_float(float x) { LOG1MEXP_IMPL(x); }
+static double log1mexp_double(double x) { LOG1MEXP_IMPL(x); }
+
+#define GEN_LOG1MEXP(x) _Generic((x), float: log1mexp_float, double: log1mexp_double)(x)
+
+// Taken from Haskell's implementation:
+// https://hackage.haskell.org/package/ghc-internal-9.1001.0/docs/src//GHC.Internal.Float.html#line-595
+#define LOG1PEXP_IMPL(x) do { \
+ if (x <= 18) return GEN_LOG1P(GEN_EXP(x)); \
+ if (x <= 100) return x + GEN_EXP(-x); \
+ return x; \
+ } while (0)
+
+static float log1pexp_float(float x) { LOG1PEXP_IMPL(x); }
+static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
+
+#define GEN_LOG1PEXP(x) _Generic((x), float: log1pexp_float, double: log1pexp_double)(x)
+
/*****************************************************************************
* Kernel functions *
@@ -37,22 +97,22 @@ typedef int64_t i64;
for (i64 i = 0; i < n; i++) out[i] = x[i] op y; \
}
+#define PREFIX_BINOP(name, op, typ) \
+ static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \
+ for (i64 i = 0; i < n; i++) out[i] = op(x, y[i]); \
+ } \
+ static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \
+ for (i64 i = 0; i < n; i++) out[i] = op(x[i], y[i]); \
+ } \
+ static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \
+ for (i64 i = 0; i < n; i++) out[i] = op(x[i], y); \
+ }
+
#define UNARY_OP(name, op, typ) \
static void oxarop_op_ ## name ## _ ## typ(i64 n, typ *out, const typ *x) { \
for (i64 i = 0; i < n; i++) out[i] = op(x[i]); \
}
-#define GEN_ABS(x) \
- _Generic((x), \
- int: abs, \
- long: labs, \
- long long: llabs, \
- float: fabsf, \
- double: fabs)(x)
-
-// This does not result in multiple loads with GCC 13.
-#define GEN_SIGNUM(x) ((x) < 0 ? -1 : (x) > 0 ? 1 : 0)
-
// Walk a orthotope-style strided array, except for the inner dimension. The
// body is run for every "inner vector".
#define TARRAY_WALK_NOINNER(again_label_name, rank, shape, strides, body) \
@@ -161,18 +221,24 @@ enum fbinop_tag_t {
void oxarop_fbinary_ ## typ ## _sv(enum binop_tag_t tag, i64 n, typ *out, typ x, const typ *y) { \
switch (tag) { \
case FB_DIV: oxarop_op_fdiv_ ## typ ## _sv(n, out, x, y); break; \
+ case FB_POW: oxarop_op_pow_ ## typ ## _sv(n, out, x, y); break; \
+ case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _sv(n, out, x, y); break; \
default: wrong_op("binary_sv", tag); \
} \
} \
void oxarop_fbinary_ ## typ ## _vs(enum binop_tag_t tag, i64 n, typ *out, const typ *x, typ y) { \
switch (tag) { \
case FB_DIV: oxarop_op_fdiv_ ## typ ## _vs(n, out, x, y); break; \
+ case FB_POW: oxarop_op_pow_ ## typ ## _vs(n, out, x, y); break; \
+ case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vs(n, out, x, y); break; \
default: wrong_op("binary_vs", tag); \
} \
} \
void oxarop_fbinary_ ## typ ## _vv(enum binop_tag_t tag, i64 n, typ *out, const typ *x, const typ *y) { \
switch (tag) { \
case FB_DIV: oxarop_op_fdiv_ ## typ ## _vv(n, out, x, y); break; \
+ case FB_POW: oxarop_op_pow_ ## typ ## _vv(n, out, x, y); break; \
+ case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vv(n, out, x, y); break; \
default: wrong_op("binary_vv", tag); \
} \
}
@@ -204,9 +270,28 @@ enum funop_tag_t {
};
#define ENTRY_FUNARY_OPS(typ) \
- void oxarop_funary_ ## typ(enum unop_tag_t tag, i64 n, typ *out, const typ *x) { \
+ void oxarop_funary_ ## typ(enum funop_tag_t tag, i64 n, typ *out, const typ *x) { \
switch (tag) { \
case FU_RECIP: oxarop_op_recip_ ## typ(n, out, x); break; \
+ case FU_EXP: oxarop_op_exp_ ## typ(n, out, x); break; \
+ case FU_LOG: oxarop_op_log_ ## typ(n, out, x); break; \
+ case FU_SQRT: oxarop_op_sqrt_ ## typ(n, out, x); break; \
+ case FU_SIN: oxarop_op_sin_ ## typ(n, out, x); break; \
+ case FU_COS: oxarop_op_cos_ ## typ(n, out, x); break; \
+ case FU_TAN: oxarop_op_tan_ ## typ(n, out, x); break; \
+ case FU_ASIN: oxarop_op_asin_ ## typ(n, out, x); break; \
+ case FU_ACOS: oxarop_op_acos_ ## typ(n, out, x); break; \
+ case FU_ATAN: oxarop_op_atan_ ## typ(n, out, x); break; \
+ case FU_SINH: oxarop_op_sinh_ ## typ(n, out, x); break; \
+ case FU_COSH: oxarop_op_cosh_ ## typ(n, out, x); break; \
+ case FU_TANH: oxarop_op_tanh_ ## typ(n, out, x); break; \
+ case FU_ASINH: oxarop_op_asinh_ ## typ(n, out, x); break; \
+ case FU_ACOSH: oxarop_op_acosh_ ## typ(n, out, x); break; \
+ case FU_ATANH: oxarop_op_atanh_ ## typ(n, out, x); break; \
+ case FU_LOG1P: oxarop_op_log1p_ ## typ(n, out, x); break; \
+ case FU_EXPM1: oxarop_op_expm1_ ## typ(n, out, x); break; \
+ case FU_LOG1PEXP: oxarop_op_log1pexp_ ## typ(n, out, x); break; \
+ case FU_LOG1MEXP: oxarop_op_log1mexp_ ## typ(n, out, x); break; \
default: wrong_op("unary", tag); \
} \
}
@@ -253,7 +338,28 @@ NUM_TYPES_XLIST
#define X(typ) \
NONCOMM_OP(fdiv, /, typ) \
+ PREFIX_BINOP(pow, GEN_POW, typ) \
+ PREFIX_BINOP(logbase, GEN_LOGBASE, typ) \
UNARY_OP(recip, 1.0/, typ) \
+ UNARY_OP(exp, GEN_EXP, typ) \
+ UNARY_OP(log, GEN_LOG, typ) \
+ UNARY_OP(sqrt, GEN_SQRT, typ) \
+ UNARY_OP(sin, GEN_SIN, typ) \
+ UNARY_OP(cos, GEN_COS, typ) \
+ UNARY_OP(tan, GEN_TAN, typ) \
+ UNARY_OP(asin, GEN_ASIN, typ) \
+ UNARY_OP(acos, GEN_ACOS, typ) \
+ UNARY_OP(atan, GEN_ATAN, typ) \
+ UNARY_OP(sinh, GEN_SINH, typ) \
+ UNARY_OP(cosh, GEN_COSH, typ) \
+ UNARY_OP(tanh, GEN_TANH, typ) \
+ UNARY_OP(asinh, GEN_ASINH, typ) \
+ UNARY_OP(acosh, GEN_ACOSH, typ) \
+ UNARY_OP(atanh, GEN_ATANH, typ) \
+ UNARY_OP(log1p, GEN_LOG1P, typ) \
+ UNARY_OP(expm1, GEN_EXPM1, typ) \
+ UNARY_OP(log1pexp, GEN_LOG1PEXP, typ) \
+ UNARY_OP(log1mexp, GEN_LOG1MEXP, typ) \
ENTRY_FBINARY_OPS(typ) \
ENTRY_FUNARY_OPS(typ)
FLOAT_TYPES_XLIST