aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-10 10:02:59 +0200
committerTom Smeding <t.j.smeding@uu.nl>2024-06-10 16:17:09 +0200
commit205a20fd581bb7c5728fd457a15e4f78fbee9e75 (patch)
treef6669ea87b56dde0f6168c109b3d7d7fcbf06136 /src/Data
parentc211316a4ab43cf34d6567c6919a3922d5840ae0 (diff)
Dot product
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs101
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Foreign.hs9
-rw-r--r--src/Data/Array/Nested.hs6
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs9
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs3
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs3
6 files changed, 127 insertions, 4 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index d2ad61f..6ecbbeb 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -19,8 +19,9 @@ import Data.List (sort)
import Data.Vector.Storable qualified as VS
import Data.Vector.Storable.Mutable qualified as VSM
import Foreign.C.Types
+import Foreign.Marshal.Alloc (alloca)
import Foreign.Ptr
-import Foreign.Storable (Storable)
+import Foreign.Storable (Storable, peek, poke)
import GHC.TypeLits
import GHC.TypeNats qualified as TypeNats
import Language.Haskell.TH
@@ -217,6 +218,69 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))
fextrem poutv (fromIntegral ndimsF) pshF pstridesF (ptrconv pvec)
insertZeros replDims . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv
+vectorDotprodOp :: (Num a, Storable a)
+ => (b -> a)
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr b -> Ptr b -> IO b) -- ^ dotprod kernel
+ -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b) -- ^ strided dotprod kernel
+ -> RS.Array 1 a -> RS.Array 1 a -> a
+vectorDotprodOp valbackconv ptrconv fred fdot fdotstrided
+ (RS.A (RG.A [len1] (OI.T [stride1] offset1 vec1)))
+ (RS.A (RG.A [len2] (OI.T [stride2] offset2 vec2)))
+ | len1 /= len2 = error $ "vectorDotprodOp: lengths unequal: " ++ show len1 ++ " vs " ++ show len2
+ | len1 == 0 = 0 -- if the arrays are empty, just return zero
+ | otherwise = case (stride1, stride2) of
+ (0, 0) -> -- replicated scalar * replicated scalar
+ fromIntegral len1 * (vec1 VS.! offset1) * (vec2 VS.! offset2)
+ (0, 1) -> -- replicated scalar * dense
+ dotScalarVector len1 ptrconv fred (vec1 VS.! offset1) (VS.slice offset2 len1 vec2)
+ (1, 0) -> -- dense * replicated scalar
+ dotScalarVector len1 ptrconv fred (vec2 VS.! offset2) (VS.slice offset1 len1 vec1)
+ (1, 1) -> -- dense * dense
+ dotVectorVector len1 valbackconv ptrconv fdot (VS.slice offset1 len1 vec1) (VS.slice offset2 len1 vec2)
+ (_, _) -> -- fallback case
+ dotVectorVectorStrided len1 valbackconv ptrconv fdotstrided offset1 stride1 vec1 offset2 stride2 vec2
+vectorDotprodOp _ _ _ _ _ _ _ = error "vectorDotprodOp: not one-dimensional?"
+
+{-# NOINLINE dotScalarVector #-}
+dotScalarVector :: forall a b. (Num a, Storable a)
+ => Int -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> a -> VS.Vector a -> a
+dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do
+ alloca @a $ \pout -> do
+ alloca @Int64 $ \pshape -> do
+ poke pshape (fromIntegral @Int @Int64 len)
+ alloca @Int64 $ \pstride -> do
+ poke pstride 1
+ VS.unsafeWith vec $ \pvec ->
+ fred 1 pshape pstride (ptrconv pout) (ptrconv pvec)
+ res <- peek pout
+ return (scalar * res)
+
+{-# NOINLINE dotVectorVector #-}
+dotVectorVector :: Storable a => Int -> (b -> a) -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr b -> Ptr b -> IO b) -- ^ dotprod kernel
+ -> VS.Vector a -> VS.Vector a -> a
+dotVectorVector len valbackconv ptrconv fdot vec1 vec2 = unsafePerformIO $ do
+ VS.unsafeWith vec1 $ \pvec1 ->
+ VS.unsafeWith vec2 $ \pvec2 ->
+ valbackconv <$> fdot (fromIntegral @Int @Int64 len) (ptrconv pvec1) (ptrconv pvec2)
+
+{-# NOINLINE dotVectorVectorStrided #-}
+dotVectorVectorStrided :: Storable a => Int -> (b -> a) -> (Ptr a -> Ptr b)
+ -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b) -- ^ dotprod kernel
+ -> Int -> Int -> VS.Vector a
+ -> Int -> Int -> VS.Vector a
+ -> a
+dotVectorVectorStrided len valbackconv ptrconv fdot offset1 stride1 vec1 offset2 stride2 vec2 = unsafePerformIO $ do
+ VS.unsafeWith vec1 $ \pvec1 ->
+ VS.unsafeWith vec2 $ \pvec2 ->
+ valbackconv <$> fdot (fromIntegral @Int @Int64 len)
+ (fromIntegral offset1) (fromIntegral stride1) (ptrconv pvec1)
+ (fromIntegral offset2) (fromIntegral stride2) (ptrconv pvec2)
+
flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
-> Int64 -> Ptr a -> Ptr a -> a -> IO ()
flipOp f n out v s = f n out s v
@@ -290,6 +354,17 @@ $(fmap concat . forM typesList $ \arithtype ->
,do body <- [| vectorExtremumOp id $c_op |]
return $ FunD name [Clause [] (NormalB body) []]])
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ name = mkName ("dotprodVector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_dotprod_" ++ atCName arithtype))
+ c_op_strided = varE (mkName ("c_dotprod_" ++ atCName arithtype ++ "_strided"))
+ c_red_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM1)))
+ sequence [SigD name <$>
+ [t| RS.Array 1 $ttyp -> RS.Array 1 $ttyp -> $ttyp |]
+ ,do body <- [| vectorDotprodOp id id $c_red_op $c_op $c_op_strided |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
-- This branch is ostensibly a runtime branch, but will (hopefully) be
-- constant-folded away by GHC.
intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
@@ -341,6 +416,21 @@ intWidBranchExtr fextr32 fextr64
| finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64
| otherwise = error "Unsupported Int width"
+intWidBranchDotprod :: forall i. (FiniteBits i, Storable i, Integral i)
+ => -- int32
+ (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr Int32 -> Ptr Int32 -> IO Int32) -- ^ dotprod kernel
+ -> (Int64 -> Int64 -> Int64 -> Ptr Int32 -> Int64 -> Int64 -> Ptr Int32 -> IO Int32) -- ^ strided dotprod kernel
+ -- int64
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64) -- ^ dotprod kernel
+ -> (Int64 -> Int64 -> Int64 -> Ptr Int64 -> Int64 -> Int64 -> Ptr Int64 -> IO Int64) -- ^ strided dotprod kernel
+ -> (RS.Array 1 i -> RS.Array 1 i -> i)
+intWidBranchDotprod fred32 fdot32 fdot32strided fred64 fdot64 fdot64strided
+ | finiteBitSize (undefined :: i) == 32 = vectorDotprodOp @i @Int32 fromIntegral castPtr fred32 fdot32 fdot32strided
+ | finiteBitSize (undefined :: i) == 64 = vectorDotprodOp @i @Int64 fromIntegral castPtr fred64 fdot64 fdot64strided
+ | otherwise = error "Unsupported Int width"
+
class NumElt a where
numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
@@ -352,6 +442,7 @@ class NumElt a where
numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
numEltMinIndex :: RS.Array n a -> [Int]
numEltMaxIndex :: RS.Array n a -> [Int]
+ numEltDotprod :: RS.Array 1 a -> RS.Array 1 a -> a
instance NumElt Int32 where
numEltAdd = addVectorInt32
@@ -364,6 +455,7 @@ instance NumElt Int32 where
numEltProduct1Inner = product1VectorInt32
numEltMinIndex = minindexVectorInt32
numEltMaxIndex = maxindexVectorInt32
+ numEltDotprod = dotprodVectorInt32
instance NumElt Int64 where
numEltAdd = addVectorInt64
@@ -376,6 +468,7 @@ instance NumElt Int64 where
numEltProduct1Inner = product1VectorInt64
numEltMinIndex = minindexVectorInt64
numEltMaxIndex = maxindexVectorInt64
+ numEltDotprod = dotprodVectorInt64
instance NumElt Float where
numEltAdd = addVectorFloat
@@ -388,6 +481,7 @@ instance NumElt Float where
numEltProduct1Inner = product1VectorFloat
numEltMinIndex = minindexVectorFloat
numEltMaxIndex = maxindexVectorFloat
+ numEltDotprod = dotprodVectorFloat
instance NumElt Double where
numEltAdd = addVectorDouble
@@ -400,6 +494,7 @@ instance NumElt Double where
numEltProduct1Inner = product1VectorDouble
numEltMinIndex = minindexVectorDouble
numEltMaxIndex = maxindexVectorDouble
+ numEltDotprod = dotprodVectorDouble
instance NumElt Int where
numEltAdd = intWidBranch2 @Int (+)
@@ -422,6 +517,8 @@ instance NumElt Int where
(c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
numEltMinIndex = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
numEltMaxIndex = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64
+ numEltDotprod = intWidBranchDotprod @Int (c_reduce_i32 (aroEnum RO_SUM1)) c_dotprod_i32 c_dotprod_i32_strided
+ (c_reduce_i64 (aroEnum RO_SUM1)) c_dotprod_i64 c_dotprod_i64_strided
instance NumElt CInt where
numEltAdd = intWidBranch2 @CInt (+)
@@ -444,6 +541,8 @@ instance NumElt CInt where
(c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
numEltMinIndex = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
numEltMaxIndex = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
+ numEltDotprod = intWidBranchDotprod @CInt (c_reduce_i32 (aroEnum RO_SUM1)) c_dotprod_i32 c_dotprod_i32_strided
+ (c_reduce_i64 (aroEnum RO_SUM1)) c_dotprod_i64 c_dotprod_i64_strided
class FloatElt a where
floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
index 0bd72e8..96a85d1 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -60,3 +60,12 @@ $(fmap concat . forM typesList $ \arithtype ->
let base = "extremum_" ++ fname ++ "_" ++ atCName arithtype
pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
[t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ let base = "dotprod_" ++ atCName arithtype
+ sequence
+ [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |]
+ ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_strided") (mkName ("c_" ++ base ++ "_strided")) <$>
+ [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |]])
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 4a23a39..51f9fc0 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -11,7 +11,7 @@ module Data.Array.Nested (
rrerank,
rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
rslice, rrev1, rreshape, riota,
- rminIndexPrim, rmaxIndexPrim,
+ rminIndexPrim, rmaxIndexPrim, rdot,
rnest, runNest,
-- ** Lifting orthotope operations to 'Ranked' arrays
rlift, rlift2,
@@ -31,7 +31,7 @@ module Data.Array.Nested (
srerank,
sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,
sslice, srev1, sreshape, siota,
- sminIndexPrim, smaxIndexPrim,
+ sminIndexPrim, smaxIndexPrim, sdot,
snest, sunNest,
-- ** Lifting orthotope operations to 'Shaped' arrays
slift, slift2,
@@ -48,7 +48,7 @@ module Data.Array.Nested (
mrerank,
mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
mslice, mrev1, mreshape, miota,
- mminIndexPrim, mmaxIndexPrim,
+ mminIndexPrim, mmaxIndexPrim, mdot,
mnest, munNest,
-- ** Lifting orthotope operations to 'Mixed' arrays
mlift, mlift2,
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index a0de08b..2c99487 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -49,6 +49,11 @@ import Data.Array.Mixed.Types
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Lemmas
+-- TODO:
+-- dotprod, sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
+-- After benchmarking: matmul and matvec
+
+
-- Invariant in the API
-- ====================
@@ -798,6 +803,10 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
ixxFromList (ssxFromShape sh) (numEltMaxIndex arr)
+mdot :: (PrimElt a, NumElt a) => Mixed '[n] a -> Mixed '[n] a -> a
+mdot (toPrimitive -> M_Primitive _ (XArray arr1)) (toPrimitive -> M_Primitive _ (XArray arr2)) =
+ numEltDotprod arr1 arr2
+
mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr)
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 589f0c1..c67e892 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -461,6 +461,9 @@ rmaxIndexPrim rarr@(Ranked arr)
| Refl <- lemRankReplicate (rrank (rtoPrimitive rarr))
= ixCvtXR (mmaxIndexPrim arr)
+rdot :: (PrimElt a, NumElt a) => Ranked 1 a -> Ranked 1 a -> a
+rdot = coerce mdot
+
rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr)
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index ca3fd45..9320495 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -381,6 +381,9 @@ sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminInde
smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh
smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr)
+sdot :: (PrimElt a, NumElt a) => Shaped '[n] a -> Shaped '[n] a -> a
+sdot = coerce mdot
+
stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr)