diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-06-09 10:04:05 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-06-09 10:04:05 +0200 | 
| commit | d2e557efba3d7bee34dfbca9e7e791485294d0a2 (patch) | |
| tree | e0d587723f1d4a4137c2e00fb0aee937f8a56c2f /src/Data/Array/Mixed/Internal | |
| parent | 0c5d0ecb7f6a1fecc382badae60df27b4bf169a4 (diff) | |
Fix stride handling of binary arith ops
liftVEltwise2 just completely ignored the existence of strides ._.
Diffstat (limited to 'src/Data/Array/Mixed/Internal')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 10 | 
1 files changed, 8 insertions, 2 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index 6417413..91f994b 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -40,6 +40,7 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))        in RS.A (RG.A sh (OI.T strides 0 vec'))    | otherwise = RS.fromVector sh (f (RS.toVector arr)) +-- TODO: test all the cases of this thing with various input strides  liftVEltwise2 :: Storable a                => SNat n                -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a) @@ -54,9 +55,14 @@ liftVEltwise2 SNat f          let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2))          in RS.A (RG.A sh1 (OI.T strides1 0 vec'))        (Just 1, Just n) ->  -- scalar * dense -        RS.fromVector sh1 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2))) +        RS.A (RG.A sh1 (OI.T strides2 0 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2)))))        (Just n, Just 1) ->  -- dense * scalar -        RS.fromVector sh1 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2))) +        RS.A (RG.A sh1 (OI.T strides1 0 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2))))) +      (Just n, Just m) +        | n == m  -- not sure if this check is necessary +        , strides1 == strides2 +        ->  -- dense * dense but the strides match +          RS.A (RG.A sh1 (OI.T strides1 0 (f (Right (VS.slice offset1 n vec1)) (Right (VS.slice offset2 n vec2)))))        (_, _) ->  -- fallback case          RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2))) | 
