diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-06-10 09:51:15 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-06-10 09:51:15 +0200 |
commit | 8274da734aba266e86ac722b6a9e73afeeae59e6 (patch) | |
tree | ca8d246907527217d64d232de1c0d10ac21db26f | |
parent | 1f3d57e13441f86b97ee7ff213bb4a677e31f2db (diff) |
Fix extremum for replicated input arrays
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index a57de1d..32ed355 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -183,6 +183,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride . RS.fromVector @_ @lenFm1 (init shF) <$> VS.unsafeFreeze outv +-- TODO: test this function -- | Find extremum (argmin or argmax) in full array {-# NOINLINE vectorExtremumOp #-} vectorExtremumOp :: forall a b n. Storable a @@ -202,6 +203,11 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec))) -- filter out replicated dimensions (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims) ndimsF = length shF -- > 0, because not all strides were <=0 + -- function to insert zeros in replicated-out dimensions + insertZeros [] idx = idx + insertZeros (True : repls) idx = 0 : insertZeros repls idx + insertZeros (False : repls) (i : idx) = i : insertZeros repls idx + insertZeros (_:_) [] = error "unreachable" in unsafePerformIO $ do outv <- VSM.unsafeNew (length shF) VSM.unsafeWith outv $ \poutv -> @@ -209,7 +215,7 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec))) VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> fextrem poutv (fromIntegral ndimsF) pshF pstridesF (ptrconv pvec) - map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv + insertZeros replDims . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) -> Int64 -> Ptr a -> Ptr a -> a -> IO () |