diff options
| -rw-r--r-- | bench/Main.hs | 21 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 34 |
2 files changed, 35 insertions, 20 deletions
diff --git a/bench/Main.hs b/bench/Main.hs index 8df09a9..01d9a3f 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -17,19 +17,12 @@ import Text.Show (showListWith) import Data.Array.Nested import Data.Array.Nested.Mixed (Mixed(M_Primitive), mliftPrim, mliftPrim2, toPrimitive) +import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Ranked (liftRanked1, liftRanked2) import Data.Array.Strided.Arith.Internal qualified as Arith import Data.Array.XArray (XArray(..)) -enableMisc :: Bool -enableMisc = False - -bgroupIf :: Bool -> String -> [Benchmark] -> Benchmark -bgroupIf True = bgroup -bgroupIf False = \name _ -> bgroup name [] - - main :: IO () main = do let enable = False @@ -104,7 +97,7 @@ main_tests = defaultMain in nf (\a -> RS.normalize a) (RS.rev [0] (RS.rev [0] (RS.iota @Double n))) ] - ,bgroupIf enableMisc "misc" + ,bgroup "misc" [let n = 1000 k = 1000 in bgroup ("fusion [" ++ show k ++ "]*" ++ show n) @@ -148,6 +141,16 @@ main_tests = defaultMain | ki <- [1::Int ..5]] ] ] + ,bench "ixxFromLinear 10000x" $ + let n = 10000 + sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX + in nf (\sh -> [ixxFromLinear sh i | i <- [1..n]]) sh0 + ,bench "ixxFromLinear 1x" $ + let sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX + in nf (\sh -> ixxFromLinear sh 1234) sh0 + ,bench "shxEnum" $ + let sh0 = SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% SUnknown 10 :$% ZSX + in nf (\sh -> shxEnum sh) sh0 ] ] diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 1b008e5..f127e3a 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -262,18 +262,30 @@ ixxZipWith _ ZIX ZIX = ZIX ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js ixxFromLinear :: IShX sh -> Int -> IIxX sh -ixxFromLinear = \sh i -> case go sh i of - (idx, 0) -> idx - _ -> error $ "ixxFromLinear: out of range (" ++ show i ++ - " in array of shape " ++ show sh ++ ")" +ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared when it's called many times + let suffixes = drop 1 (scanr (*) 1 (shxToList sh)) + in \i -> + if i < 0 then outrange sh i else + case (sh, suffixes) of -- unfold first iteration of fromLin to do the range check + (ZSX, _) | i > 0 -> outrange sh i + | otherwise -> ZIX + (n :$% sh', suff : suffs) -> + let (q, r) = i `quotRem` suff + in if q >= fromSMayNat' n then outrange sh i else + q :.% fromLin sh' suffs r + _ -> error "impossible" where - -- returns (index in subarray, remaining index in enclosing array) - go :: IShX sh -> Int -> (IIxX sh, Int) - go ZSX i = (ZIX, i) - go (n :$% sh) i = - let (idx, i') = go sh i - (upi, locali) = i' `quotRem` fromSMayNat' n - in (locali :.% idx, upi) + fromLin :: IShX sh -> [Int] -> Int -> IxX sh Int + fromLin ZSX _ !_ = ZIX + fromLin (_ :$% sh') (suff : suffs) i = + let (q, r) = i `quotRem` suff -- suff == shrSize sh' + in q :.% fromLin sh' suffs r + fromLin _ _ _ = error "impossible" + + {-# NOINLINE outrange #-} + outrange :: IShX sh -> Int -> a + outrange sh i = error $ "ixxFromLinear: out of range (" ++ show i ++ + " in array of shape " ++ show sh ++ ")" ixxToLinear :: IShX sh -> IIxX sh -> Int ixxToLinear = \sh i -> fst (go sh i) |
