aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-13 11:56:17 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-13 11:56:17 +0100
commit08e139de6bfeba885cacec1ad5600b85cd0f0947 (patch)
tree78dd39cf0e7774a8e794388e9b2572cabe3fccce
parent87b479d2d09eb7ef37100f883400f7bd366cdda7 (diff)
arith: Correct rank arguments to C wrapper functions
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs6
1 files changed, 3 insertions, 3 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index c940914..9c560d6 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -100,7 +100,7 @@ liftVEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv
(Just (_, 1), Just (blockOff, blockSz)) -> -- scalar * dense
let arr2' = RS.fromVector [blockSz] (VS.slice blockOff blockSz vec2)
- RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinarySV sn valconv ptrconv f_sv (vec1 VS.! offset1) arr2'
+ RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinarySV (SNat @1) valconv ptrconv f_sv (vec1 VS.! offset1) arr2'
in RS.A (RG.A sh1 (OI.T strides2 (offset2 - blockOff) resvec))
(Just (_, 1), Nothing) -> -- scalar * array
@@ -108,7 +108,7 @@ liftVEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv
(Just (blockOff, blockSz), Just (_, 1)) -> -- dense * scalar
let arr1' = RS.fromVector [blockSz] (VS.slice blockOff blockSz vec1)
- RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVS sn valconv ptrconv f_vs arr1' (vec2 VS.! offset2)
+ RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVS (SNat @1) valconv ptrconv f_vs arr1' (vec2 VS.! offset2)
in RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff) resvec))
(Nothing, Just (_, 1)) -> -- array * scalar
@@ -120,7 +120,7 @@ liftVEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv
-> -- dense * dense but the strides match
let arr1' = RS.fromVector [blockSz1] (VS.slice blockOff1 blockSz1 vec1)
arr2' = RS.fromVector [blockSz1] (VS.slice blockOff2 blockSz2 vec2)
- RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVV sn ptrconv f_vv arr1' arr2'
+ RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVV (SNat @1) ptrconv f_vv arr1' arr2'
in RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff1) resvec))
(_, _) -> -- fallback case