diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-06-10 10:02:59 +0200 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-06-10 16:17:09 +0200 |
commit | 205a20fd581bb7c5728fd457a15e4f78fbee9e75 (patch) | |
tree | f6669ea87b56dde0f6168c109b3d7d7fcbf06136 /src | |
parent | c211316a4ab43cf34d6567c6919a3922d5840ae0 (diff) |
Dot product
Diffstat (limited to 'src')
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 101 | ||||
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 9 | ||||
-rw-r--r-- | src/Data/Array/Nested.hs | 6 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 9 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 3 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 3 |
6 files changed, 127 insertions, 4 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index d2ad61f..6ecbbeb 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -19,8 +19,9 @@ import Data.List (sort) import Data.Vector.Storable qualified as VS import Data.Vector.Storable.Mutable qualified as VSM import Foreign.C.Types +import Foreign.Marshal.Alloc (alloca) import Foreign.Ptr -import Foreign.Storable (Storable) +import Foreign.Storable (Storable, peek, poke) import GHC.TypeLits import GHC.TypeNats qualified as TypeNats import Language.Haskell.TH @@ -217,6 +218,69 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec))) fextrem poutv (fromIntegral ndimsF) pshF pstridesF (ptrconv pvec) insertZeros replDims . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv +vectorDotprodOp :: (Num a, Storable a) + => (b -> a) + -> (Ptr a -> Ptr b) + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel + -> (Int64 -> Ptr b -> Ptr b -> IO b) -- ^ dotprod kernel + -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b) -- ^ strided dotprod kernel + -> RS.Array 1 a -> RS.Array 1 a -> a +vectorDotprodOp valbackconv ptrconv fred fdot fdotstrided + (RS.A (RG.A [len1] (OI.T [stride1] offset1 vec1))) + (RS.A (RG.A [len2] (OI.T [stride2] offset2 vec2))) + | len1 /= len2 = error $ "vectorDotprodOp: lengths unequal: " ++ show len1 ++ " vs " ++ show len2 + | len1 == 0 = 0 -- if the arrays are empty, just return zero + | otherwise = case (stride1, stride2) of + (0, 0) -> -- replicated scalar * replicated scalar + fromIntegral len1 * (vec1 VS.! offset1) * (vec2 VS.! offset2) + (0, 1) -> -- replicated scalar * dense + dotScalarVector len1 ptrconv fred (vec1 VS.! offset1) (VS.slice offset2 len1 vec2) + (1, 0) -> -- dense * replicated scalar + dotScalarVector len1 ptrconv fred (vec2 VS.! offset2) (VS.slice offset1 len1 vec1) + (1, 1) -> -- dense * dense + dotVectorVector len1 valbackconv ptrconv fdot (VS.slice offset1 len1 vec1) (VS.slice offset2 len1 vec2) + (_, _) -> -- fallback case + dotVectorVectorStrided len1 valbackconv ptrconv fdotstrided offset1 stride1 vec1 offset2 stride2 vec2 +vectorDotprodOp _ _ _ _ _ _ _ = error "vectorDotprodOp: not one-dimensional?" + +{-# NOINLINE dotScalarVector #-} +dotScalarVector :: forall a b. (Num a, Storable a) + => Int -> (Ptr a -> Ptr b) + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel + -> a -> VS.Vector a -> a +dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do + alloca @a $ \pout -> do + alloca @Int64 $ \pshape -> do + poke pshape (fromIntegral @Int @Int64 len) + alloca @Int64 $ \pstride -> do + poke pstride 1 + VS.unsafeWith vec $ \pvec -> + fred 1 pshape pstride (ptrconv pout) (ptrconv pvec) + res <- peek pout + return (scalar * res) + +{-# NOINLINE dotVectorVector #-} +dotVectorVector :: Storable a => Int -> (b -> a) -> (Ptr a -> Ptr b) + -> (Int64 -> Ptr b -> Ptr b -> IO b) -- ^ dotprod kernel + -> VS.Vector a -> VS.Vector a -> a +dotVectorVector len valbackconv ptrconv fdot vec1 vec2 = unsafePerformIO $ do + VS.unsafeWith vec1 $ \pvec1 -> + VS.unsafeWith vec2 $ \pvec2 -> + valbackconv <$> fdot (fromIntegral @Int @Int64 len) (ptrconv pvec1) (ptrconv pvec2) + +{-# NOINLINE dotVectorVectorStrided #-} +dotVectorVectorStrided :: Storable a => Int -> (b -> a) -> (Ptr a -> Ptr b) + -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b) -- ^ dotprod kernel + -> Int -> Int -> VS.Vector a + -> Int -> Int -> VS.Vector a + -> a +dotVectorVectorStrided len valbackconv ptrconv fdot offset1 stride1 vec1 offset2 stride2 vec2 = unsafePerformIO $ do + VS.unsafeWith vec1 $ \pvec1 -> + VS.unsafeWith vec2 $ \pvec2 -> + valbackconv <$> fdot (fromIntegral @Int @Int64 len) + (fromIntegral offset1) (fromIntegral stride1) (ptrconv pvec1) + (fromIntegral offset2) (fromIntegral stride2) (ptrconv pvec2) + flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) -> Int64 -> Ptr a -> Ptr a -> a -> IO () flipOp f n out v s = f n out s v @@ -290,6 +354,17 @@ $(fmap concat . forM typesList $ \arithtype -> ,do body <- [| vectorExtremumOp id $c_op |] return $ FunD name [Clause [] (NormalB body) []]]) +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + name = mkName ("dotprodVector" ++ nameBase (atType arithtype)) + c_op = varE (mkName ("c_dotprod_" ++ atCName arithtype)) + c_op_strided = varE (mkName ("c_dotprod_" ++ atCName arithtype ++ "_strided")) + c_red_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM1))) + sequence [SigD name <$> + [t| RS.Array 1 $ttyp -> RS.Array 1 $ttyp -> $ttyp |] + ,do body <- [| vectorDotprodOp id id $c_red_op $c_op $c_op_strided |] + return $ FunD name [Clause [] (NormalB body) []]]) + -- This branch is ostensibly a runtime branch, but will (hopefully) be -- constant-folded away by GHC. intWidBranch1 :: forall i n. (FiniteBits i, Storable i) @@ -341,6 +416,21 @@ intWidBranchExtr fextr32 fextr64 | finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64 | otherwise = error "Unsupported Int width" +intWidBranchDotprod :: forall i. (FiniteBits i, Storable i, Integral i) + => -- int32 + (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel + -> (Int64 -> Ptr Int32 -> Ptr Int32 -> IO Int32) -- ^ dotprod kernel + -> (Int64 -> Int64 -> Int64 -> Ptr Int32 -> Int64 -> Int64 -> Ptr Int32 -> IO Int32) -- ^ strided dotprod kernel + -- int64 + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64) -- ^ dotprod kernel + -> (Int64 -> Int64 -> Int64 -> Ptr Int64 -> Int64 -> Int64 -> Ptr Int64 -> IO Int64) -- ^ strided dotprod kernel + -> (RS.Array 1 i -> RS.Array 1 i -> i) +intWidBranchDotprod fred32 fdot32 fdot32strided fred64 fdot64 fdot64strided + | finiteBitSize (undefined :: i) == 32 = vectorDotprodOp @i @Int32 fromIntegral castPtr fred32 fdot32 fdot32strided + | finiteBitSize (undefined :: i) == 64 = vectorDotprodOp @i @Int64 fromIntegral castPtr fred64 fdot64 fdot64strided + | otherwise = error "Unsupported Int width" + class NumElt a where numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a @@ -352,6 +442,7 @@ class NumElt a where numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a numEltMinIndex :: RS.Array n a -> [Int] numEltMaxIndex :: RS.Array n a -> [Int] + numEltDotprod :: RS.Array 1 a -> RS.Array 1 a -> a instance NumElt Int32 where numEltAdd = addVectorInt32 @@ -364,6 +455,7 @@ instance NumElt Int32 where numEltProduct1Inner = product1VectorInt32 numEltMinIndex = minindexVectorInt32 numEltMaxIndex = maxindexVectorInt32 + numEltDotprod = dotprodVectorInt32 instance NumElt Int64 where numEltAdd = addVectorInt64 @@ -376,6 +468,7 @@ instance NumElt Int64 where numEltProduct1Inner = product1VectorInt64 numEltMinIndex = minindexVectorInt64 numEltMaxIndex = maxindexVectorInt64 + numEltDotprod = dotprodVectorInt64 instance NumElt Float where numEltAdd = addVectorFloat @@ -388,6 +481,7 @@ instance NumElt Float where numEltProduct1Inner = product1VectorFloat numEltMinIndex = minindexVectorFloat numEltMaxIndex = maxindexVectorFloat + numEltDotprod = dotprodVectorFloat instance NumElt Double where numEltAdd = addVectorDouble @@ -400,6 +494,7 @@ instance NumElt Double where numEltProduct1Inner = product1VectorDouble numEltMinIndex = minindexVectorDouble numEltMaxIndex = maxindexVectorDouble + numEltDotprod = dotprodVectorDouble instance NumElt Int where numEltAdd = intWidBranch2 @Int (+) @@ -422,6 +517,8 @@ instance NumElt Int where (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) numEltMinIndex = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64 numEltMaxIndex = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64 + numEltDotprod = intWidBranchDotprod @Int (c_reduce_i32 (aroEnum RO_SUM1)) c_dotprod_i32 c_dotprod_i32_strided + (c_reduce_i64 (aroEnum RO_SUM1)) c_dotprod_i64 c_dotprod_i64_strided instance NumElt CInt where numEltAdd = intWidBranch2 @CInt (+) @@ -444,6 +541,8 @@ instance NumElt CInt where (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) numEltMinIndex = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64 numEltMaxIndex = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64 + numEltDotprod = intWidBranchDotprod @CInt (c_reduce_i32 (aroEnum RO_SUM1)) c_dotprod_i32 c_dotprod_i32_strided + (c_reduce_i64 (aroEnum RO_SUM1)) c_dotprod_i64 c_dotprod_i64_strided class FloatElt a where floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index 0bd72e8..96a85d1 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -60,3 +60,12 @@ $(fmap concat . forM typesList $ \arithtype -> let base = "extremum_" ++ fname ++ "_" ++ atCName arithtype pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + let base = "dotprod_" ++ atCName arithtype + sequence + [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> + [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |] + ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_strided") (mkName ("c_" ++ base ++ "_strided")) <$> + [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |]]) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 4a23a39..51f9fc0 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -11,7 +11,7 @@ module Data.Array.Nested ( rrerank, rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1, rslice, rrev1, rreshape, riota, - rminIndexPrim, rmaxIndexPrim, + rminIndexPrim, rmaxIndexPrim, rdot, rnest, runNest, -- ** Lifting orthotope operations to 'Ranked' arrays rlift, rlift2, @@ -31,7 +31,7 @@ module Data.Array.Nested ( srerank, sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1, sslice, srev1, sreshape, siota, - sminIndexPrim, smaxIndexPrim, + sminIndexPrim, smaxIndexPrim, sdot, snest, sunNest, -- ** Lifting orthotope operations to 'Shaped' arrays slift, slift2, @@ -48,7 +48,7 @@ module Data.Array.Nested ( mrerank, mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1, mslice, mrev1, mreshape, miota, - mminIndexPrim, mmaxIndexPrim, + mminIndexPrim, mmaxIndexPrim, mdot, mnest, munNest, -- ** Lifting orthotope operations to 'Mixed' arrays mlift, mlift2, diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index a0de08b..2c99487 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -49,6 +49,11 @@ import Data.Array.Mixed.Types import Data.Array.Mixed.Permutation import Data.Array.Mixed.Lemmas +-- TODO: +-- dotprod, sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a +-- After benchmarking: matmul and matvec + + -- Invariant in the API -- ==================== @@ -798,6 +803,10 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = ixxFromList (ssxFromShape sh) (numEltMaxIndex arr) +mdot :: (PrimElt a, NumElt a) => Mixed '[n] a -> Mixed '[n] a -> a +mdot (toPrimitive -> M_Primitive _ (XArray arr1)) (toPrimitive -> M_Primitive _ (XArray arr2)) = + numEltDotprod arr1 arr2 + mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr) diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 589f0c1..c67e892 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -461,6 +461,9 @@ rmaxIndexPrim rarr@(Ranked arr) | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) = ixCvtXR (mmaxIndexPrim arr) +rdot :: (PrimElt a, NumElt a) => Ranked 1 a -> Ranked 1 a -> a +rdot = coerce mdot + rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr) diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index ca3fd45..9320495 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -381,6 +381,9 @@ sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminInde smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) +sdot :: (PrimElt a, NumElt a) => Shaped '[n] a -> Shaped '[n] a -> a +sdot = coerce mdot + stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr) |