aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs63
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Foreign.hs7
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs10
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs12
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs8
5 files changed, 100 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 91f994b..a57de1d 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -183,6 +183,34 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride
. RS.fromVector @_ @lenFm1 (init shF)
<$> VS.unsafeFreeze outv
+-- | Find extremum (argmin or argmax) in full array
+{-# NOINLINE vectorExtremumOp #-}
+vectorExtremumOp :: forall a b n. Storable a
+ => (Ptr a -> Ptr b)
+ -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ()) -- ^ extremum kernel
+ -> RS.Array n a -> [Int] -- result length: n
+vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))
+ | null sh = []
+ | any (<= 0) sh = error "Extremum (argmin/argmax): empty array"
+ -- now the input array is nonempty
+ | all (<= 0) strides = 0 <$ sh
+ | otherwise =
+ let -- replicated dimensions: dimensions with zero stride. The extremum
+ -- kernel need not concern itself with those (and in fact has a
+ -- precondition that there are no such dimensions in its input).
+ replDims = map (== 0) strides
+ -- filter out replicated dimensions
+ (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims)
+ ndimsF = length shF -- > 0, because not all strides were <=0
+ in unsafePerformIO $ do
+ outv <- VSM.unsafeNew (length shF)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
+ VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF ->
+ VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec ->
+ fextrem poutv (fromIntegral ndimsF) pshF pstridesF (ptrconv pvec)
+ map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv
+
flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
-> Int64 -> Ptr a -> Ptr a -> a -> IO ()
flipOp f n out v s = f n out s v
@@ -246,6 +274,16 @@ $(fmap concat . forM typesList $ \arithtype -> do
,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |]
return $ FunD name [Clause [] (NormalB body) []]])
+$(fmap concat . forM typesList $ \arithtype ->
+ fmap concat . forM ["min", "max"] $ \fname -> do
+ let ttyp = conT (atType arithtype)
+ name = mkName ("arg" ++ fname ++ "Vector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_extremum_" ++ fname ++ "_" ++ atCName arithtype))
+ sequence [SigD name <$>
+ [t| forall n. RS.Array n $ttyp -> [Int] |]
+ ,do body <- [| vectorExtremumOp id $c_op |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
-- This branch is ostensibly a runtime branch, but will (hopefully) be
-- constant-folded away by GHC.
intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
@@ -286,6 +324,17 @@ intWidBranchRed fsc32 fred32 fsc64 fred64 sn
| finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
| otherwise = error "Unsupported Int width"
+intWidBranchExtr :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => -- int32
+ (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ extremum kernel
+ -- int64
+ -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ extremum kernel
+ -> (RS.Array n i -> [Int])
+intWidBranchExtr fextr32 fextr64
+ | finiteBitSize (undefined :: i) == 32 = vectorExtremumOp @i @Int32 castPtr fextr32
+ | finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64
+ | otherwise = error "Unsupported Int width"
+
class NumElt a where
numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
@@ -295,6 +344,8 @@ class NumElt a where
numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
+ numEltArgMin :: RS.Array n a -> [Int]
+ numEltArgMax :: RS.Array n a -> [Int]
instance NumElt Int32 where
numEltAdd = addVectorInt32
@@ -305,6 +356,8 @@ instance NumElt Int32 where
numEltSignum = signumVectorInt32
numEltSum1Inner = sum1VectorInt32
numEltProduct1Inner = product1VectorInt32
+ numEltArgMin = argminVectorInt32
+ numEltArgMax = argmaxVectorInt32
instance NumElt Int64 where
numEltAdd = addVectorInt64
@@ -315,6 +368,8 @@ instance NumElt Int64 where
numEltSignum = signumVectorInt64
numEltSum1Inner = sum1VectorInt64
numEltProduct1Inner = product1VectorInt64
+ numEltArgMin = argminVectorInt64
+ numEltArgMax = argmaxVectorInt64
instance NumElt Float where
numEltAdd = addVectorFloat
@@ -325,6 +380,8 @@ instance NumElt Float where
numEltSignum = signumVectorFloat
numEltSum1Inner = sum1VectorFloat
numEltProduct1Inner = product1VectorFloat
+ numEltArgMin = argminVectorFloat
+ numEltArgMax = argmaxVectorFloat
instance NumElt Double where
numEltAdd = addVectorDouble
@@ -335,6 +392,8 @@ instance NumElt Double where
numEltSignum = signumVectorDouble
numEltSum1Inner = sum1VectorDouble
numEltProduct1Inner = product1VectorDouble
+ numEltArgMin = argminVectorDouble
+ numEltArgMax = argmaxVectorDouble
instance NumElt Int where
numEltAdd = intWidBranch2 @Int (+)
@@ -355,6 +414,8 @@ instance NumElt Int where
numEltProduct1Inner = intWidBranchRed @Int
(c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
(c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
+ numEltArgMin = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
+ numEltArgMax = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64
instance NumElt CInt where
numEltAdd = intWidBranch2 @CInt (+)
@@ -375,6 +436,8 @@ instance NumElt CInt where
numEltProduct1Inner = intWidBranchRed @CInt
(c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
(c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
+ numEltArgMin = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
+ numEltArgMax = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
class FloatElt a where
floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
index 6fc7229..0bd72e8 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -53,3 +53,10 @@ $(fmap concat . forM typesList $ \arithtype -> do
let base = "reduce_" ++ atCName arithtype
pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
[t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+
+$(fmap concat . forM typesList $ \arithtype ->
+ fmap concat . forM ["min", "max"] $ \fname -> do
+ let ttyp = conT (atType arithtype)
+ let base = "extremum_" ++ fname ++ "_" ++ atCName arithtype
+ pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 4746f31..31c4e55 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -788,6 +788,16 @@ mreshape sh' arr =
miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a
miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
+-- | Throws if the array is empty.
+margMinPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
+margMinPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
+ ixxFromList (ssxFromShape sh) (numEltArgMin arr)
+
+-- | Throws if the array is empty.
+margMaxPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
+margMaxPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
+ ixxFromList (ssxFromShape sh) (numEltArgMax arr)
+
mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr)
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 55ae59f..c16cfb7 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -449,6 +449,18 @@ rreshape sh' rarr@(Ranked arr)
riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a
riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota
+-- | Throws if the array is empty.
+rargMinPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
+rargMinPrim rarr@(Ranked arr)
+ | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
+ = ixCvtXR (margMinPrim arr)
+
+-- | Throws if the array is empty.
+rargMaxPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
+rargMaxPrim rarr@(Ranked arr)
+ | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
+ = ixCvtXR (margMaxPrim arr)
+
rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr)
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index 544a2fa..fae486b 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -373,6 +373,14 @@ sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr)
siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
siota sn = Shaped (miota sn)
+-- | Throws if the array is empty.
+sargMinPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
+sargMinPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (margMinPrim arr)
+
+-- | Throws if the array is empty.
+sargMaxPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
+sargMaxPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (margMaxPrim arr)
+
stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr)