// Architecture detection: OX_ARCH_INTEL is defined when either __x86_64__ or
// _M_X64 is predefined, i.e. when targeting x86-64. It gates the
// <emmintrin.h> include and the SSE2-intrinsic implementations of
// oxarop_dotprod_float/double further down; on other targets those two fall
// back to the portable DOTPROD_OP loop.
#if defined(__x86_64__) || defined(_M_X64)
#define OX_ARCH_INTEL
#endif

#include <stdio.h>
#include <stdint.h>
#include <inttypes.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
#include <math.h>

#ifdef OX_ARCH_INTEL
#include <emmintrin.h>
#endif

// These are the wrapper macros used in arith_lists.h. Preset them all to
// expand to nothing. Each '#include "arith_lists.h"' further down then emits
// only the one list whose macro was temporarily redefined just before it (see
// the tag enums below); every other list in the header is ignored.
#define LIST_BINOP(name, id, hsop)
#define LIST_IBINOP(name, id, hsop)
#define LIST_FBINOP(name, id, hsop)
#define LIST_UNOP(name, id, _)
#define LIST_FUNOP(name, id, _)
#define LIST_REDOP(name, id, _)


// Short, single-identifier names for the fixed-width integer types. The
// X-macros below use the element type name both as a C type and, via token
// pasting, inside generated function names (e.g. oxarop_dotprod_i64). The
// name must therefore be a single identifier, and the short form keeps the
// generated names compact.
typedef int32_t i32;
typedef int64_t i64;

/*****************************************************************************
 *                         Additional math functions                         *
 *****************************************************************************/

// Type-generic absolute value: picks the C library function matching the
// static type of x (double, float, long long, long or int).
#define GEN_ABS(x) \
  _Generic((x), \
      double: fabs, \
      float: fabsf, \
      long long: llabs, \
      long: labs, \
      int: abs)(x)

// Type-generic sign as an int: -1 for negative, 1 for positive, and 0
// otherwise. Both zeros and NaN give 0, because both comparisons are false.
// x is evaluated twice, so pass only side-effect-free expressions.
#define GEN_SIGNUM(x) (((x) > 0) - ((x) < 0))

#define GEN_POW(x, y) _Generic((x), float: powf, double: pow)(x, y)
#define GEN_LOGBASE(x, y) _Generic((x), float: logf(y) / logf(x), double: log(y) / log(x))
#define GEN_ATAN2(y, x) _Generic((x), float: atan2f(y, x), double: atan2(y, x))
#define GEN_EXP(x) _Generic((x), float: expf, double: exp)(x)
#define GEN_LOG(x) _Generic((x), float: logf, double: log)(x)
#define GEN_SQRT(x) _Generic((x), float: sqrtf, double: sqrt)(x)
#define GEN_SIN(x) _Generic((x), float: sinf, double: sin)(x)
#define GEN_COS(x) _Generic((x), float: cosf, double: cos)(x)
#define GEN_TAN(x) _Generic((x), float: tanf, double: tan)(x)
#define GEN_ASIN(x) _Generic((x), float: asinf, double: asin)(x)
#define GEN_ACOS(x) _Generic((x), float: acosf, double: acos)(x)
#define GEN_ATAN(x) _Generic((x), float: atanf, double: atan)(x)
#define GEN_SINH(x) _Generic((x), float: sinhf, double: sinh)(x)
#define GEN_COSH(x) _Generic((x), float: coshf, double: cosh)(x)
#define GEN_TANH(x) _Generic((x), float: tanhf, double: tanh)(x)
#define GEN_ASINH(x) _Generic((x), float: asinhf, double: asinh)(x)
#define GEN_ACOSH(x) _Generic((x), float: acoshf, double: acosh)(x)
#define GEN_ATANH(x) _Generic((x), float: atanhf, double: atanh)(x)
#define GEN_LOG1P(x) _Generic((x), float: log1pf, double: log1p)(x)
#define GEN_EXPM1(x) _Generic((x), float: expm1f, double: expm1)(x)

// Taken from Haskell's implementation:
//   https://hackage.haskell.org/package/ghc-internal-9.1001.0/docs/src//GHC.Internal.Float.html#log1mexpOrd
#define LOG1MEXP_IMPL(x) do { \
    if (x > _Generic((x), float: logf, double: log)(2)) return GEN_LOG(-GEN_EXPM1(x)); \
    else return GEN_LOG1P(-GEN_EXP(x)); \
  } while (0)

static float log1mexp_float(float x) { LOG1MEXP_IMPL(x); }
static double log1mexp_double(double x) { LOG1MEXP_IMPL(x); }

#define GEN_LOG1MEXP(x) _Generic((x), float: log1mexp_float, double: log1mexp_double)(x)

// Taken from Haskell's implementation:
//   https://hackage.haskell.org/package/ghc-internal-9.1001.0/docs/src//GHC.Internal.Float.html#line-595
#define LOG1PEXP_IMPL(x) do { \
      if (x <= 18) return GEN_LOG1P(GEN_EXP(x)); \
      if (x <= 100) return x + GEN_EXP(-x); \
      return x; \
  } while (0)

static float log1pexp_float(float x) { LOG1PEXP_IMPL(x); }
static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }

#define GEN_LOG1PEXP(x) _Generic((x), float: log1pexp_float, double: log1pexp_double)(x)


/*****************************************************************************
 *                             Helper functions                              *
 *****************************************************************************/

// Debug helper: writes a shape (or any i64 vector of length 'rank') to
// 'stream' as "[d0,d1,...]". A rank of 0 or less prints "[]". Marked 'used'
// so the compiler keeps it even when nothing references it.
__attribute__((used))
static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
  fputc('[', stream);
  if (rank > 0) {
    fprintf(stream, "%" PRIi64, shape[0]);
    for (i64 d = 1; d < rank; d++) {
      fputc(',', stream);
      fprintf(stream, "%" PRIi64, shape[d]);
    }
  }
  fputc(']', stream);
}


/*****************************************************************************
 *                                Skeletons                                  *
 *****************************************************************************/

// Walk an orthotope-style strided array, except for the inner dimension. The
// body is run once for every "inner vector", i.e. for every index tuple over
// the outer rank - 1 dimensions. The last outer index varies fastest
// (row-major order).
// Provides to 'body':
//   idx       - i64 array; idx[0 .. rank-2] holds the current outer indices.
//               idx[rank-1] is unused padding; see Note: [zero-length VLA].
//   outlinidx - 0, 1, 2, ...: sequence number of the current inner vector,
//               usable as an index into a dense output.
//   arrlinidx - element offset of the current inner vector's start in the
//               array: the sum over d < rank-1 of idx[d] * strides[d].
// 'again_label_name' becomes a goto label, so it must be unique within the
// enclosing function. Requires rank >= 1, because the memset clears rank - 1
// elements.
// The body runs at least once, even if some outer extent is 0, so callers must
// ensure shape[d] > 0. When advancing fails in a dimension (idx[dim] reached
// shape[dim]), arrlinidx is rewound by (shape[dim] - 1) * strides[dim] and the
// next-outer dimension is tried.
#define TARRAY_WALK_NOINNER(again_label_name, rank, shape, strides, body) \
  do { \
    i64 idx[(rank) /* - 1 */]; /* Note: [zero-length VLA] */ \
    memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \
    i64 arrlinidx = 0; \
    i64 outlinidx = 0; \
  again_label_name: \
    { \
      body \
    } \
    for (i64 dim = (rank) - 2; dim >= 0; dim--) { \
      if (++idx[dim] < (shape)[dim]) { \
        arrlinidx += (strides)[dim]; \
        outlinidx++; \
        goto again_label_name; \
      } \
      arrlinidx -= (idx[dim] - 1) * (strides)[dim]; \
      idx[dim] = 0; \
    } \
  } while (false)

// Walk TWO orthotope-style strided arrays simultaneously, except for their
// inner dimension. The arrays must have the same shape, but may have different
// strides. The body is run for every pair of "inner vectors", in the same
// row-major order as TARRAY_WALK_NOINNER.
// Provides idx, outlinidx, arrlinidx1, arrlinidx2. idx and outlinidx are as in
// TARRAY_WALK_NOINNER. arrlinidx1 and arrlinidx2 are the start offsets of the
// current inner vector in the first and second array, computed with strides1
// and strides2 respectively.
// Same requirements as TARRAY_WALK_NOINNER: the label must be unique within
// the function, rank >= 1, and every outer extent must be > 0 (the body runs
// at least once).
#define TARRAY_WALK2_NOINNER(again_label_name, rank, shape, strides1, strides2, body) \
  do { \
    i64 idx[(rank) /* - 1 */]; /* Note: [zero-length VLA] */ \
    memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \
    i64 arrlinidx1 = 0, arrlinidx2 = 0; \
    i64 outlinidx = 0; \
  again_label_name: \
    { \
      body \
    } \
    for (i64 dim = (rank) - 2; dim >= 0; dim--) { \
      if (++idx[dim] < (shape)[dim]) { \
        arrlinidx1 += (strides1)[dim]; \
        arrlinidx2 += (strides2)[dim]; \
        outlinidx++; \
        goto again_label_name; \
      } \
      arrlinidx1 -= (idx[dim] - 1) * (strides1)[dim]; \
      arrlinidx2 -= (idx[dim] - 1) * (strides2)[dim]; \
      idx[dim] = 0; \
    } \
  } while (false)


/*****************************************************************************
 *                             Kernel functions                              *
 *****************************************************************************/

// Elementwise kernels for an infix operator 'op'. For element type 'typ' this
// generates:
//   oxarop_op_<name>_<typ>_sv_strided: out = x op y, with scalar x and strided
//                                      array y
//   oxarop_op_<name>_<typ>_vv_strided: out = x op y, with strided arrays x and
//                                      y of the same shape
// 'out' is written densely in linearisation order; inputs are read through
// their strides. For commutative operators these two variants suffice:
// array-op-scalar can reuse sv with the operands swapped.
#define COMM_OP_STRIDED(name, op, typ) \
  static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
    const i64 inner_len = shape[rank - 1]; \
    const i64 inner_str = strides[rank - 1]; \
    TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
      const i64 out_base = outlinidx * inner_len; \
      for (i64 j = 0; j < inner_len; j++) { \
        out[out_base + j] = x op y[arrlinidx + inner_str * j]; \
      } \
    }); \
  } \
  static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
    const i64 inner_len = shape[rank - 1]; \
    const i64 inner_str1 = strides1[rank - 1]; \
    const i64 inner_str2 = strides2[rank - 1]; \
    TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \
      const i64 out_base = outlinidx * inner_len; \
      for (i64 j = 0; j < inner_len; j++) { \
        out[out_base + j] = x[arrlinidx1 + inner_str1 * j] op y[arrlinidx2 + inner_str2 * j]; \
      } \
    }); \
  }

// Like COMM_OP_STRIDED (whose sv and vv kernels it also emits), plus the
// array-on-the-left scalar variant needed when x op y differs from y op x:
//   oxarop_op_<name>_<typ>_vs_strided: out = x op y, with strided array x and
//                                      scalar y
// 'out' is written densely in linearisation order.
#define NONCOMM_OP_STRIDED(name, op, typ) \
  COMM_OP_STRIDED(name, op, typ) \
  static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
    const i64 inner_len = shape[rank - 1]; \
    const i64 inner_str = strides[rank - 1]; \
    TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
      const i64 out_base = outlinidx * inner_len; \
      for (i64 j = 0; j < inner_len; j++) { \
        out[out_base + j] = x[arrlinidx + inner_str * j] op y; \
      } \
    }); \
  }

// Elementwise kernels for a function-like binary operator applied as op(a, b),
// e.g. GEN_POW. For element type 'typ' this generates all three operand
// layouts:
//   ..._sv_strided: out = op(x, y), with scalar x and strided array y
//   ..._vv_strided: out = op(x, y), with strided arrays x and y of the same
//                   shape
//   ..._vs_strided: out = op(x, y), with strided array x and scalar y
// 'out' is written densely in linearisation order.
#define PREFIX_BINOP_STRIDED(name, op, typ) \
  static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
    const i64 inner_len = shape[rank - 1]; \
    const i64 inner_str = strides[rank - 1]; \
    TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
      const i64 out_base = outlinidx * inner_len; \
      for (i64 j = 0; j < inner_len; j++) { \
        out[out_base + j] = op(x, y[arrlinidx + inner_str * j]); \
      } \
    }); \
  } \
  static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
    const i64 inner_len = shape[rank - 1]; \
    const i64 inner_str1 = strides1[rank - 1]; \
    const i64 inner_str2 = strides2[rank - 1]; \
    TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \
      const i64 out_base = outlinidx * inner_len; \
      for (i64 j = 0; j < inner_len; j++) { \
        out[out_base + j] = op(x[arrlinidx1 + inner_str1 * j], y[arrlinidx2 + inner_str2 * j]); \
      } \
    }); \
  } \
  static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
    const i64 inner_len = shape[rank - 1]; \
    const i64 inner_str = strides[rank - 1]; \
    TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
      const i64 out_base = outlinidx * inner_len; \
      for (i64 j = 0; j < inner_len; j++) { \
        out[out_base + j] = op(x[arrlinidx + inner_str * j], y); \
      } \
    }); \
  }

// Elementwise kernel applying a prefix operator or function-like macro 'op'
// (e.g. '-', GEN_SQRT or '1.0/') to every element of a strided array, giving
// out = op(arr). 'out' is written densely in linearisation order.
#define UNARY_OP_STRIDED(name, op, typ) \
  static void oxarop_op_ ## name ## _ ## typ ## _strided(i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \
    const i64 inner_len = shape[rank - 1]; \
    const i64 inner_str = strides[rank - 1]; \
    TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
      const i64 out_base = outlinidx * inner_len; \
      for (i64 j = 0; j < inner_len; j++) { \
        out[out_base + j] = op(arr[arrlinidx + inner_str * j]); \
      } \
    }); \
  }

// preconditions:
// - all strides are >0
// - shape is everywhere >0
// - rank is >= 1
// - out has capacity for (shape[0] * ... * shape[rank - 2]) elements
// Reduces along the innermost dimension with the infix operator 'op'. Each
// inner vector is folded left to right, starting from its own first element,
// so no identity element is needed. 'out' is filled densely in linearisation
// order, one value per inner vector.
#define REDUCE1_OP(name, op, typ) \
  static void oxarop_op_ ## name ## _ ## typ(i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \
    const i64 inner_len = shape[rank - 1]; \
    const i64 inner_str = strides[rank - 1]; \
    TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
      typ acc = arr[arrlinidx]; \
      for (i64 j = 1; j < inner_len; j++) { \
        acc = acc op arr[arrlinidx + inner_str * j]; \
      } \
      out[outlinidx] = acc; \
    }); \
  }

// preconditions
// - all strides are >0
// - shape is everywhere >0
// - rank is >= 1
// Reduces the whole array to a single value with the infix operator 'op'.
// Each inner vector is folded left to right from its first element, and the
// per-vector results are then folded together in linearisation order.
// The running result is seeded from the first inner vector (outlinidx == 0)
// rather than from a constant. A constant seed of 0 is only an identity for
// '+'; with '*' it would force the product reduction to 0. Seeding from the
// data also matches how REDUCE1_OP folds without an identity element.
#define REDUCEFULL_OP(name, op, typ) \
  typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
    typ res = 0; /* overwritten on the first inner vector; the walk body always runs */ \
    TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
      typ accum = arr[arrlinidx]; \
      for (i64 i = 1; i < shape[rank - 1]; i++) { \
        accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \
      } \
      if (outlinidx == 0) { \
        res = accum; \
      } else { \
        res = res op accum; \
      } \
    }); \
    return res; \
  }

// preconditions
// - all strides are >0
// - shape is everywhere >0
// - rank is >= 1
// Writes extreme index to outidx. If 'cmp' is '<', computes minindex ("argmin"); if '>', maxindex.
// outidx receives 'rank' indices, one per dimension. The comparison is strict,
// so among equal extreme values the first one in linearisation order wins.
// Every element is read through the inner stride (strides[rank - 1]), and the
// same value is used for both the comparison and the update.
#define EXTREMUM_OP(name, cmp, typ) \
  void oxarop_extremum_ ## name ## _ ## typ(i64 *restrict outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
    typ best = arr[0]; \
    memset(outidx, 0, rank * sizeof(i64)); \
    TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
      bool found = false; \
      for (i64 i = 0; i < shape[rank - 1]; i++) { \
        const typ val = arr[arrlinidx + strides[rank - 1] * i]; \
        if (val cmp best) { \
          best = val; \
          found = true; \
          outidx[rank - 1] = i; \
        } \
      } \
      /* A new extremum was found in this inner vector: record its outer indices. */ \
      if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \
    }); \
  }

// Dense dot product: oxarop_dotprod_<typ>(length, arr1, arr2) returns the sum
// over k < length of arr1[k] * arr2[k], accumulated in 'typ' in increasing
// index order. The result is 0 when length <= 0.
#define DOTPROD_OP(typ) \
  typ oxarop_dotprod_ ## typ(i64 length, const typ *arr1, const typ *arr2) { \
    typ acc = 0; \
    for (i64 k = 0; k < length; ++k) { \
      acc += arr1[k] * arr2[k]; \
    } \
    return acc; \
  }

// Strided dot product: the sum over k < length of
// arr1[stride1 * k] * arr2[stride2 * k], accumulated in 'typ' in increasing k
// order. Strides may differ between the two operands.
#define DOTPROD_STRIDED_OP(typ) \
  typ oxarop_dotprod_ ## typ ## _strided(i64 length, i64 stride1, const typ *arr1, i64 stride2, const typ *arr2) { \
    typ acc = 0; \
    for (i64 k = 0; k < length; ++k) { \
      acc += arr1[stride1 * k] * arr2[stride2 * k]; \
    } \
    return acc; \
  }

// Dense dot products oxarop_dotprod_<typ>, used by DOTPROD_INNER_OP for
// contiguous inner vectors. The integer versions always use the generic
// DOTPROD_OP loop. On OX_ARCH_INTEL the float and double versions are written
// with SSE2 intrinsics; elsewhere they use the generic loop as well.
// The 'double' version here is about 2x as fast as gcc's own vectorisation.
DOTPROD_OP(i32)
DOTPROD_OP(i64)
#ifdef OX_ARCH_INTEL
// Keeps four float partial sums in one SSE register. Lane k accumulates the
// products at indices i with i % 4 == k. The lanes are then added together,
// and the last length % 4 elements are finished with a scalar loop. Loads are
// unaligned (loadu), so arr1 and arr2 need no particular alignment. The
// summation order therefore differs from the plain DOTPROD_OP loop, which can
// change the low-order bits of the result.
float oxarop_dotprod_float(i64 length, const float *arr1, const float *arr2) {
  __m128 accum = _mm_setzero_ps();
  i64 i;
  for (i = 0; i + 3 < length; i += 4) {
    accum = _mm_add_ps(accum, _mm_mul_ps(_mm_loadu_ps(arr1 + i), _mm_loadu_ps(arr2 + i)));
  }
  // Horizontal sum of the four lanes.
  float dest[4];
  _mm_storeu_ps(dest, accum);
  float tot = dest[0] + dest[1] + dest[2] + dest[3];
  // Scalar tail.
  for (; i < length; i++) tot += arr1[i] * arr2[i];
  return tot;
}
// Same scheme with two double lanes. The vector loop advances by 2, so at most
// one element is left over; the final 'if' handles it.
double oxarop_dotprod_double(i64 length, const double *arr1, const double *arr2) {
  __m128d accum = _mm_setzero_pd();
  i64 i;
  for (i = 0; i + 1 < length; i += 2) {
    accum = _mm_add_pd(accum, _mm_mul_pd(_mm_loadu_pd(arr1 + i), _mm_loadu_pd(arr2 + i)));
  }
  // Low lane + high lane.
  double tot = _mm_cvtsd_f64(accum) + _mm_cvtsd_f64(_mm_unpackhi_pd(accum, accum));
  if (i < length) tot += arr1[i] * arr2[i];
  return tot;
}
#else
DOTPROD_OP(float)
DOTPROD_OP(double)
#endif

// preconditions:
// - all strides are >0
// - shape is everywhere >0
// - rank is >= 1
// - out has capacity for (shape[0] * ... * shape[rank - 2]) elements
// For every inner vector (one per index over the outer rank - 1 dimensions),
// stores the dot product of the corresponding inner vectors of arr1 and arr2.
// 'out' will be filled densely in linearisation order.
// Dispatch on the two inner strides:
// - Both 1: call the dense kernel oxarop_dotprod_<typ> directly.
// - Both -1: the elements occupy a contiguous run ending at the walk offset,
//   so run the dense kernel forwards from the lowest address.
// - Anything else: call the strided kernel.
// NOTE(review): the -1 fast path suggests negative inner strides do occur, so
// the "strides > 0" precondition above may be stale. Confirm with callers.
#define DOTPROD_INNER_OP(typ) \
  void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \
    const i64 len = shape[rank - 1]; \
    const i64 istride1 = strides1[rank - 1]; \
    const i64 istride2 = strides2[rank - 1]; \
    if (istride1 == 1 && istride2 == 1) { \
      TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \
        out[outlinidx] = oxarop_dotprod_ ## typ(len, arr1 + arrlinidx1, arr2 + arrlinidx2); \
      }); \
      return; \
    } \
    if (istride1 == -1 && istride2 == -1) { \
      TARRAY_WALK2_NOINNER(again2, rank, shape, strides1, strides2, { \
        out[outlinidx] = oxarop_dotprod_ ## typ(len, arr1 + arrlinidx1 - (len - 1), arr2 + arrlinidx2 - (len - 1)); \
      }); \
      return; \
    } \
    TARRAY_WALK2_NOINNER(again3, rank, shape, strides1, strides2, { \
      out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(len, istride1, arr1 + arrlinidx1, istride2, arr2 + arrlinidx2); \
    }); \
  }


/*****************************************************************************
 *                           Entry point functions                           *
 *****************************************************************************/

// Reports an operation tag that the dispatcher 'name' does not handle, then
// aborts the process. This indicates a mismatch between the caller (presumably
// the Haskell side) and this file's tag lists, so it is not recoverable.
__attribute__((cold, noreturn))
static void wrong_op(const char *name, int tag) {
  (void)fprintf(stderr, "ox-arrays: Invalid operation tag passed to %s C code: %d\n", name, tag);
  abort();
}

// Tags selecting the operation in the oxarop_binary_* entry points. The
// enumerators ('name = id,') are generated from the LIST_BINOP entries of
// arith_lists.h. LIST_BINOP is redefined to emit an enumerator, the list is
// included, and LIST_BINOP is then reset to its empty preset so that later
// includes ignore these entries.
enum binop_tag_t {
#undef LIST_BINOP
#define LIST_BINOP(name, id, hsop) name = id,
#include "arith_lists.h"
#undef LIST_BINOP
#define LIST_BINOP(name, id, hsop)
};

// Exported dispatchers for the arithmetic binary operators (+, -, *) on
// strided arrays. 'tag' selects the kernel; an unhandled tag aborts via
// wrong_op. Naming: sv = scalar x with array y, vs = array x with scalar y,
// vv = two arrays of identical shape.
#define ENTRY_BINARY_STRIDED_OPS(typ) \
  void oxarop_binary_ ## typ ## _sv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
    switch (tag) { \
      case BO_ADD: \
        oxarop_op_add_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); \
        return; \
      case BO_SUB: \
        oxarop_op_sub_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); \
        return; \
      case BO_MUL: \
        oxarop_op_mul_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); \
        return; \
      default: \
        break; \
    } \
    wrong_op("binary_sv_strided", tag); \
  } \
  void oxarop_binary_ ## typ ## _vs_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
    switch (tag) { \
      /* + and * commute, so array-op-scalar reuses the scalar-op-array kernel. */ \
      case BO_ADD: \
        oxarop_op_add_ ## typ ## _sv_strided(rank, shape, out, y, strides, x); \
        return; \
      case BO_SUB: \
        oxarop_op_sub_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); \
        return; \
      case BO_MUL: \
        oxarop_op_mul_ ## typ ## _sv_strided(rank, shape, out, y, strides, x); \
        return; \
      default: \
        break; \
    } \
    wrong_op("binary_vs_strided", tag); \
  } \
  void oxarop_binary_ ## typ ## _vv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
    switch (tag) { \
      case BO_ADD: \
        oxarop_op_add_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); \
        return; \
      case BO_SUB: \
        oxarop_op_sub_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); \
        return; \
      case BO_MUL: \
        oxarop_op_mul_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); \
        return; \
      default: \
        break; \
    } \
    wrong_op("binary_vv_strided", tag); \
  }

// Tags selecting the operation in the oxarop_ibinary_* (integer-only) entry
// points, generated from the LIST_IBINOP entries of arith_lists.h. This uses
// the same redefine/include/reset pattern as binop_tag_t.
enum ibinop_tag_t {
#undef LIST_IBINOP
#define LIST_IBINOP(name, id, hsop) name = id,
#include "arith_lists.h"
#undef LIST_IBINOP
#define LIST_IBINOP(name, id, hsop)
};

// Exported dispatchers for the integer-only binary operators (quot, rem) on
// strided arrays. 'tag' selects the kernel; an unhandled tag aborts via
// wrong_op. Naming: sv = scalar x with array y, vs = array x with scalar y,
// vv = two arrays of identical shape.
#define ENTRY_IBINARY_STRIDED_OPS(typ) \
  void oxarop_ibinary_ ## typ ## _sv_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
    switch (tag) { \
      case IB_QUOT: \
        oxarop_op_quot_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); \
        return; \
      case IB_REM: \
        oxarop_op_rem_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); \
        return; \
      default: \
        break; \
    } \
    wrong_op("ibinary_sv_strided", tag); \
  } \
  void oxarop_ibinary_ ## typ ## _vs_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
    switch (tag) { \
      case IB_QUOT: \
        oxarop_op_quot_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); \
        return; \
      case IB_REM: \
        oxarop_op_rem_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); \
        return; \
      default: \
        break; \
    } \
    wrong_op("ibinary_vs_strided", tag); \
  } \
  void oxarop_ibinary_ ## typ ## _vv_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
    switch (tag) { \
      case IB_QUOT: \
        oxarop_op_quot_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); \
        return; \
      case IB_REM: \
        oxarop_op_rem_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); \
        return; \
      default: \
        break; \
    } \
    wrong_op("ibinary_vv_strided", tag); \
  }

// Tags selecting the operation in the oxarop_fbinary_* (floating-point-only)
// entry points, generated from the LIST_FBINOP entries of arith_lists.h. This
// uses the same redefine/include/reset pattern as binop_tag_t.
enum fbinop_tag_t {
#undef LIST_FBINOP
#define LIST_FBINOP(name, id, hsop) name = id,
#include "arith_lists.h"
#undef LIST_FBINOP
#define LIST_FBINOP(name, id, hsop)
};

// Exported dispatchers for the floating-point binary operators (division,
// pow, logarithm to a base, atan2) on strided arrays. 'tag' selects the
// kernel; an unhandled tag aborts via wrong_op. Naming: sv = scalar x with
// array y, vs = array x with scalar y, vv = two arrays of identical shape.
#define ENTRY_FBINARY_STRIDED_OPS(typ) \
  void oxarop_fbinary_ ## typ ## _sv_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
    switch (tag) { \
      case FB_DIV: \
        oxarop_op_fdiv_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); \
        return; \
      case FB_POW: \
        oxarop_op_pow_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); \
        return; \
      case FB_LOGBASE: \
        oxarop_op_logbase_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); \
        return; \
      case FB_ATAN2: \
        oxarop_op_atan2_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); \
        return; \
      default: \
        break; \
    } \
    wrong_op("fbinary_sv_strided", tag); \
  } \
  void oxarop_fbinary_ ## typ ## _vs_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
    switch (tag) { \
      case FB_DIV: \
        oxarop_op_fdiv_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); \
        return; \
      case FB_POW: \
        oxarop_op_pow_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); \
        return; \
      case FB_LOGBASE: \
        oxarop_op_logbase_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); \
        return; \
      case FB_ATAN2: \
        oxarop_op_atan2_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); \
        return; \
      default: \
        break; \
    } \
    wrong_op("fbinary_vs_strided", tag); \
  } \
  void oxarop_fbinary_ ## typ ## _vv_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
    switch (tag) { \
      case FB_DIV: \
        oxarop_op_fdiv_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); \
        return; \
      case FB_POW: \
        oxarop_op_pow_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); \
        return; \
      case FB_LOGBASE: \
        oxarop_op_logbase_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); \
        return; \
      case FB_ATAN2: \
        oxarop_op_atan2_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); \
        return; \
      default: \
        break; \
    } \
    wrong_op("fbinary_vv_strided", tag); \
  }

// Tags selecting the operation in the oxarop_unary_* entry points, generated
// from the LIST_UNOP entries of arith_lists.h. This uses the same
// redefine/include/reset pattern as binop_tag_t.
enum unop_tag_t {
#undef LIST_UNOP
#define LIST_UNOP(name, id, _) name = id,
#include "arith_lists.h"
#undef LIST_UNOP
#define LIST_UNOP(name, id, _)
};

// Exported dispatcher for the unary operators available on all numeric types
// (negate, abs, signum) on a strided array. 'tag' selects the kernel; an
// unhandled tag aborts via wrong_op.
#define ENTRY_UNARY_STRIDED_OPS(typ) \
  void oxarop_unary_ ## typ ## _strided(enum unop_tag_t tag, i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *x) { \
    switch (tag) { \
      case UO_NEG: \
        oxarop_op_neg_ ## typ ## _strided(rank, out, shape, strides, x); \
        return; \
      case UO_ABS: \
        oxarop_op_abs_ ## typ ## _strided(rank, out, shape, strides, x); \
        return; \
      case UO_SIGNUM: \
        oxarop_op_signum_ ## typ ## _strided(rank, out, shape, strides, x); \
        return; \
      default: \
        break; \
    } \
    wrong_op("unary_strided", tag); \
  }

// Tags selecting the operation in the oxarop_funary_* (floating-point-only)
// entry points, generated from the LIST_FUNOP entries of arith_lists.h. This
// uses the same redefine/include/reset pattern as binop_tag_t.
enum funop_tag_t {
#undef LIST_FUNOP
#define LIST_FUNOP(name, id, _) name = id,
#include "arith_lists.h"
#undef LIST_FUNOP
#define LIST_FUNOP(name, id, _)
};

// FUNARY_STRIDED_CASE(TAG, fn, typ) emits one 'case' label. The case calls the
// oxarop_op_<fn>_<typ>_strided kernel with the enclosing dispatcher's
// arguments (rank, out, shape, strides, x) and then returns. 'fn' and 'typ'
// are only used for token pasting, so they are never macro-expanded.
#define FUNARY_STRIDED_CASE(TAG, fn, typ) \
  case TAG: \
    oxarop_op_ ## fn ## _ ## typ ## _strided(rank, out, shape, strides, x); \
    return;

// Exported dispatcher for the floating-point unary functions on a strided
// array. 'tag' selects the kernel; an unhandled tag aborts via wrong_op.
#define ENTRY_FUNARY_STRIDED_OPS(typ) \
  void oxarop_funary_ ## typ ## _strided(enum funop_tag_t tag, i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *x) { \
    switch (tag) { \
      FUNARY_STRIDED_CASE(FU_RECIP,    recip,    typ) \
      FUNARY_STRIDED_CASE(FU_EXP,      exp,      typ) \
      FUNARY_STRIDED_CASE(FU_LOG,      log,      typ) \
      FUNARY_STRIDED_CASE(FU_SQRT,     sqrt,     typ) \
      FUNARY_STRIDED_CASE(FU_SIN,      sin,      typ) \
      FUNARY_STRIDED_CASE(FU_COS,      cos,      typ) \
      FUNARY_STRIDED_CASE(FU_TAN,      tan,      typ) \
      FUNARY_STRIDED_CASE(FU_ASIN,     asin,     typ) \
      FUNARY_STRIDED_CASE(FU_ACOS,     acos,     typ) \
      FUNARY_STRIDED_CASE(FU_ATAN,     atan,     typ) \
      FUNARY_STRIDED_CASE(FU_SINH,     sinh,     typ) \
      FUNARY_STRIDED_CASE(FU_COSH,     cosh,     typ) \
      FUNARY_STRIDED_CASE(FU_TANH,     tanh,     typ) \
      FUNARY_STRIDED_CASE(FU_ASINH,    asinh,    typ) \
      FUNARY_STRIDED_CASE(FU_ACOSH,    acosh,    typ) \
      FUNARY_STRIDED_CASE(FU_ATANH,    atanh,    typ) \
      FUNARY_STRIDED_CASE(FU_LOG1P,    log1p,    typ) \
      FUNARY_STRIDED_CASE(FU_EXPM1,    expm1,    typ) \
      FUNARY_STRIDED_CASE(FU_LOG1PEXP, log1pexp, typ) \
      FUNARY_STRIDED_CASE(FU_LOG1MEXP, log1mexp, typ) \
      default: \
        break; \
    } \
    wrong_op("funary_strided", tag); \
  }

// Tags selecting the reduction in the oxarop_reduce1_* and
// oxarop_reducefull_* entry points, generated from the LIST_REDOP entries of
// arith_lists.h. This uses the same redefine/include/reset pattern as
// binop_tag_t.
enum redop_tag_t {
#undef LIST_REDOP
#define LIST_REDOP(name, id, _) name = id,
#include "arith_lists.h"
#undef LIST_REDOP
#define LIST_REDOP(name, id, _)
};

// Exported dispatcher for reductions along the innermost dimension (sum,
// product). 'out' receives one value per inner vector; an unhandled tag aborts
// via wrong_op.
#define ENTRY_REDUCE1_OPS(typ) \
  void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \
    switch (tag) { \
      case RO_SUM: \
        oxarop_op_sum1_ ## typ(rank, out, shape, strides, arr); \
        return; \
      case RO_PRODUCT: \
        oxarop_op_product1_ ## typ(rank, out, shape, strides, arr); \
        return; \
      default: \
        break; \
    } \
    wrong_op("reduce", tag); \
  }

// Exported dispatcher for whole-array reductions (sum, product). Returns the
// reduced value; an unhandled tag aborts via wrong_op, which does not return.
#define ENTRY_REDUCEFULL_OPS(typ) \
  typ oxarop_reducefull_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
    switch (tag) { \
      case RO_SUM: \
        return oxarop_op_sumfull_ ## typ(rank, shape, strides, arr); \
      case RO_PRODUCT: \
        return oxarop_op_productfull_ ## typ(rank, shape, strides, arr); \
      default: \
        break; \
    } \
    wrong_op("reduce", tag); \
  }


/*****************************************************************************
 *                        Generate all the functions                         *
 *****************************************************************************/

// Instantiate every kernel and entry point for each element type:
//   INT_TYPES_XLIST:   i32, i64
//   FLOAT_TYPES_XLIST: double, float
//   NUM_TYPES_XLIST:   all of the above
// Within each X definition, the kernels are expanded before the ENTRY_*
// dispatchers, because the dispatchers call the kernels by name.
#define INT_TYPES_XLIST X(i32) X(i64)
#define FLOAT_TYPES_XLIST X(double) X(float)
#define NUM_TYPES_XLIST INT_TYPES_XLIST FLOAT_TYPES_XLIST

// Operations on every numeric type: +, -, *, negate, abs, signum, sum/product
// reductions (per inner vector and whole-array), their dispatchers,
// argmin/argmax, and inner-dimension dot products. DOTPROD_INNER_OP relies on
// the dense oxarop_dotprod_<typ> functions defined earlier.
#define X(typ) \
  COMM_OP_STRIDED(add, +, typ) \
  NONCOMM_OP_STRIDED(sub, -, typ) \
  COMM_OP_STRIDED(mul, *, typ) \
  UNARY_OP_STRIDED(neg, -, typ) \
  UNARY_OP_STRIDED(abs, GEN_ABS, typ) \
  UNARY_OP_STRIDED(signum, GEN_SIGNUM, typ) \
  REDUCE1_OP(sum1, +, typ) \
  REDUCE1_OP(product1, *, typ) \
  REDUCEFULL_OP(sumfull, +, typ) \
  REDUCEFULL_OP(productfull, *, typ) \
  ENTRY_BINARY_STRIDED_OPS(typ) \
  ENTRY_UNARY_STRIDED_OPS(typ) \
  ENTRY_REDUCE1_OPS(typ) \
  ENTRY_REDUCEFULL_OPS(typ) \
  EXTREMUM_OP(min, <, typ) \
  EXTREMUM_OP(max, >, typ) \
  DOTPROD_STRIDED_OP(typ) \
  DOTPROD_INNER_OP(typ)
NUM_TYPES_XLIST
#undef X

// Integer-only operations: quot and rem, implemented with C's '/' (which
// truncates toward zero) and '%'.
// NOTE(review): C '/' and '%' are undefined for a zero divisor and for
// INT_MIN / -1. Presumably the caller rules these cases out — confirm.
#define X(typ) \
  NONCOMM_OP_STRIDED(quot, /, typ) \
  NONCOMM_OP_STRIDED(rem, %, typ) \
  ENTRY_IBINARY_STRIDED_OPS(typ)
INT_TYPES_XLIST
#undef X

// Floating-point-only operations: division, pow, logarithm to a base, atan2,
// reciprocal, and the <math.h>-based unary functions (via the GEN_* wrappers),
// plus their dispatchers.
#define X(typ) \
  NONCOMM_OP_STRIDED(fdiv, /, typ) \
  PREFIX_BINOP_STRIDED(pow, GEN_POW, typ) \
  PREFIX_BINOP_STRIDED(logbase, GEN_LOGBASE, typ) \
  PREFIX_BINOP_STRIDED(atan2, GEN_ATAN2, typ) \
  UNARY_OP_STRIDED(recip, 1.0/, typ) \
  UNARY_OP_STRIDED(exp, GEN_EXP, typ) \
  UNARY_OP_STRIDED(log, GEN_LOG, typ) \
  UNARY_OP_STRIDED(sqrt, GEN_SQRT, typ) \
  UNARY_OP_STRIDED(sin, GEN_SIN, typ) \
  UNARY_OP_STRIDED(cos, GEN_COS, typ) \
  UNARY_OP_STRIDED(tan, GEN_TAN, typ) \
  UNARY_OP_STRIDED(asin, GEN_ASIN, typ) \
  UNARY_OP_STRIDED(acos, GEN_ACOS, typ) \
  UNARY_OP_STRIDED(atan, GEN_ATAN, typ) \
  UNARY_OP_STRIDED(sinh, GEN_SINH, typ) \
  UNARY_OP_STRIDED(cosh, GEN_COSH, typ) \
  UNARY_OP_STRIDED(tanh, GEN_TANH, typ) \
  UNARY_OP_STRIDED(asinh, GEN_ASINH, typ) \
  UNARY_OP_STRIDED(acosh, GEN_ACOSH, typ) \
  UNARY_OP_STRIDED(atanh, GEN_ATANH, typ) \
  UNARY_OP_STRIDED(log1p, GEN_LOG1P, typ) \
  UNARY_OP_STRIDED(expm1, GEN_EXPM1, typ) \
  UNARY_OP_STRIDED(log1pexp, GEN_LOG1PEXP, typ) \
  UNARY_OP_STRIDED(log1mexp, GEN_LOG1MEXP, typ) \
  ENTRY_FBINARY_STRIDED_OPS(typ) \
  ENTRY_FUNARY_STRIDED_OPS(typ)
FLOAT_TYPES_XLIST
#undef X

// Note: [zero-length VLA]
//
// Zero-length variable-length arrays are not allowed in C (since C99, a VLA's
// length must be greater than zero). Thus, whenever a VLA could legitimately
// need zero elements (e.g. `idx` in the TARRAY_WALK_NOINNER macros, which
// only needs rank - 1 entries), we tweak the length formula (typically by just
// adding 1) so that it never ends up empty.