diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-14 14:40:02 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-14 14:40:02 +0100 |
commit | 639acb0abed995400d984203684e178a11d91fa1 (patch) | |
tree | 48828391e23b2fb53578a1554864f0ab6af913d1 | |
parent | 08e139de6bfeba885cacec1ad5600b85cd0f0947 (diff) |
arith: Remove CASE1, add restrict
Turns out that GCC already splits generates separate code for an inner
stride of 1 automatically, so no need to do fancy stuff in C.
Also, GCC generated a whole bunch of superfluous code to correctly
handle the case where output and input arrays overlap; since this never
happens in our case, let's add `restrict` and save some binary size.
-rw-r--r-- | cbits/arith.c | 75 |
1 file changed, 32 insertions, 43 deletions
diff --git a/cbits/arith.c b/cbits/arith.c index c984255..752fc1c 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -161,32 +161,21 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { } \ } while (false) -// Same as TARRAY_WALK_NOINNER, except the body is specialised twice: once on -// strides[rank-1] == 1 and a fallback case. -#define TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, body) \ - do { \ - if ((strides)[(rank) - 1] == 1) { \ - TARRAY_WALK_NOINNER(tar_wa_again1, rank, shape, strides, body); \ - } else { \ - TARRAY_WALK_NOINNER(tar_wa_again2, rank, shape, strides, body); \ - } \ - } while (false) - /***************************************************************************** * Kernel functions * *****************************************************************************/ #define COMM_OP_STRIDED(name, op, typ) \ - static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \ - TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ + static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \ + TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ for (i64 i = 0; i < shape[rank - 1]; i++) { \ out[outlinidx * shape[rank - 1] + i] = x op y[arrlinidx + strides[rank - 1] * i]; \ } \ }); \ } \ - static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ - TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \ + static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ + TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \ for (i64 i = 0; i < shape[rank - 1]; i++) { \ out[outlinidx * shape[rank - 1] + i] = x[arrlinidx1 + 
strides1[rank - 1] * i] op y[arrlinidx2 + strides2[rank - 1] * i]; \ } \ @@ -195,8 +184,8 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { #define NONCOMM_OP_STRIDED(name, op, typ) \ COMM_OP_STRIDED(name, op, typ) \ - static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \ - TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ + static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \ + TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ for (i64 i = 0; i < shape[rank - 1]; i++) { \ out[outlinidx * shape[rank - 1] + i] = x[arrlinidx + strides[rank - 1] * i] op y; \ } \ @@ -204,22 +193,22 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { } #define PREFIX_BINOP_STRIDED(name, op, typ) \ - static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \ - TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ + static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \ + TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ for (i64 i = 0; i < shape[rank - 1]; i++) { \ out[outlinidx * shape[rank - 1] + i] = op(x, y[arrlinidx + strides[rank - 1] * i]); \ } \ }); \ } \ - static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ - TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \ + static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ + TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \ for (i64 i = 0; i < shape[rank - 1]; i++) { \ 
out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx1 + strides1[rank - 1] * i], y[arrlinidx2 + strides2[rank - 1] * i]); \ } \ }); \ } \ - static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \ - TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ + static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \ + TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ for (i64 i = 0; i < shape[rank - 1]; i++) { \ out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx + strides[rank - 1] * i], y); \ } \ @@ -227,13 +216,13 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { } #define UNARY_OP_STRIDED(name, op, typ) \ - static void oxarop_op_ ## name ## _ ## typ ## _strided(i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \ + static void oxarop_op_ ## name ## _ ## typ ## _strided(i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \ /* fprintf(stderr, "oxarop_op_" #name "_" #typ "_strided: rank=%ld shape=", rank); \ print_shape(stderr, rank, shape); \ fprintf(stderr, " strides="); \ print_shape(stderr, rank, strides); \ fprintf(stderr, "\n"); */ \ - TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ + TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ for (i64 i = 0; i < shape[rank - 1]; i++) { \ out[outlinidx * shape[rank - 1] + i] = op(arr[arrlinidx + strides[rank - 1] * i]); \ } \ @@ -248,8 +237,8 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { // Reduces along the innermost dimension. // 'out' will be filled densely in linearisation order. 
#define REDUCE1_OP(name, op, typ) \ - static void oxarop_op_ ## name ## _ ## typ(i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \ - TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ + static void oxarop_op_ ## name ## _ ## typ(i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \ + TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ typ accum = arr[arrlinidx]; \ for (i64 i = 1; i < shape[rank - 1]; i++) { \ accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ @@ -265,7 +254,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { #define REDUCEFULL_OP(name, op, typ) \ typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ typ res = 0; \ - TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ + TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ typ accum = arr[arrlinidx]; \ for (i64 i = 1; i < shape[rank - 1]; i++) { \ accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \ @@ -281,10 +270,10 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { // - rank is >= 1 // Writes extreme index to outidx. If 'cmp' is '<', computes minindex ("argmin"); if '>', maxindex. #define EXTREMUM_OP(name, cmp, typ) \ - void oxarop_extremum_ ## name ## _ ## typ(i64 *outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ + void oxarop_extremum_ ## name ## _ ## typ(i64 *restrict outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ typ best = arr[0]; \ memset(outidx, 0, rank * sizeof(i64)); \ - TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \ + TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ bool found = false; \ for (i64 i = 0; i < shape[rank - 1]; i++) { \ if (arr[arrlinidx + i] cmp best) { \ @@ -350,7 +339,7 @@ DOTPROD_OP(double) // Reduces along the innermost dimension. // 'out' will be filled densely in linearisation order. 
#define DOTPROD_INNER_OP(typ) \ - void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \ + void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \ if (strides1[rank - 1] == 1 && strides2[rank - 1] == 1) { \ TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \ out[outlinidx] = oxarop_dotprod_ ## typ(shape[rank - 1], arr1 + arrlinidx1, arr2 + arrlinidx2); \ @@ -387,7 +376,7 @@ enum binop_tag_t { }; #define ENTRY_BINARY_STRIDED_OPS(typ) \ - void oxarop_binary_ ## typ ## _sv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \ + void oxarop_binary_ ## typ ## _sv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \ switch (tag) { \ case BO_ADD: oxarop_op_add_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ case BO_SUB: oxarop_op_sub_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ @@ -395,7 +384,7 @@ enum binop_tag_t { default: wrong_op("binary_sv_strided", tag); \ } \ } \ - void oxarop_binary_ ## typ ## _vs_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \ + void oxarop_binary_ ## typ ## _vs_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \ switch (tag) { \ case BO_ADD: oxarop_op_add_ ## typ ## _sv_strided(rank, shape, out, y, strides, x); break; \ case BO_SUB: oxarop_op_sub_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ @@ -403,7 +392,7 @@ enum binop_tag_t { default: wrong_op("binary_vs_strided", tag); \ } \ } \ - void oxarop_binary_ ## typ ## _vv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, 
const i64 *strides2, const typ *y) { \ + void oxarop_binary_ ## typ ## _vv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ switch (tag) { \ case BO_ADD: oxarop_op_add_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ case BO_SUB: oxarop_op_sub_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ @@ -421,21 +410,21 @@ enum ibinop_tag_t { }; #define ENTRY_IBINARY_STRIDED_OPS(typ) \ - void oxarop_ibinary_ ## typ ## _sv_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \ + void oxarop_ibinary_ ## typ ## _sv_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \ switch (tag) { \ case IB_QUOT: oxarop_op_quot_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ case IB_REM: oxarop_op_rem_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ default: wrong_op("ibinary_sv_strided", tag); \ } \ } \ - void oxarop_ibinary_ ## typ ## _vs_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \ + void oxarop_ibinary_ ## typ ## _vs_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \ switch (tag) { \ case IB_QUOT: oxarop_op_quot_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ case IB_REM: oxarop_op_rem_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ default: wrong_op("ibinary_vs_strided", tag); \ } \ } \ - void oxarop_ibinary_ ## typ ## _vv_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ + void oxarop_ibinary_ ## typ ## _vv_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, 
const typ *x, const i64 *strides2, const typ *y) { \ switch (tag) { \ case IB_QUOT: oxarop_op_quot_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ case IB_REM: oxarop_op_rem_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ @@ -452,7 +441,7 @@ enum fbinop_tag_t { }; #define ENTRY_FBINARY_STRIDED_OPS(typ) \ - void oxarop_fbinary_ ## typ ## _sv_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \ + void oxarop_fbinary_ ## typ ## _sv_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \ switch (tag) { \ case FB_DIV: oxarop_op_fdiv_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ case FB_POW: oxarop_op_pow_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ @@ -461,7 +450,7 @@ enum fbinop_tag_t { default: wrong_op("fbinary_sv_strided", tag); \ } \ } \ - void oxarop_fbinary_ ## typ ## _vs_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \ + void oxarop_fbinary_ ## typ ## _vs_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \ switch (tag) { \ case FB_DIV: oxarop_op_fdiv_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ case FB_POW: oxarop_op_pow_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ @@ -470,7 +459,7 @@ enum fbinop_tag_t { default: wrong_op("fbinary_vs_strided", tag); \ } \ } \ - void oxarop_fbinary_ ## typ ## _vv_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ + void oxarop_fbinary_ ## typ ## _vv_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ switch (tag) { \ case FB_DIV: 
oxarop_op_fdiv_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ case FB_POW: oxarop_op_pow_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ @@ -489,7 +478,7 @@ enum unop_tag_t { }; #define ENTRY_UNARY_STRIDED_OPS(typ) \ - void oxarop_unary_ ## typ ## _strided(enum unop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *x) { \ + void oxarop_unary_ ## typ ## _strided(enum unop_tag_t tag, i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *x) { \ switch (tag) { \ case UO_NEG: oxarop_op_neg_ ## typ ## _strided(rank, out, shape, strides, x); break; \ case UO_ABS: oxarop_op_abs_ ## typ ## _strided(rank, out, shape, strides, x); break; \ @@ -507,7 +496,7 @@ enum funop_tag_t { }; #define ENTRY_FUNARY_STRIDED_OPS(typ) \ - void oxarop_funary_ ## typ ## _strided(enum funop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *x) { \ + void oxarop_funary_ ## typ ## _strided(enum funop_tag_t tag, i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *x) { \ switch (tag) { \ case FU_RECIP: oxarop_op_recip_ ## typ ## _strided(rank, out, shape, strides, x); break; \ case FU_EXP: oxarop_op_exp_ ## typ ## _strided(rank, out, shape, strides, x); break; \ @@ -542,7 +531,7 @@ enum redop_tag_t { }; #define ENTRY_REDUCE1_OPS(typ) \ - void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \ + void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \ switch (tag) { \ case RO_SUM: oxarop_op_sum1_ ## typ(rank, out, shape, strides, arr); break; \ case RO_PRODUCT: oxarop_op_product1_ ## typ(rank, out, shape, strides, arr); break; \ |