aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed/Internal/Arith.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed/Internal/Arith.hs')
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs101
1 files changed, 100 insertions, 1 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index d2ad61f..6ecbbeb 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -19,8 +19,9 @@ import Data.List (sort)
import Data.Vector.Storable qualified as VS
import Data.Vector.Storable.Mutable qualified as VSM
import Foreign.C.Types
+import Foreign.Marshal.Alloc (alloca)
import Foreign.Ptr
-import Foreign.Storable (Storable)
+import Foreign.Storable (Storable, peek, poke)
import GHC.TypeLits
import GHC.TypeNats qualified as TypeNats
import Language.Haskell.TH
@@ -217,6 +218,69 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))
fextrem poutv (fromIntegral ndimsF) pshF pstridesF (ptrconv pvec)
insertZeros replDims . map (fromIntegral @Int64 @Int) . VS.toList <$> VS.unsafeFreeze outv
+vectorDotprodOp :: (Num a, Storable a)
+ => (b -> a)
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr b -> Ptr b -> IO b) -- ^ dotprod kernel
+ -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b) -- ^ strided dotprod kernel
+ -> RS.Array 1 a -> RS.Array 1 a -> a
+vectorDotprodOp valbackconv ptrconv fred fdot fdotstrided
+ (RS.A (RG.A [len1] (OI.T [stride1] offset1 vec1)))
+ (RS.A (RG.A [len2] (OI.T [stride2] offset2 vec2)))
+ | len1 /= len2 = error $ "vectorDotprodOp: lengths unequal: " ++ show len1 ++ " vs " ++ show len2
+ | len1 == 0 = 0 -- if the arrays are empty, just return zero
+ | otherwise = case (stride1, stride2) of
+ (0, 0) -> -- replicated scalar * replicated scalar
+ fromIntegral len1 * (vec1 VS.! offset1) * (vec2 VS.! offset2)
+ (0, 1) -> -- replicated scalar * dense
+ dotScalarVector len1 ptrconv fred (vec1 VS.! offset1) (VS.slice offset2 len1 vec2)
+ (1, 0) -> -- dense * replicated scalar
+ dotScalarVector len1 ptrconv fred (vec2 VS.! offset2) (VS.slice offset1 len1 vec1)
+ (1, 1) -> -- dense * dense
+ dotVectorVector len1 valbackconv ptrconv fdot (VS.slice offset1 len1 vec1) (VS.slice offset2 len1 vec2)
+ (_, _) -> -- fallback case
+ dotVectorVectorStrided len1 valbackconv ptrconv fdotstrided offset1 stride1 vec1 offset2 stride2 vec2
+vectorDotprodOp _ _ _ _ _ _ _ = error "vectorDotprodOp: not one-dimensional?"
+
+{-# NOINLINE dotScalarVector #-}
+dotScalarVector :: forall a b. (Num a, Storable a)
+ => Int -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> a -> VS.Vector a -> a
+dotScalarVector len ptrconv fred scalar vec = unsafePerformIO $ do
+ alloca @a $ \pout -> do
+ alloca @Int64 $ \pshape -> do
+ poke pshape (fromIntegral @Int @Int64 len)
+ alloca @Int64 $ \pstride -> do
+ poke pstride 1
+ VS.unsafeWith vec $ \pvec ->
+ fred 1 pshape pstride (ptrconv pout) (ptrconv pvec)
+ res <- peek pout
+ return (scalar * res)
+
+{-# NOINLINE dotVectorVector #-}
+dotVectorVector :: Storable a => Int -> (b -> a) -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr b -> Ptr b -> IO b) -- ^ dotprod kernel
+ -> VS.Vector a -> VS.Vector a -> a
+dotVectorVector len valbackconv ptrconv fdot vec1 vec2 = unsafePerformIO $ do
+ VS.unsafeWith vec1 $ \pvec1 ->
+ VS.unsafeWith vec2 $ \pvec2 ->
+ valbackconv <$> fdot (fromIntegral @Int @Int64 len) (ptrconv pvec1) (ptrconv pvec2)
+
+{-# NOINLINE dotVectorVectorStrided #-}
+dotVectorVectorStrided :: Storable a => Int -> (b -> a) -> (Ptr a -> Ptr b)
+ -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b) -- ^ dotprod kernel
+ -> Int -> Int -> VS.Vector a
+ -> Int -> Int -> VS.Vector a
+ -> a
+dotVectorVectorStrided len valbackconv ptrconv fdot offset1 stride1 vec1 offset2 stride2 vec2 = unsafePerformIO $ do
+ VS.unsafeWith vec1 $ \pvec1 ->
+ VS.unsafeWith vec2 $ \pvec2 ->
+ valbackconv <$> fdot (fromIntegral @Int @Int64 len)
+ (fromIntegral offset1) (fromIntegral stride1) (ptrconv pvec1)
+ (fromIntegral offset2) (fromIntegral stride2) (ptrconv pvec2)
+
flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
-> Int64 -> Ptr a -> Ptr a -> a -> IO ()
flipOp f n out v s = f n out s v
@@ -290,6 +354,17 @@ $(fmap concat . forM typesList $ \arithtype ->
,do body <- [| vectorExtremumOp id $c_op |]
return $ FunD name [Clause [] (NormalB body) []]])
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ name = mkName ("dotprodVector" ++ nameBase (atType arithtype))
+ c_op = varE (mkName ("c_dotprod_" ++ atCName arithtype))
+ c_op_strided = varE (mkName ("c_dotprod_" ++ atCName arithtype ++ "_strided"))
+ c_red_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM1)))
+ sequence [SigD name <$>
+ [t| RS.Array 1 $ttyp -> RS.Array 1 $ttyp -> $ttyp |]
+ ,do body <- [| vectorDotprodOp id id $c_red_op $c_op $c_op_strided |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
-- This branch is ostensibly a runtime branch, but will (hopefully) be
-- constant-folded away by GHC.
intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
@@ -341,6 +416,21 @@ intWidBranchExtr fextr32 fextr64
| finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64
| otherwise = error "Unsupported Int width"
+intWidBranchDotprod :: forall i. (FiniteBits i, Storable i, Integral i)
+ => -- int32
+ (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr Int32 -> Ptr Int32 -> IO Int32) -- ^ dotprod kernel
+ -> (Int64 -> Int64 -> Int64 -> Ptr Int32 -> Int64 -> Int64 -> Ptr Int32 -> IO Int32) -- ^ strided dotprod kernel
+ -- int64
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64) -- ^ dotprod kernel
+ -> (Int64 -> Int64 -> Int64 -> Ptr Int64 -> Int64 -> Int64 -> Ptr Int64 -> IO Int64) -- ^ strided dotprod kernel
+ -> (RS.Array 1 i -> RS.Array 1 i -> i)
+intWidBranchDotprod fred32 fdot32 fdot32strided fred64 fdot64 fdot64strided
+ | finiteBitSize (undefined :: i) == 32 = vectorDotprodOp @i @Int32 fromIntegral castPtr fred32 fdot32 fdot32strided
+ | finiteBitSize (undefined :: i) == 64 = vectorDotprodOp @i @Int64 fromIntegral castPtr fred64 fdot64 fdot64strided
+ | otherwise = error "Unsupported Int width"
+
class NumElt a where
numEltAdd :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
numEltSub :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
@@ -352,6 +442,7 @@ class NumElt a where
numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
numEltMinIndex :: RS.Array n a -> [Int]
numEltMaxIndex :: RS.Array n a -> [Int]
+ numEltDotprod :: RS.Array 1 a -> RS.Array 1 a -> a
instance NumElt Int32 where
numEltAdd = addVectorInt32
@@ -364,6 +455,7 @@ instance NumElt Int32 where
numEltProduct1Inner = product1VectorInt32
numEltMinIndex = minindexVectorInt32
numEltMaxIndex = maxindexVectorInt32
+ numEltDotprod = dotprodVectorInt32
instance NumElt Int64 where
numEltAdd = addVectorInt64
@@ -376,6 +468,7 @@ instance NumElt Int64 where
numEltProduct1Inner = product1VectorInt64
numEltMinIndex = minindexVectorInt64
numEltMaxIndex = maxindexVectorInt64
+ numEltDotprod = dotprodVectorInt64
instance NumElt Float where
numEltAdd = addVectorFloat
@@ -388,6 +481,7 @@ instance NumElt Float where
numEltProduct1Inner = product1VectorFloat
numEltMinIndex = minindexVectorFloat
numEltMaxIndex = maxindexVectorFloat
+ numEltDotprod = dotprodVectorFloat
instance NumElt Double where
numEltAdd = addVectorDouble
@@ -400,6 +494,7 @@ instance NumElt Double where
numEltProduct1Inner = product1VectorDouble
numEltMinIndex = minindexVectorDouble
numEltMaxIndex = maxindexVectorDouble
+ numEltDotprod = dotprodVectorDouble
instance NumElt Int where
numEltAdd = intWidBranch2 @Int (+)
@@ -422,6 +517,8 @@ instance NumElt Int where
(c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
numEltMinIndex = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
numEltMaxIndex = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64
+ numEltDotprod = intWidBranchDotprod @Int (c_reduce_i32 (aroEnum RO_SUM1)) c_dotprod_i32 c_dotprod_i32_strided
+ (c_reduce_i64 (aroEnum RO_SUM1)) c_dotprod_i64 c_dotprod_i64_strided
instance NumElt CInt where
numEltAdd = intWidBranch2 @CInt (+)
@@ -444,6 +541,8 @@ instance NumElt CInt where
(c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
numEltMinIndex = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
numEltMaxIndex = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
+ numEltDotprod = intWidBranchDotprod @CInt (c_reduce_i32 (aroEnum RO_SUM1)) c_dotprod_i32 c_dotprod_i32_strided
+ (c_reduce_i64 (aroEnum RO_SUM1)) c_dotprod_i64 c_dotprod_i64_strided
class FloatElt a where
floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a