From 9b0651bf19e889dfb28ba81b6ada25b27b0e6071 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Mon, 17 Jun 2024 13:08:13 +0200
Subject: sumAllPrim

---
 cbits/arith.c       | 49 ++++++++++++++++++++++++++++++++++++++++++++-----
 cbits/arith_lists.h |  4 ++--
 2 files changed, 46 insertions(+), 7 deletions(-)

(limited to 'cbits')

diff --git a/cbits/arith.c b/cbits/arith.c
index fb993c8..5d74c01 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -173,6 +173,33 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
     } \
   }
 
+// preconditions
+// - all strides are >0
+// - shape is everywhere >0
+// - rank is >= 1
+#define REDUCEFULL_OP(name, op, typ) \
+  typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
+    typ res = 0; \
+    if (strides[rank - 1] == 1) { \
+      TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \
+        typ accum = arr[arrlinidx]; \
+        for (i64 i = 1; i < shape[rank - 1]; i++) { \
+          accum = accum op arr[arrlinidx + i]; \
+        } \
+        res = res op accum; \
+      }); \
+    } else { \
+      TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \
+        typ accum = arr[arrlinidx]; \
+        for (i64 i = 1; i < shape[rank - 1]; i++) { \
+          accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \
+        } \
+        res = res op accum; \
+      }); \
+    } \
+    return res; \
+  }
+
 // preconditions
 // - all strides are >0
 // - shape is everywhere >0
@@ -394,11 +421,20 @@ enum redop_tag_t {
 #define LIST_REDOP(name, id, _)
 };
 
-#define ENTRY_REDUCE_OPS(typ) \
-  void oxarop_reduce_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \
+#define ENTRY_REDUCE1_OPS(typ) \
+  void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \
+    switch (tag) { \
+      case RO_SUM: oxarop_op_sum1_ ## typ(rank, shape, strides, out, arr); break; \
+      case RO_PRODUCT: oxarop_op_product1_ ## typ(rank, shape, strides, out, arr); break; \
+      default: wrong_op("reduce", tag); \
+    } \
+  }
+
+#define ENTRY_REDUCEFULL_OPS(typ) \
+  typ oxarop_reducefull_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
     switch (tag) { \
-      case RO_SUM1: oxarop_op_sum1_ ## typ(rank, shape, strides, out, arr); break; \
-      case RO_PRODUCT1: oxarop_op_product1_ ## typ(rank, shape, strides, out, arr); break; \
+      case RO_SUM: return oxarop_op_sumfull_ ## typ(rank, shape, strides, arr); \
+      case RO_PRODUCT: return oxarop_op_productfull_ ## typ(rank, shape, strides, arr); \
       default: wrong_op("reduce", tag); \
     } \
   }
@@ -420,9 +456,12 @@ enum redop_tag_t {
   UNARY_OP(signum, GEN_SIGNUM, typ) \
   REDUCE1_OP(sum1, +, typ) \
   REDUCE1_OP(product1, *, typ) \
+  REDUCEFULL_OP(sumfull, +, typ) \
+  REDUCEFULL_OP(productfull, *, typ) \
   ENTRY_BINARY_OPS(typ) \
   ENTRY_UNARY_OPS(typ) \
-  ENTRY_REDUCE_OPS(typ) \
+  ENTRY_REDUCE1_OPS(typ) \
+  ENTRY_REDUCEFULL_OPS(typ) \
   EXTREMUM_OP(min, <, typ) \
   EXTREMUM_OP(max, >, typ) \
   DOTPROD_STRIDED_OP(typ)
diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h
index 2e37575..58de65a 100644
--- a/cbits/arith_lists.h
+++ b/cbits/arith_lists.h
@@ -31,5 +31,5 @@ LIST_FUNOP(FU_EXPM1, 18,)
 LIST_FUNOP(FU_LOG1PEXP, 19,)
 LIST_FUNOP(FU_LOG1MEXP, 20,)
 
-LIST_REDOP(RO_SUM1, 1,)
-LIST_REDOP(RO_PRODUCT1, 2,)
+LIST_REDOP(RO_SUM, 1,)
+LIST_REDOP(RO_PRODUCT, 2,)
-- 
cgit v1.2.3-70-g09d2