From d2e557efba3d7bee34dfbca9e7e791485294d0a2 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Sun, 9 Jun 2024 10:04:05 +0200
Subject: Fix stride handling of binary arith ops

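liftVEltwise2 ignored the strides of its inputs: the scalar/dense fast paths passed a raw slice of the underlying vector to RS.fromVector, discarding the dense operand's strides.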
---
 src/Data/Array/Mixed/Internal/Arith.hs | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

(limited to 'src/Data/Array/Mixed/Internal')

diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 6417413..91f994b 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -40,6 +40,7 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
       in RS.A (RG.A sh (OI.T strides 0 vec'))
   | otherwise = RS.fromVector sh (f (RS.toVector arr))
 
+-- TODO: test every case of liftVEltwise2 with various input strides
 liftVEltwise2 :: Storable a
               => SNat n
               -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a)
@@ -54,9 +55,14 @@ liftVEltwise2 SNat f
         let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2))
         in RS.A (RG.A sh1 (OI.T strides1 0 vec'))
       (Just 1, Just n) ->  -- scalar * dense
-        RS.fromVector sh1 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2)))
+        RS.A (RG.A sh1 (OI.T strides2 0 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2)))))
       (Just n, Just 1) ->  -- dense * scalar
-        RS.fromVector sh1 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2)))
+        RS.A (RG.A sh1 (OI.T strides1 0 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2)))))
+      (Just n, Just m)
+        | n == m  -- not sure if this check is necessary
+        , strides1 == strides2
+        ->  -- dense * dense but the strides match
+          RS.A (RG.A sh1 (OI.T strides1 0 (f (Right (VS.slice offset1 n vec1)) (Right (VS.slice offset2 n vec2)))))
       (_, _) ->  -- fallback case
         RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2)))
 
-- 
cgit v1.2.3-70-g09d2