aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cbits/arith.c49
-rw-r--r--cbits/arith_lists.h4
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs150
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Foreign.hs10
-rw-r--r--src/Data/Array/Mixed/XArray.hs4
-rw-r--r--src/Data/Array/Nested.hs6
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs3
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs3
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs3
9 files changed, 184 insertions, 48 deletions
diff --git a/cbits/arith.c b/cbits/arith.c
index fb993c8..5d74c01 100644
--- a/cbits/arith.c
+++ b/cbits/arith.c
@@ -177,6 +177,33 @@ static double log1pexp_double(double x) { LOG1PEXP_IMPL(x); }
// - all strides are >0
// - shape is everywhere >0
// - rank is >= 1
+// Defines oxarop_op_<name>_<typ>: reduces the whole array with 'op' and
+// returns the result. For every position of the outer dimensions the inner
+// dimension is folded into 'accum', and the per-row results are then folded
+// into 'res'. The first row result seeds 'res', so no neutral element of
+// 'op' is assumed (seeding with 0 would make every product equal to 0).
+// The two branches are identical except that the first one is specialised
+// to an inner stride of 1.
+#define REDUCEFULL_OP(name, op, typ) \
+  typ oxarop_op_ ## name ## _ ## typ(i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
+    typ res = 0; \
+    int seeded = 0; /* becomes 1 once 'res' holds a partial result */ \
+    if (strides[rank - 1] == 1) { \
+      TARRAY_WALK_NOINNER(again1, rank, shape, strides, { \
+        typ accum = arr[arrlinidx]; \
+        for (i64 i = 1; i < shape[rank - 1]; i++) { \
+          accum = accum op arr[arrlinidx + i]; \
+        } \
+        res = seeded ? res op accum : accum; \
+        seeded = 1; \
+      }); \
+    } else { \
+      TARRAY_WALK_NOINNER(again2, rank, shape, strides, { \
+        typ accum = arr[arrlinidx]; \
+        for (i64 i = 1; i < shape[rank - 1]; i++) { \
+          accum = accum op arr[arrlinidx + strides[rank - 1] * i]; \
+        } \
+        res = seeded ? res op accum : accum; \
+        seeded = 1; \
+      }); \
+    } \
+    return res; \
+  }
+
+// preconditions
+// - all strides are >0
+// - shape is everywhere >0
+// - rank is >= 1
// Writes extreme index to outidx. If 'cmp' is '<', computes minindex ("argmin"); if '>', maxindex.
#define EXTREMUM_OP(name, cmp, typ) \
void oxarop_extremum_ ## name ## _ ## typ(i64 *outidx, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
@@ -394,11 +421,20 @@ enum redop_tag_t {
#define LIST_REDOP(name, id, _)
};
-#define ENTRY_REDUCE_OPS(typ) \
- void oxarop_reduce_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \
+// Defines oxarop_reduce1_<typ>: dispatches on 'tag' to the sum1 or product1
+// kernel for 'typ', forwarding all arguments; the kernel writes its result
+// through 'out'. Any other tag is reported via wrong_op.
+#define ENTRY_REDUCE1_OPS(typ) \
+ void oxarop_reduce1_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, typ *out, const typ *arr) { \
+ switch (tag) { \
+ case RO_SUM: oxarop_op_sum1_ ## typ(rank, shape, strides, out, arr); break; \
+ case RO_PRODUCT: oxarop_op_product1_ ## typ(rank, shape, strides, out, arr); break; \
+ default: wrong_op("reduce", tag); \
+ } \
+ }
+
+// Defines oxarop_reducefull_<typ>: dispatches on 'tag' to the sumfull or
+// productfull kernel for 'typ' and returns the kernel's result.
+// NOTE(review): the default branch has no return statement; this is only
+// well-defined if wrong_op never returns -- confirm against its definition.
+#define ENTRY_REDUCEFULL_OPS(typ) \
+ typ oxarop_reducefull_ ## typ(enum redop_tag_t tag, i64 rank, const i64 *shape, const i64 *strides, const typ *arr) { \
switch (tag) { \
- case RO_SUM1: oxarop_op_sum1_ ## typ(rank, shape, strides, out, arr); break; \
- case RO_PRODUCT1: oxarop_op_product1_ ## typ(rank, shape, strides, out, arr); break; \
+ case RO_SUM: return oxarop_op_sumfull_ ## typ(rank, shape, strides, arr); \
+ case RO_PRODUCT: return oxarop_op_productfull_ ## typ(rank, shape, strides, arr); \
default: wrong_op("reduce", tag); \
} \
}
@@ -420,9 +456,12 @@ enum redop_tag_t {
UNARY_OP(signum, GEN_SIGNUM, typ) \
REDUCE1_OP(sum1, +, typ) \
REDUCE1_OP(product1, *, typ) \
+ REDUCEFULL_OP(sumfull, +, typ) \
+ REDUCEFULL_OP(productfull, *, typ) \
ENTRY_BINARY_OPS(typ) \
ENTRY_UNARY_OPS(typ) \
- ENTRY_REDUCE_OPS(typ) \
+ ENTRY_REDUCE1_OPS(typ) \
+ ENTRY_REDUCEFULL_OPS(typ) \
EXTREMUM_OP(min, <, typ) \
EXTREMUM_OP(max, >, typ) \
DOTPROD_STRIDED_OP(typ)
diff --git a/cbits/arith_lists.h b/cbits/arith_lists.h
index 2e37575..58de65a 100644
--- a/cbits/arith_lists.h
+++ b/cbits/arith_lists.h
@@ -31,5 +31,5 @@ LIST_FUNOP(FU_EXPM1, 18,)
LIST_FUNOP(FU_LOG1PEXP, 19,)
LIST_FUNOP(FU_LOG1MEXP, 20,)
-LIST_REDOP(RO_SUM1, 1,)
-LIST_REDOP(RO_PRODUCT1, 2,)
+LIST_REDOP(RO_SUM, 1,)
+LIST_REDOP(RO_PRODUCT, 2,)
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 579c0da..d547084 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -33,6 +33,9 @@ import Data.Array.Mixed.Internal.Arith.Foreign
import Data.Array.Mixed.Internal.Arith.Lists
+-- TODO: need to sort strides for reduction-like functions so that the C inner-loop specialisation has some chance of working even after transposition
+
+
-- TODO: test all the cases of this thing with various input strides
liftVEltwise1 :: (Storable a, Storable b)
=> SNat n
@@ -186,7 +189,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride
-- precondition that there are no such dimensions in its input).
replDims = map (== 0) strides
-- filter out replicated dimensions
- (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims)
+ (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides replDims]
-- replace replicated dimensions with ones
shOnes = zipWith (\n repl -> if repl then 1 else n) sh replDims
ndimsF = length shF -- > 0, otherwise `last strides == 0`
@@ -213,6 +216,48 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred (RS.A (RG.A sh (OI.T stride
. RS.fromVector @_ @lenFm1 (init shF) -- the partially-reversed result array
<$> VS.unsafeFreeze outvR
+-- TODO: test handling of negative strides
+-- | Reduce a full array to a single value.
+--
+-- Replicated (zero-stride) and reversed (negative-stride) dimensions are
+-- removed or normalised before the kernel is called; the effect of the
+-- replicated dimensions is reinstated afterwards through the scale function.
+{-# NOINLINE vectorRedFullOp #-}
+vectorRedFullOp :: forall a b n. (Num a, Storable a)
+                => SNat n
+                -> (a -> Int -> a)  -- ^ @scaleval v k@: the result of reducing @k@ copies of @v@
+                                    --   (multiplication by @k@ for sum, @v ^ k@ for product)
+                -> (b -> a)  -- ^ convert the kernel result back to the element type
+                -> (Ptr a -> Ptr b)  -- ^ convert the element pointer for the kernel
+                -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr b -> IO b)  -- ^ reduction kernel
+                -> RS.Array n a -> a
+vectorRedFullOp _ scaleval valbackconv ptrconv fred (RS.A (RG.A sh (OI.T strides offset vec)))
+  | null sh = vec VS.! offset  -- 0D array has one element
+  -- An empty array is a reduction over zero elements, so the result is the
+  -- neutral element of the reduction. Reducing zero copies of any value
+  -- yields exactly that (0 for sum, 1 for product), whereas a literal 0
+  -- would be wrong for product.
+  | any (<= 0) sh = scaleval 0 0
+  -- now the input array is nonempty
+  -- Fully replicated array: @product sh@ copies of a single element. This
+  -- must go through 'scaleval'; plain multiplication is only right for sum.
+  | all (== 0) strides = scaleval (vec VS.! offset) (product sh)
+  -- now there is at least one non-replicated dimension
+  | otherwise =
+      let -- replicated dimensions: dimensions with zero stride. The reduction
+          -- kernel need not concern itself with those (and in fact has a
+          -- precondition that there are no such dimensions in its input).
+          replDims = map (== 0) strides
+          -- filter out replicated dimensions
+          (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides replDims]
+          ndimsF = length shF  -- > 0, otherwise `all (== 0) strides`
+          -- we should scale up the output this many times to account for the replicated dimensions
+          multiplier = product [n | (n, True) <- zip sh replDims]
+
+          -- reversed dimensions: dimensions with negative stride. Reversal is
+          -- irrelevant for a reduction, and indeed the kernel has a
+          -- precondition that there are no such dimensions.
+          revDims = map (< 0) stridesF
+          stridesR = map abs stridesF
+          -- For a reversed dimension s < 0, so (n - 1) * s moves the offset
+          -- back to the lowest address touched along that dimension.
+          offsetR = offset + sum (zipWith3 (\rev n s -> if rev then (n - 1) * s else 0) revDims shF stridesF)
+          -- The *R values give an array with strides all > 0, hence the
+          -- left-most element is at offsetR.
+      in -- NOTE(review): unsafePerformIO assumes 'fred' only reads from the
+         -- pointers it is given and has no other observable effects.
+         unsafePerformIO $ do
+           VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral shF)) $ \pshF ->
+             VS.unsafeWith (VS.fromListN ndimsF (map fromIntegral stridesR)) $ \pstridesR ->
+               VS.unsafeWith (VS.slice offsetR (VS.length vec - offsetR) vec) $ \pvecR ->
+                 (`scaleval` fromIntegral multiplier) . valbackconv
+                   <$> fred (fromIntegral ndimsF) pshF pstridesR (ptrconv pvecR)
+
-- TODO: test this function
-- | Find extremum (minindex ("argmin") or maxindex) in full array
{-# NOINLINE vectorExtremumOp #-}
@@ -232,7 +277,7 @@ vectorExtremumOp ptrconv fextrem (RS.A (RG.A sh (OI.T strides offset vec)))
-- precondition that there are no such dimensions in its input).
replDims = map (== 0) strides
-- filter out replicated dimensions
- (shF, stridesF) = unzip $ map fst $ filter (not . snd) (zip (zip sh strides) replDims)
+ (shF, stridesF) = unzip [(n, s) | (n, s, False) <- zip3 sh strides replDims]
ndimsF = length shF -- > 0, because not all strides were <=0
-- un-reverse reversed dimensions
@@ -380,16 +425,29 @@ $(fmap concat . forM floatTypesList $ \arithtype -> do
,do body <- [| \sn -> liftVEltwise1 sn (vectorOp1 id $c_op) |]
return $ FunD name [Clause [] (NormalB body) []]])
+-- | Multiply a value by an 'Int' count, converting the count with
+-- 'fromIntegral'. Used as the scale function for sum reductions, where
+-- @k@ copies of @v@ sum to @v * k@.
+mulWithInt :: Num a => a -> Int -> a
+mulWithInt a i = a * fromIntegral i
+
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
fmap concat . forM [minBound..maxBound] $ \arithop -> do
- let name = mkName (aroName arithop ++ "Vector" ++ nameBase (atType arithtype))
- c_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
+ let scaleVar = case arithop of
+ RO_SUM -> varE 'mulWithInt
+ RO_PRODUCT -> varE '(^)
+ let name1 = mkName (aroName arithop ++ "1Vector" ++ nameBase (atType arithtype))
+ namefull = mkName (aroName arithop ++ "FullVector" ++ nameBase (atType arithtype))
+ c_op1 = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
+ c_opfull = varE (mkName ("c_reducefull_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum arithop)))
c_scale_op = varE (mkName ("c_binary_" ++ atCName arithtype ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum BO_MUL)))
- sequence [SigD name <$>
+ sequence [SigD name1 <$>
[t| forall n. SNat n -> RS.Array (n + 1) $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op |]
- return $ FunD name [Clause [] (NormalB body) []]])
+ ,do body <- [| \sn -> vectorRedInnerOp sn id id $c_scale_op $c_op1 |]
+ return $ FunD name1 [Clause [] (NormalB body) []]
+ ,SigD namefull <$>
+ [t| forall n. SNat n -> RS.Array n $ttyp -> $ttyp |]
+ ,do body <- [| \sn -> vectorRedFullOp sn $scaleVar id id $c_opfull |]
+ return $ FunD namefull [Clause [] (NormalB body) []]
+ ])
$(fmap concat . forM typesList $ \arithtype ->
fmap concat . forM ["min", "max"] $ \fname -> do
@@ -406,7 +464,7 @@ $(fmap concat . forM typesList $ \arithtype -> do
name = mkName ("dotprodVector" ++ nameBase (atType arithtype))
c_op = varE (mkName ("c_dotprod_" ++ atCName arithtype))
c_op_strided = varE (mkName ("c_dotprod_" ++ atCName arithtype ++ "_strided"))
- c_red_op = varE (mkName ("c_reduce_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM1)))
+ c_red_op = varE (mkName ("c_reduce1_" ++ atCName arithtype)) `appE` litE (integerL (fromIntegral (aroEnum RO_SUM)))
sequence [SigD name <$>
[t| RS.Array 1 $ttyp -> RS.Array 1 $ttyp -> $ttyp |]
,do body <- [| vectorDotprodOp id id $c_red_op $c_op $c_op_strided |]
@@ -439,19 +497,31 @@ intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
| finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64)
| otherwise = error "Unsupported Int width"
-intWidBranchRed :: forall i n. (FiniteBits i, Storable i, Integral i)
- => -- int32
- (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
- -- int64
- -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel
- -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i)
-intWidBranchRed fsc32 fred32 fsc64 fred64 sn
+-- | Inner-dimension reduction for an integral type @i@ that is 32 or 64 bits
+-- wide: selects the Int32 or Int64 scale/reduction kernels based on
+-- 'finiteBitSize' and passes them to 'vectorRedInnerOp', using
+-- 'fromIntegral' and 'castPtr' as the value and pointer conversions.
+-- Any other width is an error.
+intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => -- int32
+ (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- ^ reduction kernel
+ -- int64
+ -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- ^ scale by constant
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- ^ reduction kernel
+ -> (SNat n -> RS.Array (n + 1) i -> RS.Array n i)
+intWidBranchRed1 fsc32 fred32 fsc64 fred64 sn
| finiteBitSize (undefined :: i) == 32 = vectorRedInnerOp @i @Int32 sn fromIntegral castPtr fsc32 fred32
| finiteBitSize (undefined :: i) == 64 = vectorRedInnerOp @i @Int64 sn fromIntegral castPtr fsc64 fred64
| otherwise = error "Unsupported Int width"
+-- | Full-array reduction for an integral type @i@ that is 32 or 64 bits
+-- wide: selects the Int32 or Int64 reduction kernel based on
+-- 'finiteBitSize' and passes it to 'vectorRedFullOp', using 'fromIntegral'
+-- to convert the kernel result back to @i@ and 'castPtr' to convert the
+-- element pointer. The scale op works on @i@ directly and is shared by both
+-- widths. Any other width is an error.
+intWidBranchRedFull :: forall i n. (FiniteBits i, Storable i, Integral i)
+ => (i -> Int -> i) -- ^ scale op
+ -- int32
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO Int32) -- ^ reduction kernel
+ -- int64
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO Int64) -- ^ reduction kernel
+ -> (SNat n -> RS.Array n i -> i)
+intWidBranchRedFull fsc fred32 fred64 sn
+ | finiteBitSize (undefined :: i) == 32 = vectorRedFullOp @i @Int32 sn fsc fromIntegral castPtr fred32
+ | finiteBitSize (undefined :: i) == 64 = vectorRedFullOp @i @Int64 sn fsc fromIntegral castPtr fred64
+ | otherwise = error "Unsupported Int width"
+
intWidBranchExtr :: forall i n. (FiniteBits i, Storable i, Integral i)
=> -- int32
(Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- ^ extremum kernel
@@ -487,6 +557,8 @@ class NumElt a where
numEltSignum :: SNat n -> RS.Array n a -> RS.Array n a
numEltSum1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
numEltProduct1Inner :: SNat n -> RS.Array (n + 1) a -> RS.Array n a
+ numEltSumFull :: SNat n -> RS.Array n a -> a
+ numEltProductFull :: SNat n -> RS.Array n a -> a
numEltMinIndex :: RS.Array n a -> [Int]
numEltMaxIndex :: RS.Array n a -> [Int]
numEltDotprod :: RS.Array 1 a -> RS.Array 1 a -> a
@@ -500,6 +572,8 @@ instance NumElt Int32 where
numEltSignum = signumVectorInt32
numEltSum1Inner = sum1VectorInt32
numEltProduct1Inner = product1VectorInt32
+ numEltSumFull = sumFullVectorInt32
+ numEltProductFull = productFullVectorInt32
numEltMinIndex = minindexVectorInt32
numEltMaxIndex = maxindexVectorInt32
numEltDotprod = dotprodVectorInt32
@@ -513,6 +587,8 @@ instance NumElt Int64 where
numEltSignum = signumVectorInt64
numEltSum1Inner = sum1VectorInt64
numEltProduct1Inner = product1VectorInt64
+ numEltSumFull = sumFullVectorInt64
+ numEltProductFull = productFullVectorInt64
numEltMinIndex = minindexVectorInt64
numEltMaxIndex = maxindexVectorInt64
numEltDotprod = dotprodVectorInt64
@@ -526,6 +602,8 @@ instance NumElt Float where
numEltSignum = signumVectorFloat
numEltSum1Inner = sum1VectorFloat
numEltProduct1Inner = product1VectorFloat
+ numEltSumFull = sumFullVectorFloat
+ numEltProductFull = productFullVectorFloat
numEltMinIndex = minindexVectorFloat
numEltMaxIndex = maxindexVectorFloat
numEltDotprod = dotprodVectorFloat
@@ -539,6 +617,8 @@ instance NumElt Double where
numEltSignum = signumVectorDouble
numEltSum1Inner = sum1VectorDouble
numEltProduct1Inner = product1VectorDouble
+ numEltSumFull = sumFullVectorDouble
+ numEltProductFull = productFullVectorDouble
numEltMinIndex = minindexVectorDouble
numEltMaxIndex = maxindexVectorDouble
numEltDotprod = dotprodVectorDouble
@@ -556,16 +636,18 @@ instance NumElt Int where
numEltNeg = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))
numEltAbs = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))
numEltSignum = intWidBranch1 @Int (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM))
- numEltSum1Inner = intWidBranchRed @Int
- (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1))
- (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1))
- numEltProduct1Inner = intWidBranchRed @Int
- (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
- (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
+ numEltSum1Inner = intWidBranchRed1 @Int
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM))
+ numEltProduct1Inner = intWidBranchRed1 @Int
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_PRODUCT))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT))
+ numEltSumFull = intWidBranchRedFull @Int (*) (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM))
+ numEltProductFull = intWidBranchRedFull @Int (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
numEltMinIndex = intWidBranchExtr @Int c_extremum_min_i32 c_extremum_min_i64
numEltMaxIndex = intWidBranchExtr @Int c_extremum_max_i32 c_extremum_max_i64
- numEltDotprod = intWidBranchDotprod @Int (c_reduce_i32 (aroEnum RO_SUM1)) c_dotprod_i32 c_dotprod_i32_strided
- (c_reduce_i64 (aroEnum RO_SUM1)) c_dotprod_i64 c_dotprod_i64_strided
+ numEltDotprod = intWidBranchDotprod @Int (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprod_i32 c_dotprod_i32_strided
+ (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprod_i64 c_dotprod_i64_strided
instance NumElt CInt where
numEltAdd = intWidBranch2 @CInt (+)
@@ -580,16 +662,18 @@ instance NumElt CInt where
numEltNeg = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_NEG)) (c_unary_i64 (auoEnum UO_NEG))
numEltAbs = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_ABS)) (c_unary_i64 (auoEnum UO_ABS))
numEltSignum = intWidBranch1 @CInt (c_unary_i32 (auoEnum UO_SIGNUM)) (c_unary_i64 (auoEnum UO_SIGNUM))
- numEltSum1Inner = intWidBranchRed @CInt
- (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_SUM1))
- (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_SUM1))
- numEltProduct1Inner = intWidBranchRed @CInt
- (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce_i32 (aroEnum RO_PRODUCT1))
- (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce_i64 (aroEnum RO_PRODUCT1))
+ numEltSum1Inner = intWidBranchRed1 @CInt
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_SUM))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_SUM))
+ numEltProduct1Inner = intWidBranchRed1 @CInt
+ (c_binary_i32_sv (aboEnum BO_MUL)) (c_reduce1_i32 (aroEnum RO_PRODUCT))
+ (c_binary_i64_sv (aboEnum BO_MUL)) (c_reduce1_i64 (aroEnum RO_PRODUCT))
+ numEltSumFull = intWidBranchRedFull @CInt mulWithInt (c_reducefull_i32 (aroEnum RO_SUM)) (c_reducefull_i64 (aroEnum RO_SUM))
+ numEltProductFull = intWidBranchRedFull @CInt (^) (c_reducefull_i32 (aroEnum RO_PRODUCT)) (c_reducefull_i64 (aroEnum RO_PRODUCT))
numEltMinIndex = intWidBranchExtr @CInt c_extremum_min_i32 c_extremum_min_i64
numEltMaxIndex = intWidBranchExtr @CInt c_extremum_max_i32 c_extremum_max_i64
- numEltDotprod = intWidBranchDotprod @CInt (c_reduce_i32 (aroEnum RO_SUM1)) c_dotprod_i32 c_dotprod_i32_strided
- (c_reduce_i64 (aroEnum RO_SUM1)) c_dotprod_i64 c_dotprod_i64_strided
+ numEltDotprod = intWidBranchDotprod @CInt (c_reduce1_i32 (aroEnum RO_SUM)) c_dotprod_i32 c_dotprod_i32_strided
+ (c_reduce1_i64 (aroEnum RO_SUM)) c_dotprod_i64 c_dotprod_i64_strided
class FloatElt a where
floatEltDiv :: SNat n -> RS.Array n a -> RS.Array n a -> RS.Array n a
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
index a406dab..ca96093 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -49,9 +49,13 @@ $(fmap concat . forM floatTypesList $ \arithtype -> do
$(fmap concat . forM typesList $ \arithtype -> do
let ttyp = conT (atType arithtype)
- let base = "reduce_" ++ atCName arithtype
- pure . ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base) (mkName ("c_" ++ base)) <$>
- [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
+ let base1 = "reduce1_" ++ atCName arithtype
+ basefull = "reducefull_" ++ atCName arithtype
+ sequence
+ [ForeignD . ImportF CCall Unsafe ("oxarop_" ++ base1) (mkName ("c_" ++ base1)) <$>
+ [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO () |]
+ ,ForeignD . ImportF CCall Unsafe ("oxarop_" ++ basefull) (mkName ("c_" ++ basefull)) <$>
+ [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |]])
$(fmap concat . forM typesList $ \arithtype ->
fmap concat . forM ["min", "max"] $ \fname -> do
diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs
index 08295cd..fa753bb 100644
--- a/src/Data/Array/Mixed/XArray.hs
+++ b/src/Data/Array/Mixed/XArray.hs
@@ -240,8 +240,8 @@ transpose2 ssh1 ssh2 (XArray arr)
, let n1 = ssxLength ssh1
= XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
-sumFull :: (Storable a, NumElt a) => XArray sh a -> a
-sumFull (XArray arr) =
+sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a
+sumFull _ (XArray arr) =
S.unScalar $
numEltSum1Inner (SNat @0) $
S.fromVector [product (S.shapeL arr)] $
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 3a60305..53417bd 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -6,7 +6,7 @@ module Data.Array.Nested (
ListR(ZR, (:::)),
IxR(.., ZIR, (:.:)), IIxR,
ShR(.., ZSR, (:$:)), IShR,
- rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rsumOuter1,
+ rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rsumOuter1, rsumAllPrim,
rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar,
rrerank,
rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
@@ -26,7 +26,7 @@ module Data.Array.Nested (
ListS(ZS, (::$)),
IxS(.., ZIS, (:.$)), IIxS,
ShS(.., ZSS, (:$$)), KnownShS(..),
- sshape, srank, ssize, sindex, sindexPartial, sgenerate, ssumOuter1,
+ sshape, srank, ssize, sindex, sindexPartial, sgenerate, ssumOuter1, ssumAllPrim,
stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,
-- TODO: sconcat? What should its type be?
srerank,
@@ -48,7 +48,7 @@ module Data.Array.Nested (
ShX(.., ZSX, (:$%)), KnownShX(..),
StaticShX(.., ZKX, (:!%)),
SMayNat(..),
- mshape, mrank, msize, mindex, mindexPartial, mgenerate, msumOuter1,
+ mshape, mrank, msize, mindex, mindexPartial, mgenerate, msumOuter1, msumAllPrim,
mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar,
mrerank,
mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 594383c..215313e 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -713,6 +713,9 @@ msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
=> Mixed (n : sh) a -> Mixed sh a
msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
+-- | Sum all elements of a mixed array of primitive elements into a single
+-- value. Converts the array with 'toPrimitive' and delegates to
+-- @X.sumFull@, passing the static shape derived from the array's shape.
+msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
+msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShape sh) arr
+
mappend :: forall n m sh a. Elt a
=> Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index bd37e7a..74b2186 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -282,6 +282,9 @@ rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
=> Ranked (n + 1) a -> Ranked n a
rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive
+-- | Sum all elements of a ranked array of primitive elements into a single
+-- value; unwraps the 'Ranked' constructor and delegates to 'msumAllPrim'.
+rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a
+rsumAllPrim (Ranked arr) = msumAllPrim arr
+
rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a
rtranspose perm arr
| sn@SNat <- rrank arr
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index f50ed28..ea979fa 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -277,6 +277,9 @@ ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
=> Shaped (n : sh) a -> Shaped sh a
ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive
+-- | Sum all elements of a shaped array of primitive elements into a single
+-- value; unwraps the 'Shaped' constructor and delegates to 'msumAllPrim'.
+-- The index of 'Shaped' is a shape (a list of dimensions), hence the type
+-- variable is named @sh@ as in the neighbouring signatures.
+ssumAllPrim :: (PrimElt a, NumElt a) => Shaped sh a -> a
+ssumAllPrim (Shaped arr) = msumAllPrim arr
+
stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a)
=> Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a
stranspose perm sarr@(Shaped arr)