diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-13 11:56:17 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-13 11:56:17 +0100 |
commit | 08e139de6bfeba885cacec1ad5600b85cd0f0947 (patch) | |
tree | 78dd39cf0e7774a8e794388e9b2572cabe3fccce | |
parent | 87b479d2d09eb7ef37100f883400f7bd366cdda7 (diff) |
arith: Correct rank arguments to C wrapper functions
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index c940914..9c560d6 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -100,7 +100,7 @@ liftVEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv (Just (_, 1), Just (blockOff, blockSz)) -> -- scalar * dense let arr2' = RS.fromVector [blockSz] (VS.slice blockOff blockSz vec2) - RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinarySV sn valconv ptrconv f_sv (vec1 VS.! offset1) arr2' + RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinarySV (SNat @1) valconv ptrconv f_sv (vec1 VS.! offset1) arr2' in RS.A (RG.A sh1 (OI.T strides2 (offset2 - blockOff) resvec)) (Just (_, 1), Nothing) -> -- scalar * array @@ -108,7 +108,7 @@ liftVEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv (Just (blockOff, blockSz), Just (_, 1)) -> -- dense * scalar let arr1' = RS.fromVector [blockSz] (VS.slice blockOff blockSz vec1) - RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVS sn valconv ptrconv f_vs arr1' (vec2 VS.! offset2) + RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVS (SNat @1) valconv ptrconv f_vs arr1' (vec2 VS.! offset2) in RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff) resvec)) (Nothing, Just (_, 1)) -> -- array * scalar @@ -120,7 +120,7 @@ liftVEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv -> -- dense * dense but the strides match let arr1' = RS.fromVector [blockSz1] (VS.slice blockOff1 blockSz1 vec1) arr2' = RS.fromVector [blockSz1] (VS.slice blockOff2 blockSz2 vec2) - RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVV sn ptrconv f_vv arr1' arr2' + RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVV (SNat @1) ptrconv f_vv arr1' arr2' in RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff1) resvec)) (_, _) -> -- fallback case |