diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-06-10 09:51:15 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-06-10 09:51:15 +0200 | 
| commit | 8274da734aba266e86ac722b6a9e73afeeae59e6 (patch) | |
| tree | ca8d246907527217d64d232de1c0d10ac21db26f /src | |
| parent | 1f3d57e13441f86b97ee7ff213bb4a677e31f2db (diff) | |
Fix extremum for replicated input arrays
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 8 | 
1 files changed, 7 insertions, 1 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index a57de1d..32ed355 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -183,6 +183,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride                 . RS.fromVector @_ @lenFm1 (init shF)                 <$> VS.unsafeFreeze outv +-- TODO: test this function  -- | Find extremum (argmin or argmax) in full array  {-# NOINLINE vectorExtremumOp #-}  vectorExtremumOp :: forall a b n. Storable a @@ -202,6 +203,11 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))            -- filter out replicated dimensions            (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims)            ndimsF = length shF  -- > 0, because not all strides were <=0 +          -- function to insert zeros in replicated-out dimensions +          insertZeros [] idx = idx +          insertZeros (True : repls) idx = 0 : insertZeros repls idx +          insertZeros (False : repls) (i : idx) = i : insertZeros repls idx +          insertZeros (_:_) [] = error "unreachable"        in unsafePerformIO $ do             outv <- VSM.unsafeNew (length shF)             VSM.unsafeWith outv $ \poutv -> @@ -209,7 +215,7 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))                 VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF ->                   VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec ->                     fextrem poutv (fromIntegral ndimsF) pshF pstridesF (ptrconv pvec) -           map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv +           insertZeros replDims . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv  flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())         ->  Int64 -> Ptr a -> Ptr a -> a -> IO () | 
