// Architecture detection #if defined(__x86_64__) || defined(_M_X64) #define OX_ARCH_INTEL #endif #include #include #include #include #include #include #ifdef OX_ARCH_INTEL #include #endif // These are the wrapper macros used in arith_lists.h. Preset them to empty to // avoid having to touch macros unrelated to the particular operation set below. #define LIST_BINOP(name, id, hsop) #define LIST_FBINOP(name, id, hsop) #define LIST_UNOP(name, id, _) #define LIST_FUNOP(name, id, _) #define LIST_REDOP(name, id, _) // Shorter names, due to CPP used both in function names and in C types. typedef int32_t i32; typedef int64_t i64; /***************************************************************************** * Additional math functions * *****************************************************************************/ #define GEN_ABS(x) \ _Generic((x), \ int: abs, \ long: labs, \ long long: llabs, \ float: fabsf, \ double: fabs)(x) // This does not result in multiple loads with GCC 13. #define GEN_SIGNUM(x) ((x) < 0 ? -1 : (x) > 0 ? 
1 : 0) #define GEN_POW(x, y) _Generic((x), float: powf, double: pow)(x, y) #define GEN_LOGBASE(x, y) _Generic((x), float: logf(y) / logf(x), double: log(y) / log(x)) #define GEN_EXP(x) _Generic((x), float: expf, double: exp)(x) #define GEN_LOG(x) _Generic((x), float: logf, double: log)(x) #define GEN_SQRT(x) _Generic((x), float: sqrtf, double: sqrt)(x) #define GEN_SIN(x) _Generic((x), float: sinf, double: sin)(x) #define GEN_COS(x) _Generic((x), float: cosf, double: cos)(x) #define GEN_TAN(x) _Generic((x), float: tanf, double: tan)(x) #define GEN_ASIN(x) _Generic((x), float: asinf, double: asin)(x) #define GEN_ACOS(x) _Generic((x), float: acosf, double: acos)(x) #define GEN_ATAN(x) _Generic((x), float: atanf, double: atan)(x) #define GEN_SINH(x) _Generic((x), float: sinhf, double: sinh)(x) #define GEN_COSH(x) _Generic((x), float: coshf, double: cosh)(x) #define GEN_TANH(x) _Generic((x), float: tanhf, double: tanh)(x) #define GEN_ASINH(x) _Generic((x), float: asinhf, double: asinh)(x) #define GEN_ACOSH(x) _Generic((x), float: acoshf, double: acosh)(x) #define GEN_ATANH(x) _Generic((x), float: atanhf, double: atanh)(x) #define GEN_LOG1P(x) _Generic((x), float: log1pf, double: log1p)(x) #define GEN_EXPM1(x) _Generic((x), float: expm1f, double: expm1)(x) // Taken from Haskell's implementation: // https://hackage.haskell.org/package/ghc-internal-9.1001.0/docs/src//GHC.Internal.Float.html#log1mexpOrd #define LOG1MEXP_IMPL(x) do { \ if (x > _Generic((x), float: logf, double: log)(2)) return GEN_LOG(-GEN_EXPM1(x)); \ else return GEN_LOG1P(-GEN_EXP(x)); \ } while (0) static float log1mexp_float(float x) { LOG1MEXP_IMPL(x); } static double log1mexp_double(double x) { LOG1MEXP_IMPL(x); } #define GEN_LOG1MEXP(x) _Generic((x), float: log1mexp_float, double: log1mexp_double)(x) // Taken from Haskell's implementation: // https://hackage.haskell.org/package/ghc-internal-9.1001.0/docs/src//GHC.Internal.Float.html#line-595 #define LOG1PEXP_IMPL(x) do { \ if (x <= 18) return 
GEN_LOG1P(GEN_EXP(x)); \ if (x <= 100) return x + GEN_EXP(-x); \ return x; \ } while (0) static float log1pexp_float(float x) { LOG1PEXP_IMPL(x); } static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); } #define GEN_LOG1PEXP(x) _Generic((x), float: log1pexp_float, double: log1pexp_double)(x) /***************************************************************************** * Kernel functions * *****************************************************************************/ #define COMM_OP(name, op, typ) \ static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \ for (i64 i = 0; i < n; i++) out[i] = x op y[i]; \ } \ static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \ for (i64 i = 0; i < n; i++) out[i] = x[i] op y[i]; \ } #define NONCOMM_OP(name, op, typ) \ COMM_OP(name, op, typ) \ static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \ for (i64 i = 0; i < n; i++) out[i] = x[i] op y; \ } #define PREFIX_BINOP(name, op, typ) \ static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \ for (i64 i = 0; i < n; i++) out[i] = op(x, y[i]); \ } \ static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \ for (i64 i = 0; i < n; i++) out[i] = op(x[i], y[i]); \ } \ static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \ for (i64 i = 0; i < n; i++) out[i] = op(x[i], y); \ } #define UNARY_OP(name, op, typ) \ static void oxarop_op_ ## name ## _ ## typ(i64 n, typ *out, const typ *x) { \ for (i64 i = 0; i < n; i++) out[i] = op(x[i]); \ } // Walk a orthotope-style strided array, except for the inner dimension. The // body is run for every "inner vector". // Provides idx, outlinidx, arrlinidx. 
#define TARRAY_WALK_NOINNER(again_label_name, rank, shape, strides, body) \
  do { \
    /* idx[d] is the current index along outer dimension d. A VLA must have \
     * a positive length (C11 6.7.6.2p5), so keep one unused slot when \
     * rank == 1 instead of declaring a zero-length array. */ \
    i64 idx[(rank) > 1 ? (rank) - 1 : 1]; \
    memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \
    i64 arrlinidx = 0; \
    i64 outlinidx = 0; \
  again_label_name: \
    { \
      body \
    } \
    /* Odometer-style increment, last outer dimension first. */ \
    for (i64 dim = (rank) - 2; dim >= 0; dim--) { \
      if (++idx[dim] < (shape)[dim]) { \
        arrlinidx += (strides)[dim]; \
        outlinidx++; \
        goto again_label_name; \
      } \
      /* Wrapped: undo the (shape[dim] - 1) steps taken along dim. */ \
      arrlinidx -= (idx[dim] - 1) * (strides)[dim]; \
      idx[dim] = 0; \
    } \
  } while (false)

// Walk TWO orthotope-style strided arrays simultaneously, except for their
// inner dimension. The arrays must have the same shape, but may have different
// strides. The body is run for every pair of "inner vectors".
// Provides idx, outlinidx, arrlinidx1, arrlinidx2.
#define TARRAY_WALK2_NOINNER(again_label_name, rank, shape, strides1, strides2, body) \
  do { \
    /* See TARRAY_WALK_NOINNER: never declare a zero-length VLA. */ \
    i64 idx[(rank) > 1 ? (rank) - 1 : 1]; \
    memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \
    i64 arrlinidx1 = 0, arrlinidx2 = 0; \
    i64 outlinidx = 0; \
  again_label_name: \
    { \
      body \
    } \
    for (i64 dim = (rank) - 2; dim >= 0; dim--) { \
      if (++idx[dim] < (shape)[dim]) { \
        arrlinidx1 += (strides1)[dim]; \
        arrlinidx2 += (strides2)[dim]; \
        outlinidx++; \
        goto again_label_name; \
      } \
      arrlinidx1 -= (idx[dim] - 1) * (strides1)[dim]; \
      arrlinidx2 -= (idx[dim] - 1) * (strides2)[dim]; \
      idx[dim] = 0; \
    } \
  } while (false)

// Same as TARRAY_WALK_NOINNER, except the body is specialised twice: once on
// strides[rank-1] == 1 and a fallback case.
#define TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, body) \
  do { \
    if ((strides)[(rank) - 1] == 1) { \
      TARRAY_WALK_NOINNER(tar_wa_again1, rank, shape, strides, body); \
    } else { \
      TARRAY_WALK_NOINNER(tar_wa_again2, rank, shape, strides, body); \
    } \
  } while (false)

// preconditions:
// - all strides are >0
// - shape is everywhere >0
// - rank is >= 1
// - out has capacity for (shape[0] * ... * shape[rank - 2]) elements
// Reduces along the innermost dimension.
// 'out' will be filled densely in linearisation order.
// Reduce every inner vector with 'op' (left fold starting at element 0) and
// store the results densely in 'out'.
#define REDUCE1_OP(name, op, typ) \
  static void oxarop_op_ ## name ## _ ## typ(i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \
    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
      typ accum = arr[arrlinidx]; \
      for (i64 i = 1; i < shape[rank - 1]; i++) { \
        accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \
      } \
      out[outlinidx] = accum; \
    }); \
  }

// preconditions
// - all strides are >0
// - shape is everywhere >0
// - rank is >= 1
// Reduces the whole array to a single value with 'op'. Each inner vector is
// folded first; the per-vector results are then combined. The first inner
// vector (outlinidx == 0) seeds the result, so no identity element for 'op'
// is needed. Seeding with 0 instead would make every product reduction 0.
#define REDUCEFULL_OP(name, op, typ) \
  typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
    typ res = 0; \
    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
      typ accum = arr[arrlinidx]; \
      for (i64 i = 1; i < shape[rank - 1]; i++) { \
        accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \
      } \
      res = outlinidx == 0 ? accum : (res op accum); \
    }); \
    return res; \
  }

// preconditions
// - all strides are >0
// - shape is everywhere >0
// - rank is >= 1
// Writes extreme index to outidx. If 'cmp' is '<', computes minindex ("argmin"); if '>', maxindex.
// Scans the array in linearisation order and records the multi-index of the
// first element that is strictly better (per 'cmp') than all earlier ones.
// The comparison and the stored value must read the SAME strided element:
// indexing with plain 'i' would ignore strides[rank - 1] and compare against
// the wrong elements whenever the inner stride is not 1.
#define EXTREMUM_OP(name, cmp, typ) \
  void oxarop_extremum_ ## name ## _ ## typ(i64 *outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
    typ best = arr[0]; \
    memset(outidx, 0, rank * sizeof(i64)); \
    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
      bool found = false; \
      for (i64 i = 0; i < shape[rank - 1]; i++) { \
        const typ elem = arr[arrlinidx + strides[rank - 1] * i]; \
        if (elem cmp best) { \
          best = elem; \
          found = true; \
          outidx[rank - 1] = i; \
        } \
      } \
      /* Outer indices only change when this inner vector held a new best. */ \
      if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \
    }); \
  }

// Dot product of two contiguous vectors of the given length.
#define DOTPROD_OP(typ) \
  typ oxarop_dotprod_ ## typ(i64 length, const typ *arr1, const typ *arr2) { \
    typ res = 0; \
    for (i64 i = 0; i < length; i++) res += arr1[i] * arr2[i]; \
    return res; \
  }

// Dot product of two strided vectors; element i of arrK is at
// arrK[offsetK + strideK * i].
#define DOTPROD_STRIDED_OP(typ) \
  typ oxarop_dotprod_ ## typ ## _strided(i64 length, i64 offset1, i64 stride1, const typ *arr1, i64 offset2, i64 stride2, const typ *arr2) { \
    typ res = 0; \
    for (i64 i = 0; i < length; i++) res += arr1[offset1 + stride1 * i] * arr2[offset2 + stride2 * i]; \
    return res; \
  }

// The 'double' version here is about 2x as fast as gcc's own vectorisation.
DOTPROD_OP(i32)
DOTPROD_OP(i64)

#ifdef OX_ARCH_INTEL
// SSE float dot product: four running partial sums (one per lane) over
// unaligned loads. The lanes are then folded left to right, and the final
// length % 4 elements are added with scalar code.
float oxarop_dotprod_float(i64 length, const float *arr1, const float *arr2) {
  __m128 lanes = _mm_setzero_ps();
  i64 pos = 0;
  while (pos + 3 < length) {
    lanes = _mm_add_ps(lanes, _mm_mul_ps(_mm_loadu_ps(arr1 + pos), _mm_loadu_ps(arr2 + pos)));
    pos += 4;
  }

  float partial[4];
  _mm_storeu_ps(partial, lanes);
  float sum = ((partial[0] + partial[1]) + partial[2]) + partial[3];

  for (; pos < length; pos++) sum += arr1[pos] * arr2[pos];
  return sum;
}

// SSE2 double dot product: two running partial sums, combined low + high,
// plus at most one leftover element when length is odd.
double oxarop_dotprod_double(i64 length, const double *arr1, const double *arr2) {
  __m128d lanes = _mm_setzero_pd();
  i64 pos = 0;
  while (pos + 1 < length) {
    lanes = _mm_add_pd(lanes, _mm_mul_pd(_mm_loadu_pd(arr1 + pos), _mm_loadu_pd(arr2 + pos)));
    pos += 2;
  }

  const double low = _mm_cvtsd_f64(lanes);
  const double high = _mm_cvtsd_f64(_mm_unpackhi_pd(lanes, lanes));
  double sum = low + high;

  if (pos < length) sum += arr1[pos] * arr2[pos];
  return sum;
}
#else
DOTPROD_OP(float)
DOTPROD_OP(double)
#endif

// preconditions:
// - all strides are >0
//   (NOTE(review): DOTPROD_INNER_OP also special-cases an inner stride of -1,
//   so negative inner strides apparently do occur; confirm with callers.)
// - shape is everywhere >0
// - rank is >= 1
// - out has capacity for (shape[0] * ... * shape[rank - 2]) elements
// Reduces along the innermost dimension.
// 'out' will be filled densely in linearisation order.
#define DOTPROD_INNER_OP(typ) \ void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \ if (strides1[rank - 1] == 1 && strides2[rank - 1] == 1) { \ TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \ out[outlinidx] = oxarop_dotprod_ ## typ(shape[rank - 1], arr1 + arrlinidx1, arr2 + arrlinidx2); \ }); \ } else if (strides1[rank - 1] == -1 && strides2[rank - 1] == -1) { \ TARRAY_WALK2_NOINNER(again2, rank, shape, strides1, strides2, { \ const i64 len = shape[rank - 1]; \ out[outlinidx] = oxarop_dotprod_ ## typ(len, arr1 + arrlinidx1 - (len - 1), arr2 + arrlinidx2 - (len - 1)); \ }); \ } else { \ TARRAY_WALK2_NOINNER(again3, rank, shape, strides1, strides2, { \ out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(shape[rank - 1], arrlinidx1, strides1[rank - 1], arr1, arrlinidx2, strides2[rank - 1], arr2); \ }); \ } \ } /***************************************************************************** * Entry point functions * *****************************************************************************/ __attribute__((noreturn, cold)) static void wrong_op(const char *name, int tag) { fprintf(stderr, "ox-arrays: Invalid operation tag passed to %s C code: %d\n", name, tag); abort(); } enum binop_tag_t { #undef LIST_BINOP #define LIST_BINOP(name, id, hsop) name = id, #include "arith_lists.h" #undef LIST_BINOP #define LIST_BINOP(name, id, hsop) }; #define ENTRY_BINARY_OPS(typ) \ void oxarop_binary_ ## typ ## _sv(enum binop_tag_t tag, i64 n, typ *out, typ x, const typ *y) { \ switch (tag) { \ case BO_ADD: oxarop_op_add_ ## typ ## _sv(n, out, x, y); break; \ case BO_SUB: oxarop_op_sub_ ## typ ## _sv(n, out, x, y); break; \ case BO_MUL: oxarop_op_mul_ ## typ ## _sv(n, out, x, y); break; \ default: wrong_op("binary_sv", tag); \ } \ } \ void oxarop_binary_ ## typ ## _vs(enum binop_tag_t tag, i64 n, typ *out, const typ *x, typ y) { \ switch (tag) { \ case BO_ADD: 
oxarop_op_add_ ## typ ## _sv(n, out, y, x); break; \ case BO_SUB: oxarop_op_sub_ ## typ ## _vs(n, out, x, y); break; \ case BO_MUL: oxarop_op_mul_ ## typ ## _sv(n, out, y, x); break; \ default: wrong_op("binary_vs", tag); \ } \ } \ void oxarop_binary_ ## typ ## _vv(enum binop_tag_t tag, i64 n, typ *out, const typ *x, const typ *y) { \ switch (tag) { \ case BO_ADD: oxarop_op_add_ ## typ ## _vv(n, out, x, y); break; \ case BO_SUB: oxarop_op_sub_ ## typ ## _vv(n, out, x, y); break; \ case BO_MUL: oxarop_op_mul_ ## typ ## _vv(n, out, x, y); break; \ default: wrong_op("binary_vv", tag); \ } \ } enum fbinop_tag_t { #undef LIST_FBINOP #define LIST_FBINOP(name, id, hsop) name = id, #include "arith_lists.h" #undef LIST_FBINOP #define LIST_FBINOP(name, id, hsop) }; #define ENTRY_FBINARY_OPS(typ) \ void oxarop_fbinary_ ## typ ## _sv(enum binop_tag_t tag, i64 n, typ *out, typ x, const typ *y) { \ switch (tag) { \ case FB_DIV: oxarop_op_fdiv_ ## typ ## _sv(n, out, x, y); break; \ case FB_POW: oxarop_op_pow_ ## typ ## _sv(n, out, x, y); break; \ case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _sv(n, out, x, y); break; \ default: wrong_op("binary_sv", tag); \ } \ } \ void oxarop_fbinary_ ## typ ## _vs(enum binop_tag_t tag, i64 n, typ *out, const typ *x, typ y) { \ switch (tag) { \ case FB_DIV: oxarop_op_fdiv_ ## typ ## _vs(n, out, x, y); break; \ case FB_POW: oxarop_op_pow_ ## typ ## _vs(n, out, x, y); break; \ case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vs(n, out, x, y); break; \ default: wrong_op("binary_vs", tag); \ } \ } \ void oxarop_fbinary_ ## typ ## _vv(enum binop_tag_t tag, i64 n, typ *out, const typ *x, const typ *y) { \ switch (tag) { \ case FB_DIV: oxarop_op_fdiv_ ## typ ## _vv(n, out, x, y); break; \ case FB_POW: oxarop_op_pow_ ## typ ## _vv(n, out, x, y); break; \ case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vv(n, out, x, y); break; \ default: wrong_op("binary_vv", tag); \ } \ } enum unop_tag_t { #undef LIST_UNOP #define LIST_UNOP(name, id, _) name = id, #include 
"arith_lists.h" #undef LIST_UNOP #define LIST_UNOP(name, id, _) }; #define ENTRY_UNARY_OPS(typ) \ void oxarop_unary_ ## typ(enum unop_tag_t tag, i64 n, typ *out, const typ *x) { \ switch (tag) { \ case UO_NEG: oxarop_op_neg_ ## typ(n, out, x); break; \ case UO_ABS: oxarop_op_abs_ ## typ(n, out, x); break; \ case UO_SIGNUM: oxarop_op_signum_ ## typ(n, out, x); break; \ default: wrong_op("unary", tag); \ } \ } enum funop_tag_t { #undef LIST_FUNOP #define LIST_FUNOP(name, id, _) name = id, #include "arith_lists.h" #undef LIST_FUNOP #define LIST_FUNOP(name, id, _) }; #define ENTRY_FUNARY_OPS(typ) \ void oxarop_funary_ ## typ(enum funop_tag_t tag, i64 n, typ *out, const typ *x) { \ switch (tag) { \ case FU_RECIP: oxarop_op_recip_ ## typ(n, out, x); break; \ case FU_EXP: oxarop_op_exp_ ## typ(n, out, x); break; \ case FU_LOG: oxarop_op_log_ ## typ(n, out, x); break; \ case FU_SQRT: oxarop_op_sqrt_ ## typ(n, out, x); break; \ case FU_SIN: oxarop_op_sin_ ## typ(n, out, x); break; \ case FU_COS: oxarop_op_cos_ ## typ(n, out, x); break; \ case FU_TAN: oxarop_op_tan_ ## typ(n, out, x); break; \ case FU_ASIN: oxarop_op_asin_ ## typ(n, out, x); break; \ case FU_ACOS: oxarop_op_acos_ ## typ(n, out, x); break; \ case FU_ATAN: oxarop_op_atan_ ## typ(n, out, x); break; \ case FU_SINH: oxarop_op_sinh_ ## typ(n, out, x); break; \ case FU_COSH: oxarop_op_cosh_ ## typ(n, out, x); break; \ case FU_TANH: oxarop_op_tanh_ ## typ(n, out, x); break; \ case FU_ASINH: oxarop_op_asinh_ ## typ(n, out, x); break; \ case FU_ACOSH: oxarop_op_acosh_ ## typ(n, out, x); break; \ case FU_ATANH: oxarop_op_atanh_ ## typ(n, out, x); break; \ case FU_LOG1P: oxarop_op_log1p_ ## typ(n, out, x); break; \ case FU_EXPM1: oxarop_op_expm1_ ## typ(n, out, x); break; \ case FU_LOG1PEXP: oxarop_op_log1pexp_ ## typ(n, out, x); break; \ case FU_LOG1MEXP: oxarop_op_log1mexp_ ## typ(n, out, x); break; \ default: wrong_op("unary", tag); \ } \ } enum redop_tag_t { #undef LIST_REDOP #define LIST_REDOP(name, id, _) name = 
id, #include "arith_lists.h" #undef LIST_REDOP #define LIST_REDOP(name, id, _) }; #define ENTRY_REDUCE1_OPS(typ) \ void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \ switch (tag) { \ case RO_SUM: oxarop_op_sum1_ ## typ(rank, out, shape, strides, arr); break; \ case RO_PRODUCT: oxarop_op_product1_ ## typ(rank, out, shape, strides, arr); break; \ default: wrong_op("reduce", tag); \ } \ } #define ENTRY_REDUCEFULL_OPS(typ) \ typ oxarop_reducefull_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ switch (tag) { \ case RO_SUM: return oxarop_op_sumfull_ ## typ(rank, shape, strides, arr); \ case RO_PRODUCT: return oxarop_op_productfull_ ## typ(rank, shape, strides, arr); \ default: wrong_op("reduce", tag); \ } \ } /***************************************************************************** * Generate all the functions * *****************************************************************************/ #define FLOAT_TYPES_XLIST X(double) X(float) #define NUM_TYPES_XLIST X(i32) X(i64) FLOAT_TYPES_XLIST #define X(typ) \ COMM_OP(add, +, typ) \ NONCOMM_OP(sub, -, typ) \ COMM_OP(mul, *, typ) \ UNARY_OP(neg, -, typ) \ UNARY_OP(abs, GEN_ABS, typ) \ UNARY_OP(signum, GEN_SIGNUM, typ) \ REDUCE1_OP(sum1, +, typ) \ REDUCE1_OP(product1, *, typ) \ REDUCEFULL_OP(sumfull, +, typ) \ REDUCEFULL_OP(productfull, *, typ) \ ENTRY_BINARY_OPS(typ) \ ENTRY_UNARY_OPS(typ) \ ENTRY_REDUCE1_OPS(typ) \ ENTRY_REDUCEFULL_OPS(typ) \ EXTREMUM_OP(min, <, typ) \ EXTREMUM_OP(max, >, typ) \ DOTPROD_STRIDED_OP(typ) \ DOTPROD_INNER_OP(typ) NUM_TYPES_XLIST #undef X #define X(typ) \ NONCOMM_OP(fdiv, /, typ) \ PREFIX_BINOP(pow, GEN_POW, typ) \ PREFIX_BINOP(logbase, GEN_LOGBASE, typ) \ UNARY_OP(recip, 1.0/, typ) \ UNARY_OP(exp, GEN_EXP, typ) \ UNARY_OP(log, GEN_LOG, typ) \ UNARY_OP(sqrt, GEN_SQRT, typ) \ UNARY_OP(sin, GEN_SIN, typ) \ UNARY_OP(cos, GEN_COS, typ) \ UNARY_OP(tan, GEN_TAN, typ) \ 
UNARY_OP(asin, GEN_ASIN, typ) \ UNARY_OP(acos, GEN_ACOS, typ) \ UNARY_OP(atan, GEN_ATAN, typ) \ UNARY_OP(sinh, GEN_SINH, typ) \ UNARY_OP(cosh, GEN_COSH, typ) \ UNARY_OP(tanh, GEN_TANH, typ) \ UNARY_OP(asinh, GEN_ASINH, typ) \ UNARY_OP(acosh, GEN_ACOSH, typ) \ UNARY_OP(atanh, GEN_ATANH, typ) \ UNARY_OP(log1p, GEN_LOG1P, typ) \ UNARY_OP(expm1, GEN_EXPM1, typ) \ UNARY_OP(log1pexp, GEN_LOG1PEXP, typ) \ UNARY_OP(log1mexp, GEN_LOG1MEXP, typ) \ ENTRY_FBINARY_OPS(typ) \ ENTRY_FUNARY_OPS(typ) FLOAT_TYPES_XLIST #undef X