From 4c86a3a4231cecc5b7c31491398f43b4ba667eea Mon Sep 17 00:00:00 2001
From: Tom Smeding
Date: Thu, 23 May 2024 13:47:18 +0200
Subject: Fast sum

Also fast product, but that's currently unused
---
 src/Data/Array/Nested/Internal/Arith.hs | 118 ++++++++++++++++++++++++++------
 1 file changed, 97 insertions(+), 21 deletions(-)

diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs
index 4312cd5..bd582b7 100644
--- a/src/Data/Array/Nested/Internal/Arith.hs
+++ b/src/Data/Array/Nested/Internal/Arith.hs
@@ -1,8 +1,11 @@
+{-# LANGUAGE DataKinds #-}
 {-# LANGUAGE LambdaCase #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TemplateHaskell #-}
 {-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
 module Data.Array.Nested.Internal.Arith where
 
 import Control.Monad (forM, guard)
@@ -25,24 +28,24 @@
 import Data.Array.Nested.Internal.Arith.Foreign
 import Data.Array.Nested.Internal.Arith.Lists
 
-mliftVEltwise1 :: Storable a
-               => SNat n
-               -> (VS.Vector a -> VS.Vector a)
-               -> RS.Array n a -> RS.Array n a
-mliftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
+liftVEltwise1 :: Storable a
+              => SNat n
+              -> (VS.Vector a -> VS.Vector a)
+              -> RS.Array n a -> RS.Array n a
+liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
   | Just prefixSz <- stridesDense sh strides =
       let vec' = f (VS.slice offset prefixSz vec)
       in RS.A (RG.A sh (OI.T strides 0 vec'))
   | otherwise = RS.fromVector sh (f (RS.toVector arr))
 
-mliftVEltwise2 :: Storable a
-               => SNat n
-               -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a)
-               -> RS.Array n a -> RS.Array n a -> RS.Array n a
-mliftVEltwise2 SNat f
+liftVEltwise2 :: Storable a
+              => SNat n
+              -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a)
+              -> RS.Array n a -> RS.Array n a -> RS.Array n a
+liftVEltwise2 SNat f
   arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
   arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
-  | sh1 /= sh2 = error $ "mliftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
+  | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
   | product sh1 == 0 = arr1  -- if the arrays are empty, just return one of the empty inputs
   | otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of
       (Just 1, Just 1) ->  -- both are a (potentially replicated) scalar; just apply f to the scalars
@@ -62,11 +65,11 @@ mliftVEltwise2 SNat f
 stridesDense :: [Int] -> [Int] -> Maybe Int
 stridesDense sh _ | any (<= 0) sh = Just 0
 stridesDense sh str =
-  -- sort dimensions on their stride, ascending
-  case sort (zip str sh) of
-    [] -> Just 0
+  -- sort dimensions on their stride, ascending, dropping any zero strides
+  case dropWhile ((== 0) . fst) (sort (zip str sh)) of
+    [] -> Just 1
     (1, n) : (unzip -> (str', sh')) -> checkCover n sh' str'
-    _ -> error "Orthotope array's shape vector and stride vector have different lengths"
+    _ -> Nothing  -- if the smallest stride is not 1, it will never be dense
   where
     -- Given size of currently densely covered region at beginning of the
    -- array, the remaining shape vector and the corresponding remaining stride
@@ -77,6 +80,7 @@ stridesDense sh str =
     checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str'
     checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths"
 
+{-# NOINLINE vectorOp1 #-}
 vectorOp1 :: forall a b. Storable a
           => (Ptr a -> Ptr b)
           -> (Int64 -> Ptr b -> Ptr b -> IO ())
@@ -89,6 +93,7 @@ vectorOp1 ptrconv f v = unsafePerformIO $ do
   VS.unsafeFreeze outv
 
 -- | If two vectors are given, assumes that they have the same length.
+{-# NOINLINE vectorOp2 #-}
 vectorOp2 :: forall a b. Storable a
           => (a -> b)
           -> (Ptr a -> Ptr b)
@@ -127,6 +132,39 @@ vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases
         VS.unsafeFreeze outv
     | otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy)
 
+-- TODO: test all the weird cases of this function
+-- | Reduce along the inner dimension
+{-# NOINLINE vectorRedInnerOp #-}
+vectorRedInnerOp :: forall a b n. (Num a, Storable a)
+                 => SNat n
+                 -> (a -> b)
+                 -> (Ptr a -> Ptr b)
+                 -> (Int64 -> Ptr b -> b -> Ptr b -> IO ())  -- ^ scale by constant
+                 -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ())  -- ^ reduction kernel
+                 -> RS.Array (n + 1) a -> RS.Array n a
+vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec)))
+  | null sh = error "unreachable"
+  | any (<= 0) sh = RS.A (RG.A (init sh) (OI.T (map (const 0) (init strides)) 0 VS.empty))
+  -- now the input array is nonempty
+  | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec))
+  | last strides == 0 =
+      liftVEltwise1 sn
+        (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px)))
+        (RS.A (RG.A (init sh) (OI.T (init strides) offset vec)))
+  -- now there is useful work along the inner dimension
+  | otherwise =
+      let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those
+          (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides)
+          ndimsF = length shF
+      in unsafePerformIO $ do
+           outv <- VSM.unsafeNew (product (init shF))
+           VSM.unsafeWith outv $ \poutv ->
+             VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
+               VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF ->
+                 VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec ->
+                   fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec)
+           RS.fromVector (init sh) <$> VS.unsafeFreeze outv
+
 flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
        -> Int64 -> Ptr a -> Ptr a -> a -> IO ()
 flipOp f n out v s = f n out s v
@@ -138,6 +176,8 @@ class NumElt a where
   numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a
   numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a
   numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
+  numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
+  numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
 
 $(fmap concat . forM typesList $ \arithtype -> do
   let ttyp = conT (atType arithtype)
@@ -151,7 +191,7 @@ $(fmap concat . forM typesList $ \arithtype -> do
         c_vv = varE $ mkName (cnamebase ++ "_vv")
     sequence [SigD name <$>
                 [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
-             ,do body <- [| \sn -> mliftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
+             ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
                  return $ FunD name [Clause [] (NormalB body) []]])
 
 $(fmap concat . forM typesList $ \arithtype -> do
@@ -161,7 +201,18 @@ $(fmap concat . forM typesList $ \arithtype -> do
         c_op = varE $ mkName ("c_" ++ auoName arithop ++ "_" ++ atCName arithtype)
     sequence [SigD name <$>
                 [t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
-             ,do body <- [| \sn -> mliftVEltwise1 sn (vectorOp1 id $c_op) |]
+             ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
                  return $ FunD name [Clause [] (NormalB body) []]])
 
+$(fmap concat . forM typesList $ \arithtype -> do
+  let ttyp = conT (atType arithtype)
+  fmap concat . forM redopsList $ \redop -> do
+    let name = mkName (aroName redop ++ "Vector" ++ nameBase (atType arithtype))
+        c_op = varE $ mkName ("c_" ++ aroName redop ++ "_" ++ atCName arithtype)
+        c_scale_op = varE $ mkName ("c_mul_" ++ atCName arithtype ++ "_sv")
+    sequence [SigD name <$>
+                [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
+             ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |]
+                 return $ FunD name [Clause [] (NormalB body) []]])
+
 -- This branch is ostensibly a runtime branch, but will (hopefully) be
@@ -171,8 +222,8 @@ intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
               -> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())
               -> (SNat n -> RS.Array n i -> RS.Array n i)
 intWidBranch1 f32 f64 sn
-  | finiteBitSize (undefined :: i) == 32 = mliftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32)
-  | finiteBitSize (undefined :: i) == 64 = mliftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64)
+  | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32)
+  | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64)
   | otherwise = error "Unsupported Int width"
 
 intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
@@ -187,8 +238,21 @@ intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
               -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- vv
               -> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i)
 intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
-  | finiteBitSize (undefined :: i) == 32 = mliftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32)
-  | finiteBitSize (undefined :: i) == 64 = mliftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64)
+  | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32)
+  | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64)
   | otherwise = error "Unsupported Int width"
 
+intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i)
+                => -- int32
+                   (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ())  -- ^ scale by constant
+                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ())  -- ^ reduction kernel
+                   -- int64
+                -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ())  -- ^ scale by constant
+                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())  -- ^ reduction kernel
+                -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i)
+intWidBranchRed fsc32 fred32 fsc64 fred64 sn
+  | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32
+  | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
+  | otherwise = error "Unsupported Int width"
+
 instance NumElt Int32 where
@@ -198,6 +262,8 @@ instance NumElt Int32 where
   numEltNeg = negVectorInt32
   numEltAbs = absVectorInt32
   numEltSignum = signumVectorInt32
+  numEltSum1Inner = sum1VectorInt32
+  numEltProduct1Inner = product1VectorInt32
 
 instance NumElt Int64 where
   numEltAdd = addVectorInt64
@@ -206,6 +272,8 @@ instance NumElt Int64 where
   numEltNeg = negVectorInt64
   numEltAbs = absVectorInt64
   numEltSignum = signumVectorInt64
+  numEltSum1Inner = sum1VectorInt64
+  numEltProduct1Inner = product1VectorInt64
 
 instance NumElt Float where
   numEltAdd = addVectorFloat
@@ -214,6 +282,8 @@ instance NumElt Float where
   numEltNeg = negVectorFloat
   numEltAbs = absVectorFloat
   numEltSignum = signumVectorFloat
+  numEltSum1Inner = sum1VectorFloat
+  numEltProduct1Inner = product1VectorFloat
 
 instance NumElt Double where
   numEltAdd = addVectorDouble
@@ -222,6 +292,8 @@ instance NumElt Double where
   numEltNeg = negVectorDouble
   numEltAbs = absVectorDouble
   numEltSignum = signumVectorDouble
+  numEltSum1Inner = sum1VectorDouble
+  numEltProduct1Inner = product1VectorDouble
 
 instance NumElt Int where
   numEltAdd = intWidBranch2 @Int (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv
@@ -230,6 +302,8 @@ instance NumElt Int where
   numEltNeg = intWidBranch1 @Int c_neg_i32 c_neg_i64
   numEltAbs = intWidBranch1 @Int c_abs_i32 c_abs_i64
   numEltSignum = intWidBranch1 @Int c_signum_i32 c_signum_i64
+  numEltSum1Inner = intWidBranchRed @Int c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64
+  numEltProduct1Inner = intWidBranchRed @Int c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64
 
 instance NumElt CInt where
   numEltAdd = intWidBranch2 @CInt (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv
@@ -238,3 +312,5 @@ instance NumElt CInt where
   numEltNeg = intWidBranch1 @CInt c_neg_i32 c_neg_i64
   numEltAbs = intWidBranch1 @CInt c_abs_i32 c_abs_i64
   numEltSignum = intWidBranch1 @CInt c_signum_i32 c_signum_i64
+  numEltSum1Inner = intWidBranchRed @CInt c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64
+  numEltProduct1Inner = intWidBranchRed @CInt c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64
--
cgit v1.2.3-70-g09d2
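
Note on the new reduction path: for intuition, the following standalone sketch shows the semantics that numEltSum1Inner / vectorRedInnerOp implement once zero strides have been filtered out: one sum per run of `last sh` consecutive elements, producing an array of shape `init sh`. This is illustrative only; `sumInnerRef` is a made-up name, not code from this commit, and the real work happens in the C kernels (`c_sum1_*`) called through `fred`. The `last strides == 0` branch instead uses the scale-by-constant kernel, since summing k replicated copies of a value equals multiplying it by k.

  import Data.List (foldl')

  -- Reference semantics of the inner-dimension sum on a dense,
  -- row-major array: shape sh reduces to shape (init sh).
  sumInnerRef :: Num a => [Int] -> [a] -> [a]
  sumInnerRef sh xs
    | null sh = error "rank 0 has no inner dimension"
    | otherwise =
        let k = last sh  -- size of the innermost dimension
            chunk [] = []
            chunk ys = let (h, t) = splitAt k ys
                       in foldl' (+) 0 h : chunk t
        in chunk xs

  main :: IO ()
  main = print (sumInnerRef [2, 3] [1, 2, 3, 10, 20, 30 :: Int])
  -- prints [6,60]: one sum per row of the 2x3 array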
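
Note on the stridesDense change: zero strides are now dropped before the coverage check, so a fully replicated array (all strides zero) yields Just 1, meaning one real element backs it, rather than the old Just 0, and a smallest stride other than 1 yields Nothing. Below is a standalone copy of that logic with example calls, for illustration only: `denseRef` and its `go` helper are hypothetical names, and the base case of `checkCover` is not shown in the hunk above, so it is assumed here to return the covered block size.

  import Control.Monad (guard)
  import Data.List (sort)

  -- Illustrative copy of the patched stridesDense logic.
  denseRef :: [Int] -> [Int] -> Maybe Int
  denseRef sh _ | any (<= 0) sh = Just 0
  denseRef sh str =
    case dropWhile ((== 0) . fst) (sort (zip str sh)) of
      [] -> Just 1  -- only zero strides: a replicated scalar
      (1, n) : rest -> let (str', sh') = unzip rest in go n sh' str'
      _ -> Nothing  -- smallest stride is not 1: never dense
    where
      go block (n : sh') (s : str') =
        guard (s <= block) >> go (max block (n * s)) sh' str'
      go block [] [] = Just block  -- assumed base case
      go _ _ _ = error "shape/stride length mismatch"

  main :: IO ()
  main = do
    print (denseRef [2, 3] [3, 1])  -- Just 6: plain dense 2x3 array
    print (denseRef [2, 3] [0, 1])  -- Just 3: outer dimension broadcast
    print (denseRef [2, 3] [6, 2])  -- Nothing: inner stride /= 1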