diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-06-10 10:02:59 +0200 | 
|---|---|---|
| committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-06-10 16:17:09 +0200 | 
| commit | 205a20fd581bb7c5728fd457a15e4f78fbee9e75 (patch) | |
| tree | f6669ea87b56dde0f6168c109b3d7d7fcbf06136 /src/Data/Array/Mixed | |
| parent | c211316a4ab43cf34d6567c6919a3922d5840ae0 (diff) | |
Dot product
Diffstat (limited to 'src/Data/Array/Mixed')
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 101 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith/Foreign.hs | 9 | 
2 files changed, 109 insertions, 1 deletions
| diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index d2ad61f..6ecbbeb 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -19,8 +19,9 @@ import Data.List (sort)  import Data.Vector.Storable qualified as VS  import Data.Vector.Storable.Mutable qualified as VSM  import Foreign.C.Types +import Foreign.Marshal.Alloc (alloca)  import Foreign.Ptr -import Foreign.Storable (Storable) +import Foreign.Storable (Storable, peek, poke)  import GHC.TypeLits  import GHC.TypeNats qualified as TypeNats  import Language.Haskell.TH @@ -217,6 +218,69 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))                     fextrem poutv (fromIntegral ndimsF) pshF pstridesF (ptrconv pvec)             insertZeros replDims . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv +vectorDotprodOp :: (Num a, Storable a) +                => (b -> a) +                -> (Ptr a -> Ptr b) +                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ())  -- ^ reduction kernel +                -> (Int64 -> Ptr b -> Ptr b -> IO b)  -- ^ dotprod kernel +                -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b)  -- ^ strided dotprod kernel +                -> RS.Array 1 a -> RS.Array 1 a -> a +vectorDotprodOp valbackconv ptrconv fred fdot fdotstrided +    (RS.A (RG.A [len1] (OI.T [stride1] offset1 vec1))) +    (RS.A (RG.A [len2] (OI.T [stride2] offset2 vec2))) +  | len1 /= len2 = error $ "vectorDotprodOp: lengths unequal: " ++ show len1 ++ " vs " ++ show len2 +  | len1 == 0 = 0  -- if the arrays are empty, just return zero +  | otherwise = case (stride1, stride2) of +      (0, 0) ->  -- replicated scalar * replicated scalar +        fromIntegral len1 * (vec1 VS.! offset1) * (vec2 VS.! offset2) +      (0, 1) ->  -- replicated scalar * dense +        dotScalarVector len1 ptrconv fred (vec1 VS.! offset1) (VS.slice offset2 len1 vec2) +      (1, 0) ->  -- dense * replicated scalar +        dotScalarVector len1 ptrconv fred (vec2 VS.! offset2) (VS.slice offset1 len1 vec1) +      (1, 1) ->  -- dense * dense +        dotVectorVector len1 valbackconv ptrconv fdot (VS.slice offset1 len1 vec1) (VS.slice offset2 len1 vec2) +      (_, _) ->  -- fallback case +        dotVectorVectorStrided len1 valbackconv ptrconv fdotstrided offset1 stride1 vec1 offset2 stride2 vec2 +vectorDotprodOp _ _ _ _ _ _ _ = error "vectorDotprodOp: not one-dimensional?" + +{-# NOINLINE dotScalarVector #-} +dotScalarVector :: forall a b. (Num a, Storable a) +                => Int -> (Ptr a -> Ptr b) +                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ())  -- ^ reduction kernel +                -> a -> VS.Vector a -> a +dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do +  alloca @a $ \pout -> do +    alloca @Int64 $ \pshape -> do +      poke pshape (fromIntegral @Int @Int64 len) +      alloca @Int64 $ \pstride -> do +        poke pstride 1 +        VS.unsafeWith vec $ \pvec -> +          fred 1 pshape pstride (ptrconv pout) (ptrconv pvec) +    res <- peek pout +    return (scalar * res) + +{-# NOINLINE dotVectorVector #-} +dotVectorVector :: Storable a => Int -> (b -> a) -> (Ptr a -> Ptr b) +                -> (Int64 -> Ptr b -> Ptr b -> IO b)  -- ^ dotprod kernel +                -> VS.Vector a -> VS.Vector a -> a +dotVectorVector len valbackconv ptrconv fdot vec1 vec2 = unsafePerformIO $ do +  VS.unsafeWith vec1 $ \pvec1 -> +    VS.unsafeWith vec2 $ \pvec2 -> +      valbackconv <$> fdot (fromIntegral @Int @Int64 len) (ptrconv pvec1) (ptrconv pvec2) + +{-# NOINLINE dotVectorVectorStrided #-} +dotVectorVectorStrided :: Storable a => Int -> (b -> a) -> (Ptr a -> Ptr b) +                       -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b)  -- ^ dotprod kernel +                       -> Int -> Int -> VS.Vector a +                       -> Int -> Int -> VS.Vector a +                       -> a +dotVectorVectorStrided len valbackconv ptrconv fdot offset1 stride1 vec1 offset2 stride2 vec2 = unsafePerformIO $ do +  VS.unsafeWith vec1 $ \pvec1 -> +    VS.unsafeWith vec2 $ \pvec2 -> +      valbackconv <$> fdot (fromIntegral @Int @Int64 len) +                           (fromIntegral offset1) (fromIntegral stride1) (ptrconv pvec1) +                           (fromIntegral offset2) (fromIntegral stride2) (ptrconv pvec2) +  flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())         ->  Int64 -> Ptr a -> Ptr a -> a -> IO ()  flipOp f n out v s = f n out s v @@ -290,6 +354,17 @@ $(fmap concat . forM typesList $ \arithtype ->                 ,do body <- [| vectorExtremumOp id $c_op |]                     return $ FunD name [Clause [] (NormalB body) []]]) +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +        name = mkName ("dotprodVector" ++ nameBase (atType arithtype)) +        c_op = varE (mkName ("c_dotprod_" ++ atCName arithtype)) +        c_op_strided = varE (mkName ("c_dotprod_" ++ atCName arithtype ++ "_strided")) +        c_red_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM1))) +    sequence [SigD name <$> +                   [t| RS.Array 1 $ttyp -> RS.Array 1 $ttyp -> $ttyp |] +             ,do body <- [| vectorDotprodOp id id $c_red_op $c_op $c_op_strided |] +                 return $ FunD name [Clause [] (NormalB body) []]]) +  -- This branch is ostensibly a runtime branch, but will (hopefully) be  -- constant-folded away by GHC.  intWidBranch1 :: forall i n. (FiniteBits i, Storable i) @@ -341,6 +416,21 @@ intWidBranchExtr fextr32 fextr64    | finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64    | otherwise = error "Unsupported Int width" +intWidBranchDotprod :: forall i. (FiniteBits i, Storable i, Integral i) +                    => -- int32 +                       (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- ^ reduction kernel +                    -> (Int64 -> Ptr Int32 -> Ptr Int32 -> IO Int32)  -- ^ dotprod kernel +                    -> (Int64 -> Int64 -> Int64 -> Ptr Int32 -> Int64 -> Int64 -> Ptr Int32 -> IO Int32)  -- ^ strided dotprod kernel +                       -- int64 +                    -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ reduction kernel +                    -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64)  -- ^ dotprod kernel +                    -> (Int64 -> Int64 -> Int64 -> Ptr Int64 -> Int64 -> Int64 -> Ptr Int64 -> IO Int64)  -- ^ strided dotprod kernel +                    -> (RS.Array 1 i -> RS.Array 1 i -> i) +intWidBranchDotprod fred32 fdot32 fdot32strided fred64 fdot64 fdot64strided +  | finiteBitSize (undefined :: i) == 32 = vectorDotprodOp @i @Int32 fromIntegral castPtr fred32 fdot32 fdot32strided +  | finiteBitSize (undefined :: i) == 64 = vectorDotprodOp @i @Int64 fromIntegral castPtr fred64 fdot64 fdot64strided +  | otherwise = error "Unsupported Int width" +  class NumElt a where    numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a    numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a @@ -352,6 +442,7 @@ class NumElt a where    numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a    numEltMinIndex :: RS.Array n a -> [Int]    numEltMaxIndex :: RS.Array n a -> [Int] +  numEltDotprod :: RS.Array 1 a -> RS.Array 1 a -> a  instance NumElt Int32 where    numEltAdd = addVectorInt32 @@ -364,6 +455,7 @@ instance NumElt Int32 where    numEltProduct1Inner = product1VectorInt32    numEltMinIndex = minindexVectorInt32    numEltMaxIndex = maxindexVectorInt32 +  numEltDotprod = dotprodVectorInt32  instance NumElt Int64 where    numEltAdd = addVectorInt64 @@ -376,6 +468,7 @@ instance NumElt Int64 where    numEltProduct1Inner = product1VectorInt64    numEltMinIndex = minindexVectorInt64    numEltMaxIndex = maxindexVectorInt64 +  numEltDotprod = dotprodVectorInt64  instance NumElt Float where    numEltAdd = addVectorFloat @@ -388,6 +481,7 @@ instance NumElt Float where    numEltProduct1Inner = product1VectorFloat    numEltMinIndex = minindexVectorFloat    numEltMaxIndex = maxindexVectorFloat +  numEltDotprod = dotprodVectorFloat  instance NumElt Double where    numEltAdd = addVectorDouble @@ -400,6 +494,7 @@ instance NumElt Double where    numEltProduct1Inner = product1VectorDouble    numEltMinIndex = minindexVectorDouble    numEltMaxIndex = maxindexVectorDouble +  numEltDotprod = dotprodVectorDouble  instance NumElt Int where    numEltAdd = intWidBranch2 @Int (+) @@ -422,6 +517,8 @@ instance NumElt Int where                            (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))    numEltMinIndex = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64    numEltMaxIndex = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64 +  numEltDotprod = intWidBranchDotprod @Int (c_reduce_i32 (aroEnum RO_SUM1)) c_dotprod_i32 c_dotprod_i32_strided +                                           (c_reduce_i64 (aroEnum RO_SUM1)) c_dotprod_i64 c_dotprod_i64_strided  instance NumElt CInt where    numEltAdd = intWidBranch2 @CInt (+) @@ -444,6 +541,8 @@ instance NumElt CInt where                            (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))    numEltMinIndex = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64    numEltMaxIndex = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64 +  numEltDotprod = intWidBranchDotprod @CInt (c_reduce_i32 (aroEnum RO_SUM1)) c_dotprod_i32 c_dotprod_i32_strided +                                            (c_reduce_i64 (aroEnum RO_SUM1)) c_dotprod_i64 c_dotprod_i64_strided  class FloatElt a where    floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs index 0bd72e8..96a85d1 100644 --- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs +++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs @@ -60,3 +60,12 @@ $(fmap concat . forM typesList $ \arithtype ->        let base = "extremum_" ++ fname ++ "_" ++ atCName arithtype        pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>          [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |]) + +$(fmap concat . forM typesList $ \arithtype -> do +    let ttyp = conT (atType arithtype) +    let base = "dotprod_" ++ atCName arithtype +    sequence +      [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$> +         [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |] +      ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_strided") (mkName ("c_" ++ base ++ "_strided")) <$> +         [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |]]) | 
