diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-06-10 10:02:59 +0200 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-06-10 16:17:09 +0200 |
commit | 205a20fd581bb7c5728fd457a15e4f78fbee9e75 (patch) | |
tree | f6669ea87b56dde0f6168c109b3d7d7fcbf06136 /src/Data/Array/Mixed/Internal/Arith.hs | |
parent | c211316a4ab43cf34d6567c6919a3922d5840ae0 (diff) |
Dot product
Diffstat (limited to 'src/Data/Array/Mixed/Internal/Arith.hs')
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 101 |
1 files changed, 100 insertions, 1 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index d2ad61f..6ecbbeb 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -19,8 +19,9 @@ import Data.List (sort) import Data.Vector.Storable qualified as VS import Data.Vector.Storable.Mutable qualified as VSM import Foreign.C.Types +import Foreign.Marshal.Alloc (alloca) import Foreign.Ptr -import Foreign.Storable (Storable) +import Foreign.Storable (Storable, peek, poke) import GHC.TypeLits import GHC.TypeNats qualified as TypeNats import Language.Haskell.TH @@ -217,6 +218,69 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec))) fextrem poutv (fromIntegral ndimsF) pshF pstridesF (ptrconv pvec) insertZeros replDims . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv +vectorDotprodOp :: (Num a, Storable a) + => (b -> a) + -> (Ptr a -> Ptr b) + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel + -> (Int64 -> Ptr b -> Ptr b -> IO b) -- ^ dotprod kernel + -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b) -- ^ strided dotprod kernel + -> RS.Array 1 a -> RS.Array 1 a -> a +vectorDotprodOp valbackconv ptrconv fred fdot fdotstrided + (RS.A (RG.A [len1] (OI.T [stride1] offset1 vec1))) + (RS.A (RG.A [len2] (OI.T [stride2] offset2 vec2))) + | len1 /= len2 = error $ "vectorDotprodOp: lengths unequal: " ++ show len1 ++ " vs " ++ show len2 + | len1 == 0 = 0 -- if the arrays are empty, just return zero + | otherwise = case (stride1, stride2) of + (0, 0) -> -- replicated scalar * replicated scalar + fromIntegral len1 * (vec1 VS.! offset1) * (vec2 VS.! offset2) + (0, 1) -> -- replicated scalar * dense + dotScalarVector len1 ptrconv fred (vec1 VS.! offset1) (VS.slice offset2 len1 vec2) + (1, 0) -> -- dense * replicated scalar + dotScalarVector len1 ptrconv fred (vec2 VS.! offset2) (VS.slice offset1 len1 vec1) + (1, 1) -> -- dense * dense + dotVectorVector len1 valbackconv ptrconv fdot (VS.slice offset1 len1 vec1) (VS.slice offset2 len1 vec2) + (_, _) -> -- fallback case + dotVectorVectorStrided len1 valbackconv ptrconv fdotstrided offset1 stride1 vec1 offset2 stride2 vec2 +vectorDotprodOp _ _ _ _ _ _ _ = error "vectorDotprodOp: not one-dimensional?" + +{-# NOINLINE dotScalarVector #-} +dotScalarVector :: forall a b. (Num a, Storable a) + => Int -> (Ptr a -> Ptr b) + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel + -> a -> VS.Vector a -> a +dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do + alloca @a $ \pout -> do + alloca @Int64 $ \pshape -> do + poke pshape (fromIntegral @Int @Int64 len) + alloca @Int64 $ \pstride -> do + poke pstride 1 + VS.unsafeWith vec $ \pvec -> + fred 1 pshape pstride (ptrconv pout) (ptrconv pvec) + res <- peek pout + return (scalar * res) + +{-# NOINLINE dotVectorVector #-} +dotVectorVector :: Storable a => Int -> (b -> a) -> (Ptr a -> Ptr b) + -> (Int64 -> Ptr b -> Ptr b -> IO b) -- ^ dotprod kernel + -> VS.Vector a -> VS.Vector a -> a +dotVectorVector len valbackconv ptrconv fdot vec1 vec2 = unsafePerformIO $ do + VS.unsafeWith vec1 $ \pvec1 -> + VS.unsafeWith vec2 $ \pvec2 -> + valbackconv <$> fdot (fromIntegral @Int @Int64 len) (ptrconv pvec1) (ptrconv pvec2) + +{-# NOINLINE dotVectorVectorStrided #-} +dotVectorVectorStrided :: Storable a => Int -> (b -> a) -> (Ptr a -> Ptr b) + -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b) -- ^ dotprod kernel + -> Int -> Int -> VS.Vector a + -> Int -> Int -> VS.Vector a + -> a +dotVectorVectorStrided len valbackconv ptrconv fdot offset1 stride1 vec1 offset2 stride2 vec2 = unsafePerformIO $ do + VS.unsafeWith vec1 $ \pvec1 -> + VS.unsafeWith vec2 $ \pvec2 -> + valbackconv <$> fdot (fromIntegral @Int @Int64 len) + (fromIntegral offset1) (fromIntegral stride1) (ptrconv pvec1) + (fromIntegral offset2) (fromIntegral stride2) (ptrconv pvec2) + flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ()) -> Int64 -> Ptr a -> Ptr a -> a -> IO () flipOp f n out v s = f n out s v @@ -290,6 +354,17 @@ $(fmap concat . forM typesList $ \arithtype -> ,do body <- [| vectorExtremumOp id $c_op |] return $ FunD name [Clause [] (NormalB body) []]]) +$(fmap concat . forM typesList $ \arithtype -> do + let ttyp = conT (atType arithtype) + name = mkName ("dotprodVector" ++ nameBase (atType arithtype)) + c_op = varE (mkName ("c_dotprod_" ++ atCName arithtype)) + c_op_strided = varE (mkName ("c_dotprod_" ++ atCName arithtype ++ "_strided")) + c_red_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM1))) + sequence [SigD name <$> + [t| RS.Array 1 $ttyp -> RS.Array 1 $ttyp -> $ttyp |] + ,do body <- [| vectorDotprodOp id id $c_red_op $c_op $c_op_strided |] + return $ FunD name [Clause [] (NormalB body) []]]) + -- This branch is ostensibly a runtime branch, but will (hopefully) be -- constant-folded away by GHC. intWidBranch1 :: forall i n. (FiniteBits i, Storable i) @@ -341,6 +416,21 @@ intWidBranchExtr fextr32 fextr64 | finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64 | otherwise = error "Unsupported Int width" +intWidBranchDotprod :: forall i. (FiniteBits i, Storable i, Integral i) + => -- int32 + (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel + -> (Int64 -> Ptr Int32 -> Ptr Int32 -> IO Int32) -- ^ dotprod kernel + -> (Int64 -> Int64 -> Int64 -> Ptr Int32 -> Int64 -> Int64 -> Ptr Int32 -> IO Int32) -- ^ strided dotprod kernel + -- int64 + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel + -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64) -- ^ dotprod kernel + -> (Int64 -> Int64 -> Int64 -> Ptr Int64 -> Int64 -> Int64 -> Ptr Int64 -> IO Int64) -- ^ strided dotprod kernel + -> (RS.Array 1 i -> RS.Array 1 i -> i) +intWidBranchDotprod fred32 fdot32 fdot32strided fred64 fdot64 fdot64strided + | finiteBitSize (undefined :: i) == 32 = vectorDotprodOp @i @Int32 fromIntegral castPtr fred32 fdot32 fdot32strided + | finiteBitSize (undefined :: i) == 64 = vectorDotprodOp @i @Int64 fromIntegral castPtr fred64 fdot64 fdot64strided + | otherwise = error "Unsupported Int width" + class NumElt a where numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a @@ -352,6 +442,7 @@ class NumElt a where numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a numEltMinIndex :: RS.Array n a -> [Int] numEltMaxIndex :: RS.Array n a -> [Int] + numEltDotprod :: RS.Array 1 a -> RS.Array 1 a -> a instance NumElt Int32 where numEltAdd = addVectorInt32 @@ -364,6 +455,7 @@ instance NumElt Int32 where numEltProduct1Inner = product1VectorInt32 numEltMinIndex = minindexVectorInt32 numEltMaxIndex = maxindexVectorInt32 + numEltDotprod = dotprodVectorInt32 instance NumElt Int64 where numEltAdd = addVectorInt64 @@ -376,6 +468,7 @@ instance NumElt Int64 where numEltProduct1Inner = product1VectorInt64 numEltMinIndex = minindexVectorInt64 numEltMaxIndex = maxindexVectorInt64 + numEltDotprod = dotprodVectorInt64 instance NumElt Float where numEltAdd = addVectorFloat @@ -388,6 +481,7 @@ instance NumElt Float where numEltProduct1Inner = product1VectorFloat numEltMinIndex = minindexVectorFloat numEltMaxIndex = maxindexVectorFloat + numEltDotprod = dotprodVectorFloat instance NumElt Double where numEltAdd = addVectorDouble @@ -400,6 +494,7 @@ instance NumElt Double where numEltProduct1Inner = product1VectorDouble numEltMinIndex = minindexVectorDouble numEltMaxIndex = maxindexVectorDouble + numEltDotprod = dotprodVectorDouble instance NumElt Int where numEltAdd = intWidBranch2 @Int (+) @@ -422,6 +517,8 @@ instance NumElt Int where (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) numEltMinIndex = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64 numEltMaxIndex = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64 + numEltDotprod = intWidBranchDotprod @Int (c_reduce_i32 (aroEnum RO_SUM1)) c_dotprod_i32 c_dotprod_i32_strided + (c_reduce_i64 (aroEnum RO_SUM1)) c_dotprod_i64 c_dotprod_i64_strided instance NumElt CInt where numEltAdd = intWidBranch2 @CInt (+) @@ -444,6 +541,8 @@ instance NumElt CInt where (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1)) numEltMinIndex = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64 numEltMaxIndex = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64 + numEltDotprod = intWidBranchDotprod @CInt (c_reduce_i32 (aroEnum RO_SUM1)) c_dotprod_i32 c_dotprod_i32_strided + (c_reduce_i64 (aroEnum RO_SUM1)) c_dotprod_i64 c_dotprod_i64_strided class FloatElt a where floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a |