aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs10
1 files changed, 8 insertions, 2 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 6417413..91f994b 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -40,6 +40,7 @@ liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
in RS.A (RG.A sh (OI.T strides 0 vec'))
| otherwise = RS.fromVector sh (f (RS.toVector arr))
+-- TODO: test all the cases of this thing with various input strides
liftVEltwise2 :: Storable a
=> SNat n
-> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a)
@@ -54,9 +55,14 @@ liftVEltwise2 SNat f
let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2))
in RS.A (RG.A sh1 (OI.T strides1 0 vec'))
(Just 1, Just n) -> -- scalar * dense
- RS.fromVector sh1 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2)))
+ RS.A (RG.A sh1 (OI.T strides2 0 (f (Left (vec1 VS.! offset1)) (Right (VS.slice offset2 n vec2)))))
(Just n, Just 1) -> -- dense * scalar
- RS.fromVector sh1 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2)))
+ RS.A (RG.A sh1 (OI.T strides1 0 (f (Right (VS.slice offset1 n vec1)) (Left (vec2 VS.! offset2)))))
+ (Just n, Just m)
+ | n == m -- not sure if this check is necessary
+ , strides1 == strides2
+ -> -- dense * dense but the strides match
+ RS.A (RG.A sh1 (OI.T strides1 0 (f (Right (VS.slice offset1 n vec1)) (Right (VS.slice offset2 n vec2)))))
(_, _) -> -- fallback case
RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2)))