path: root/src/Data/Array/Nested/Internal/Arith.hs
author    Tom Smeding <tom@tomsmeding.com>  2024-05-23 13:47:18 +0200
committer Tom Smeding <tom@tomsmeding.com>  2024-05-23 13:47:18 +0200
commit    4c86a3a4231cecc5b7c31491398f43b4ba667eea (patch)
tree      2e06f293f1350b7dd712bf1ad0eccb7b9d7686b4 /src/Data/Array/Nested/Internal/Arith.hs
parent    827a9ce7adc6cf1debc08d154e4c11b7b83bfdf0 (diff)
Fast sum
Also fast product, but that's currently unused
Diffstat (limited to 'src/Data/Array/Nested/Internal/Arith.hs')
-rw-r--r--  src/Data/Array/Nested/Internal/Arith.hs  118
1 file changed, 97 insertions(+), 21 deletions(-)
diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs
index 4312cd5..bd582b7 100644
--- a/src/Data/Array/Nested/Internal/Arith.hs
+++ b/src/Data/Array/Nested/Internal/Arith.hs
@@ -1,8 +1,11 @@
+{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Internal.Arith where
import Control.Monad (forM, guard)
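
The new -fplugin line pulls in the solver from the ghc-typelits-knownnat package. The signatures added below mix SNat n with arrays of rank n + 1, and the plugin lets GHC discharge constraints such as KnownNat (n + 1) from the KnownNat n already in scope, which vanilla GHC cannot do. A standalone sketch of what the plugin buys (succVal is a made-up name, not library code):

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
import Data.Proxy (Proxy (..))
import GHC.TypeNats (KnownNat, natVal, type (+))
import Numeric.Natural (Natural)

-- Without the plugin this is rejected with "No instance for (KnownNat (n + 1))";
-- the solver derives that instance from the KnownNat n in scope.
succVal :: forall n. KnownNat n => Proxy n -> Natural
succVal _ = natVal (Proxy @(n + 1))

main :: IO ()
main = print (succVal (Proxy @41))  -- prints 42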
@@ -25,24 +28,24 @@ import Data.Array.Nested.Internal.Arith.Foreign
import Data.Array.Nested.Internal.Arith.Lists
-mliftVEltwise1 :: Storable a
- => SNat n
- -> (VS.Vector a -> VS.Vector a)
- -> RS.Array n a -> RS.Array n a
-mliftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
+liftVEltwise1 :: Storable a
+ => SNat n
+ -> (VS.Vector a -> VS.Vector a)
+ -> RS.Array n a -> RS.Array n a
+liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
| Just prefixSz <- stridesDense sh strides =
let vec' = f (VS.slice offset prefixSz vec)
in RS.A (RG.A sh (OI.T strides 0 vec'))
| otherwise = RS.fromVector sh (f (RS.toVector arr))
-mliftVEltwise2 :: Storable a
- => SNat n
- -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a)
- -> RS.Array n a -> RS.Array n a -> RS.Array n a
-mliftVEltwise2 SNat f
+liftVEltwise2 :: Storable a
+ => SNat n
+ -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a)
+ -> RS.Array n a -> RS.Array n a -> RS.Array n a
+liftVEltwise2 SNat f
arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
- | sh1 /= sh2 = error $ "mliftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
+ | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
| product sh1 == 0 = arr1 -- if the arrays are empty, just return one of the empty inputs
| otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of
(Just 1, Just 1) -> -- both are a (potentially replicated) scalar; just apply f to the scalars
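
The renamed liftVEltwise1 keeps the fast path: when the view is detected (via stridesDense, revised below) to cover a dense prefix of the buffer, f is applied to just that prefix and the original strides are reused, so a replicated array is processed once rather than once per logical element. A self-contained toy model of the idea (View and mapDensePrefix are illustrative names, not library code):

import qualified Data.Vector.Storable as VS

-- Toy stand-in for orthotope's strided view: (shape, strides, offset, buffer).
type View a = ([Int], [Int], Int, VS.Vector a)

-- Apply f to the dense prefix only; the old strides still describe the result,
-- now relative to offset 0 in the new buffer.
mapDensePrefix :: VS.Storable a
               => Int                            -- size of the dense prefix
               -> (VS.Vector a -> VS.Vector a)
               -> View a -> View a
mapDensePrefix prefixSz f (sh, strides, offset, vec) =
  (sh, strides, 0, f (VS.slice offset prefixSz vec))

main :: IO ()
main = do
  -- A stride-0 view replicating one scalar into shape [3]:
  let v = ([3], [0], 2, VS.fromList [9, 9, 5, 9 :: Double])
  -- f touches 1 element instead of 3; the result still replicates it.
  print (mapDensePrefix 1 (VS.map negate) v)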
@@ -62,11 +65,11 @@ mliftVEltwise2 SNat f
stridesDense :: [Int] -> [Int] -> Maybe Int
stridesDense sh _ | any (<= 0) sh = Just 0
stridesDense sh str =
- -- sort dimensions on their stride, ascending
- case sort (zip str sh) of
- [] -> Just 0
+ -- sort dimensions on their stride, ascending, dropping any zero strides
+ case dropWhile ((== 0) . fst) (sort (zip str sh)) of
+ [] -> Just 1
(1, n) : (unzip -> (str', sh')) -> checkCover n sh' str'
- _ -> error "Orthotope array's shape vector and stride vector have different lengths"
+ _ -> Nothing -- if the smallest stride is not 1, it will never be dense
where
-- Given size of currently densely covered region at beginning of the
-- array, the remaining shape vector and the corresponding remaining stride
@@ -77,6 +80,7 @@ stridesDense sh str =
checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str'
checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths"
+{-# NOINLINE vectorOp1 #-}
vectorOp1 :: forall a b. Storable a
=> (Ptr a -> Ptr b)
-> (Int64 -> Ptr b -> Ptr b -> IO ())
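
Some worked examples of the revised stridesDense contract, computed by hand from the definition above (doctest-style; not taken from the package's test suite):

-- >>> stridesDense [2,3] [3,1]   -- row-major, fully dense
-- Just 6
-- >>> stridesDense [2,3] [1,2]   -- transposed view, still covers 6 elements
-- Just 6
-- >>> stridesDense [2,3] [0,1]   -- outer dimension replicated: dense prefix of 3
-- Just 3
-- >>> stridesDense [2,3] [0,0]   -- fully replicated scalar
-- Just 1
-- >>> stridesDense [2,3] [4,1]   -- padded rows: not dense
-- Nothing
-- >>> stridesDense [3] [2]       -- smallest nonzero stride is not 1
-- Nothing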
@@ -89,6 +93,7 @@ vectorOp1 ptrconv f v = unsafePerformIO $ do
VS.unsafeFreeze outv
-- | If two vectors are given, assumes that they have the same length.
+{-# NOINLINE vectorOp2 #-}
vectorOp2 :: forall a b. Storable a
=> (a -> b)
-> (Ptr a -> Ptr b)
@@ -127,6 +132,39 @@ vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases
VS.unsafeFreeze outv
| otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy)
+-- TODO: test all the weird cases of this function
+-- | Reduce along the inner dimension
+{-# NOINLINE vectorRedInnerOp #-}
+vectorRedInnerOp :: forall a b n. (Num a, Storable a)
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> RS.Array (n + 1) a -> RS.Array n a
+vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec)))
+ | null sh = error "unreachable"
+ | any (<= 0) sh = RS.A (RG.A (init sh) (OI.T (map (const 0) (init strides)) 0 VS.empty))
+ -- now the input array is nonempty
+ | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec))
+ | last strides == 0 =
+ liftVEltwise1 sn
+ (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px)))
+ (RS.A (RG.A (init sh) (OI.T (init strides) offset vec)))
+ -- now there is useful work along the inner dimension
+ | otherwise =
+ let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those
+ (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides)
+ ndimsF = length shF
+ in unsafePerformIO $ do
+ outv <- VSM.unsafeNew (product (init shF))
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
+ VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF ->
+ VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec ->
+ fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec)
+ RS.fromVector (init sh) <$> VS.unsafeFreeze outv
+
flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
-> Int64 -> Ptr a -> Ptr a -> a -> IO ()
flipOp f n out v s = f n out s v
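
vectorRedInnerOp peels off the degenerate cases before calling into C: an empty result for empty input, a plain reshape when the inner dimension has size 1, and a multiply-by-constant (fscale) when the inner stride is 0, since summing k replicated copies of x is just k * x. Only then does it hand the zero-stride-filtered shape and strides to the kernel. A pure reference model of what the kernel is expected to compute for the sum1 case, under my reading of the pointer contract (refRedInner is illustrative, not the C code):

import Data.List (foldl')

-- out[i0,..,i(k-1)] = fold over the innermost index j of in[base + j*innerStr],
-- where base = sum_d i_d * stride_d; outputs are produced in row-major order.
refRedInner :: Num a => [Int] -> [Int] -> (Int -> a) -> [a]
refRedInner sh strides look =
  [ foldl' (\acc j -> acc + look (base + j * innerStr)) 0 [0 .. innerSh - 1]
  | idx <- mapM (\n -> [0 .. n - 1]) (init sh)   -- row-major outer indices
  , let base = sum (zipWith (*) idx (init strides))
  ]
  where
    innerSh  = last sh
    innerStr = last strides

main :: IO ()
main = print (refRedInner [2,3] [3,1] (([1,2,3,4,5,6] :: [Double]) !!))
-- [6.0,15.0]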
@@ -138,6 +176,8 @@ class NumElt a where
numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a
numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a
numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
+ numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
@@ -151,7 +191,7 @@ $(fmap concat . forM typesList $ \arithtype -> do
c_vv = varE $ mkName (cnamebase ++ "_vv")
sequence [SigD name <$>
[t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> mliftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
+ ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
return $ FunD name [Clause [] (NormalB body) []]])
$(fmap concat . forM typesList $ \arithtype -> do
@@ -161,7 +201,18 @@ $(fmap concat . forM typesList $ \arithtype -> do
c_op = varE $ mkName ("c_" ++ auoName arithop ++ "_" ++ atCName arithtype)
sequence [SigD name <$>
[t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> mliftVEltwise1 sn (vectorOp1 id $c_op) |]
+ ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM redopsList $ \redop -> do
+ let name = mkName (aroName redop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op = varE $ mkName ("c_" ++ aroName redop ++ "_" ++ atCName arithtype)
+ c_scale_op = varE $ mkName ("c_mul_" ++ atCName arithtype ++ "_sv")
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |]
return $ FunD name [Clause [] (NormalB body) []]])
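
For concreteness, a hand expansion of one binding this new splice produces, for the Int32/sum1 case (derived from the name-building code above, not from running the splice):

sum1VectorInt32 :: forall n. SNat n
                -> RS.Array (n + 1) Int32 -> RS.Array n Int32
sum1VectorInt32 = \sn -> vectorRedInnerOp sn id id c_mul_i32_sv c_sum1_i32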
-- This branch is ostensibly a runtime branch, but will (hopefully) be
@@ -171,8 +222,8 @@ intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
-> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())
-> (SNat n -> RS.Array n i -> RS.Array n i)
intWidBranch1 f32 f64 sn
- | finiteBitSize (undefined :: i) == 32 = mliftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32)
- | finiteBitSize (undefined :: i) == 64 = mliftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64)
+ | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32)
+ | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64)
| otherwise = error "Unsupported Int width"
intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
@@ -187,8 +238,21 @@ intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
-> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv
-> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i)
intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
- | finiteBitSize (undefined :: i) == 32 = mliftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32)
- | finiteBitSize (undefined :: i) == 64 = mliftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64)
+ | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32)
+ | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64)
+ | otherwise = error "Unsupported Int width"
+
+intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => -- int32
+ (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
+ -- int64
+ -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel
+ -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i)
+intWidBranchRed fsc32 fred32 fsc64 fred64 sn
+ | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32
+ | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
| otherwise = error "Unsupported Int width"
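
intWidBranchRed mirrors intWidBranch1/2: although written as a runtime guard, finiteBitSize (undefined :: i) is a compile-time constant for any concrete i, so the branch should be resolved by the simplifier, and castPtr is then safe because Int (or CInt) has the same size and representation as the fixed-width type that was selected. A minimal check of the dispatch constant (standard Data.Bits, nothing library-specific):

import Data.Bits (finiteBitSize)

main :: IO ()
main = print (finiteBitSize (undefined :: Int))  -- 64 on typical 64-bit GHCs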
instance NumElt Int32 where
@@ -198,6 +262,8 @@ instance NumElt Int32 where
numEltNeg = negVectorInt32
numEltAbs = absVectorInt32
numEltSignum = signumVectorInt32
+ numEltSum1Inner = sum1VectorInt32
+ numEltProduct1Inner = product1VectorInt32
instance NumElt Int64 where
numEltAdd = addVectorInt64
@@ -206,6 +272,8 @@ instance NumElt Int64 where
numEltNeg = negVectorInt64
numEltAbs = absVectorInt64
numEltSignum = signumVectorInt64
+ numEltSum1Inner = sum1VectorInt64
+ numEltProduct1Inner = product1VectorInt64
instance NumElt Float where
numEltAdd = addVectorFloat
@@ -214,6 +282,8 @@ instance NumElt Float where
numEltNeg = negVectorFloat
numEltAbs = absVectorFloat
numEltSignum = signumVectorFloat
+ numEltSum1Inner = sum1VectorFloat
+ numEltProduct1Inner = product1VectorFloat
instance NumElt Double where
numEltAdd = addVectorDouble
@@ -222,6 +292,8 @@ instance NumElt Double where
numEltNeg = negVectorDouble
numEltAbs = absVectorDouble
numEltSignum = signumVectorDouble
+ numEltSum1Inner = sum1VectorDouble
+ numEltProduct1Inner = product1VectorDouble
instance NumElt Int where
numEltAdd = intWidBranch2 @Int (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv
@@ -230,6 +302,8 @@ instance NumElt Int where
numEltNeg = intWidBranch1 @Int c_neg_i32 c_neg_i64
numEltAbs = intWidBranch1 @Int c_abs_i32 c_abs_i64
numEltSignum = intWidBranch1 @Int c_signum_i32 c_signum_i64
+ numEltSum1Inner = intWidBranchRed @Int c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64
+ numEltProduct1Inner = intWidBranchRed @Int c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64
instance NumElt CInt where
numEltAdd = intWidBranch2 @CInt (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv
@@ -238,3 +312,5 @@ instance NumElt CInt where
numEltNeg = intWidBranch1 @CInt c_neg_i32 c_neg_i64
numEltAbs = intWidBranch1 @CInt c_abs_i32 c_abs_i64
numEltSignum = intWidBranch1 @CInt c_signum_i32 c_signum_i64
+ numEltSum1Inner = intWidBranchRed @CInt c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64
+ numEltProduct1Inner = intWidBranchRed @CInt c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64
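
Finally, a hypothetical GHCi session showing the new entry point end to end. RS.fromVector and RS.toVector are used elsewhere in this module; the SNat construction assumes GHC.TypeNats' pattern synonym is in scope, and the result is worked out by hand:

-- >>> let a = RS.fromVector [2,3] (VS.fromList [1,2,3,4,5,6]) :: RS.Array 2 Double
-- >>> RS.toVector (numEltSum1Inner (SNat :: SNat 1) a)
-- [6.0,15.0]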