aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2024-06-10 13:23:38 +0200
committerTom Smeding <t.j.smeding@uu.nl>2024-06-10 13:28:51 +0200
commit596bce9b869cafc06d9b1567c3a3ed282f7441ba (patch)
tree6406ea65d4ac1f099ac8851d2db38a5fca4bc4e3
parent8274da734aba266e86ac722b6a9e73afeeae59e6 (diff)
Rename arg{min,max} to {min,max}Index
-rw-r--r--cbits/arith.c2
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs34
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs12
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs12
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs8
5 files changed, 34 insertions, 34 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index 6ac49b8..5594c80 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -168,7 +168,7 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
// - all strides are >0
// - shape is everywhere >0
// - rank is >= 1
-// Writes extreme index to outidx. If 'cmp' is '<', computes argmin; if '>', argmax.
+// Writes extreme index to outidx. If 'cmp' is '<', computes minindex ("argmin"); if '>', maxindex.
#define EXTREMUM_OP(name, cmp, typ) \
void oxarop_extremum_ ## name ## _ ## typ(i64 *outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
typ best = arr[0]; \
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 32ed355..d2ad61f 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -184,7 +184,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride
<$> VS.unsafeFreeze outv
-- TODO: test this function
--- | Find extremum (argmin or argmax) in full array
+-- | Find extremum (minindex ("argmin") or maxindex) in full array
{-# NOINLINE vectorExtremumOp #-}
vectorExtremumOp :: forall a b n. Storable a
=> (Ptr a -> Ptr b)
@@ -192,7 +192,7 @@ vectorExtremumOp :: forall a b n. Storable a
-> RS.Array n a -> [Int] -- result length: n
vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))
| null sh = []
- | any (<= 0) sh = error "Extremum (argmin/argmax): empty array"
+ | any (<= 0) sh = error "Extremum (minindex/maxindex): empty array"
-- now the input array is nonempty
| all (<= 0) strides = 0 <$ sh
| otherwise =
@@ -283,7 +283,7 @@ $(fmap concat . forM typesList $ \arithtype -> do
$(fmap concat . forM typesList $ \arithtype ->
fmap concat . forM ["min", "max"] $ \fname -> do
let ttyp = conT (atType arithtype)
- name = mkName ("arg" ++ fname ++ "Vector" ++ nameBase (atType arithtype))
+ name = mkName (fname ++ "indexVector" ++ nameBase (atType arithtype))
c_op = varE (mkName ("c_extremum_" ++ fname ++ "_" ++ atCName arithtype))
sequence [SigD name <$>
[t| forall n. RS.Array n $ttyp -> [Int] |]
@@ -350,8 +350,8 @@ class NumElt a where
numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
- numEltArgMin :: RS.Array n a -> [Int]
- numEltArgMax :: RS.Array n a -> [Int]
+ numEltMinIndex :: RS.Array n a -> [Int]
+ numEltMaxIndex :: RS.Array n a -> [Int]
instance NumElt Int32 where
numEltAdd = addVectorInt32
@@ -362,8 +362,8 @@ instance NumElt Int32 where
numEltSignum = signumVectorInt32
numEltSum1Inner = sum1VectorInt32
numEltProduct1Inner = product1VectorInt32
- numEltArgMin = argminVectorInt32
- numEltArgMax = argmaxVectorInt32
+ numEltMinIndex = minindexVectorInt32
+ numEltMaxIndex = maxindexVectorInt32
instance NumElt Int64 where
numEltAdd = addVectorInt64
@@ -374,8 +374,8 @@ instance NumElt Int64 where
numEltSignum = signumVectorInt64
numEltSum1Inner = sum1VectorInt64
numEltProduct1Inner = product1VectorInt64
- numEltArgMin = argminVectorInt64
- numEltArgMax = argmaxVectorInt64
+ numEltMinIndex = minindexVectorInt64
+ numEltMaxIndex = maxindexVectorInt64
instance NumElt Float where
numEltAdd = addVectorFloat
@@ -386,8 +386,8 @@ instance NumElt Float where
numEltSignum = signumVectorFloat
numEltSum1Inner = sum1VectorFloat
numEltProduct1Inner = product1VectorFloat
- numEltArgMin = argminVectorFloat
- numEltArgMax = argmaxVectorFloat
+ numEltMinIndex = minindexVectorFloat
+ numEltMaxIndex = maxindexVectorFloat
instance NumElt Double where
numEltAdd = addVectorDouble
@@ -398,8 +398,8 @@ instance NumElt Double where
numEltSignum = signumVectorDouble
numEltSum1Inner = sum1VectorDouble
numEltProduct1Inner = product1VectorDouble
- numEltArgMin = argminVectorDouble
- numEltArgMax = argmaxVectorDouble
+ numEltMinIndex = minindexVectorDouble
+ numEltMaxIndex = maxindexVectorDouble
instance NumElt Int where
numEltAdd = intWidBranch2 @Int (+)
@@ -420,8 +420,8 @@ instance NumElt Int where
numEltProduct1Inner = intWidBranchRed @Int
(c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
(c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
- numEltArgMin = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
- numEltArgMax = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64
+ numEltMinIndex = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
+ numEltMaxIndex = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64
instance NumElt CInt where
numEltAdd = intWidBranch2 @CInt (+)
@@ -442,8 +442,8 @@ instance NumElt CInt where
numEltProduct1Inner = intWidBranchRed @CInt
(c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
(c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
- numEltArgMin = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
- numEltArgMax = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
+ numEltMinIndex = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
+ numEltMaxIndex = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
class FloatElt a where
floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 31c4e55..a0de08b 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -789,14 +789,14 @@ miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a
miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
-- | Throws if the array is empty.
-margMinPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
-margMinPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
- ixxFromList (ssxFromShape sh) (numEltArgMin arr)
+mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
+mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
+ ixxFromList (ssxFromShape sh) (numEltMinIndex arr)
-- | Throws if the array is empty.
-margMaxPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
-margMaxPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
- ixxFromList (ssxFromShape sh) (numEltArgMax arr)
+mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
+mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
+ ixxFromList (ssxFromShape sh) (numEltMaxIndex arr)
mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr)
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index c16cfb7..589f0c1 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -450,16 +450,16 @@ riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a
riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota
-- | Throws if the array is empty.
-rargMinPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
-rargMinPrim rarr@(Ranked arr)
+rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
+rminIndexPrim rarr@(Ranked arr)
| Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
- = ixCvtXR (margMinPrim arr)
+ = ixCvtXR (mminIndexPrim arr)
-- | Throws if the array is empty.
-rargMaxPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
-rargMaxPrim rarr@(Ranked arr)
+rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
+rmaxIndexPrim rarr@(Ranked arr)
| Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
- = ixCvtXR (margMaxPrim arr)
+ = ixCvtXR (mmaxIndexPrim arr)
rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr)
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index fae486b..ca3fd45 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -374,12 +374,12 @@ siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
siota sn = Shaped (miota sn)
-- | Throws if the array is empty.
-sargMinPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
-sargMinPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (margMinPrim arr)
+sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
+sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminIndexPrim arr)
-- | Throws if the array is empty.
-sargMaxPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
-sargMaxPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (margMaxPrim arr)
+smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
+smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr)
stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr)