aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-26 22:44:44 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-26 22:44:44 +0100
commit829109ba73211394691d5789f35a23120feaf3f6 (patch)
treedf523326d0a7f6c6698e2b1aae7d177c17de792a
parent2177f3e9cdb8a1f10529f678d5dad9d8c7d60d86 (diff)
Benchmark and improve ixxFromLinear
-rw-r--r--bench/Main.hs21
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs34
2 files changed, 35 insertions, 20 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index 8df09a9..01d9a3f 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -17,19 +17,12 @@ import Text.Show (showListWith)
import Data.Array.Nested
import Data.Array.Nested.Mixed (Mixed(M_Primitive), mliftPrim, mliftPrim2, toPrimitive)
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked (liftRanked1, liftRanked2)
import Data.Array.Strided.Arith.Internal qualified as Arith
import Data.Array.XArray (XArray(..))
-enableMisc :: Bool
-enableMisc = False
-
-bgroupIf :: Bool -> String -> [Benchmark] -> Benchmark
-bgroupIf True = bgroup
-bgroupIf False = \name _ -> bgroup name []
-
-
main :: IO ()
main = do
let enable = False
@@ -104,7 +97,7 @@ main_tests = defaultMain
in nf (\a -> RS.normalize a)
(RS.rev [0] (RS.rev [0] (RS.iota @Double n)))
]
- ,bgroupIf enableMisc "misc"
+ ,bgroup "misc"
[let n = 1000
k = 1000
in bgroup ("fusion [" ++ show k ++ "]*" ++ show n)
@@ -148,6 +141,16 @@ main_tests = defaultMain
| ki <- [1::Int ..5]]
]
]
+ ,bench "ixxFromLinear 10000x" $
+ let n = 10000
+ sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX
+ in nf (\sh -> [ixxFromLinear sh i | i <- [1..n]]) sh0
+ ,bench "ixxFromLinear 1x" $
+ let sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX
+ in nf (\sh -> ixxFromLinear sh 1234) sh0
+ ,bench "shxEnum" $
+ let sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX
+ in nf (\sh -> shxEnum sh) sh0
]
]
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index 1b008e5..f127e3a 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -262,18 +262,30 @@ ixxZipWith _ ZIX ZIX = ZIX
ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js
ixxFromLinear :: IShX sh -> Int -> IIxX sh
-ixxFromLinear = \sh i -> case go sh i of
- (idx, 0) -> idx
- _ -> error $ "ixxFromLinear: out of range (" ++ show i ++
- " in array of shape " ++ show sh ++ ")"
+ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared when it's called many times
+ let suffixes = drop 1 (scanr (*) 1 (shxToList sh))
+ in \i ->
+ if i < 0 then outrange sh i else
+ case (sh, suffixes) of -- unfold first iteration of fromLin to do the range check
+ (ZSX, _) | i > 0 -> outrange sh i
+ | otherwise -> ZIX
+ (n :$% sh', suff : suffs) ->
+ let (q, r) = i `quotRem` suff
+ in if q >= fromSMayNat' n then outrange sh i else
+ q :.% fromLin sh' suffs r
+ _ -> error "impossible"
where
- -- returns (index in subarray, remaining index in enclosing array)
- go :: IShX sh -> Int -> (IIxX sh, Int)
- go ZSX i = (ZIX, i)
- go (n :$% sh) i =
- let (idx, i') = go sh i
- (upi, locali) = i' `quotRem` fromSMayNat' n
- in (locali :.% idx, upi)
+ fromLin :: IShX sh -> [Int] -> Int -> IxX sh Int
+ fromLin ZSX _ !_ = ZIX
+ fromLin (_ :$% sh') (suff : suffs) i =
+ let (q, r) = i `quotRem` suff -- suff == shrSize sh'
+ in q :.% fromLin sh' suffs r
+ fromLin _ _ _ = error "impossible"
+
+ {-# NOINLINE outrange #-}
+ outrange :: IShX sh -> Int -> a
+ outrange sh i = error $ "ixxFromLinear: out of range (" ++ show i ++
+ " in array of shape " ++ show sh ++ ")"
ixxToLinear :: IShX sh -> IIxX sh -> Int
ixxToLinear = \sh i -> fst (go sh i)