From b63642a41f3bddc991d92f2f59b9e3ad53c1f15e Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 2 Dec 2025 15:03:10 +0100 Subject: Provide ix*FromLinear for all three shape kinds This speeds up {r,s}generatePrim --- src/Data/Array/Nested/Mixed/Shape.hs | 34 ++++-------------- src/Data/Array/Nested/Mixed/Shape/Internal.hs | 50 +++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 27 deletions(-) create mode 100644 src/Data/Array/Nested/Mixed/Shape/Internal.hs (limited to 'src/Data/Array/Nested/Mixed') diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 8aa5a77..5a45a09 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -16,6 +16,7 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE StrictData #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} @@ -44,6 +45,7 @@ import GHC.TypeLits import GHC.TypeLits.Orphans () #endif +import Data.Array.Nested.Mixed.Shape.Internal import Data.Array.Nested.Types @@ -276,33 +278,6 @@ ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k ixxZipWith _ ZIX ZIX = ZIX ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js -{-# INLINEABLE ixxFromLinear #-} -ixxFromLinear :: Num i => IShX sh -> Int -> IxX sh i -ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared when it's called many times - let suffixes = drop 1 (scanr (*) 1 (shxToList sh)) - in \i -> - if i < 0 then outrange sh i else - case (sh, suffixes) of -- unfold first iteration of fromLin to do the range check - (ZSX, _) | i > 0 -> outrange sh i - | otherwise -> ZIX - (n :$% sh', suff : suffs) -> - let (q, r) = i `quotRem` suff - in if q >= fromSMayNat' n then outrange sh i else - fromIntegral q :.% fromLin sh' suffs r - _ -> error "impossible" - where - fromLin :: Num i => IShX sh -> [Int] -> Int -> IxX sh i - fromLin ZSX _ !_ = ZIX - fromLin (_ :$% sh') (suff : suffs) i = - let (q, r) = i `quotRem` suff -- suff == shrSize sh' - in fromIntegral q :.% fromLin sh' suffs r - fromLin _ _ _ = error "impossible" - - {-# NOINLINE outrange #-} - outrange :: IShX sh -> Int -> a - outrange sh i = error $ "ixxFromLinear: out of range (" ++ show i ++ - " in array of shape " ++ show sh ++ ")" - ixxToLinear :: IShX sh -> IIxX sh -> Int ixxToLinear = \sh i -> fst (go sh i) where @@ -684,3 +659,8 @@ instance KnownShX sh => IsList (ShX sh Int) where type Item (ShX sh Int) = Int fromList = shxFromList (knownShX @sh) toList = shxToList + +-- This needs to be at the bottom of the file to not split the file into +-- pieces; some of the shape/index stuff refers to StaticShX. +$(ixFromLinearStub "ixxFromLinear" [t| IShX |] [t| IxX |] [p| ZSX |] (\a b -> [p| (fromSMayNat' -> $a) :$% $b |]) [| ZIX |] [| (:.%) |] [| shxToList |]) +{-# INLINEABLE ixxFromLinear #-} diff --git a/src/Data/Array/Nested/Mixed/Shape/Internal.hs b/src/Data/Array/Nested/Mixed/Shape/Internal.hs new file mode 100644 index 0000000..cf44522 --- /dev/null +++ b/src/Data/Array/Nested/Mixed/Shape/Internal.hs @@ -0,0 +1,50 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +module Data.Array.Nested.Mixed.Shape.Internal where + +import Language.Haskell.TH + + +-- | A TH stub function to avoid having to write the same code three times for +-- the three kinds of shapes. +ixFromLinearStub :: String + -> TypeQ -> TypeQ + -> PatQ -> (PatQ -> PatQ -> PatQ) + -> ExpQ -> ExpQ + -> ExpQ + -> DecsQ +ixFromLinearStub fname' ishty ixty zshC consshC ixz ixcons shtolist = do + let fname = mkName fname' + typesig <- [t| forall i sh. Num i => $ishty sh -> Int -> $ixty sh i |] + + locals <- [d| + fromLin :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i + fromLin $zshC _ !_ = $ixz + fromLin ($(consshC wildP (varP (mkName "sh'")))) (suff : suffs) i = + let (q, r) = i `quotRem` suff -- suff == shrSize sh' + in $ixcons (fromIntegral q) (fromLin sh' suffs r) + fromLin _ _ _ = error "impossible" + + {-# NOINLINE outrange #-} + outrange :: $ishty sh -> Int -> a + outrange sh i = error $ fname' ++ ": out of range (" ++ show i ++ + " in array of shape " ++ show sh ++ ")" |] + + body <- [| + \sh -> -- give this function arity 1 so that 'suffixes' is shared when + -- it's called many times + let suffixes = drop 1 (scanr (*) 1 ($shtolist sh)) + in \i -> + if i < 0 then outrange sh i else + case (sh, suffixes) of -- unfold first iteration of fromLin to do the range check + ($zshC, _) | i > 0 -> outrange sh i + | otherwise -> $ixz + ($(consshC (varP (mkName "n")) (varP (mkName "sh'"))), suff : suffs) -> + let (q, r) = i `quotRem` suff + in if q >= n then outrange sh i else + $ixcons (fromIntegral q) (fromLin sh' suffs r) + _ -> error "impossible" |] + + return [SigD fname typesig + ,FunD fname [Clause [] (NormalB body) locals]] -- cgit v1.2.3-70-g09d2