diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-06-03 21:29:53 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-06-03 21:29:53 +0200 |
commit | c5108efd1402dcb52beca27d13b4880eed35ef5b (patch) | |
tree | b25e4ee26c1f894671db2e68c0afdaf6a1378cb5 /src/Data/Array/Mixed | |
parent | 0fd727dcb3fe05816aa9c68be5ebac84a55fcf4b (diff) |
Properly test C reductions
Diffstat (limited to 'src/Data/Array/Mixed')
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 20 |
1 files changed, 15 insertions, 5 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index bb3ee4a..6417413 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -22,6 +22,7 @@ import Foreign.C.Types import Foreign.Ptr import Foreign.Storable (Storable) import GHC.TypeLits +import GHC.TypeNats qualified as TypeNats import Language.Haskell.TH import System.IO.Unsafe @@ -133,7 +134,6 @@ vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases VS.unsafeFreeze outv | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy) --- TODO: test all the weird cases of this function -- | Reduce along the inner dimension {-# NOINLINE vectorRedInnerOp #-} vectorRedInnerOp :: forall a b n. (Num a, Storable a) @@ -155,9 +155,15 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride (RS.A (RG.A (init sh) (OI.T (init strides) offset vec))) -- now there is useful work along the inner dimension | otherwise = - let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those - (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides) - ndimsF = length shF + let -- replicated dimensions: dimensions with zero stride. The reduction + -- kernel need not concern itself with those (and in fact has a + -- precondition that there are no such dimensions in its input). + replDims = map (== 0) strides + -- filter out replicated dimensions + (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims) + -- replace replicated dimensions with ones + shOnes = zipWith (\n repl -> if repl then 1 else n) sh replDims + ndimsF = length shF -- > 0, otherwise `last strides == 0` in unsafePerformIO $ do outv <- VSM.unsafeNew (product (init shF)) VSM.unsafeWith outv $ \poutv -> @@ -165,7 +171,11 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF -> VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec -> fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec) - RS.fromVector (init sh) <$> VS.unsafeFreeze outv + TypeNats.withSomeSNat (fromIntegral (ndimsF - 1)) $ \(SNat :: SNat lenFm1) -> + RS.stretch (init sh) + . RS.reshape (init shOnes) + . RS.fromVector @_ @lenFm1 (init shF) + <$> VS.unsafeFreeze outv flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) -> Int64 -> Ptr a -> Ptr a -> a -> IO () |