aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/Internal
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed/Internal')
-rw-r--r-- src/Data/Array/Mixed/Internal/Arith.hs 20
1 files changed, 15 insertions, 5 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index bb3ee4a..6417413 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -22,6 +22,7 @@ import Foreign.C.Types
import Foreign.Ptr
import Foreign.Storable (Storable)
import GHC.TypeLits
+import GHC.TypeNats qualified as TypeNats
import Language.Haskell.TH
import System.IO.Unsafe
@@ -133,7 +134,6 @@ vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases
VS.unsafeFreeze outv
| otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy)
--- TODO: test all the weird cases of this function
-- | Reduce along the inner dimension
{-# NOINLINE vectorRedInnerOp #-}
vectorRedInnerOp :: forall a b n. (Num a, Storable a)
@@ -155,9 +155,15 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride
(RS.A (RG.A (init sh) (OI.T (init strides) offset vec)))
-- now there is useful work along the inner dimension
| otherwise =
- let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those
- (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides)
- ndimsF = length shF
+ let -- replicated dimensions: dimensions with zero stride. The reduction
+ -- kernel need not concern itself with those (and in fact has a
+ -- precondition that there are no such dimensions in its input).
+ replDims = map (== 0) strides
+ -- filter out replicated dimensions
+ (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims)
+ -- replace replicated dimensions with ones
+ shOnes = zipWith (\n repl -> if repl then 1 else n) sh replDims
+ ndimsF = length shF -- > 0, otherwise `last strides == 0`
in unsafePerformIO $ do
outv <- VSM.unsafeNew (product (init shF))
VSM.unsafeWith outv $ \poutv ->
@@ -165,7 +171,11 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride
VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF ->
VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec ->
fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec)
- RS.fromVector (init sh) <$> VS.unsafeFreeze outv
+ TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) ->
+ RS.stretch (init sh)
+ . RS.reshape (init shOnes)
+ . RS.fromVector @_ @lenFm1 (init shF)
+ <$> VS.unsafeFreeze outv
flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
-> Int64 -> Ptr a -> Ptr a -> a -> IO ()