diff options
Diffstat (limited to 'src/Data/Array')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 8 | 
1 files changed, 7 insertions, 1 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index a57de1d..32ed355 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -183,6 +183,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride                 . RS.fromVector @_ @lenFm1 (init shF)                 <$> VS.unsafeFreeze outv +-- TODO: test this function  -- | Find extremum (argmin or argmax) in full array  {-# NOINLINE vectorExtremumOp #-}  vectorExtremumOp :: forall a b n. Storable a @@ -202,6 +203,11 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))            -- filter out replicated dimensions            (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims)            ndimsF = length shF  -- > 0, because not all strides were <=0 +          -- function to insert zeros in replicated-out dimensions +          insertZeros [] idx = idx +          insertZeros (True : repls) idx = 0 : insertZeros repls idx +          insertZeros (False : repls) (i : idx) = i : insertZeros repls idx +          insertZeros (_:_) [] = error "unreachable"        in unsafePerformIO $ do             outv <- VSM.unsafeNew (length shF)             VSM.unsafeWith outv $ \poutv -> @@ -209,7 +215,7 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))                 VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF ->                   VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec ->                     fextrem poutv (fromIntegral ndimsF) pshF pstridesF (ptrconv pvec) -           map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv +           insertZeros replDims . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv  flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())         ->  Int64 -> Ptr a -> Ptr a -> a -> IO () | 
