diff options
Diffstat (limited to 'cbits/arith.c')
-rw-r--r-- | cbits/arith.c | 40 |
1 file changed, 21 insertions, 19 deletions
diff --git a/cbits/arith.c b/cbits/arith.c index ca0af51..b574d54 100644 --- a/cbits/arith.c +++ b/cbits/arith.c @@ -24,6 +24,17 @@ typedef int32_t i32; typedef int64_t i64; +// PRECONDITIONS +// +// All strided array operations in this file assume that none of the shape +// components are zero -- that is, the input arrays are non-empty. This must +// be arranged on the Haskell side. +// +// Furthermore, note that while the Haskell side has an offset into the backing +// vector, the C side assumes that the offset is zero. Shift the pointer if +// necessary. + + /***************************************************************************** * Performance statistics * *****************************************************************************/ @@ -370,6 +381,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { #define COMM_OP_STRIDED(name, op, typ) \ static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \ + if (rank == 0) { out[0] = x op y[0]; return; } \ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ for (i64 i = 0; i < shape[rank - 1]; i++) { \ out[outlinidx * shape[rank - 1] + i] = x op y[arrlinidx + strides[rank - 1] * i]; \ @@ -377,6 +389,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { }); \ } \ static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ + if (rank == 0) { out[0] = x[0] op y[0]; return; } \ TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \ for (i64 i = 0; i < shape[rank - 1]; i++) { \ out[outlinidx * shape[rank - 1] + i] = x[arrlinidx1 + strides1[rank - 1] * i] op y[arrlinidx2 + strides2[rank - 1] * i]; \ @@ -387,6 +400,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { #define NONCOMM_OP_STRIDED(name, op, typ) \ COMM_OP_STRIDED(name, op, typ) \ 
static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \ + if (rank == 0) { out[0] = x[0] op y; return; } \ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ for (i64 i = 0; i < shape[rank - 1]; i++) { \ out[outlinidx * shape[rank - 1] + i] = x[arrlinidx + strides[rank - 1] * i] op y; \ @@ -396,6 +410,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { #define PREFIX_BINOP_STRIDED(name, op, typ) \ static void oxarop_op_ ## name ## _ ## typ ## _sv_strided(i64 rank, const i64 *shape, typ *restrict out, typ x, const i64 *strides, const typ *y) { \ + if (rank == 0) { out[0] = op(x, y[0]); return; } \ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ for (i64 i = 0; i < shape[rank - 1]; i++) { \ out[outlinidx * shape[rank - 1] + i] = op(x, y[arrlinidx + strides[rank - 1] * i]); \ @@ -403,6 +418,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { }); \ } \ static void oxarop_op_ ## name ## _ ## typ ## _vv_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *x, const i64 *strides2, const typ *y) { \ + if (rank == 0) { out[0] = op(x[0], y[0]); return; } \ TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \ for (i64 i = 0; i < shape[rank - 1]; i++) { \ out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx1 + strides1[rank - 1] * i], y[arrlinidx2 + strides2[rank - 1] * i]); \ @@ -410,6 +426,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { }); \ } \ static void oxarop_op_ ## name ## _ ## typ ## _vs_strided(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides, const typ *x, typ y) { \ + if (rank == 0) { out[0] = op(x[0], y); return; } \ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ for (i64 i = 0; i < shape[rank - 1]; i++) { \ out[outlinidx * shape[rank - 1] + i] = op(x[arrlinidx + strides[rank - 1] * i], y); \ @@ -424,6 +441,7 @@ static void 
print_shape(FILE *stream, i64 rank, const i64 *shape) { fprintf(stderr, " strides="); \ print_shape(stderr, rank, strides); \ fprintf(stderr, "\n"); */ \ + if (rank == 0) { out[0] = op(arr[0]); return; } \ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ for (i64 i = 0; i < shape[rank - 1]; i++) { \ out[outlinidx * shape[rank - 1] + i] = op(arr[arrlinidx + strides[rank - 1] * i]); \ @@ -434,7 +452,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { // Used for reduction and dot product kernels below #define MANUAL_VECT_WID 8 -// Used in REDUCE1_OP and REDUCEFULL_OP below; requires the same preconditions +// Used in REDUCE1_OP and REDUCEFULL_OP below #define REDUCE_BODY_CODE(op, typ, innerLen, innerStride, arr, arrlinidx, destination) \ do { \ const i64 n = innerLen; const i64 s = innerStride; \ @@ -458,11 +476,6 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { } \ } while (0) -// preconditions: -// - all strides are >0 -// - shape is everywhere >0 -// - rank is >= 1 -// - out has capacity for (shape[0] * ... * shape[rank - 2]) elements // Reduces along the innermost dimension. // 'out' will be filled densely in linearisation order. #define REDUCE1_OP(name, op, typ) \ @@ -472,12 +485,9 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { }); \ } -// preconditions -// - all strides are >0 -// - shape is everywhere >0 -// - rank is >= 1 #define REDUCEFULL_OP(name, op, typ) \ typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ + if (rank == 0) return arr[0]; \ typ result = 0; \ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ REDUCE_BODY_CODE(op, typ, shape[rank - 1], strides[rank - 1], arr, arrlinidx, result); \ @@ -485,13 +495,10 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { return result; \ } -// preconditions -// - all strides are >0 -// - shape is everywhere >0 -// - rank is >= 1 // Writes extreme index to outidx. 
If 'cmp' is '<', computes minindex ("argmin"); if '>', maxindex. #define EXTREMUM_OP(name, cmp, typ) \ void oxarop_extremum_ ## name ## _ ## typ(i64 *restrict outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \ + if (rank == 0) return; /* output index vector has length 0 anyways */ \ typ best = arr[0]; \ memset(outidx, 0, rank * sizeof(i64)); \ TARRAY_WALK_NOINNER(again, rank, shape, strides, { \ @@ -527,11 +534,6 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) { } \ } -// preconditions: -// - all strides are >0 -// - shape is everywhere >0 -// - rank is >= 1 -// - out has capacity for (shape[0] * ... * shape[rank - 2]) elements // Reduces along the innermost dimension. // 'out' will be filled densely in linearisation order. #define DOTPROD_INNER_OP(typ) \ |