From 1f3d57e13441f86b97ee7ff213bb4a677e31f2db Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 9 Jun 2024 23:09:19 +0200 Subject: argmin and argmax --- src/Data/Array/Nested/Internal/Mixed.hs | 10 ++++++++++ src/Data/Array/Nested/Internal/Ranked.hs | 12 ++++++++++++ src/Data/Array/Nested/Internal/Shaped.hs | 8 ++++++++ 3 files changed, 30 insertions(+) (limited to 'src/Data/Array/Nested') diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 4746f31..31c4e55 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -788,6 +788,16 @@ mreshape sh' arr = miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) +-- | Throws if the array is empty. +margMinPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh +margMinPrim (toPrimitive -> M_Primitive sh (XArray arr)) = + ixxFromList (ssxFromShape sh) (numEltArgMin arr) + +-- | Throws if the array is empty. +margMaxPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh +margMaxPrim (toPrimitive -> M_Primitive sh (XArray arr)) = + ixxFromList (ssxFromShape sh) (numEltArgMax arr) + mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr) diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 55ae59f..c16cfb7 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -449,6 +449,18 @@ rreshape sh' rarr@(Ranked arr) riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota +-- | Throws if the array is empty. +rargMinPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n +rargMinPrim rarr@(Ranked arr) + | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) + = ixCvtXR (margMinPrim arr) + +-- | Throws if the array is empty. +rargMaxPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n +rargMaxPrim rarr@(Ranked arr) + | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) + = ixCvtXR (margMaxPrim arr) + rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr) diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index 544a2fa..fae486b 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -373,6 +373,14 @@ sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a siota sn = Shaped (miota sn) +-- | Throws if the array is empty. +sargMinPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh +sargMinPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (margMinPrim arr) + +-- | Throws if the array is empty. +sargMaxPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh +sargMaxPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (margMaxPrim arr) + stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr) -- cgit v1.2.3-70-g09d2