diff options
Diffstat (limited to 'cbits/arith.c')
-rw-r--r-- | cbits/arith.c | 42 |
1 files changed, 41 insertions, 1 deletions
diff --git a/cbits/arith.c b/cbits/arith.c index 9aed3b4..4646ca4 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -18,6 +18,7 @@ // These are the wrapper macros used in arith_lists.h. Preset them to empty to // avoid having to touch macros unrelated to the particular operation set below. #define LIST_BINOP(name, id, hsop) +#define LIST_IBINOP(name, id, hsop) #define LIST_FBINOP(name, id, hsop) #define LIST_UNOP(name, id, _) #define LIST_FUNOP(name, id, _) @@ -410,6 +411,37 @@ enum binop_tag_t { } \ } +enum ibinop_tag_t { +#undef LIST_IBINOP +#define LIST_IBINOP(name, id, hsop) name = id, +#include "arith_lists.h" +#undef LIST_IBINOP +#define LIST_IBINOP(name, id, hsop) +}; + +#define ENTRY_IBINARY_STRIDED_OPS(typ) \ + void oxarop_ibinary_ ## typ ## _sv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, typ x, const i64 *strides, const typ *y) { \ + switch (tag) { \ + case IB_QUOT: oxarop_op_quot_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ + case IB_REM: oxarop_op_rem_ ## typ ## _sv_strided(rank, shape, out, x, strides, y); break; \ + default: wrong_op("ibinary_sv_strided", tag); \ + } \ + } \ + void oxarop_ibinary_ ## typ ## _vs_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides, const typ *x, typ y) { \ + switch (tag) { \ + case IB_QUOT: oxarop_op_quot_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ + case IB_REM: oxarop_op_rem_ ## typ ## _vs_strided(rank, shape, out, strides, x, y); break; \ + default: wrong_op("ibinary_vs_strided", tag); \ + } \ + } \ + void oxarop_ibinary_ ## typ ## _vv_strided(enum binop_tag_t tag, i64 rank, const i64 *shape, typ *out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ + switch (tag) { \ + case IB_QUOT: oxarop_op_quot_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ + case IB_REM: oxarop_op_rem_ ## typ ## _vv_strided(rank, shape, out, strides1, x, strides2, y); break; \ + default: wrong_op("ibinary_vv_strided", tag); \ + } \ + } + enum fbinop_tag_t { #undef LIST_FBINOP #define LIST_FBINOP(name, id, hsop) name = id, @@ -528,8 +560,9 @@ enum redop_tag_t { * Generate all the functions * *****************************************************************************/ +#define INT_TYPES_XLIST X(i32) X(i64) #define FLOAT_TYPES_XLIST X(double) X(float) -#define NUM_TYPES_XLIST X(i32) X(i64) FLOAT_TYPES_XLIST +#define NUM_TYPES_XLIST INT_TYPES_XLIST FLOAT_TYPES_XLIST #define X(typ) \ COMM_OP_STRIDED(add, +, typ) \ @@ -554,6 +587,13 @@ NUM_TYPES_XLIST #undef X #define X(typ) \ + NONCOMM_OP_STRIDED(quot, /, typ) \ + NONCOMM_OP_STRIDED(rem, %, typ) \ + ENTRY_IBINARY_STRIDED_OPS(typ) +INT_TYPES_XLIST +#undef X + +#define X(typ) \ NONCOMM_OP_STRIDED(fdiv, /, typ) \ PREFIX_BINOP_STRIDED(pow, GEN_POW, typ) \ PREFIX_BINOP_STRIDED(logbase, GEN_LOGBASE, typ) \ |