diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-02-18 00:40:51 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-18 00:41:05 +0100 | 
| commit | 00432bf95d1e3f756ff4b805389897b9dff2e169 (patch) | |
| tree | 37b2ead601f8e14ffb9ea53c9f3aa5f49cc6e9bd | |
| parent | 7abd6dd42ded4e18787464e5eff111c05ac659c6 (diff) | |
arith: Fix unary op stride bugs
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 21 | 
1 files changed, 12 insertions, 9 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 58108f2..a403d3c 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -58,20 +58,23 @@ liftOpEltwise1 :: (Storable a, Storable b)                 -> RS.Array n a -> RS.Array n b  liftOpEltwise1 sn@SNat ptrconv1 ptrconv2 cf_strided (RS.A (RG.A sh (OI.T strides offset vec)))    -- TODO: less code duplication between these two branches -  | Just (blockOff, blockSz) <- stridesDense sh offset strides = unsafePerformIO $ do -      outv <- VSM.unsafeNew blockSz -      VSM.unsafeWith outv $ \poutv -> -        VS.unsafeWith (VS.singleton (fromIntegral blockSz)) $ \psh -> -          VS.unsafeWith (VS.singleton 1) $ \pstrides -> -            VS.unsafeWith (VS.slice blockOff blockSz vec) $ \pv -> -              cf_strided 1 (ptrconv2 poutv) psh pstrides (ptrconv1 pv) -      RS.A . RG.A sh . OI.T strides (offset - blockOff) <$> VS.unsafeFreeze outv +  | Just (blockOff, blockSz) <- stridesDense sh offset strides = +      if blockSz == 0 +        then RS.A (RG.A sh (OI.T (map (const 0) strides) 0 VS.empty)) +        else unsafePerformIO $ do +               outv <- VSM.unsafeNew blockSz +               VSM.unsafeWith outv $ \poutv -> +                 VS.unsafeWith (VS.singleton (fromIntegral blockSz)) $ \psh -> +                   VS.unsafeWith (VS.singleton 1) $ \pstrides -> +                     VS.unsafeWith (VS.slice blockOff blockSz vec) $ \pv -> +                      cf_strided 1 (ptrconv2 poutv) psh pstrides (ptrconv1 pv) +               RS.A . RG.A sh . OI.T strides (offset - blockOff) <$> VS.unsafeFreeze outv    | otherwise = unsafePerformIO $ do        outv <- VSM.unsafeNew (product sh)        VSM.unsafeWith outv $ \poutv ->          VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->            VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides -> -            VS.unsafeWith vec $ \pv -> +            VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pv ->                cf_strided (fromIntegral (fromSNat sn)) (ptrconv2 poutv) psh pstrides (ptrconv1 pv)        RS.fromVector sh <$> VS.unsafeFreeze outv | 
