aboutsummaryrefslogtreecommitdiff
path: root/cbits/arith.c
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-06 00:08:40 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-11 21:53:08 +0100
commite347f70a3a65e1bb529d2df44289fb5fcf652d8a (patch)
tree6c7c3496b0679728e0c2b76e32330f57a1a0e5f7 /cbits/arith.c
parenta36d23048be6e2ad0e4516965f1e8b48756ef78b (diff)
WIP binary ops without normalisation
Diffstat (limited to 'cbits/arith.c')
-rw-r--r--cbits/arith.c138
1 files changed, 81 insertions, 57 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index 2788e41..9aed3b4 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -175,29 +175,53 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
* Kernel functions *
*****************************************************************************/
-#define COMM_OP(name, op, typ) \
- static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \
- for (i64 i = 0; i < n; i++) out[i] = x op y[i]; \
+#define COMM_OP_STRIDED(name, op, typ) \
+ static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \
+ TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+ for (i64 i = 0; i < shape[rank - 1]; i++) { \
+ out[outlinidx * shape[rank - 1] + i] = x op y[arrlinidx + strides[rank - 1] * i]; \
+ } \
+ }); \
} \
- static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \
- for (i64 i = 0; i < n; i++) out[i] = x[i] op y[i]; \
+ static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+ TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \
+ for (i64 i = 0; i < shape[rank - 1]; i++) { \
+ out[outlinidx * shape[rank - 1] + i] = x[arrlinidx1 + strides1[rank - 1] * i] op y[arrlinidx2 + strides2[rank - 1] * i]; \
+ } \
+ }); \
}
-#define NONCOMM_OP(name, op, typ) \
- COMM_OP(name, op, typ) \
- static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \
- for (i64 i = 0; i < n; i++) out[i] = x[i] op y; \
+#define NONCOMM_OP_STRIDED(name, op, typ) \
+ COMM_OP_STRIDED(name, op, typ) \
+ static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \
+ TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+ for (i64 i = 0; i < shape[rank - 1]; i++) { \
+ out[outlinidx * shape[rank - 1] + i] = x[arrlinidx + strides[rank - 1] * i] op y; \
+ } \
+ }); \
}
-#define PREFIX_BINOP(name, op, typ) \
- static void oxarop_op_ ## name ## _ ## typ ## _sv(i64 n, typ *out, typ x, const typ *y) { \
- for (i64 i = 0; i < n; i++) out[i] = op(x, y[i]); \
+#define PREFIX_BINOP_STRIDED(name, op, typ) \
+ static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \
+ TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+ for (i64 i = 0; i < shape[rank - 1]; i++) { \
+ out[outlinidx * shape[rank - 1] + i] = op(x, y[arrlinidx + strides[rank - 1] * i]); \
+ } \
+ }); \
} \
- static void oxarop_op_ ## name ## _ ## typ ## _vv(i64 n, typ *out, const typ *x, const typ *y) { \
- for (i64 i = 0; i < n; i++) out[i] = op(x[i], y[i]); \
+ static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
+ TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \
+ for (i64 i = 0; i < shape[rank - 1]; i++) { \
+ out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx1 + strides1[rank - 1] * i], y[arrlinidx2 + strides2[rank - 1] * i]); \
+ } \
+ }); \
} \
- static void oxarop_op_ ## name ## _ ## typ ## _vs(i64 n, typ *out, const typ *x, typ y) { \
- for (i64 i = 0; i < n; i++) out[i] = op(x[i], y); \
+ static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \
+ TARRAY_WALK_NOINNER_CASE1(rank, shape, strides, { \
+ for (i64 i = 0; i < shape[rank - 1]; i++) { \
+ out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx + strides[rank - 1] * i], y); \
+ } \
+ }); \
}
#define UNARY_OP_STRIDED(name, op, typ) \
@@ -360,29 +384,29 @@ enum binop_tag_t {
#define LIST_BINOP(name, id, hsop)
};
-#define ENTRY_BINARY_OPS(typ) \
- void oxarop_binary_ ## typ ## _sv(enum binop_tag_t tag, i64 n, typ *out, typ x, const typ *y) { \
+#define ENTRY_BINARY_STRIDED_OPS(typ) \
+ void oxarop_binary_ ## typ ## _sv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \
switch (tag) { \
- case BO_ADD: oxarop_op_add_ ## typ ## _sv(n, out, x, y); break; \
- case BO_SUB: oxarop_op_sub_ ## typ ## _sv(n, out, x, y); break; \
- case BO_MUL: oxarop_op_mul_ ## typ ## _sv(n, out, x, y); break; \
- default: wrong_op("binary_sv", tag); \
+ case BO_ADD: oxarop_op_add_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+ case BO_SUB: oxarop_op_sub_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+ case BO_MUL: oxarop_op_mul_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+ default: wrong_op("binary_sv_strided", tag); \
} \
} \
- void oxarop_binary_ ## typ ## _vs(enum binop_tag_t tag, i64 n, typ *out, const typ *x, typ y) { \
+ void oxarop_binary_ ## typ ## _vs_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \
switch (tag) { \
- case BO_ADD: oxarop_op_add_ ## typ ## _sv(n, out, y, x); break; \
- case BO_SUB: oxarop_op_sub_ ## typ ## _vs(n, out, x, y); break; \
- case BO_MUL: oxarop_op_mul_ ## typ ## _sv(n, out, y, x); break; \
- default: wrong_op("binary_vs", tag); \
+ case BO_ADD: oxarop_op_add_ ## typ ## _sv_strided(rank, shape, out, y, strides, x); break; \
+ case BO_SUB: oxarop_op_sub_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
+ case BO_MUL: oxarop_op_mul_ ## typ ## _sv_strided(rank, shape, out, y, strides, x); break; \
+ default: wrong_op("binary_vs_strided", tag); \
} \
} \
- void oxarop_binary_ ## typ ## _vv(enum binop_tag_t tag, i64 n, typ *out, const typ *x, const typ *y) { \
+ void oxarop_binary_ ## typ ## _vv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
switch (tag) { \
- case BO_ADD: oxarop_op_add_ ## typ ## _vv(n, out, x, y); break; \
- case BO_SUB: oxarop_op_sub_ ## typ ## _vv(n, out, x, y); break; \
- case BO_MUL: oxarop_op_mul_ ## typ ## _vv(n, out, x, y); break; \
- default: wrong_op("binary_vv", tag); \
+ case BO_ADD: oxarop_op_add_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+ case BO_SUB: oxarop_op_sub_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+ case BO_MUL: oxarop_op_mul_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+ default: wrong_op("binary_vv_strided", tag); \
} \
}
@@ -394,29 +418,29 @@ enum fbinop_tag_t {
#define LIST_FBINOP(name, id, hsop)
};
-#define ENTRY_FBINARY_OPS(typ) \
- void oxarop_fbinary_ ## typ ## _sv(enum binop_tag_t tag, i64 n, typ *out, typ x, const typ *y) { \
+#define ENTRY_FBINARY_STRIDED_OPS(typ) \
+ void oxarop_fbinary_ ## typ ## _sv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \
switch (tag) { \
- case FB_DIV: oxarop_op_fdiv_ ## typ ## _sv(n, out, x, y); break; \
- case FB_POW: oxarop_op_pow_ ## typ ## _sv(n, out, x, y); break; \
- case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _sv(n, out, x, y); break; \
- default: wrong_op("binary_sv", tag); \
+ case FB_DIV: oxarop_op_fdiv_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+ case FB_POW: oxarop_op_pow_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+ case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \
+ default: wrong_op("fbinary_sv_strided", tag); \
} \
} \
- void oxarop_fbinary_ ## typ ## _vs(enum binop_tag_t tag, i64 n, typ *out, const typ *x, typ y) { \
+ void oxarop_fbinary_ ## typ ## _vs_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \
switch (tag) { \
- case FB_DIV: oxarop_op_fdiv_ ## typ ## _vs(n, out, x, y); break; \
- case FB_POW: oxarop_op_pow_ ## typ ## _vs(n, out, x, y); break; \
- case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vs(n, out, x, y); break; \
- default: wrong_op("binary_vs", tag); \
+ case FB_DIV: oxarop_op_fdiv_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
+ case FB_POW: oxarop_op_pow_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
+ case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \
+ default: wrong_op("fbinary_vs_strided", tag); \
} \
} \
- void oxarop_fbinary_ ## typ ## _vv(enum binop_tag_t tag, i64 n, typ *out, const typ *x, const typ *y) { \
+ void oxarop_fbinary_ ## typ ## _vv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \
switch (tag) { \
- case FB_DIV: oxarop_op_fdiv_ ## typ ## _vv(n, out, x, y); break; \
- case FB_POW: oxarop_op_pow_ ## typ ## _vv(n, out, x, y); break; \
- case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vv(n, out, x, y); break; \
- default: wrong_op("binary_vv", tag); \
+ case FB_DIV: oxarop_op_fdiv_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+ case FB_POW: oxarop_op_pow_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+ case FB_LOGBASE: oxarop_op_logbase_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \
+ default: wrong_op("fbinary_vv_strided", tag); \
} \
}
@@ -469,7 +493,7 @@ enum funop_tag_t {
case FU_EXPM1: oxarop_op_expm1_ ## typ ## _strided(rank, out, shape, strides, x); break; \
case FU_LOG1PEXP: oxarop_op_log1pexp_ ## typ ## _strided(rank, out, shape, strides, x); break; \
case FU_LOG1MEXP: oxarop_op_log1mexp_ ## typ ## _strided(rank, out, shape, strides, x); break; \
- default: wrong_op("unary", tag); \
+ default: wrong_op("funary_strided", tag); \
} \
}
@@ -508,9 +532,9 @@ enum redop_tag_t {
#define NUM_TYPES_XLIST X(i32) X(i64) FLOAT_TYPES_XLIST
#define X(typ) \
- COMM_OP(add, +, typ) \
- NONCOMM_OP(sub, -, typ) \
- COMM_OP(mul, *, typ) \
+ COMM_OP_STRIDED(add, +, typ) \
+ NONCOMM_OP_STRIDED(sub, -, typ) \
+ COMM_OP_STRIDED(mul, *, typ) \
UNARY_OP_STRIDED(neg, -, typ) \
UNARY_OP_STRIDED(abs, GEN_ABS, typ) \
UNARY_OP_STRIDED(signum, GEN_SIGNUM, typ) \
@@ -518,7 +542,7 @@ enum redop_tag_t {
REDUCE1_OP(product1, *, typ) \
REDUCEFULL_OP(sumfull, +, typ) \
REDUCEFULL_OP(productfull, *, typ) \
- ENTRY_BINARY_OPS(typ) \
+ ENTRY_BINARY_STRIDED_OPS(typ) \
ENTRY_UNARY_STRIDED_OPS(typ) \
ENTRY_REDUCE1_OPS(typ) \
ENTRY_REDUCEFULL_OPS(typ) \
@@ -530,9 +554,9 @@ NUM_TYPES_XLIST
#undef X
#define X(typ) \
- NONCOMM_OP(fdiv, /, typ) \
- PREFIX_BINOP(pow, GEN_POW, typ) \
- PREFIX_BINOP(logbase, GEN_LOGBASE, typ) \
+ NONCOMM_OP_STRIDED(fdiv, /, typ) \
+ PREFIX_BINOP_STRIDED(pow, GEN_POW, typ) \
+ PREFIX_BINOP_STRIDED(logbase, GEN_LOGBASE, typ) \
UNARY_OP_STRIDED(recip, 1.0/, typ) \
UNARY_OP_STRIDED(exp, GEN_EXP, typ) \
UNARY_OP_STRIDED(log, GEN_LOG, typ) \
@@ -553,7 +577,7 @@ NUM_TYPES_XLIST
UNARY_OP_STRIDED(expm1, GEN_EXPM1, typ) \
UNARY_OP_STRIDED(log1pexp, GEN_LOG1PEXP, typ) \
UNARY_OP_STRIDED(log1mexp, GEN_LOG1MEXP, typ) \
- ENTRY_FBINARY_OPS(typ) \
+ ENTRY_FBINARY_STRIDED_OPS(typ) \
ENTRY_FUNARY_STRIDED_OPS(typ)
FLOAT_TYPES_XLIST
#undef X