From 00432bf95d1e3f756ff4b805389897b9dff2e169 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 18 Feb 2025 00:40:51 +0100 Subject: arith: Fix unary op stride bugs --- src/Data/Array/Mixed/Internal/Arith.hs | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) (limited to 'src/Data') diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 58108f2..a403d3c 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -58,20 +58,23 @@ liftOpEltwise1 :: (Storable a, Storable b) -> RS.Array n a -> RS.Array n b liftOpEltwise1 sn@SNat ptrconv1 ptrconv2 cf_strided (RS.A (RG.A sh (OI.T strides offset vec))) -- TODO: less code duplication between these two branches - | Just (blockOff, blockSz) <- stridesDense sh offset strides = unsafePerformIO $ do - outv <- VSM.unsafeNew blockSz - VSM.unsafeWith outv $ \poutv -> - VS.unsafeWith (VS.singleton (fromIntegral blockSz)) $ \psh -> - VS.unsafeWith (VS.singleton 1) $ \pstrides -> - VS.unsafeWith (VS.slice blockOff blockSz vec) $ \pv -> - cf_strided 1 (ptrconv2 poutv) psh pstrides (ptrconv1 pv) - RS.A . RG.A sh . OI.T strides (offset - blockOff) <$> VS.unsafeFreeze outv + | Just (blockOff, blockSz) <- stridesDense sh offset strides = + if blockSz == 0 + then RS.A (RG.A sh (OI.T (map (const 0) strides) 0 VS.empty)) + else unsafePerformIO $ do + outv <- VSM.unsafeNew blockSz + VSM.unsafeWith outv $ \poutv -> + VS.unsafeWith (VS.singleton (fromIntegral blockSz)) $ \psh -> + VS.unsafeWith (VS.singleton 1) $ \pstrides -> + VS.unsafeWith (VS.slice blockOff blockSz vec) $ \pv -> + cf_strided 1 (ptrconv2 poutv) psh pstrides (ptrconv1 pv) + RS.A . RG.A sh . OI.T strides (offset - blockOff) <$> VS.unsafeFreeze outv | otherwise = unsafePerformIO $ do outv <- VSM.unsafeNew (product sh) VSM.unsafeWith outv $ \poutv -> VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh -> VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides -> - VS.unsafeWith vec $ \pv -> + VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pv -> cf_strided (fromIntegral (fromSNat sn)) (ptrconv2 poutv) psh pstrides (ptrconv1 pv) RS.fromVector sh <$> VS.unsafeFreeze outv -- cgit v1.2.3-70-g09d2