aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bench/Main.hs9
-rw-r--r--src/Data/Array/XArray.hs6
2 files changed, 10 insertions, 5 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index 8fe0fdc..185bef0 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -176,6 +176,15 @@ tests_compare =
,bench "sum Double [1e6]" $
nf (\a -> runScalar (rsumOuter1Prim a))
(riota @Double n)
+ ,bench "sumAll iota [1e6]" $
+ nf (\a -> rsumAllPrim a)
+ (riota @Double n)
+ ,bench "sumAll rev1(iota) [1e6]" $
+ nf (\a -> rsumAllPrim a)
+ (rrev1 $ riota @Double n)
+ ,bench "sumAll reshape(iota) [1e6]" $
+ nf (\a -> rsumAllPrim a)
+ (rreshape (1 :$: n :$: 1 :$: ZSR) $ riota @Double n)
]
,bgroup "NumElt"
[bench "sum(+) Double [1e6]" $
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index 38ccee6..4f5bb08 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -269,11 +269,7 @@ transpose2 ssh1 ssh2 (XArray arr)
= XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr)
sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a
-sumFull _ (XArray arr) =
- S.unScalar $
- liftO1 (numEltSum1Inner (SNat @0)) $
- S.fromVector [product (S.shapeL arr)] $
- S.toVector arr
+sumFull ssx (XArray arr) = numEltSumFull (ssxRank ssx) $ fromO arr
sumInner :: forall sh sh' a. (Storable a, NumElt a)
=> StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a