aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-14 14:40:02 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-14 14:40:02 +0100
commit639acb0abed995400d984203684e178a11d91fa1 (patch)
tree48828391e23b2fb53578a1554864f0ab6af913d1
parent08e139de6bfeba885cacec1ad5600b85cd0f0947 (diff)
arith: Remove CASE1, add restrict
Turns out that GCC already generates separate code for an inner stride of 1 automatically, so no need to do fancy stuff in C. Also, GCC generated a whole bunch of superfluous code to correctly handle the case where output and input arrays overlap; since this never happens in our case, let's add `restrict` and save some binary size.
-rw-r--r--cbits/arith.c75
1 files changed, 32 insertions, 43 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index c984255..752fc1c 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -161,32 +161,21 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
} \
} while (false)
-// Same as TARRAY_WALK_NOINNER, except the body is specialised twice: once on
-// strides[rank-1] == 1 and a fallback case.
-#define TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, body) \
- do { \
- if ((strides)[(rank) - 1] == 1) { \
- TARRAY_WALK_NOINNER(tar_wa_again1, rank, shape, strides, body); \
- } else { \
- TARRAY_WALK_NOINNER(tar_wa_again2, rank, shape, strides, body); \
- } \
- } while (false)
-
/*****************************************************************************
* Kernel functions *
*****************************************************************************/
#define COMM_OP_STRIDED(name, op, typ) \
- static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \
- TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+ static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
+ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
for (i64 i = 0; i < shape[rank - 1]; i++) { \
out[outlinidx * shape[rank - 1] + i] = x op y[arrlinidx + strides[rank - 1] * i]; \
} \
}); \
} \
- static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
- TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \
+ static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+ TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \
for (i64 i = 0; i < shape[rank - 1]; i++) { \
out[outlinidx * shape[rank - 1] + i] = x[arrlinidx1 + strides1[rank - 1] * i] op y[arrlinidx2 + strides2[rank - 1] * i]; \
} \
@@ -195,8 +184,8 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
#define NONCOMM_OP_STRIDED(name, op, typ) \
COMM_OP_STRIDED(name, op, typ) \
- static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \
- TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+ static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
+ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
for (i64 i = 0; i < shape[rank - 1]; i++) { \
out[outlinidx * shape[rank - 1] + i] = x[arrlinidx + strides[rank - 1] * i] op y; \
} \
@@ -204,22 +193,22 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
}
#define PREFIX_BINOP_STRIDED(name, op, typ) \
- static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \
- TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+ static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
+ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
for (i64 i = 0; i < shape[rank - 1]; i++) { \
out[outlinidx * shape[rank - 1] + i] = op(x, y[arrlinidx + strides[rank - 1] * i]); \
} \
}); \
} \
- static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
- TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \
+ static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+ TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \
for (i64 i = 0; i < shape[rank - 1]; i++) { \
out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx1 + strides1[rank - 1] * i], y[arrlinidx2 + strides2[rank - 1] * i]); \
} \
}); \
} \
- static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \
- TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+ static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
+ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
for (i64 i = 0; i < shape[rank - 1]; i++) { \
out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx + strides[rank - 1] * i], y); \
} \
@@ -227,13 +216,13 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
}
#define UNARY_OP_STRIDED(name, op, typ) \
- static void oxarop_op_ ## name ## _ ## typ ## _strided(i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \
+ static void oxarop_op_ ## name ## _ ## typ ## _strided(i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \
/* fprintf(stderr, "oxarop_op_" #name "_" #typ "_strided: rank=%ld shape=", rank); \
print_shape(stderr, rank, shape); \
fprintf(stderr, " strides="); \
print_shape(stderr, rank, strides); \
fprintf(stderr, "\n"); */ \
- TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
for (i64 i = 0; i < shape[rank - 1]; i++) { \
out[outlinidx * shape[rank - 1] + i] = op(arr[arrlinidx + strides[rank - 1] * i]); \
} \
@@ -248,8 +237,8 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
// Reduces along the innermost dimension.
// 'out' will be filled densely in linearisation order.
#define REDUCE1_OP(name, op, typ) \
- static void oxarop_op_ ## name ## _ ## typ(i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \
- TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+ static void oxarop_op_ ## name ## _ ## typ(i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \
+ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
typ accum = arr[arrlinidx]; \
for (i64 i = 1; i < shape[rank - 1]; i++) { \
accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \
@@ -265,7 +254,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
#define REDUCEFULL_OP(name, op, typ) \
typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
typ res = 0; \
- TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
typ accum = arr[arrlinidx]; \
for (i64 i = 1; i < shape[rank - 1]; i++) { \
accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \
@@ -281,10 +270,10 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
// - rank is >= 1
// Writes extreme index to outidx. If 'cmp' is '<', computes minindex ("argmin"); if '>', maxindex.
#define EXTREMUM_OP(name, cmp, typ) \
- void oxarop_extremum_ ## name ## _ ## typ(i64 *outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
+ void oxarop_extremum_ ## name ## _ ## typ(i64 *restrict outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
typ best = arr[0]; \
memset(outidx, 0, rank * sizeof(i64)); \
- TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \
bool found = false; \
for (i64 i = 0; i < shape[rank - 1]; i++) { \
if (arr[arrlinidx + i] cmp best) { \
@@ -350,7 +339,7 @@ DOTPROD_OP(double)
// Reduces along the innermost dimension.
// 'out' will be filled densely in linearisation order.
#define DOTPROD_INNER_OP(typ) \
- void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \
+ void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \
if (strides1[rank - 1] == 1 && strides2[rank - 1] == 1) { \
TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \
out[outlinidx] = oxarop_dotprod_ ## typ(shape[rank - 1], arr1 + arrlinidx1, arr2 + arrlinidx2); \
@@ -387,7 +376,7 @@ enum binop_tag_t {
};
#define ENTRY_BINARY_STRIDED_OPS(typ) \
- void oxarop_binary_ ## typ ## _sv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \
+ void oxarop_binary_ ## typ ## _sv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
switch (tag) { \
case BO_ADD: oxarop_op_add_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
case BO_SUB: oxarop_op_sub_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
@@ -395,7 +384,7 @@ enum binop_tag_t {
default: wrong_op("binary_sv_strided", tag); \
} \
} \
- void oxarop_binary_ ## typ ## _vs_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \
+ void oxarop_binary_ ## typ ## _vs_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
switch (tag) { \
case BO_ADD: oxarop_op_add_ ## typ ## _sv_strided(rank, shape, out, y, strides, x); break; \
case BO_SUB: oxarop_op_sub_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
@@ -403,7 +392,7 @@ enum binop_tag_t {
default: wrong_op("binary_vs_strided", tag); \
} \
} \
- void oxarop_binary_ ## typ ## _vv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+ void oxarop_binary_ ## typ ## _vv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
switch (tag) { \
case BO_ADD: oxarop_op_add_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
case BO_SUB: oxarop_op_sub_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
@@ -421,21 +410,21 @@ enum ibinop_tag_t {
};
#define ENTRY_IBINARY_STRIDED_OPS(typ) \
- void oxarop_ibinary_ ## typ ## _sv_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \
+ void oxarop_ibinary_ ## typ ## _sv_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
switch (tag) { \
case IB_QUOT: oxarop_op_quot_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
case IB_REM: oxarop_op_rem_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
default: wrong_op("ibinary_sv_strided", tag); \
} \
} \
- void oxarop_ibinary_ ## typ ## _vs_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \
+ void oxarop_ibinary_ ## typ ## _vs_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
switch (tag) { \
case IB_QUOT: oxarop_op_quot_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
case IB_REM: oxarop_op_rem_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
default: wrong_op("ibinary_vs_strided", tag); \
} \
} \
- void oxarop_ibinary_ ## typ ## _vv_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+ void oxarop_ibinary_ ## typ ## _vv_strided(enum ibinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
switch (tag) { \
case IB_QUOT: oxarop_op_quot_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
case IB_REM: oxarop_op_rem_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
@@ -452,7 +441,7 @@ enum fbinop_tag_t {
};
#define ENTRY_FBINARY_STRIDED_OPS(typ) \
- void oxarop_fbinary_ ## typ ## _sv_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \
+ void oxarop_fbinary_ ## typ ## _sv_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \
switch (tag) { \
case FB_DIV: oxarop_op_fdiv_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
case FB_POW: oxarop_op_pow_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
@@ -461,7 +450,7 @@ enum fbinop_tag_t {
default: wrong_op("fbinary_sv_strided", tag); \
} \
} \
- void oxarop_fbinary_ ## typ ## _vs_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \
+ void oxarop_fbinary_ ## typ ## _vs_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \
switch (tag) { \
case FB_DIV: oxarop_op_fdiv_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
case FB_POW: oxarop_op_pow_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
@@ -470,7 +459,7 @@ enum fbinop_tag_t {
default: wrong_op("fbinary_vs_strided", tag); \
} \
} \
- void oxarop_fbinary_ ## typ ## _vv_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+ void oxarop_fbinary_ ## typ ## _vv_strided(enum fbinop_tag_t tag, i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
switch (tag) { \
case FB_DIV: oxarop_op_fdiv_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
case FB_POW: oxarop_op_pow_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
@@ -489,7 +478,7 @@ enum unop_tag_t {
};
#define ENTRY_UNARY_STRIDED_OPS(typ) \
- void oxarop_unary_ ## typ ## _strided(enum unop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *x) { \
+ void oxarop_unary_ ## typ ## _strided(enum unop_tag_t tag, i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *x) { \
switch (tag) { \
case UO_NEG: oxarop_op_neg_ ## typ ## _strided(rank, out, shape, strides, x); break; \
case UO_ABS: oxarop_op_abs_ ## typ ## _strided(rank, out, shape, strides, x); break; \
@@ -507,7 +496,7 @@ enum funop_tag_t {
};
#define ENTRY_FUNARY_STRIDED_OPS(typ) \
- void oxarop_funary_ ## typ ## _strided(enum funop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *x) { \
+ void oxarop_funary_ ## typ ## _strided(enum funop_tag_t tag, i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *x) { \
switch (tag) { \
case FU_RECIP: oxarop_op_recip_ ## typ ## _strided(rank, out, shape, strides, x); break; \
case FU_EXP: oxarop_op_exp_ ## typ ## _strided(rank, out, shape, strides, x); break; \
@@ -542,7 +531,7 @@ enum redop_tag_t {
};
#define ENTRY_REDUCE1_OPS(typ) \
- void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, typ *out, const i64 *shape, const i64 *strides, const typ *arr) { \
+ void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, typ *restrict out, const i64 *shape, const i64 *strides, const typ *arr) { \
switch (tag) { \
case RO_SUM: oxarop_op_sum1_ ## typ(rank, out, shape, strides, arr); break; \
case RO_PRODUCT: oxarop_op_product1_ ## typ(rank, out, shape, strides, arr); break; \