aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-09 23:09:19 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-09 23:09:19 +0200
commit1f3d57e13441f86b97ee7ff213bb4a677e31f2db (patch)
treee72bfd568b032a9af611118038c2eeb6f347ea22 /src/Data/Array/Nested
parentc8f99847359a92289cf0ded280069794f6abae6a (diff)
argmin and argmax
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs10
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs12
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs8
3 files changed, 30 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 4746f31..31c4e55 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -788,6 +788,16 @@ mreshape sh' arr =
miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a
miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
+-- | Throws if the array is empty.
+margMinPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
+margMinPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
+ ixxFromList (ssxFromShape sh) (numEltArgMin arr)
+
+-- | Throws if the array is empty.
+margMaxPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
+margMaxPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
+ ixxFromList (ssxFromShape sh) (numEltArgMax arr)
+
mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr)
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 55ae59f..c16cfb7 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -449,6 +449,18 @@ rreshape sh' rarr@(Ranked arr)
riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a
riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota
+-- | Throws if the array is empty.
+rargMinPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
+rargMinPrim rarr@(Ranked arr)
+ | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
+ = ixCvtXR (margMinPrim arr)
+
+-- | Throws if the array is empty.
+rargMaxPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
+rargMaxPrim rarr@(Ranked arr)
+ | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
+ = ixCvtXR (margMaxPrim arr)
+
rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr)
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index 544a2fa..fae486b 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -373,6 +373,14 @@ sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr)
siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
siota sn = Shaped (miota sn)
+-- | Throws if the array is empty.
+sargMinPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
+sargMinPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (margMinPrim arr)
+
+-- | Throws if the array is empty.
+sargMaxPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
+sargMaxPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (margMaxPrim arr)
+
stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr)