diff options
Diffstat (limited to 'src/Data/Array/Mixed/Internal')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 20 | 
1 files changed, 15 insertions, 5 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index bb3ee4a..6417413 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -22,6 +22,7 @@ import Foreign.C.Types  import Foreign.Ptr  import Foreign.Storable (Storable)  import GHC.TypeLits +import GHC.TypeNats qualified as TypeNats  import Language.Haskell.TH  import System.IO.Unsafe @@ -133,7 +134,6 @@ vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases            VS.unsafeFreeze outv      | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) --- TODO: test all the weird cases of this function  -- | Reduce along the inner dimension  {-# NOINLINE vectorRedInnerOp #-}  vectorRedInnerOp :: forall a b n. (Num a, Storable a) @@ -155,9 +155,15 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride          (RS.A (RG.A (init sh) (OI.T (init strides) offset vec)))    -- now there is useful work along the inner dimension    | otherwise = -      let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those -          (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides) -          ndimsF = length shF +      let -- replicated dimensions: dimensions with zero stride. The reduction +          -- kernel need not concern itself with those (and in fact has a +          -- precondition that there are no such dimensions in its input). +          replDims = map (== 0) strides +          -- filter out replicated dimensions +          (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims) +          -- replace replicated dimensions with ones +          shOnes = zipWith (\n repl -> if repl then 1 else n) sh replDims +          ndimsF = length shF  -- > 0, otherwise `last strides == 0`        in unsafePerformIO $ do             outv <- VSM.unsafeNew (product (init shF))             VSM.unsafeWith outv $ \poutv -> @@ -165,7 +171,11 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride                 VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF ->                   VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec ->                     fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec) -           RS.fromVector (init sh) <$> VS.unsafeFreeze outv +           TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) -> +             RS.stretch (init sh) +               . RS.reshape (init shOnes) +               . RS.fromVector @_ @lenFm1 (init shF) +               <$> VS.unsafeFreeze outv  flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())         ->  Int64 -> Ptr a -> Ptr a -> a -> IO () | 
