aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-02-18 00:40:51 +0100
committerTom Smeding <tom@tomsmeding.com>2025-02-18 00:41:05 +0100
commit00432bf95d1e3f756ff4b805389897b9dff2e169 (patch)
tree37b2ead601f8e14ffb9ea53c9f3aa5f49cc6e9bd
parent7abd6dd42ded4e18787464e5eff111c05ac659c6 (diff)
arith: Fix unary op stride bugs
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs21
1 files changed, 12 insertions, 9 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 58108f2..a403d3c 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -58,20 +58,23 @@ liftOpEltwise1 :: (Storable a, Storable b)
-> RS.Array n a -> RS.Array n b
liftOpEltwise1 sn@SNat ptrconv1 ptrconv2 cf_strided (RS.A (RG.A sh (OI.T strides offset vec)))
-- TODO: less code duplication between these two branches
- | Just (blockOff, blockSz) <- stridesDense sh offset strides = unsafePerformIO $ do
- outv <- VSM.unsafeNew blockSz
- VSM.unsafeWith outv $ \poutv ->
- VS.unsafeWith (VS.singleton (fromIntegral blockSz)) $ \psh ->
- VS.unsafeWith (VS.singleton 1) $ \pstrides ->
- VS.unsafeWith (VS.slice blockOff blockSz vec) $ \pv ->
- cf_strided 1 (ptrconv2 poutv) psh pstrides (ptrconv1 pv)
- RS.A . RG.A sh . OI.T strides (offset - blockOff) <$> VS.unsafeFreeze outv
+ | Just (blockOff, blockSz) <- stridesDense sh offset strides =
+ if blockSz == 0
+ then RS.A (RG.A sh (OI.T (map (const 0) strides) 0 VS.empty))
+ else unsafePerformIO $ do
+ outv <- VSM.unsafeNew blockSz
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.singleton (fromIntegral blockSz)) $ \psh ->
+ VS.unsafeWith (VS.singleton 1) $ \pstrides ->
+ VS.unsafeWith (VS.slice blockOff blockSz vec) $ \pv ->
+ cf_strided 1 (ptrconv2 poutv) psh pstrides (ptrconv1 pv)
+ RS.A . RG.A sh . OI.T strides (offset - blockOff) <$> VS.unsafeFreeze outv
| otherwise = unsafePerformIO $ do
outv <- VSM.unsafeNew (product sh)
VSM.unsafeWith outv $ \poutv ->
VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->
VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides ->
- VS.unsafeWith vec $ \pv ->
+ VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pv ->
cf_strided (fromIntegral (fromSNat sn)) (ptrconv2 poutv) psh pstrides (ptrconv1 pv)
RS.fromVector sh <$> VS.unsafeFreeze outv