From 1b69f540b0c1fa8d45b80f452cab8e7ac02dffd9 Mon Sep 17 00:00:00 2001
From: Mikolaj Konarski <mikolaj.konarski@gmail.com>
Date: Tue, 19 Nov 2024 10:15:55 +0100
Subject: Add the criminally absent singletons to numEltMinIndex and
 numEltMaxIndex

---
 src/Data/Array/Mixed/Internal/Arith.hs  | 28 ++++++++++++++--------------
 src/Data/Array/Nested/Internal/Mixed.hs |  4 ++--
 2 files changed, 16 insertions(+), 16 deletions(-)

(limited to 'src/Data')

diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 0ee6708..a24efd6 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -571,8 +571,8 @@ class NumElt a where
   numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
   numEltSumFull :: SNat n -> RS.Array n a -> a
   numEltProductFull :: SNat n -> RS.Array n a -> a
-  numEltMinIndex :: RS.Array n a -> [Int]
-  numEltMaxIndex :: RS.Array n a -> [Int]
+  numEltMinIndex :: SNat n -> RS.Array n a -> [Int]
+  numEltMaxIndex :: SNat n -> RS.Array n a -> [Int]
   numEltDotprodInner :: SNat n -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a
 
 instance NumElt Int32 where
@@ -586,8 +586,8 @@ instance NumElt Int32 where
   numEltProduct1Inner = product1VectorInt32
   numEltSumFull = sumFullVectorInt32
   numEltProductFull = productFullVectorInt32
-  numEltMinIndex = minindexVectorInt32
-  numEltMaxIndex = maxindexVectorInt32
+  numEltMinIndex _ = minindexVectorInt32
+  numEltMaxIndex _ = maxindexVectorInt32
   numEltDotprodInner = dotprodinnerVectorInt32
 
 instance NumElt Int64 where
@@ -601,8 +601,8 @@ instance NumElt Int64 where
   numEltProduct1Inner = product1VectorInt64
   numEltSumFull = sumFullVectorInt64
   numEltProductFull = productFullVectorInt64
-  numEltMinIndex = minindexVectorInt64
-  numEltMaxIndex = maxindexVectorInt64
+  numEltMinIndex _ = minindexVectorInt64
+  numEltMaxIndex _ = maxindexVectorInt64
   numEltDotprodInner = dotprodinnerVectorInt64
 
 instance NumElt Float where
@@ -616,8 +616,8 @@ instance NumElt Float where
   numEltProduct1Inner = product1VectorFloat
   numEltSumFull = sumFullVectorFloat
   numEltProductFull = productFullVectorFloat
-  numEltMinIndex = minindexVectorFloat
-  numEltMaxIndex = maxindexVectorFloat
+  numEltMinIndex _ = minindexVectorFloat
+  numEltMaxIndex _ = maxindexVectorFloat
   numEltDotprodInner = dotprodinnerVectorFloat
 
 instance NumElt Double where
@@ -631,8 +631,8 @@ instance NumElt Double where
   numEltProduct1Inner = product1VectorDouble
   numEltSumFull = sumFullVectorDouble
   numEltProductFull = productFullVectorDouble
-  numEltMinIndex = minindexVectorDouble
-  numEltMaxIndex = maxindexVectorDouble
+  numEltMinIndex _ = minindexVectorDouble
+  numEltMaxIndex _ = maxindexVectorDouble
   numEltDotprodInner = dotprodinnerVectorDouble
 
 instance NumElt Int where
@@ -656,8 +656,8 @@ instance NumElt Int where
                           (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT))
   numEltSumFull = intWidBranchRedFull @Int (*) (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM))
   numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
-  numEltMinIndex = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
-  numEltMaxIndex = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64
+  numEltMinIndex _ = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
+  numEltMaxIndex _ = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64
   numEltDotprodInner = intWidBranchDotprod @Int (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
                                                 (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
 
@@ -682,8 +682,8 @@ instance NumElt CInt where
                           (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT))
   numEltSumFull = intWidBranchRedFull @CInt mulWithInt (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM))
   numEltProductFull = intWidBranchRedFull @CInt (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
-  numEltMinIndex = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
-  numEltMaxIndex = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
+  numEltMinIndex _ = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
+  numEltMaxIndex _ = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
   numEltDotprodInner = intWidBranchDotprod @CInt (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
                                                  (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
 
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 023f6fa..0e4f5e6 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -849,12 +849,12 @@ miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
 -- | Throws if the array is empty.
 mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
 mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
-  ixxFromList (ssxFromShape sh) (numEltMinIndex arr)
+  ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) arr)
 
 -- | Throws if the array is empty.
 mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
 mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
-  ixxFromList (ssxFromShape sh) (numEltMaxIndex arr)
+  ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) arr)
 
 mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
            => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a
-- 
cgit v1.2.3-70-g09d2