From 1b69f540b0c1fa8d45b80f452cab8e7ac02dffd9 Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Tue, 19 Nov 2024 10:15:55 +0100 Subject: Add the criminally absent singletons to numEltMinIndex and numEltMaxIndex --- src/Data/Array/Mixed/Internal/Arith.hs | 28 ++++++++++++++-------------- src/Data/Array/Nested/Internal/Mixed.hs | 4 ++-- 2 files changed, 16 insertions(+), 16 deletions(-) (limited to 'src') diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 0ee6708..a24efd6 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -571,8 +571,8 @@ class NumElt a where numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a numEltSumFull :: SNat n -> RS.Array n a -> a numEltProductFull :: SNat n -> RS.Array n a -> a - numEltMinIndex :: RS.Array n a -> [Int] - numEltMaxIndex :: RS.Array n a -> [Int] + numEltMinIndex :: SNat n -> RS.Array n a -> [Int] + numEltMaxIndex :: SNat n -> RS.Array n a -> [Int] numEltDotprodInner :: SNat n -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a instance NumElt Int32 where @@ -586,8 +586,8 @@ instance NumElt Int32 where numEltProduct1Inner = product1VectorInt32 numEltSumFull = sumFullVectorInt32 numEltProductFull = productFullVectorInt32 - numEltMinIndex = minindexVectorInt32 - numEltMaxIndex = maxindexVectorInt32 + numEltMinIndex _ = minindexVectorInt32 + numEltMaxIndex _ = maxindexVectorInt32 numEltDotprodInner = dotprodinnerVectorInt32 instance NumElt Int64 where @@ -601,8 +601,8 @@ instance NumElt Int64 where numEltProduct1Inner = product1VectorInt64 numEltSumFull = sumFullVectorInt64 numEltProductFull = productFullVectorInt64 - numEltMinIndex = minindexVectorInt64 - numEltMaxIndex = maxindexVectorInt64 + numEltMinIndex _ = minindexVectorInt64 + numEltMaxIndex _ = maxindexVectorInt64 numEltDotprodInner = dotprodinnerVectorInt64 instance NumElt Float where @@ -616,8 +616,8 @@ instance NumElt Float where numEltProduct1Inner = product1VectorFloat numEltSumFull = sumFullVectorFloat numEltProductFull = productFullVectorFloat - numEltMinIndex = minindexVectorFloat - numEltMaxIndex = maxindexVectorFloat + numEltMinIndex _ = minindexVectorFloat + numEltMaxIndex _ = maxindexVectorFloat numEltDotprodInner = dotprodinnerVectorFloat instance NumElt Double where @@ -631,8 +631,8 @@ instance NumElt Double where numEltProduct1Inner = product1VectorDouble numEltSumFull = sumFullVectorDouble numEltProductFull = productFullVectorDouble - numEltMinIndex = minindexVectorDouble - numEltMaxIndex = maxindexVectorDouble + numEltMinIndex _ = minindexVectorDouble + numEltMaxIndex _ = maxindexVectorDouble numEltDotprodInner = dotprodinnerVectorDouble instance NumElt Int where @@ -656,8 +656,8 @@ instance NumElt Int where (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT)) numEltSumFull = intWidBranchRedFull @Int (*) (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM)) numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT)) - numEltMinIndex = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64 - numEltMaxIndex = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64 + numEltMinIndex _ = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64 + numEltMaxIndex _ = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64 numEltDotprodInner = intWidBranchDotprod @Int (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32 (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64 @@ -682,8 +682,8 @@ instance NumElt CInt where (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT)) numEltSumFull = intWidBranchRedFull @CInt mulWithInt (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM)) numEltProductFull = intWidBranchRedFull @CInt (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT)) - numEltMinIndex = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64 - numEltMaxIndex = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64 + numEltMinIndex _ = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64 + numEltMaxIndex _ = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64 numEltDotprodInner = intWidBranchDotprod @CInt (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32 (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64 diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 023f6fa..0e4f5e6 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -849,12 +849,12 @@ miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) -- | Throws if the array is empty. mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = - ixxFromList (ssxFromShape sh) (numEltMinIndex arr) + ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) arr) -- | Throws if the array is empty. mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = - ixxFromList (ssxFromShape sh) (numEltMaxIndex arr) + ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) arr) mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a -- cgit v1.2.3-70-g09d2