From a78ddeaa5d34fa8b6fa52eee42977cc46e8c36a5 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Tue, 25 Mar 2025 17:09:20 +0100
Subject: Dotprod: Optimise reversed and replicated dimensions

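oxarop_dotprodinner previously distinguished three inner-stride cases,
each calling the out-of-line helper oxarop_dotprod_<typ>_strided once
per inner vector: both strides 1 (contiguous), both strides -1 (both
inputs reversed, rewritten as a forward pass over shifted base
pointers), and a generic fallback passing the strides through at run
time. Drop the helper and the branches, and inline its
manually-vectorised kernel directly into the TARRAY_WALK2_NOINNER
body: reversed (negative-stride) and replicated (zero-stride) inner
dimensions now take the same accumulator loop as everything else, and
the compiler sees the whole loop nest instead of an opaque call per
inner vector.

The inlined body declares its locals in one comma-separated statement,
which a walker macro taking a single 'body' parameter would split into
several arguments, since braces do not group for the preprocessor.
Switch the walker macros to __VA_ARGS__ so braced bodies containing
top-level commas expand intact. A minimal sketch of the failure mode
(WALK stands in for the real walker macros):

    #define WALK(label, body) do { body } while (0)

    // The top-level comma splits the braced block into two macro
    // arguments, so WALK receives three arguments instead of two and
    // the expansion fails to compile.
    WALK(lbl, {
      const i64 length = shape[rank - 1], stride1 = strides1[rank - 1];
    });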
---
 cbits/arith.c | 64 +++++++++++++++++++++++++---------------------------------------
 1 file changed, 25 insertions(+), 39 deletions(-)

diff --git a/cbits/arith.c b/cbits/arith.c
index b574d54..3659f6c 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -326,7 +326,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
 // Walk an orthotope-style strided array, except for the inner dimension. The
 // body is run for every "inner vector".
 // Provides idx, outlinidx, arrlinidx.
-#define TARRAY_WALK_NOINNER(again_label_name, rank, shape, strides, body) \
+#define TARRAY_WALK_NOINNER(again_label_name, rank, shape, strides, ...) \
   do { \
     i64 idx[(rank) /* - 1 */]; /* Note: [zero-length VLA] */ \
     memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \
@@ -334,7 +334,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
     i64 outlinidx = 0; \
   again_label_name: \
     { \
-      body \
+      __VA_ARGS__ \
     } \
     for (i64 dim = (rank) - 2; dim >= 0; dim--) { \
       if (++idx[dim] < (shape)[dim]) { \
@@ -351,7 +351,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
 // inner dimension. The arrays must have the same shape, but may have different
 // strides. The body is run for every pair of "inner vectors".
 // Provides idx, outlinidx, arrlinidx1, arrlinidx2.
-#define TARRAY_WALK2_NOINNER(again_label_name, rank, shape, strides1, strides2, body) \
+#define TARRAY_WALK2_NOINNER(again_label_name, rank, shape, strides1, strides2, ...) \
   do { \
     i64 idx[(rank) /* - 1 */]; /* Note: [zero-length VLA] */ \
     memset(idx, 0, ((rank) - 1) * sizeof(idx[0])); \
@@ -359,7 +359,7 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
     i64 outlinidx = 0; \
   again_label_name: \
     { \
-      body \
+      __VA_ARGS__ \
     } \
     for (i64 dim = (rank) - 2; dim >= 0; dim--) { \
       if (++idx[dim] < (shape)[dim]) { \
@@ -514,45 +514,32 @@ static void print_shape(FILE *stream, i64 rank, const i64 *shape) {
     }); \
   }
 
-#define DOTPROD_STRIDED_OP(typ) \
-  typ oxarop_dotprod_ ## typ ## _strided(i64 length, i64 stride1, const typ *arr1, i64 stride2, const typ *arr2) { \
-    if (length < MANUAL_VECT_WID) { \
-      typ res = 0; \
-      for (i64 i = 0; i < length; i++) res += arr1[stride1 * i] * arr2[stride2 * i]; \
-      return res; \
-    } else { \
-      typ accum[MANUAL_VECT_WID]; \
-      for (i64 j = 0; j < MANUAL_VECT_WID; j++) accum[j] = arr1[stride1 * j] * arr2[stride2 * j]; \
-      for (i64 i = 1; i < length / MANUAL_VECT_WID; i++) \
-        for (i64 j = 0; j < MANUAL_VECT_WID; j++) \
-          accum[j] += arr1[stride1 * (MANUAL_VECT_WID * i + j)] * arr2[stride2 * (MANUAL_VECT_WID * i + j)]; \
-      typ res = accum[0]; \
-      for (i64 j = 1; j < MANUAL_VECT_WID; j++) res += accum[j]; \
-      for (i64 i = length / MANUAL_VECT_WID * MANUAL_VECT_WID; i < length; i++) \
-        res += arr1[stride1 * i] * arr2[stride2 * i]; \
-      return res; \
-    } \
-  }
-
 // Reduces along the innermost dimension.
 // 'out' will be filled densely in linearisation order.
 #define DOTPROD_INNER_OP(typ) \
   void oxarop_dotprodinner_ ## typ(i64 rank, const i64 *shape, typ *restrict out, const i64 *strides1, const typ *arr1, const i64 *strides2, const typ *arr2) { \
     TIME_START(tm); \
-    if (strides1[rank - 1] == 1 && strides2[rank - 1] == 1) { \
-      TARRAY_WALK2_NOINNER(again1, rank, shape, strides1, strides2, { \
-        out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(shape[rank - 1], 1, arr1 + arrlinidx1, 1, arr2 + arrlinidx2); \
-      }); \
-    } else if (strides1[rank - 1] == -1 && strides2[rank - 1] == -1) { \
-      TARRAY_WALK2_NOINNER(again2, rank, shape, strides1, strides2, { \
-        const i64 len = shape[rank - 1]; \
-        out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(len, 1, arr1 + arrlinidx1 - (len - 1), 1, arr2 + arrlinidx2 - (len - 1)); \
-      }); \
-    } else { \
-      TARRAY_WALK2_NOINNER(again3, rank, shape, strides1, strides2, { \
-        out[outlinidx] = oxarop_dotprod_ ## typ ## _strided(shape[rank - 1], strides1[rank - 1], arr1 + arrlinidx1, strides2[rank - 1], arr2 + arrlinidx2); \
-      }); \
-    } \
+    TARRAY_WALK2_NOINNER(again, rank, shape, strides1, strides2, { \
+      const i64 length = shape[rank - 1], stride1 = strides1[rank - 1], stride2 = strides2[rank - 1]; \
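+      /* stride1/stride2 may be 1, -1 (reversed dimension), 0 (replicated dimension), or anything else; all cases share the loop below. */ \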
+      if (length < MANUAL_VECT_WID) { \
+        typ res = 0; \
+        for (i64 i = 0; i < length; i++) res += arr1[arrlinidx1 + stride1 * i] * arr2[arrlinidx2 + stride2 * i]; \
+        out[outlinidx] = res; \
+      } else { \
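+        /* MANUAL_VECT_WID independent partial sums: the j-iterations carry no dependency, so the compiler can vectorise them. */ \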
+        typ accum[MANUAL_VECT_WID]; \
+        for (i64 j = 0; j < MANUAL_VECT_WID; j++) accum[j] = arr1[arrlinidx1 + stride1 * j] * arr2[arrlinidx2 + stride2 * j]; \
+        for (i64 i = 1; i < length / MANUAL_VECT_WID; i++) \
+          for (i64 j = 0; j < MANUAL_VECT_WID; j++) \
+            accum[j] += arr1[arrlinidx1 + stride1 * (MANUAL_VECT_WID * i + j)] * arr2[arrlinidx2 + stride2 * (MANUAL_VECT_WID * i + j)]; \
+        typ res = accum[0]; \
+        for (i64 j = 1; j < MANUAL_VECT_WID; j++) res += accum[j]; \
+        for (i64 i = length / MANUAL_VECT_WID * MANUAL_VECT_WID; i < length; i++) \
+          res += arr1[arrlinidx1 + stride1 * i] * arr2[arrlinidx2 + stride2 * i]; \
+        out[outlinidx] = res; \
+      } \
+    }); \
     stats_record_binary(sbi_dotprod, rank, shape, strides1, strides2, TIME_END(tm)); \
   }
 
@@ -774,7 +759,6 @@ enum redop_tag_t {
   ENTRY_REDUCEFULL_OPS(typ) \
   EXTREMUM_OP(min, <, typ) \
   EXTREMUM_OP(max, >, typ) \
-  DOTPROD_STRIDED_OP(typ) \
   DOTPROD_INNER_OP(typ)
 NUM_TYPES_XLIST
 #undef X
-- 
cgit v1.2.3-70-g09d2