From 9e5945120bbcfeff15ee7356398e06ab5ba25561 Mon Sep 17 00:00:00 2001
From: Tom Smeding
Date: Mon, 27 May 2024 14:10:57 +0200
Subject: Fast (C) Floating ops

---
 cbits/arith.c       | 130 +++++++++++++++++++++++++++++++++++++++++++++++-----
 cbits/arith_lists.h |  21 +++++++++
 2 files changed, 139 insertions(+), 12 deletions(-)

(limited to 'cbits')

diff --git a/cbits/arith.c b/cbits/arith.c
index a71c1b9..e20578b 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -18,6 +18,66 @@ typedef int32_t i32;
 typedef int64_t i64;
 
 
+/*****************************************************************************
+ *                         Additional math functions                         *
+ *****************************************************************************/
+
+#define GEN_ABS(x) \
+  _Generic((x), \
+      int: abs, \
+      long: labs, \
+      long long: llabs, \
+      float: fabsf, \
+      double: fabs)(x)
+
+// This does not result in multiple loads with GCC 13.
+#define GEN_SIGNUM(x) ((x) < 0 ? -1 : (x) > 0 ? 1 : 0)
+
+#define GEN_POW(x, y) _Generic((x), float: powf, double: pow)(x, y)
+#define GEN_LOGBASE(x, y) _Generic((x), float: logf(y) / logf(x), double: log(y) / log(x))
+#define GEN_EXP(x) _Generic((x), float: expf, double: exp)(x)
+#define GEN_LOG(x) _Generic((x), float: logf, double: log)(x)
+#define GEN_SQRT(x) _Generic((x), float: sqrtf, double: sqrt)(x)
+#define GEN_SIN(x) _Generic((x), float: sinf, double: sin)(x)
+#define GEN_COS(x) _Generic((x), float: cosf, double: cos)(x)
+#define GEN_TAN(x) _Generic((x), float: tanf, double: tan)(x)
+#define GEN_ASIN(x) _Generic((x), float: asinf, double: asin)(x)
+#define GEN_ACOS(x) _Generic((x), float: acosf, double: acos)(x)
+#define GEN_ATAN(x) _Generic((x), float: atanf, double: atan)(x)
+#define GEN_SINH(x) _Generic((x), float: sinhf, double: sinh)(x)
+#define GEN_COSH(x) _Generic((x), float: coshf, double: cosh)(x)
+#define GEN_TANH(x) _Generic((x), float: tanhf, double: tanh)(x)
+#define GEN_ASINH(x) _Generic((x), float: asinhf, double: asinh)(x)
+#define GEN_ACOSH(x) _Generic((x), float: acoshf, double: acosh)(x)
+#define GEN_ATANH(x) _Generic((x), float: atanhf, double: atanh)(x)
+#define GEN_LOG1P(x) _Generic((x), float: log1pf, double: log1p)(x)
+#define GEN_EXPM1(x) _Generic((x), float: expm1f, double: expm1)(x)
+
+// log1mexp(x) = log(1 - exp(x)); the branch point is -log(2) (note the sign), as in Haskell's implementation:
+//   https://hackage.haskell.org/package/ghc-internal-9.1001.0/docs/src//GHC.Internal.Float.html#log1mexpOrd
+#define LOG1MEXP_IMPL(x) do { \
+    if (x > -(_Generic((x), float: logf, double: log)(2))) return GEN_LOG(-GEN_EXPM1(x)); \
+    else return GEN_LOG1P(-GEN_EXP(x)); \
+  } while (0)
+
+static float log1mexp_float(float x) { LOG1MEXP_IMPL(x); }
+static double log1mexp_double(double x) { LOG1MEXP_IMPL(x); }
+
+#define GEN_LOG1MEXP(x) _Generic((x), float: log1mexp_float, double: log1mexp_double)(x)
+
+// log1pexp(x) = log(1 + exp(x)); the cutoffs 18 and 100 are taken from Haskell's implementation:
+//   https://hackage.haskell.org/package/ghc-internal-9.1001.0/docs/src//GHC.Internal.Float.html#line-595
+#define LOG1PEXP_IMPL(x) do { \
+    if (x <= 18) return GEN_LOG1P(GEN_EXP(x)); \
+    if (x <= 100) return x + GEN_EXP(-x); \
+    return x; \
+  } while (0)
+
+static float log1pexp_float(float x) { LOG1PEXP_IMPL(x); }
+static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
+
+#define GEN_LOG1PEXP(x) _Generic((x), float: log1pexp_float, double: log1pexp_double)(x)
+
 
 /*****************************************************************************
  *                             Kernel functions                              *
@@ -37,22 +97,22 @@ typedef int64_t i64;
     for (i64 i = 0; i < n; i++) out[i] = x[i] op y; \
   }
 
+#define PREFIX_BINOP(name, op, typ) \
+  static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \
+    for (i64 i = 0; i < n; i++) out[i] = op(x, y[i]); \
+  } \
+  static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \
+    for (i64 i = 0; i < n; i++) out[i] = op(x[i], y[i]); \
+  } \
+  static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \
+    for (i64 i = 0; i < n; i++) out[i] = op(x[i], y); \
+  }
+
 #define UNARY_OP(name, op, typ) \
   static void oxarop_op_ ## name ## _ ## typ(i64 n, typ *out, const typ *x) { \
     for (i64 i = 0; i < n; i++) out[i] = op(x[i]); \
   }
 
-#define GEN_ABS(x) \
-  _Generic((x), \
-      int: abs, \
-      long: labs, \
-      long long: llabs, \
-      float: fabsf, \
-      double: fabs)(x)
-
-// This does not result in multiple loads with GCC 13.
-#define GEN_SIGNUM(x) ((x) < 0 ? -1 : (x) > 0 ? 1 : 0)
-
 // Walk a orthotope-style strided array, except for the inner dimension. The
 // body is run for every "inner vector".
 #define TARRAY_WALK_NOINNER(again_label_name, rank, shape, strides, body) \
@@ -161,18 +221,24 @@ enum fbinop_tag_t {
   void oxarop_fbinary_ ## typ ## _sv(enum binop_tag_t tag, i64 n, typ *out, typ x, const typ *y) { \
     switch (tag) { \
       case FB_DIV: oxarop_op_fdiv_ ## typ ## _sv(n, out, x, y); break; \
+      case FB_POW: oxarop_op_pow_ ## typ ## _sv(n, out, x, y); break; \
+      case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _sv(n, out, x, y); break; \
       default: wrong_op("binary_sv", tag); \
     } \
   } \
   void oxarop_fbinary_ ## typ ## _vs(enum binop_tag_t tag, i64 n, typ *out, const typ *x, typ y) { \
     switch (tag) { \
       case FB_DIV: oxarop_op_fdiv_ ## typ ## _vs(n, out, x, y); break; \
+      case FB_POW: oxarop_op_pow_ ## typ ## _vs(n, out, x, y); break; \
+      case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vs(n, out, x, y); break; \
       default: wrong_op("binary_vs", tag); \
     } \
   } \
   void oxarop_fbinary_ ## typ ## _vv(enum binop_tag_t tag, i64 n, typ *out, const typ *x, const typ *y) { \
     switch (tag) { \
       case FB_DIV: oxarop_op_fdiv_ ## typ ## _vv(n, out, x, y); break; \
+      case FB_POW: oxarop_op_pow_ ## typ ## _vv(n, out, x, y); break; \
+      case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vv(n, out, x, y); break; \
       default: wrong_op("binary_vv", tag); \
     } \
   }
@@ -204,9 +270,28 @@ enum funop_tag_t {
 };
 
 #define ENTRY_FUNARY_OPS(typ) \
-  void oxarop_funary_ ## typ(enum unop_tag_t tag, i64 n, typ *out, const typ *x) { \
+  void oxarop_funary_ ## typ(enum funop_tag_t tag, i64 n, typ *out, const typ *x) { \
     switch (tag) { \
       case FU_RECIP: oxarop_op_recip_ ## typ(n, out, x); break; \
+      case FU_EXP: oxarop_op_exp_ ## typ(n, out, x); break; \
+      case FU_LOG: oxarop_op_log_ ## typ(n, out, x); break; \
+      case FU_SQRT: oxarop_op_sqrt_ ## typ(n, out, x); break; \
+      case FU_SIN: oxarop_op_sin_ ## typ(n, out, x); break; \
+      case FU_COS: oxarop_op_cos_ ## typ(n, out, x); break; \
+      case FU_TAN: oxarop_op_tan_ ## typ(n, out, x); break; \
+      case FU_ASIN: oxarop_op_asin_ ## typ(n, out, x); break; \
+      case FU_ACOS: oxarop_op_acos_ ## typ(n, out, x); break; \
+      case FU_ATAN: oxarop_op_atan_ ## typ(n, out, x); break; \
+      case FU_SINH: oxarop_op_sinh_ ## typ(n, out, x); break; \
+      case FU_COSH: oxarop_op_cosh_ ## typ(n, out, x); break; \
+      case FU_TANH: oxarop_op_tanh_ ## typ(n, out, x); break; \
+      case FU_ASINH: oxarop_op_asinh_ ## typ(n, out, x); break; \
+      case FU_ACOSH: oxarop_op_acosh_ ## typ(n, out, x); break; \
+      case FU_ATANH: oxarop_op_atanh_ ## typ(n, out, x); break; \
+      case FU_LOG1P: oxarop_op_log1p_ ## typ(n, out, x); break; \
+      case FU_EXPM1: oxarop_op_expm1_ ## typ(n, out, x); break; \
+      case FU_LOG1PEXP: oxarop_op_log1pexp_ ## typ(n, out, x); break; \
+      case FU_LOG1MEXP: oxarop_op_log1mexp_ ## typ(n, out, x); break; \
       default: wrong_op("unary", tag); \
     } \
   }
@@ -253,7 +338,28 @@ NUM_TYPES_XLIST
 
 #define X(typ) \
   NONCOMM_OP(fdiv, /, typ) \
+  PREFIX_BINOP(pow, GEN_POW, typ) \
+  PREFIX_BINOP(logbase, GEN_LOGBASE, typ) \
   UNARY_OP(recip, 1.0/, typ) \
+  UNARY_OP(exp, GEN_EXP, typ) \
+  UNARY_OP(log, GEN_LOG, typ) \
+  UNARY_OP(sqrt, GEN_SQRT, typ) \
+  UNARY_OP(sin, GEN_SIN, typ) \
+  UNARY_OP(cos, GEN_COS, typ) \
+  UNARY_OP(tan, GEN_TAN, typ) \
+  UNARY_OP(asin, GEN_ASIN, typ) \
+  UNARY_OP(acos, GEN_ACOS, typ) \
+  UNARY_OP(atan, GEN_ATAN, typ) \
+  UNARY_OP(sinh, GEN_SINH, typ) \
+  UNARY_OP(cosh, GEN_COSH, typ) \
+  UNARY_OP(tanh, GEN_TANH, typ) \
+  UNARY_OP(asinh, GEN_ASINH, typ) \
+  UNARY_OP(acosh, GEN_ACOSH, typ) \
+  UNARY_OP(atanh, GEN_ATANH, typ) \
+  UNARY_OP(log1p, GEN_LOG1P, typ) \
+  UNARY_OP(expm1, GEN_EXPM1, typ) \
+  UNARY_OP(log1pexp, GEN_LOG1PEXP, typ) \
+  UNARY_OP(log1mexp, GEN_LOG1MEXP, typ) \
   ENTRY_FBINARY_OPS(typ) \
   ENTRY_FUNARY_OPS(typ)
 FLOAT_TYPES_XLIST
diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h
index 1137c18..2e37575 100644
--- a/cbits/arith_lists.h
+++ b/cbits/arith_lists.h
@@ -3,12 +3,33 @@ LIST_BINOP(BO_SUB, 2, -)
 LIST_BINOP(BO_MUL, 3, *)
 
 LIST_FBINOP(FB_DIV, 1, /)
+LIST_FBINOP(FB_POW, 2, **)
+LIST_FBINOP(FB_LOGBASE, 3, logBase)
 
 LIST_UNOP(UO_NEG, 1,)
 LIST_UNOP(UO_ABS, 2,)
 LIST_UNOP(UO_SIGNUM, 3,)
 
 LIST_FUNOP(FU_RECIP, 1,)
+LIST_FUNOP(FU_EXP, 2,)
+LIST_FUNOP(FU_LOG, 3,)
+LIST_FUNOP(FU_SQRT, 4,)
+LIST_FUNOP(FU_SIN, 5,)
+LIST_FUNOP(FU_COS, 6,)
+LIST_FUNOP(FU_TAN, 7,)
+LIST_FUNOP(FU_ASIN, 8,)
+LIST_FUNOP(FU_ACOS, 9,)
+LIST_FUNOP(FU_ATAN, 10,)
+LIST_FUNOP(FU_SINH, 11,)
+LIST_FUNOP(FU_COSH, 12,)
+LIST_FUNOP(FU_TANH, 13,)
+LIST_FUNOP(FU_ASINH, 14,)
+LIST_FUNOP(FU_ACOSH, 15,)
+LIST_FUNOP(FU_ATANH, 16,)
+LIST_FUNOP(FU_LOG1P, 17,)
+LIST_FUNOP(FU_EXPM1, 18,)
+LIST_FUNOP(FU_LOG1PEXP, 19,)
+LIST_FUNOP(FU_LOG1MEXP, 20,)
 
 LIST_REDOP(RO_SUM1, 1,)
 LIST_REDOP(RO_PRODUCT1, 2,)
-- 
cgit v1.2.3-70-g09d2