author    | Tom Smeding <t.j.smeding@uu.nl> | 2024-06-19 15:57:43 +0200
committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-06-19 15:57:43 +0200
commit    | aafe5f6b5fa772d0e2e9f9b4f91bc3e7cf696840 (patch)
tree      | c0d0d81a9c40f72adf041b165819ab0c7daa44bf /src/Data/Array/Mixed/Internal/Arith.hs
parent    | 97ab8502b9cd3f7d908160d13c7d85d23c99e203 (diff)
Add {m,r,s}dot1Inner
Diffstat (limited to 'src/Data/Array/Mixed/Internal/Arith.hs')
-rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 118
1 file changed, 64 insertions(+), 54 deletions(-)
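The commit replaces the rank-1 `vectorDotprodOp` backend with `vectorDotprodInnerOp`: given two arrays of rank n+1 with equal shapes, it reduces the innermost dimension with a dot product and returns a rank-n array, presumably what the new {m,r,s}dot1Inner operations build on. Below is a minimal sketch of that semantics on plain nested lists; `dotprodInnerLists` is an illustrative name, not something defined in this package.

```haskell
-- Illustrative only: the inner dot product expressed on nested lists.
-- Both inputs have the same shape; each innermost row is reduced to a
-- single value, so a "rank n+1" structure becomes a "rank n" one.
dotprodInnerLists :: Num a => [[a]] -> [[a]] -> [a]
dotprodInnerLists = zipWith (\r1 r2 -> sum (zipWith (*) r1 r2))

main :: IO ()
main = print (dotprodInnerLists [[1,2,3],[4,5,6]] [[1,1,1],[2,2,2]])
-- prints [6,30]: each pair of inner rows collapses to one dot product
```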
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 9f99c3b..fc26633 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -31,6 +31,7 @@ import System.IO.Unsafe
 
 import Data.Array.Mixed.Internal.Arith.Foreign
 import Data.Array.Mixed.Internal.Arith.Lists
+import Data.Array.Mixed.Types (fromSNat')
 
 -- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition
 
@@ -304,36 +305,44 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))
         . VS.toList
     <$> VS.unsafeFreeze outvR
 
-vectorDotprodOp :: (Num a, Storable a)
-                => (b -> a)
-                -> (Ptr a -> Ptr b)
-                -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel
-                -> (Int64 -> Ptr b -> Ptr b -> IO b)  -- ^ dotprod kernel
-                -> (Int64 -> Int64 -> Int64 -> Ptr b -> Int64 -> Int64 -> Ptr b -> IO b)  -- ^ strided dotprod kernel
-                -> RS.Array 1 a -> RS.Array 1 a -> a
-vectorDotprodOp valbackconv ptrconv fred fdot fdotstrided
-  (RS.A (RG.A [len1] (OI.T [stride1] offset1 vec1)))
-  (RS.A (RG.A [len2] (OI.T [stride2] offset2 vec2)))
-  | len1 /= len2 = error $ "vectorDotprodOp: lengths unequal: " ++ show len1 ++ " vs " ++ show len2
-  | len1 == 0 = 0  -- if the arrays are empty, just return zero
-  | otherwise = case (stride1, stride2) of
-      (0, 0) ->  -- replicated scalar * replicated scalar
-        fromIntegral len1 * (vec1 VS.! offset1) * (vec2 VS.! offset2)
-      (0, 1) ->  -- replicated scalar * dense
-        dotScalarVector len1 ptrconv fred (vec1 VS.! offset1) (VS.slice offset2 len1 vec2)
-      (0, -1) ->  -- replicated scalar * reversed dense
-        dotScalarVector len1 ptrconv fred (vec1 VS.! offset1) (VS.slice (offset2 - (len1 - 1)) len1 vec2)
-      (1, 0) ->  -- dense * replicated scalar
-        dotScalarVector len1 ptrconv fred (vec2 VS.! offset2) (VS.slice offset1 len1 vec1)
-      (-1, 0) ->  -- reversed dense * replicated scalar
-        dotScalarVector len1 ptrconv fred (vec2 VS.! offset2) (VS.slice (offset1 - (len1 - 1)) len1 vec1)
-      (1, 1) ->  -- dense * dense
-        dotVectorVector len1 valbackconv ptrconv fdot (VS.slice offset1 len1 vec1) (VS.slice offset2 len1 vec2)
-      (-1, -1) ->  -- reversed dense * reversed dense
-        dotVectorVector len1 valbackconv ptrconv fdot (VS.slice (offset1 - (len1 - 1)) len1 vec1) (VS.slice (offset2 - (len1 - 1)) len1 vec2)
-      (_, _) ->  -- fallback case
-        dotVectorVectorStrided len1 valbackconv ptrconv fdotstrided offset1 stride1 vec1 offset2 stride2 vec2
-vectorDotprodOp _ _ _ _ _ _ _ = error "vectorDotprodOp: not one-dimensional?"
+vectorDotprodInnerOp :: forall a b n. (Num a, Storable a)
+                     => SNat n
+                     -> (a -> b)
+                     -> (Ptr a -> Ptr b)
+                     -> (SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a)  -- ^ elementwise multiplication
+                     -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- ^ scale by constant
+                     -> (Int64 -> Ptr b -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO ())  -- ^ reduction kernel
+                     -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())  -- ^ dotprod kernel
+                     -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a
+vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
+  arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
+  arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
+  | null sh1 || null sh2 = error "unreachable"
+  | sh1 /= sh2 = error $ "vectorDotprodInnerOp: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
+  | last sh1 <= 0 = RS.stretch (init sh1) (RS.fromList (1 <$ init sh1) [0])
+  | any (<= 0) (init sh1) = RS.A (RG.A (init sh1) (OI.T (0 <$ init strides1) 0 VS.empty))
+  -- now the input arrays are nonempty
+  | last sh1 == 1 = fmul sn (RS.reshape (init sh1) arr1) (RS.reshape (init sh1) arr2)
+  | last strides1 == 0 =
+      fmul sn
+           (RS.A (RG.A (init sh1) (OI.T (init strides1) offset1 vec1)))
+           (vectorRedInnerOp sn valconv ptrconv fscale fred arr2)
+  | last strides2 == 0 =
+      fmul sn
+           (vectorRedInnerOp sn valconv ptrconv fscale fred arr1)
+           (RS.A (RG.A (init sh2) (OI.T (init strides2) offset2 vec2)))
+  -- now there is useful dotprod work along the inner dimension
+  | otherwise = unsafePerformIO $ do
+      let inrank = fromSNat' sn + 1
+      outv <- VSM.unsafeNew (product (init sh1))
+      VSM.unsafeWith outv $ \poutv ->
+        VS.unsafeWith (VS.fromListN inrank (map fromIntegral sh1)) $ \psh ->
+          VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides1)) $ \pstrides1 ->
+            VS.unsafeWith vec1 $ \pvec1 ->
+              VS.unsafeWith (VS.fromListN inrank (map fromIntegral strides2)) $ \pstrides2 ->
+                VS.unsafeWith vec2 $ \pvec2 ->
+                  fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv) pstrides1 (ptrconv pvec1) pstrides2 (ptrconv pvec2)
+      RS.fromVector @_ @n (init sh1) <$> VS.unsafeFreeze outv
 
 {-# NOINLINE dotScalarVector #-}
 dotScalarVector :: forall a b. (Num a, Storable a)
@@ -461,13 +470,14 @@ $(fmap concat . forM typesList $ \arithtype ->
 
 $(fmap concat . forM typesList $ \arithtype -> do
   let ttyp = conT (atType arithtype)
-      name = mkName ("dotprodVector" ++ nameBase (atType arithtype))
-      c_op = varE (mkName ("c_dotprod_" ++ atCName arithtype))
-      c_op_strided = varE (mkName ("c_dotprod_" ++ atCName arithtype ++ "_strided"))
+      name = mkName ("dotprodinnerVector" ++ nameBase (atType arithtype))
+      c_op = varE (mkName ("c_dotprodinner_" ++ atCName arithtype))
+      mul_op = varE (mkName ("mulVector" ++ nameBase (atType arithtype)))
+      c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
      c_red_op = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM)))
   sequence [SigD name <$>
-              [t| RS.Array 1 $ttyp -> RS.Array 1 $ttyp -> $ttyp |]
-           ,do body <- [| vectorDotprodOp id id $c_red_op $c_op $c_op_strided |]
+              [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
+           ,do body <- [| \sn -> vectorDotprodInnerOp sn id id $mul_op $c_scale_op $c_red_op $c_op |]
               return $ FunD name [Clause [] (NormalB body) []]])
 
 -- This branch is ostensibly a runtime branch, but will (hopefully) be
@@ -533,19 +543,19 @@ intWidBranchExtr fextr32 fextr64
   | finiteBitSize (undefined :: i) == 64 = vectorExtremumOp @i @Int64 castPtr fextr64
   | otherwise = error "Unsupported Int width"
 
-intWidBranchDotprod :: forall i. (FiniteBits i, Storable i, Integral i)
+intWidBranchDotprod :: forall i n. (FiniteBits i, Storable i, Integral i, NumElt i)
                     => -- int32
-                       (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- ^ reduction kernel
-                    -> (Int64 -> Ptr Int32 -> Ptr Int32 -> IO Int32)  -- ^ dotprod kernel
-                    -> (Int64 -> Int64 -> Int64 -> Ptr Int32 -> Int64 -> Int64 -> Ptr Int32 -> IO Int32)  -- ^ strided dotprod kernel
+                       (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- ^ scale by constant
+                    -> (Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- ^ reduction kernel
+                    -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> IO ())  -- ^ dotprod kernel
                        -- int64
+                    -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- ^ scale by constant
                     -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ reduction kernel
-                    -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64)  -- ^ dotprod kernel
-                    -> (Int64 -> Int64 -> Int64 -> Ptr Int64 -> Int64 -> Int64 -> Ptr Int64 -> IO Int64)  -- ^ strided dotprod kernel
-                    -> (RS.Array 1 i -> RS.Array 1 i -> i)
-intWidBranchDotprod fred32 fdot32 fdot32strided fred64 fdot64 fdot64strided
-  | finiteBitSize (undefined :: i) == 32 = vectorDotprodOp @i @Int32 fromIntegral castPtr fred32 fdot32 fdot32strided
-  | finiteBitSize (undefined :: i) == 64 = vectorDotprodOp @i @Int64 fromIntegral castPtr fred64 fdot64 fdot64strided
+                    -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ dotprod kernel
+                    -> (SNat n -> RS.Array (n + 1) i -> RS.Array (n + 1) i -> RS.Array n i)
+intWidBranchDotprod fsc32 fred32 fdot32 fsc64 fred64 fdot64 sn
+  | finiteBitSize (undefined :: i) == 32 = vectorDotprodInnerOp @i @Int32 sn fromIntegral castPtr numEltMul fsc32 fred32 fdot32
+  | finiteBitSize (undefined :: i) == 64 = vectorDotprodInnerOp @i @Int64 sn fromIntegral castPtr numEltMul fsc64 fred64 fdot64
   | otherwise = error "Unsupported Int width"
 
 class NumElt a where
@@ -561,7 +571,7 @@ class NumElt a where
   numEltProductFull :: SNat n -> RS.Array n a -> a
   numEltMinIndex :: RS.Array n a -> [Int]
   numEltMaxIndex :: RS.Array n a -> [Int]
-  numEltDotprod :: RS.Array 1 a -> RS.Array 1 a -> a
+  numEltDotprodInner :: SNat n -> RS.Array (n + 1) a -> RS.Array (n + 1) a -> RS.Array n a
 
 instance NumElt Int32 where
   numEltAdd = addVectorInt32
@@ -576,7 +586,7 @@ instance NumElt Int32 where
   numEltProductFull = productFullVectorInt32
   numEltMinIndex = minindexVectorInt32
   numEltMaxIndex = maxindexVectorInt32
-  numEltDotprod = dotprodVectorInt32
+  numEltDotprodInner = dotprodinnerVectorInt32
 
 instance NumElt Int64 where
   numEltAdd = addVectorInt64
@@ -591,7 +601,7 @@ instance NumElt Int64 where
   numEltProductFull = productFullVectorInt64
   numEltMinIndex = minindexVectorInt64
   numEltMaxIndex = maxindexVectorInt64
-  numEltDotprod = dotprodVectorInt64
+  numEltDotprodInner = dotprodinnerVectorInt64
 
 instance NumElt Float where
   numEltAdd = addVectorFloat
@@ -606,7 +616,7 @@ instance NumElt Float where
   numEltProductFull = productFullVectorFloat
   numEltMinIndex = minindexVectorFloat
   numEltMaxIndex = maxindexVectorFloat
-  numEltDotprod = dotprodVectorFloat
+  numEltDotprodInner = dotprodinnerVectorFloat
 
 instance NumElt Double where
   numEltAdd = addVectorDouble
@@ -621,7 +631,7 @@ instance NumElt Double where
   numEltProductFull = productFullVectorDouble
   numEltMinIndex = minindexVectorDouble
   numEltMaxIndex = maxindexVectorDouble
-  numEltDotprod = dotprodVectorDouble
+  numEltDotprodInner = dotprodinnerVectorDouble
 
 instance NumElt Int where
   numEltAdd = intWidBranch2 @Int (+)
@@ -646,8 +656,8 @@ instance NumElt Int where
   numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
   numEltMinIndex = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
   numEltMaxIndex = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64
-  numEltDotprod = intWidBranchDotprod @Int (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprod_i32 c_dotprod_i32_strided
-                                           (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprod_i64 c_dotprod_i64_strided
+  numEltDotprodInner = intWidBranchDotprod @Int (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
+                                                (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
 
 instance NumElt CInt where
   numEltAdd = intWidBranch2 @CInt (+)
@@ -672,8 +682,8 @@ instance NumElt CInt where
   numEltProductFull = intWidBranchRedFull @CInt (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
   numEltMinIndex = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
   numEltMaxIndex = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
-  numEltDotprod = intWidBranchDotprod @CInt (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprod_i32 c_dotprod_i32_strided
-                                            (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprod_i64 c_dotprod_i64_strided
+  numEltDotprodInner = intWidBranchDotprod @CInt (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprodinner_i32
+                                                 (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprodinner_i64
 
 class FloatElt a where
   floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
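The `last strides1 == 0` and `last strides2 == 0` branches of `vectorDotprodInnerOp` above rely on the fact that a dot product against a broadcast (stride-0) inner dimension is just the broadcast constant times the inner sum of the other operand, which is why those branches combine `vectorRedInnerOp` with an elementwise multiplication instead of calling the dotprod kernel. A small self-contained check of that identity on plain lists; the names here are illustrative, not part of the library:

```haskell
-- Check: dot(replicate k c, v) == c * sum v, the identity behind the
-- stride-0 fast paths in vectorDotprodInnerOp.
innerDot :: Num a => [a] -> [a] -> a
innerDot r1 r2 = sum (zipWith (*) r1 r2)

main :: IO ()
main = do
  let v = [3, 1, 4, 1, 5] :: [Int]
      c = 7
  print (innerDot (replicate (length v) c) v == c * sum v)  -- True
```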