From 1f3d57e13441f86b97ee7ff213bb4a677e31f2db Mon Sep 17 00:00:00 2001
From: Tom Smeding
Date: Sun, 9 Jun 2024 23:09:19 +0200
Subject: argmin and argmax

---
 cbits/arith.c                                  | 40 +++++++++++++++-
 src/Data/Array/Mixed/Internal/Arith.hs         | 63 ++++++++++++++++++++++++++
 src/Data/Array/Mixed/Internal/Arith/Foreign.hs |  7 +++
 src/Data/Array/Nested/Internal/Mixed.hs        | 10 ++++
 src/Data/Array/Nested/Internal/Ranked.hs       | 12 +++++
 src/Data/Array/Nested/Internal/Shaped.hs       |  8 ++++
 6 files changed, 139 insertions(+), 1 deletion(-)

diff --git a/cbits/arith.c b/cbits/arith.c
index e20578b..6ac49b8 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -164,6 +164,42 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
     } \
   }
 
+// Preconditions: all strides are >0, shape is everywhere >0, and rank is >= 1.
+// Writes the index (rank entries) of the extreme element of 'arr' to outidx.
+// If 'cmp' is '<', computes argmin; if '>', argmax. The comparison is strict,
+// so among equal extreme elements the first one visited is the one reported.
+// The unit-stride branch indexes the innermost dimension directly; the other scales by strides[rank - 1].
+#define EXTREMUM_OP(name, cmp, typ) \
+  void oxarop_extremum_ ## name ## _ ## typ(i64 *outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
+    typ best = arr[0]; \
+    memset(outidx, 0, rank * sizeof(i64)); \
+    if (strides[rank - 1] == 1) { \
+      TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \
+        bool found = false; \
+        for (i64 i = 0; i < shape[rank - 1]; i++) { \
+          if (arr[arrlinidx + i] cmp best) { \
+            best = arr[arrlinidx + i]; \
+            found = true; \
+            outidx[rank - 1] = i; \
+          } \
+        } \
+        if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \
+      }); \
+    } else { \
+      TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \
+        bool found = false; \
+        for (i64 i = 0; i < shape[rank - 1]; i++) { \
+          if (arr[arrlinidx + strides[rank - 1] * i] cmp best) { \
+            best = arr[arrlinidx + strides[rank - 1] * i]; \
+            found = true; \
+            outidx[rank - 1] = i; \
+          } \
+        } \
+        if (found) memcpy(outidx, idx, (rank - 1) * sizeof(i64)); \
+      }); \
+    } \
+  }
+
 
 /*****************************************************************************
  *                          Entry point functions                            *
@@ -332,7 +368,9 @@ enum redop_tag_t
{
   REDUCE1_OP(product1, *, typ) \
   ENTRY_BINARY_OPS(typ) \
   ENTRY_UNARY_OPS(typ) \
-  ENTRY_REDUCE_OPS(typ)
+  ENTRY_REDUCE_OPS(typ) \
+  EXTREMUM_OP(min, <, typ) \
+  EXTREMUM_OP(max, >, typ)
 NUM_TYPES_XLIST
 #undef X
 
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 91f994b..a57de1d 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -183,6 +183,51 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride
              . RS.fromVector @_ @lenFm1 (init shF)
              <$> VS.unsafeFreeze outv
 
+-- | Find extremum (argmin or argmax) in full array.
+--
+-- The result has one index per dimension of the input array. Dimensions with
+-- stride zero are replicated, so they are hidden from the kernel and always
+-- get index 0 in the result. Calls 'error' if the array is empty.
+{-# NOINLINE vectorExtremumOp #-}
+vectorExtremumOp :: forall a b n. Storable a
+                 => (Ptr a -> Ptr b)
+                 -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ extremum kernel
+                 -> RS.Array n a -> [Int]  -- result length: n
+vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))
+  | null sh = []
+  | any (<= 0) sh = error "Extremum (argmin/argmax): empty array"
+  -- now the input array is nonempty
+  -- NOTE(review): this guard also matches arrays whose strides are all
+  -- negative, which are not replicated -- confirm those cannot occur here.
+  | all (<= 0) strides = 0 <$ sh
+  | otherwise =
+      let -- replicated dimensions: dimensions with zero stride. The extremum
+          -- kernel need not concern itself with those (and in fact has a
+          -- precondition that there are no such dimensions in its input).
+          replDims = map (== 0) strides
+          -- filter out replicated dimensions
+          (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims)
+          ndimsF = length shF  -- > 0, because not all strides were <=0
+      in unsafePerformIO $ do
+           outv <- VSM.unsafeNew (length shF)
+           VSM.unsafeWith outv $ \poutv ->
+             VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
+               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF ->
+                 VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec ->
+                   fextrem poutv (fromIntegral ndimsF) pshF pstridesF (ptrconv pvec)
+           -- The kernel wrote indices for the non-replicated dimensions only;
+           -- put index 0 back at the replicated positions so that the result
+           -- has one entry per dimension of the input again.
+           insertZeros replDims . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv
+  where
+    -- Walk the replication flags: emit 0 for a replicated dimension and the
+    -- next kernel index otherwise.
+    insertZeros :: [Bool] -> [Int] -> [Int]
+    insertZeros [] _ = []
+    insertZeros (True : repls) idx = 0 : insertZeros repls idx
+    insertZeros (False : repls) (i : idx) = i : insertZeros repls idx
+    insertZeros (False : _) [] = error "Extremum (argmin/argmax): kernel returned too few indices"
+
 flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) -> Int64 -> Ptr a -> Ptr a -> a -> IO ()
 flipOp f n out v s = f n out s v
 
@@ -246,6 +291,16 @@ $(fmap concat . forM typesList $ \arithtype -> do
               ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |]
                   return $ FunD name [Clause [] (NormalB body) []]])
 
+$(fmap concat . forM typesList $ \arithtype ->
+    fmap concat . forM ["min", "max"] $ \fname -> do
+      let ttyp = conT (atType arithtype)
+          name = mkName ("arg" ++ fname ++ "Vector" ++ nameBase (atType arithtype))
+          c_op = varE (mkName ("c_extremum_" ++ fname ++ "_" ++ atCName arithtype))
+      sequence [SigD name <$>
+                  [t| forall n. RS.Array n $ttyp -> [Int] |]
+               ,do body <- [| vectorExtremumOp id $c_op |]
+                   return $ FunD name [Clause [] (NormalB body) []]])
+
 -- This branch is ostensibly a runtime branch, but will (hopefully) be
 -- constant-folded away by GHC.
 intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
@@ -286,6 +341,17 @@ intWidBranchRed fsc32 fred32 fsc64 fred64 sn
   | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
   | otherwise = error "Unsupported Int width"
 
+intWidBranchExtr :: forall i n. (FiniteBits i, Storable i, Integral i)
+                 => -- int32
+                    (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- ^ extremum kernel
+                    -- int64
+                 -> (Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ extremum kernel
+                 -> (RS.Array n i -> [Int])
+intWidBranchExtr fextr32 fextr64
+  | finiteBitSize (undefined :: i) == 32 = vectorExtremumOp @i @Int32 castPtr fextr32
+  | finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64
+  | otherwise = error "Unsupported Int width"
+
 class NumElt a where
   numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
   numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
@@ -295,6 +361,8 @@ class NumElt a where
   numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
   numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
   numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
+  numEltArgMin :: RS.Array n a -> [Int]
+  numEltArgMax :: RS.Array n a -> [Int]
 
 instance NumElt Int32 where
   numEltAdd = addVectorInt32
@@ -305,6 +373,8 @@ instance NumElt Int32 where
   numEltSignum = signumVectorInt32
   numEltSum1Inner = sum1VectorInt32
   numEltProduct1Inner = product1VectorInt32
+  numEltArgMin = argminVectorInt32
+  numEltArgMax = argmaxVectorInt32
 
 instance NumElt Int64 where
   numEltAdd = addVectorInt64
@@ -315,6 +385,8 @@ instance NumElt Int64 where
   numEltSignum = signumVectorInt64
   numEltSum1Inner = sum1VectorInt64
   numEltProduct1Inner = product1VectorInt64
+  numEltArgMin = argminVectorInt64
+  numEltArgMax = argmaxVectorInt64
 
 instance NumElt Float where
   numEltAdd = addVectorFloat
@@ -325,6 +397,8 @@ instance NumElt Float where
   numEltSignum = signumVectorFloat
   numEltSum1Inner = sum1VectorFloat
   numEltProduct1Inner = product1VectorFloat
+  numEltArgMin = argminVectorFloat
+  numEltArgMax = argmaxVectorFloat
 
 instance NumElt Double where
   numEltAdd = addVectorDouble
@@ -335,6 +409,8 @@ instance NumElt Double where
   numEltSignum = signumVectorDouble
   numEltSum1Inner = sum1VectorDouble
   numEltProduct1Inner = product1VectorDouble
+  numEltArgMin = argminVectorDouble
+  numEltArgMax = argmaxVectorDouble
 
 instance NumElt Int where
   numEltAdd = intWidBranch2 @Int (+)
@@ -355,6 +431,8 @@ instance NumElt Int where
   numEltProduct1Inner = intWidBranchRed @Int
                           (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
                           (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
+  numEltArgMin = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
+  numEltArgMax = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64
 
 instance NumElt CInt where
   numEltAdd = intWidBranch2 @CInt (+)
@@ -375,6 +453,8 @@ instance NumElt CInt where
   numEltProduct1Inner = intWidBranchRed @CInt
                           (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
                           (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
+  numEltArgMin = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
+  numEltArgMax = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
 
 class FloatElt a where
   floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
index 6fc7229..0bd72e8 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -53,3 +53,10 @@ $(fmap concat . forM typesList $ \arithtype -> do
     let base = "reduce_" ++ atCName arithtype
     pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
       [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+
+$(fmap concat . forM typesList $ \arithtype ->
+    fmap concat . forM ["min", "max"] $ \fname -> do
+      let ttyp = conT (atType arithtype)
+      let base = "extremum_" ++ fname ++ "_" ++ atCName arithtype
+      pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+        [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 4746f31..31c4e55 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -788,6 +788,16 @@ mreshape sh' arr =
 miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a
 miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
 
+-- | Throws if the array is empty.
+margMinPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
+margMinPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
+  ixxFromList (ssxFromShape sh) (numEltArgMin arr)
+
+-- | Throws if the array is empty.
+margMaxPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
+margMaxPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
+  ixxFromList (ssxFromShape sh) (numEltArgMax arr)
+
 mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
 mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr)
 
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 55ae59f..c16cfb7 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -449,6 +449,18 @@ rreshape sh' rarr@(Ranked arr)
 riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a
 riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota
 
+-- | Throws if the array is empty.
+rargMinPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
+rargMinPrim rarr@(Ranked arr)
+  | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
+  = ixCvtXR (margMinPrim arr)
+
+-- | Throws if the array is empty.
+rargMaxPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
+rargMaxPrim ranked@(Ranked mixed) =
+  case lemRankReplicate (rrank (rtoPrimitive ranked)) of
+    Refl -> ixCvtXR (margMaxPrim mixed)
+
 rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
 rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr)
 
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index 544a2fa..fae486b 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -373,6 +373,14 @@ sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr)
 siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
 siota sn = Shaped (miota sn)
 
+-- | Index of a minimal element (via 'margMinPrim'); throws if the array is empty.
+sargMinPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
+sargMinPrim shaped@(Shaped mixed) = ixCvtXS (sshape (stoPrimitive shaped)) $ margMinPrim mixed
+
+-- | Index of a maximal element (via 'margMaxPrim'); throws if the array is empty.
+sargMaxPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
+sargMaxPrim shaped@(Shaped mixed) = ixCvtXS (sshape (stoPrimitive shaped)) $ margMaxPrim mixed
+
 stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
 stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr)
 
-- 
cgit v1.2.3-70-g09d2