From bb1deceb98b4c7bfcd35372e0289566cb593d8a9 Mon Sep 17 00:00:00 2001
From: Mikolaj Konarski
Date: Thu, 15 Jan 2026 21:30:26 +0100
Subject: Speed up sumFull from 36ms to 82 microseconds

---
Review notes (free-form text after the three-dash line; not part of the
commit message and not part of the diff):

* bench/Main.hs: adds three benchmarks to the list shown under
  tests_compare, all timing rsumAllPrim under nf. Their inputs are
  riota @Double n, rrev1 applied to that array, and the same array passed
  through rreshape (1 :$: n :$: 1 :$: ZSR).
* src/Data/Array/XArray.hs: the old sumFull ignored its StaticShX argument,
  converted the array with S.toVector, rebuilt a one-dimensional array of
  length product (S.shapeL arr) with S.fromVector, reduced it with
  numEltSum1Inner (SNat @0) via liftO1, and unwrapped the result with
  S.unScalar. The new sumFull passes ssxRank ssx and fromO arr straight to
  numEltSumFull.
* NOTE(review): the timings in the subject are presumably from the new
  "sumAll" benchmarks; the patch does not say which one - confirm.
* NOTE(review): this copy of the patch had been collapsed onto one line.
  Line breaks were restored from the hunk headers; the leading whitespace
  inside the hunks is a reconstruction - verify against the repository
  before applying.

 bench/Main.hs            | 9 +++++++++
 src/Data/Array/XArray.hs | 6 +-----
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/bench/Main.hs b/bench/Main.hs
index 8fe0fdc..185bef0 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -176,6 +176,15 @@ tests_compare =
     ,bench "sum Double [1e6]" $
       nf (\a -> runScalar (rsumOuter1Prim a))
          (riota @Double n)
+    ,bench "sumAll iota [1e6]" $
+      nf (\a -> rsumAllPrim a)
+         (riota @Double n)
+    ,bench "sumAll rev1(iota) [1e6]" $
+      nf (\a -> rsumAllPrim a)
+         (rrev1 $ riota @Double n)
+    ,bench "sumAll reshape(iota) [1e6]" $
+      nf (\a -> rsumAllPrim a)
+         (rreshape (1 :$: n :$: 1 :$: ZSR) $ riota @Double n)
     ]
   ,bgroup "NumElt"
     [bench "sum(+) Double [1e6]" $
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index 38ccee6..4f5bb08 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -269,11 +269,7 @@ transpose2 ssh1 ssh2 (XArray arr) =
   XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr)
 
 sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a
-sumFull _ (XArray arr) =
-  S.unScalar $
-    liftO1 (numEltSum1Inner (SNat @0)) $
-      S.fromVector [product (S.shapeL arr)] $
-        S.toVector arr
+sumFull ssx (XArray arr) = numEltSumFull (ssxRank ssx) $ fromO arr
 
 sumInner :: forall sh sh' a. (Storable a, NumElt a)
          => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
-- 
cgit v1.2.3-70-g09d2