From 8274da734aba266e86ac722b6a9e73afeeae59e6 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 10 Jun 2024 09:51:15 +0200 Subject: Fix extremum for replicated input arrays --- src/Data/Array/Mixed/Internal/Arith.hs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'src/Data/Array/Mixed') diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index a57de1d..32ed355 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -183,6 +183,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride . RS.fromVector @_ @lenFm1 (init shF) <$> VS.unsafeFreeze outv +-- TODO: test this function -- | Find extremum (argmin or argmax) in full array {-# NOINLINE vectorExtremumOp #-} vectorExtremumOp :: forall a b n. Storable a @@ -202,6 +203,11 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec))) -- filter out replicated dimensions (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims) ndimsF = length shF -- > 0, because not all strides were <=0 + -- function to insert zeros in replicated-out dimensions + insertZeros [] idx = idx + insertZeros (True : repls) idx = 0 : insertZeros repls idx + insertZeros (False : repls) (i : idx) = i : insertZeros repls idx + insertZeros (_:_) [] = error "unreachable" in unsafePerformIO $ do outv <- VSM.unsafeNew (length shF) VSM.unsafeWith outv $ \poutv -> @@ -209,7 +215,7 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec))) VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> fextrem poutv (fromIntegral ndimsF) pshF pstridesF (ptrconv pvec) - map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv + insertZeros replDims . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) -> Int64 -> Ptr a -> Ptr a -> a -> IO () -- cgit v1.2.3-70-g09d2