From 1f3d57e13441f86b97ee7ff213bb4a677e31f2db Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Sun, 9 Jun 2024 23:09:19 +0200
Subject: argmin and argmax

---
 src/Data/Array/Mixed/Internal/Arith.hs         | 63 ++++++++++++++++++++++++++
 src/Data/Array/Mixed/Internal/Arith/Foreign.hs |  7 +++
 2 files changed, 70 insertions(+)

(limited to 'src/Data/Array/Mixed')

diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 91f994b..a57de1d 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -183,6 +183,34 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride
                . RS.fromVector @_ @lenFm1 (init shF)
                <$> VS.unsafeFreeze outv
 
+-- | Find extremum (argmin or argmax) in full array
+{-# NOINLINE vectorExtremumOp #-}
+vectorExtremumOp :: forall a b n. Storable a
+                 => (Ptr a -> Ptr b)
+                 -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ extremum kernel
+                 -> RS.Array n a -> [Int]  -- result length: n
+vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))
+  | null sh = []
+  | any (<= 0) sh = error "Extremum (argmin/argmax): empty array"
+  -- now the input array is nonempty
+  | all (<= 0) strides = 0 <$ sh
+  | otherwise =
+      let -- replicated dimensions: dimensions with zero stride. The extremum
+          -- kernel need not concern itself with those (and in fact has a
+          -- precondition that there are no such dimensions in its input).
+          replDims = map (== 0) strides
+          -- filter out replicated dimensions
+          (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims)
+          ndimsF = length shF  -- > 0, because not all strides were <=0
+      in unsafePerformIO $ do
+           outv <- VSM.unsafeNew (length shF)
+           VSM.unsafeWith outv $ \poutv ->
+             VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
+               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF ->
+                 VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec ->
+                   fextrem poutv (fromIntegral ndimsF) pshF pstridesF (ptrconv pvec)
+           map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv
+
 flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
        ->  Int64 -> Ptr a -> Ptr a -> a -> IO ()
 flipOp f n out v s = f n out s v
@@ -246,6 +274,16 @@ $(fmap concat . forM typesList $ \arithtype -> do
                ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |]
                    return $ FunD name [Clause [] (NormalB body) []]])
 
+$(fmap concat . forM typesList $ \arithtype ->
+    fmap concat . forM ["min", "max"] $ \fname -> do
+      let ttyp = conT (atType arithtype)
+          name = mkName ("arg" ++ fname ++ "Vector" ++ nameBase (atType arithtype))
+          c_op = varE (mkName ("c_extremum_" ++ fname ++ "_" ++ atCName arithtype))
+      sequence [SigD name <$>
+                     [t| forall n. RS.Array n $ttyp -> [Int] |]
+               ,do body <- [| vectorExtremumOp id $c_op |]
+                   return $ FunD name [Clause [] (NormalB body) []]])
+
 -- This branch is ostensibly a runtime branch, but will (hopefully) be
 -- constant-folded away by GHC.
 intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
@@ -286,6 +324,17 @@ intWidBranchRed fsc32 fred32 fsc64 fred64 sn
   | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
   | otherwise = error "Unsupported Int width"
 
+intWidBranchExtr :: forall i n. (FiniteBits i, Storable i, Integral i)
+                 => -- int32
+                    (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- ^ extremum kernel
+                    -- int64
+                 -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ extremum kernel
+                 -> (RS.Array n i -> [Int])
+intWidBranchExtr fextr32 fextr64
+  | finiteBitSize (undefined :: i) == 32 = vectorExtremumOp @i @Int32 castPtr fextr32
+  | finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64
+  | otherwise = error "Unsupported Int width"
+
 class NumElt a where
   numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
   numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
@@ -295,6 +344,8 @@ class NumElt a where
   numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
   numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
   numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
+  numEltArgMin :: RS.Array n a -> [Int]
+  numEltArgMax :: RS.Array n a -> [Int]
 
 instance NumElt Int32 where
   numEltAdd = addVectorInt32
@@ -305,6 +356,8 @@ instance NumElt Int32 where
   numEltSignum = signumVectorInt32
   numEltSum1Inner = sum1VectorInt32
   numEltProduct1Inner = product1VectorInt32
+  numEltArgMin = argminVectorInt32
+  numEltArgMax = argmaxVectorInt32
 
 instance NumElt Int64 where
   numEltAdd = addVectorInt64
@@ -315,6 +368,8 @@ instance NumElt Int64 where
   numEltSignum = signumVectorInt64
   numEltSum1Inner = sum1VectorInt64
   numEltProduct1Inner = product1VectorInt64
+  numEltArgMin = argminVectorInt64
+  numEltArgMax = argmaxVectorInt64
 
 instance NumElt Float where
   numEltAdd = addVectorFloat
@@ -325,6 +380,8 @@ instance NumElt Float where
   numEltSignum = signumVectorFloat
   numEltSum1Inner = sum1VectorFloat
   numEltProduct1Inner = product1VectorFloat
+  numEltArgMin = argminVectorFloat
+  numEltArgMax = argmaxVectorFloat
 
 instance NumElt Double where
   numEltAdd = addVectorDouble
@@ -335,6 +392,8 @@ instance NumElt Double where
   numEltSignum = signumVectorDouble
   numEltSum1Inner = sum1VectorDouble
   numEltProduct1Inner = product1VectorDouble
+  numEltArgMin = argminVectorDouble
+  numEltArgMax = argmaxVectorDouble
 
 instance NumElt Int where
   numEltAdd = intWidBranch2 @Int (+)
@@ -355,6 +414,8 @@ instance NumElt Int where
   numEltProduct1Inner = intWidBranchRed @Int
                           (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
                           (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
+  numEltArgMin = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
+  numEltArgMax = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64
 
 instance NumElt CInt where
   numEltAdd = intWidBranch2 @CInt (+)
@@ -375,6 +436,8 @@ instance NumElt CInt where
   numEltProduct1Inner = intWidBranchRed @CInt
                           (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
                           (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
+  numEltArgMin = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
+  numEltArgMax = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
 
 class FloatElt a where
   floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
index 6fc7229..0bd72e8 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -53,3 +53,10 @@ $(fmap concat . forM typesList $ \arithtype -> do
     let base = "reduce_" ++ atCName arithtype
     pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
       [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+
+$(fmap concat . forM typesList $ \arithtype ->
+    fmap concat . forM ["min", "max"] $ \fname -> do
+      let ttyp = conT (atType arithtype)
+      let base = "extremum_" ++ fname ++ "_" ++ atCName arithtype
+      pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+        [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
-- 
cgit v1.2.3-70-g09d2