aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-23 13:47:18 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-23 13:47:18 +0200
commit4c86a3a4231cecc5b7c31491398f43b4ba667eea (patch)
tree2e06f293f1350b7dd712bf1ad0eccb7b9d7686b4 /src/Data/Array
parent827a9ce7adc6cf1debc08d154e4c11b7b83bfdf0 (diff)
Fast sum
Also fast product, but that's currently unused
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Mixed.hs64
-rw-r--r--src/Data/Array/Nested/Internal.hs12
-rw-r--r--src/Data/Array/Nested/Internal/Arith.hs118
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Foreign.hs9
-rw-r--r--src/Data/Array/Nested/Internal/Arith/Lists.hs10
5 files changed, 178 insertions, 35 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 9a77ccb..7293914 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -14,6 +14,7 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
@@ -44,6 +45,8 @@ import GHC.TypeLits
import qualified GHC.TypeNats as TypeNats
import Unsafe.Coerce (unsafeCoerce)
+import Data.Array.Nested.Internal.Arith
+
-- | Evidence for the constraint @c a@.
data Dict c a where
@@ -120,6 +123,10 @@ foldListX f (x ::% xs) = f x <> foldListX f xs
lengthListX :: ListX sh f -> Int
lengthListX = getSum . foldListX (\_ -> Sum 1)
+snatLengthListX :: ListX sh f -> SNat (Rank sh)
+snatLengthListX ZX = SNat
+snatLengthListX (_ ::% l) | SNat <- snatLengthListX l = SNat
+
showListX :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
showListX f l = showString "[" . go "" l . showString "]"
where
@@ -419,6 +426,26 @@ ssxIotaFrom :: Int -> StaticShX sh -> [Int]
ssxIotaFrom _ ZKX = []
ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh
+type Flatten sh = Flatten' 1 sh
+
+type family Flatten' acc sh where
+ Flatten' acc '[] = Just acc
+ Flatten' acc (Nothing : sh) = Nothing
+ Flatten' acc (Just n : sh) = Flatten' (acc * n) sh
+
+flattenSh :: IShX sh -> SMayNat Int SNat (Flatten sh)
+flattenSh = go (SNat @1)
+ where
+ go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)
+ go acc ZSX = SKnown acc
+ go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh)
+ go acc (SKnown sn :$% sh) = go (mulSNat acc sn) sh
+
+ goUnknown :: Int -> IShX sh -> Int
+ goUnknown acc ZSX = acc
+ goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh
+ goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh
+
staticShapeFrom :: IShX sh -> StaticShX sh
staticShapeFrom ZSX = ZKX
staticShapeFrom (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% staticShapeFrom sh
@@ -511,6 +538,10 @@ type family AddMaybe n m where
plusSNat :: SNat n -> SNat m -> SNat (n + m)
plusSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n + TypeNats.fromSNat m) unsafeCoerce
+-- This should be a function in base
+mulSNat :: SNat n -> SNat m -> SNat (n * m)
+mulSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n * TypeNats.fromSNat m) unsafeCoerce
+
smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
@@ -719,17 +750,36 @@ transpose2 ssh1 ssh2 (XArray arr)
sumFull :: (Storable a, Num a) => XArray sh a -> a
sumFull (XArray arr) = S.sumA arr
-sumInner :: forall sh sh' a. (Storable a, Num a)
+sumInner :: forall sh sh' a. (Storable a, NumElt a)
=> StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
-sumInner ssh ssh'
+sumInner ssh ssh' arr
| Refl <- lemAppNil @sh
- = rerank ssh ssh' ZKX (scalar . sumFull)
-
-sumOuter :: forall sh sh' a. (Storable a, Num a)
+ = let (_, sh') = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
+ sh'F = flattenSh sh' :$% ZSX
+ ssh'F = staticShapeFrom sh'F
+
+ go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a
+ go (XArray arr')
+ | Refl <- lemRankApp ssh ssh'F
+ , let sn = snatLengthListX (let StaticShX l = ssh in l)
+ = XArray (numEltSum1Inner sn arr')
+
+ in go $
+ transpose2 ssh'F ssh $
+ reshapePartial ssh' ssh sh'F $
+ transpose2 ssh ssh' $
+ arr
+
+sumOuter :: forall sh sh' a. (Storable a, NumElt a)
=> StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a
-sumOuter ssh ssh'
+sumOuter ssh ssh' arr
| Refl <- lemAppNil @sh
- = sumInner ssh' ssh . transpose2 ssh ssh'
+ = let (sh, _) = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
+ shF = flattenSh sh :$% ZSX
+ in sumInner ssh' (staticShapeFrom shF) $
+ transpose2 (staticShapeFrom shF) ssh' $
+ reshapePartial ssh ssh' shF $
+ arr
fromListOuter :: forall n sh a. Storable a
=> StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index f3f8f7d..118612f 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -884,13 +884,13 @@ mgenerate sh f = case X.enumShape sh of
mvecsWrite sh idx val vecs
mvecsFreeze sh vecs
-msumOuter1P :: forall sh n a. (Storable a, Num a)
+msumOuter1P :: forall sh n a. (Storable a, NumElt a)
=> Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
msumOuter1P (M_Primitive (n :$% sh) arr) =
let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
in M_Primitive sh (X.sumOuter nssh (X.staticShapeFrom sh) arr)
-msumOuter1 :: forall sh n a. (Num a, PrimElt a)
+msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
=> Mixed (n : sh) a -> Mixed sh a
msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
@@ -1466,13 +1466,13 @@ rlift2 :: forall n1 n2 n3 a. Elt a
rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)
rsumOuter1P :: forall n a.
- (Storable a, Num a)
+ (Storable a, NumElt a)
=> Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
rsumOuter1P (Ranked arr)
| Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
= Ranked (msumOuter1P arr)
-rsumOuter1 :: forall n a. (Num a, PrimElt a)
+rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
=> Ranked (n + 1) a -> Ranked n a
rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive
@@ -1748,11 +1748,11 @@ slift2 :: forall sh1 sh2 sh3 a. Elt a
-> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (X.staticShapeFrom (shCvtSX sh3)) f arr1 arr2)
-ssumOuter1P :: forall sh n a. (Storable a, Num a)
+ssumOuter1P :: forall sh n a. (Storable a, NumElt a)
=> Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr)
-ssumOuter1 :: forall sh n a. (Num a, PrimElt a)
+ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
=> Shaped (n : sh) a -> Shaped sh a
ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive
diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Nested/Internal/Arith.hs
index 4312cd5..bd582b7 100644
--- a/src/Data/Array/Nested/Internal/Arith.hs
+++ b/src/Data/Array/Nested/Internal/Arith.hs
@@ -1,8 +1,11 @@
+{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Internal.Arith where
import Control.Monad (forM, guard)
@@ -25,24 +28,24 @@ import Data.Array.Nested.Internal.Arith.Foreign
import Data.Array.Nested.Internal.Arith.Lists
-mliftVEltwise1 :: Storable a
- => SNat n
- -> (VS.Vector a -> VS.Vector a)
- -> RS.Array n a -> RS.Array n a
-mliftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
+liftVEltwise1 :: Storable a
+ => SNat n
+ -> (VS.Vector a -> VS.Vector a)
+ -> RS.Array n a -> RS.Array n a
+liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec)))
| Just prefixSz <- stridesDense sh strides =
let vec' = f (VS.slice offset prefixSz vec)
in RS.A (RG.A sh (OI.T strides 0 vec'))
| otherwise = RS.fromVector sh (f (RS.toVector arr))
-mliftVEltwise2 :: Storable a
- => SNat n
- -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a)
- -> RS.Array n a -> RS.Array n a -> RS.Array n a
-mliftVEltwise2 SNat f
+liftVEltwise2 :: Storable a
+ => SNat n
+ -> (Either a (VS.Vector a) -> Either a (VS.Vector a) -> VS.Vector a)
+ -> RS.Array n a -> RS.Array n a -> RS.Array n a
+liftVEltwise2 SNat f
arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
- | sh1 /= sh2 = error $ "mliftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
+ | sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
| product sh1 == 0 = arr1 -- if the arrays are empty, just return one of the empty inputs
| otherwise = case (stridesDense sh1 strides1, stridesDense sh2 strides2) of
(Just 1, Just 1) -> -- both are a (potentially replicated) scalar; just apply f to the scalars
@@ -62,11 +65,11 @@ mliftVEltwise2 SNat f
stridesDense :: [Int] -> [Int] -> Maybe Int
stridesDense sh _ | any (<= 0) sh = Just 0
stridesDense sh str =
- -- sort dimensions on their stride, ascending
- case sort (zip str sh) of
- [] -> Just 0
+ -- sort dimensions on their stride, ascending, dropping any zero strides
+ case dropWhile ((== 0) . fst) (sort (zip str sh)) of
+ [] -> Just 1
(1, n) : (unzip -> (str', sh')) -> checkCover n sh' str'
- _ -> error "Orthotope array's shape vector and stride vector have different lengths"
+ _ -> Nothing -- if the smallest stride is not 1, it will never be dense
where
-- Given size of currently densely covered region at beginning of the
-- array, the remaining shape vector and the corresponding remaining stride
@@ -77,6 +80,7 @@ stridesDense sh str =
checkCover block (n : sh') (s : str') = guard (s <= block) >> checkCover (max block (n * s)) sh' str'
checkCover _ _ _ = error "Orthotope array's shape vector and stride vector have different lengths"
+{-# NOINLINE vectorOp1 #-}
vectorOp1 :: forall a b. Storable a
=> (Ptr a -> Ptr b)
-> (Int64 -> Ptr b -> Ptr b -> IO ())
@@ -89,6 +93,7 @@ vectorOp1 ptrconv f v = unsafePerformIO $ do
VS.unsafeFreeze outv
-- | If two vectors are given, assumes that they have the same length.
+{-# NOINLINE vectorOp2 #-}
vectorOp2 :: forall a b. Storable a
=> (a -> b)
-> (Ptr a -> Ptr b)
@@ -127,6 +132,39 @@ vectorOp2 valconv ptrconv fss fsv fvs fvv = \cases
VS.unsafeFreeze outv
| otherwise -> error $ "vectorOp: unequal lengths: " ++ show (VS.length vx) ++ " /= " ++ show (VS.length vy)
+-- TODO: test all the weird cases of this function
+-- | Reduce along the inner dimension
+{-# NOINLINE vectorRedInnerOp #-}
+vectorRedInnerOp :: forall a b n. (Num a, Storable a)
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr b -> b -> Ptr b -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> Ptr b -> IO ()) -- ^ reduction kernel
+ -> RS.Array (n + 1) a -> RS.Array n a
+vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T strides offset vec)))
+ | null sh = error "unreachable"
+ | any (<= 0) sh = RS.A (RG.A (init sh) (OI.T (map (const 0) (init strides)) 0 VS.empty))
+ -- now the input array is nonempty
+ | last sh == 1 = RS.A (RG.A (init sh) (OI.T (init strides) offset vec))
+ | last strides == 0 =
+ liftVEltwise1 sn
+ (vectorOp1 id (\n pout px -> fscale n (ptrconv pout) (valconv (fromIntegral (last sh))) (ptrconv px)))
+ (RS.A (RG.A (init sh) (OI.T (init strides) offset vec)))
+ -- now there is useful work along the inner dimension
+ | otherwise =
+ let -- filter out zero-stride dimensions; the reduction kernel need not concern itself with those
+ (shF, stridesF) = unzip $ filter ((/= 0) . snd) (zip sh strides)
+ ndimsF = length shF
+ in unsafePerformIO $ do
+ outv <- VSM.unsafeNew (product (init shF))
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
+ VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesF)) $ \pstridesF ->
+ VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pvec ->
+ fred (fromIntegral ndimsF) pshF pstridesF (ptrconv poutv) (ptrconv pvec)
+ RS.fromVector (init sh) <$> VS.unsafeFreeze outv
+
flipOp :: (Int64 -> Ptr a -> a -> Ptr a -> IO ())
-> Int64 -> Ptr a -> Ptr a -> a -> IO ()
flipOp f n out v s = f n out s v
@@ -138,6 +176,8 @@ class NumElt a where
numEltNeg :: SNat n -> RS.Array n a -> RS.Array n a
numEltAbs :: SNat n -> RS.Array n a -> RS.Array n a
numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
+ numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
+ numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
@@ -151,7 +191,7 @@ $(fmap concat . forM typesList $ \arithtype -> do
c_vv = varE $ mkName (cnamebase ++ "_vv")
sequence [SigD name <$>
[t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> mliftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
+ ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
return $ FunD name [Clause [] (NormalB body) []]])
$(fmap concat . forM typesList $ \arithtype -> do
@@ -161,7 +201,18 @@ $(fmap concat . forM typesList $ \arithtype -> do
c_op = varE $ mkName ("c_" ++ auoName arithop ++ "_" ++ atCName arithtype)
sequence [SigD name <$>
[t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> mliftVEltwise1 sn (vectorOp1 id $c_op) |]
+ ,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
+ return $ FunD name [Clause [] (NormalB body) []]])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ fmap concat . forM redopsList $ \redop -> do
+ let name = mkName (aroName redop ++ "Vector" ++ nameBase (atType arithtype))
+ c_op = varE $ mkName ("c_" ++ aroName redop ++ "_" ++ atCName arithtype)
+ c_scale_op = varE $ mkName ("c_mul_" ++ atCName arithtype ++ "_sv")
+ sequence [SigD name <$>
+ [t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
+ ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |]
return $ FunD name [Clause [] (NormalB body) []]])
-- This branch is ostensibly a runtime branch, but will (hopefully) be
@@ -171,8 +222,8 @@ intWidBranch1 :: forall i n. (FiniteBits i, Storable i)
-> (Int64 -> Ptr Int64 -> Ptr Int64 -> IO ())
-> (SNat n -> RS.Array n i -> RS.Array n i)
intWidBranch1 f32 f64 sn
- | finiteBitSize (undefined :: i) == 32 = mliftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32)
- | finiteBitSize (undefined :: i) == 64 = mliftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64)
+ | finiteBitSize (undefined :: i) == 32 = liftVEltwise1 sn (vectorOp1 @i @Int32 castPtr f32)
+ | finiteBitSize (undefined :: i) == 64 = liftVEltwise1 sn (vectorOp1 @i @Int64 castPtr f64)
| otherwise = error "Unsupported Int width"
intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
@@ -187,8 +238,21 @@ intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
-> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv
-> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i)
intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
- | finiteBitSize (undefined :: i) == 32 = mliftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32)
- | finiteBitSize (undefined :: i) == 64 = mliftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64)
+ | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32)
+ | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64)
+ | otherwise = error "Unsupported Int width"
+
+intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => -- int32
+ (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
+ -- int64
+ -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel
+ -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i)
+intWidBranchRed fsc32 fred32 fsc64 fred64 sn
+ | finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32
+ | finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
| otherwise = error "Unsupported Int width"
instance NumElt Int32 where
@@ -198,6 +262,8 @@ instance NumElt Int32 where
numEltNeg = negVectorInt32
numEltAbs = absVectorInt32
numEltSignum = signumVectorInt32
+ numEltSum1Inner = sum1VectorInt32
+ numEltProduct1Inner = product1VectorInt32
instance NumElt Int64 where
numEltAdd = addVectorInt64
@@ -206,6 +272,8 @@ instance NumElt Int64 where
numEltNeg = negVectorInt64
numEltAbs = absVectorInt64
numEltSignum = signumVectorInt64
+ numEltSum1Inner = sum1VectorInt64
+ numEltProduct1Inner = product1VectorInt64
instance NumElt Float where
numEltAdd = addVectorFloat
@@ -214,6 +282,8 @@ instance NumElt Float where
numEltNeg = negVectorFloat
numEltAbs = absVectorFloat
numEltSignum = signumVectorFloat
+ numEltSum1Inner = sum1VectorFloat
+ numEltProduct1Inner = product1VectorFloat
instance NumElt Double where
numEltAdd = addVectorDouble
@@ -222,6 +292,8 @@ instance NumElt Double where
numEltNeg = negVectorDouble
numEltAbs = absVectorDouble
numEltSignum = signumVectorDouble
+ numEltSum1Inner = sum1VectorDouble
+ numEltProduct1Inner = product1VectorDouble
instance NumElt Int where
numEltAdd = intWidBranch2 @Int (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv
@@ -230,6 +302,8 @@ instance NumElt Int where
numEltNeg = intWidBranch1 @Int c_neg_i32 c_neg_i64
numEltAbs = intWidBranch1 @Int c_abs_i32 c_abs_i64
numEltSignum = intWidBranch1 @Int c_signum_i32 c_signum_i64
+ numEltSum1Inner = intWidBranchRed @Int c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64
+ numEltProduct1Inner = intWidBranchRed @Int c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64
instance NumElt CInt where
numEltAdd = intWidBranch2 @CInt (+) c_add_i32_sv (flipOp c_add_i32_sv) c_add_i32_vv c_add_i64_sv (flipOp c_add_i64_sv) c_add_i64_vv
@@ -238,3 +312,5 @@ instance NumElt CInt where
numEltNeg = intWidBranch1 @CInt c_neg_i32 c_neg_i64
numEltAbs = intWidBranch1 @CInt c_abs_i32 c_abs_i64
numEltSignum = intWidBranch1 @CInt c_signum_i32 c_signum_i64
+ numEltSum1Inner = intWidBranchRed @CInt c_mul_i32_sv c_sum1_i32 c_mul_i64_sv c_sum1_i64
+ numEltProduct1Inner = intWidBranchRed @CInt c_mul_i32_sv c_product1_i32 c_mul_i64_sv c_product1_i64
diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
index dbd9ddc..f84b1c5 100644
--- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Foreign.hs
@@ -22,7 +22,7 @@ $(fmap concat . forM typesList $ \arithtype -> do
[t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
,guard (aboComm arithop == NonComm) >>
Just (ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base ++ "_vs") (mkName ("c_" ++ base ++ "_vs")) <$>
- [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
+ [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
])
$(fmap concat . forM typesList $ \arithtype -> do
@@ -31,3 +31,10 @@ $(fmap concat . forM typesList $ \arithtype -> do
let base = auoName arithop ++ "_" ++ atCName arithtype
ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
[t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+
+$(fmap concat . forM typesList $ \arithtype -> do
+ let ttyp = conT (atType arithtype)
+ forM redopsList $ \redop -> do
+ let base = aroName redop ++ "_" ++ atCName arithtype
+ ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
+ [t| Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Nested/Internal/Arith/Lists.hs
index 1b29770..78fe24a 100644
--- a/src/Data/Array/Nested/Internal/Arith/Lists.hs
+++ b/src/Data/Array/Nested/Internal/Arith/Lists.hs
@@ -45,3 +45,13 @@ unopsList =
,ArithUOp "abs"
,ArithUOp "signum"
]
+
+data ArithRedOp = ArithRedOp
+ { aroName :: String -- "sum"
+ }
+
+redopsList :: [ArithRedOp]
+redopsList =
+ [ArithRedOp "sum1"
+ ,ArithRedOp "product1"
+ ]