diff options
| -rw-r--r-- | cbits/arith.c | 75 | 
1 files changed, 32 insertions, 43 deletions
| diff --git a/cbits/arith.c b/cbits/arith.c index c984255..752fc1c 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -161,32 +161,21 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {      } \    } while (false) -// Same as TARRAY_WALK_NOINNER, except the body is specialised twice: once on -// strides[rank-1] == 1 and a fallback case. -#define TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, body) \ -  do { \ -    if ((strides)[(rank) - 1] == 1) { \ -      TARRAY_WALK_NOINNER(tar_wa_again1, rank, shape, strides, body); \ -    } else { \ -      TARRAY_WALK_NOINNER(tar_wa_again2, rank, shape, strides, body); \ -    } \ -  } while (false) -  /*****************************************************************************   *                             Kernel functions                              *   *****************************************************************************/  #define COMM_OP_STRIDED(name, op, typ) \ -  static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \ -    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ +  static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \ +    TARRAY_WALK_NOINNER(again, rank, shape, strides, { \        for (i64 i = 0; i < shape[rank - 1]; i++) { \          out[outlinidx * shape[rank - 1] + i] = x op y[arrlinidx + strides[rank - 1] * i]; \        } \      }); \    } \ -  static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ -    TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \ +  static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ +    TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \        for (i64 i = 0; i < shape[rank - 1]; i++) { \          out[outlinidx * shape[rank - 1] + i] = x[arrlinidx1 + strides1[rank - 1] * i] op y[arrlinidx2 + strides2[rank - 1] * i]; \        } \ @@ -195,8 +184,8 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {  #define NONCOMM_OP_STRIDED(name, op, typ) \    COMM_OP_STRIDED(name, op, typ) \ -  static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \ -    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ +  static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \ +    TARRAY_WALK_NOINNER(again, rank, shape, strides, { \        for (i64 i = 0; i < shape[rank - 1]; i++) { \          out[outlinidx * shape[rank - 1] + i] = x[arrlinidx + strides[rank - 1] * i] op y; \        } \ @@ -204,22 +193,22 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {    }  #define PREFIX_BINOP_STRIDED(name, op, typ) \ -  static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \ -    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ +  static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \ +    TARRAY_WALK_NOINNER(again, rank, shape, strides, { \        for (i64 i = 0; i < shape[rank - 1]; i++) { \          out[outlinidx * shape[rank - 1] + i] = op(x, y[arrlinidx + strides[rank - 1] * i]); \        } \      }); \    } \ -  static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ -    TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \ +  static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ +    TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \        for (i64 i = 0; i < shape[rank - 1]; i++) { \          out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx1 + strides1[rank - 1] * i], y[arrlinidx2 + strides2[rank - 1] * i]); \        } \      }); \    } \ -  static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \ -    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ +  static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \ +    TARRAY_WALK_NOINNER(again, rank, shape, strides, { \        for (i64 i = 0; i < shape[rank - 1]; i++) { \          out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx + strides[rank - 1] * i], y); \        } \ @@ -227,13 +216,13 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {    }  #define UNARY_OP_STRIDED(name, op, typ) \ -  static void oxarop_op_ ## name ## _ ## typ ## _strided(i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \ +  static void oxarop_op_ ## name ## _ ## typ ## _strided(i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \      /* fprintf(stderr, "oxarop_op_" #name "_" #typ "_strided: rank=%ld shape=", rank); \      print_shape(stderr, rank, shape); \      fprintf(stderr, " strides="); \      print_shape(stderr, rank, strides); \      fprintf(stderr, "\n"); */ \ -    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ +    TARRAY_WALK_NOINNER(again, rank, shape, strides, { \        for (i64 i = 0; i < shape[rank - 1]; i++) { \          out[outlinidx * shape[rank - 1] + i] = op(arr[arrlinidx + strides[rank - 1] * i]); \        } \ @@ -248,8 +237,8 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {  // Reduces along the innermost dimension.  // 'out' will be filled densely in linearisation order.  #define REDUCE1_OP(name, op, typ) \ -  static void oxarop_op_ ## name ## _ ## typ(i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \ -    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ +  static void oxarop_op_ ## name ## _ ## typ(i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \ +    TARRAY_WALK_NOINNER(again, rank, shape, strides, { \        typ accum = arr[arrlinidx]; \        for (i64 i = 1; i < shape[rank - 1]; i++) { \          accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ @@ -265,7 +254,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {  #define REDUCEFULL_OP(name, op, typ) \    typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \      typ res = 0; \ -    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ +    TARRAY_WALK_NOINNER(again, rank, shape, strides, { \        typ accum = arr[arrlinidx]; \        for (i64 i = 1; i < shape[rank - 1]; i++) { \          accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ @@ -281,10 +270,10 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {  // - rank is >= 1  // Writes extreme index to outidx. If 'cmp' is '<', computes minindex ("argmin"); if '>', maxindex.  #define EXTREMUM_OP(name, cmp, typ) \ -  void oxarop_extremum_ ## name ## _ ## typ(i64 *outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ +  void oxarop_extremum_ ## name ## _ ## typ(i64 *restrict outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \      typ best = arr[0]; \      memset(outidx, 0, rank * sizeof(i64)); \ -    TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ +    TARRAY_WALK_NOINNER(again, rank, shape, strides, { \        bool found = false; \        for (i64 i = 0; i < shape[rank - 1]; i++) { \          if (arr[arrlinidx + i] cmp best) { \ @@ -350,7 +339,7 @@ DOTPROD_OP(double)  // Reduces along the innermost dimension.  // 'out' will be filled densely in linearisation order.  #define DOTPROD_INNER_OP(typ) \ -  void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \ +  void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \      if (strides1[rank - 1] == 1 && strides2[rank - 1] == 1) { \        TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \          out[outlinidx] = oxarop_dotprod_ ## typ(shape[rank - 1], arr1 + arrlinidx1, arr2 + arrlinidx2); \ @@ -387,7 +376,7 @@ enum binop_tag_t {  };  #define ENTRY_BINARY_STRIDED_OPS(typ) \ -  void oxarop_binary_ ## typ ## _sv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \ +  void oxarop_binary_ ## typ ## _sv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \      switch (tag) { \        case BO_ADD: oxarop_op_add_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \        case BO_SUB: oxarop_op_sub_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ @@ -395,7 +384,7 @@ enum binop_tag_t {        default: wrong_op("binary_sv_strided", tag); \      } \    } \ -  void oxarop_binary_ ## typ ## _vs_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \ +  void oxarop_binary_ ## typ ## _vs_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \      switch (tag) { \        case BO_ADD: oxarop_op_add_ ## typ ## _sv_strided(rank, shape, out, y, strides, x); break; \        case BO_SUB: oxarop_op_sub_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ @@ -403,7 +392,7 @@ enum binop_tag_t {        default: wrong_op("binary_vs_strided", tag); \      } \    } \ -  void oxarop_binary_ ## typ ## _vv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ +  void oxarop_binary_ ## typ ## _vv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \      switch (tag) { \        case BO_ADD: oxarop_op_add_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \        case BO_SUB: oxarop_op_sub_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ @@ -421,21 +410,21 @@ enum ibinop_tag_t {  };  #define ENTRY_IBINARY_STRIDED_OPS(typ) \ -  void oxarop_ibinary_ ## typ ## _sv_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \ +  void oxarop_ibinary_ ## typ ## _sv_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \      switch (tag) { \        case IB_QUOT: oxarop_op_quot_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \        case IB_REM:  oxarop_op_rem_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \        default: wrong_op("ibinary_sv_strided", tag); \      } \    } \ -  void oxarop_ibinary_ ## typ ## _vs_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \ +  void oxarop_ibinary_ ## typ ## _vs_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \      switch (tag) { \        case IB_QUOT: oxarop_op_quot_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \        case IB_REM:  oxarop_op_rem_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \        default: wrong_op("ibinary_vs_strided", tag); \      } \    } \ -  void oxarop_ibinary_ ## typ ## _vv_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ +  void oxarop_ibinary_ ## typ ## _vv_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \      switch (tag) { \        case IB_QUOT: oxarop_op_quot_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \        case IB_REM:  oxarop_op_rem_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ @@ -452,7 +441,7 @@ enum fbinop_tag_t {  };  #define ENTRY_FBINARY_STRIDED_OPS(typ) \ -  void oxarop_fbinary_ ## typ ## _sv_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \ +  void oxarop_fbinary_ ## typ ## _sv_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \      switch (tag) { \        case FB_DIV: oxarop_op_fdiv_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \        case FB_POW: oxarop_op_pow_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ @@ -461,7 +450,7 @@ enum fbinop_tag_t {        default: wrong_op("fbinary_sv_strided", tag); \      } \    } \ -  void oxarop_fbinary_ ## typ ## _vs_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \ +  void oxarop_fbinary_ ## typ ## _vs_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \      switch (tag) { \        case FB_DIV: oxarop_op_fdiv_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \        case FB_POW: oxarop_op_pow_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ @@ -470,7 +459,7 @@ enum fbinop_tag_t {        default: wrong_op("fbinary_vs_strided", tag); \      } \    } \ -  void oxarop_fbinary_ ## typ ## _vv_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ +  void oxarop_fbinary_ ## typ ## _vv_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \      switch (tag) { \        case FB_DIV: oxarop_op_fdiv_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \        case FB_POW: oxarop_op_pow_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ @@ -489,7 +478,7 @@ enum unop_tag_t {  };  #define ENTRY_UNARY_STRIDED_OPS(typ) \ -  void oxarop_unary_ ## typ ## _strided(enum unop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *x) { \ +  void oxarop_unary_ ## typ ## _strided(enum unop_tag_t tag, i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *x) { \      switch (tag) { \        case UO_NEG: oxarop_op_neg_ ## typ ## _strided(rank, out, shape, strides, x); break; \        case UO_ABS: oxarop_op_abs_ ## typ ## _strided(rank, out, shape, strides, x); break; \ @@ -507,7 +496,7 @@ enum funop_tag_t {  };  #define ENTRY_FUNARY_STRIDED_OPS(typ) \ -  void oxarop_funary_ ## typ ## _strided(enum funop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *x) { \ +  void oxarop_funary_ ## typ ## _strided(enum funop_tag_t tag, i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *x) { \      switch (tag) { \        case FU_RECIP:    oxarop_op_recip_    ## typ ## _strided(rank, out, shape, strides, x); break; \        case FU_EXP:      oxarop_op_exp_      ## typ ## _strided(rank, out, shape, strides, x); break; \ @@ -542,7 +531,7 @@ enum redop_tag_t {  };  #define ENTRY_REDUCE1_OPS(typ) \ -  void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \ +  void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \      switch (tag) { \        case RO_SUM: oxarop_op_sum1_ ## typ(rank, out, shape, strides, arr); break; \        case RO_PRODUCT: oxarop_op_product1_ ## typ(rank, out, shape, strides, arr); break; \ | 
