aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/Internal/Arith.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed/Internal/Arith.hs')
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs8
1 files changed, 7 insertions, 1 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index a57de1d..32ed355 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -183,6 +183,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride
. RS.fromVector @_ @lenFm1 (init shF)
<$> VS.unsafeFreeze outv
+-- TODO: test this function
-- | Find extremum (argmin or argmax) in full array
{-# NOINLINE vectorExtremumOp #-}
vectorExtremumOp :: forall a b n. Storable a
@@ -202,6 +203,11 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))
-- filter out replicated dimensions
(shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims)
ndimsF = length shF -- > 0, because not all strides were <=0
+ -- function to insert zeros in replicated-out dimensions
+ insertZeros [] idx = idx
+ insertZeros (True : repls) idx = 0 : insertZeros repls idx
+ insertZeros (False : repls) (i : idx) = i : insertZeros repls idx
+ insertZeros (_:_) [] = error "unreachable"
in unsafePerformIO $ do
outv <- VSM.unsafeNew (length shF)
VSM.unsafeWith outv $ \poutv ->
@@ -209,7 +215,7 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))
VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF ->
VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec ->
fextrem poutv (fromIntegral ndimsF) pshF pstridesF (ptrconv pvec)
- map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv
+ insertZeros replDims . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv
flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
-> Int64 -> Ptr a -> Ptr a -> a -> IO ()