From f42b2b139ed8377bcf63d7a28db237350d3cb773 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 21 Jul 2024 18:58:33 +0200 Subject: arith: Respect offsets in dotprodinner --- src/Data/Array/Mixed/Internal/Arith.hs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'src/Data/Array') diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index fc26633..0ee6708 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -23,7 +23,7 @@ import Data.Vector.Storable.Mutable qualified as VSM import Foreign.C.Types import Foreign.Marshal.Alloc (alloca) import Foreign.Ptr -import Foreign.Storable (Storable, peek, poke) +import Foreign.Storable (Storable(sizeOf), peek, poke) import GHC.TypeLits import GHC.TypeNats qualified as TypeNats import Language.Haskell.TH @@ -341,7 +341,9 @@ vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner VS.unsafeWith vec1 $ \pvec1 -> VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2)) $ \pstrides2 -> VS.unsafeWith vec2 $ \pvec2 -> - fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv) pstrides1 (ptrconv pvec1) pstrides2 (ptrconv pvec2) + fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv) + pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1)) + pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2)) RS.fromVector @_ @n (init sh1) <$> VS.unsafeFreeze outv {-# NOINLINE dotScalarVector #-} -- cgit v1.2.3-70-g09d2