diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-16 13:24:25 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-16 19:35:33 +0100 |
| commit | 682c584b26e872b7613cbcd73e3d15fc39867713 (patch) | |
| tree | 778abd95adb8516c8f3b83883f37ddcd26787c3f /src/Data/Array | |
| parent | 6f2206b61ea05d4b1cd1fb6d0971484bbc820b02 (diff) | |
Define ix?FromLinear without TH
Diffstat (limited to 'src/Data/Array')
| -rw-r--r-- | src/Data/Array/Nested/Convert.hs | 7 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 41 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape/Internal.hs | 59 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 14 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 6 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 12 |
6 files changed, 58 insertions, 81 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 32248c4..91752c4 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -63,8 +63,7 @@ import Data.Array.Nested.Types ixrFromIxS :: IxS sh i -> IxR (Rank sh) i ixrFromIxS = unsafeCoerce -ixrFromIxX :: IxX sh i -> IxR (Rank sh) i -ixrFromIxX = unsafeCoerce +-- ixrFromIxX re-exported shrFromShS :: ShS sh -> IShR (Rank sh) shrFromShS ZSS = ZSR @@ -97,9 +96,7 @@ ixsFromIxR' ZSS ZIR = ZIS ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank" --- TODO: remove the ShS now that no KnownNats is inside IxS. -ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i -ixsFromIxX _ = unsafeCoerce +-- ixsFromIxX re-exported -- TODO: if possible, remove the ShS now that no KnownNats is inside IxS. -- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 5ffd40c..ebf0a07 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -46,7 +46,6 @@ import GHC.TypeLits import GHC.TypeLits.Orphans () #endif -import Data.Array.Nested.Mixed.Shape.Internal import Data.Array.Nested.Types @@ -293,6 +292,41 @@ ixxToLinear = \sh i -> go sh i 0 go ZSX ZIX a = a go (n :$% sh) (i :.% ix) a = go sh ix (fromIntegral (fromSMayNat' n) * a + i) +{-# INLINEABLE ixxFromLinear #-} +ixxFromLinear :: Num i => IShX sh -> Int -> IxX sh i +ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared when it's called many times + let suffixes = drop 1 (scanr (*) 1 (shxToList sh)) + in fromLin0 sh suffixes + where + -- Unfold first iteration of fromLin to do the range check. + -- Don't inline this function at first to allow GHC to inline the outer + -- function and realise that 'suffixes' is shared. But then later inline it + -- anyway, to avoid the function call. Removing the pragma makes GHC + -- somehow unable to recognise that 'suffixes' can be shared in a loop. + {-# NOINLINE [0] fromLin0 #-} + fromLin0 :: Num i => IShX sh -> [Int] -> Int -> IxX sh i + fromLin0 sh suffixes i = + if i < 0 then outrange sh i else + case (sh, suffixes) of + (ZSX, _) | i > 0 -> outrange sh i + | otherwise -> ZIX + ((fromSMayNat' -> n) :$% sh', suff : suffs) -> + let (q, r) = i `quotRem` suff + in if q >= n then outrange sh i else + fromIntegral q :.% fromLin sh' suffs r + _ -> error "impossible" + + fromLin :: Num i => IShX sh -> [Int] -> Int -> IxX sh i + fromLin ZSX _ !_ = ZIX + fromLin (_ :$% sh') (suff : suffs) i = + let (q, r) = i `quotRem` suff -- suff == shrSize sh' + in fromIntegral q :.% fromLin sh' suffs r + fromLin _ _ _ = error "impossible" + + {-# NOINLINE outrange #-} + outrange :: IShX sh -> Int -> a + outrange sh i = error $ "ixxFromLinear: out of range (" ++ show i ++ + " in array of shape " ++ show sh ++ ")" -- * Mixed shape-like lists to be used for ShX and StaticShX @@ -798,8 +832,3 @@ instance KnownShX sh => IsList (IShX sh) where type Item (IShX sh) = Int fromList = shxFromList (knownShX @sh) toList = shxToList - --- This needs to be at the bottom of the file to not split the file into --- pieces; some of the shape/index stuff refers to StaticShX. -$(ixFromLinearStub "ixxFromLinear" [t| IShX |] [t| IxX |] [p| ZSX |] (\a b -> [p| (fromSMayNat' -> $a) :$% $b |]) [| ZIX |] [| (:.%) |] [| shxToList |]) -{-# INLINEABLE ixxFromLinear #-} diff --git a/src/Data/Array/Nested/Mixed/Shape/Internal.hs b/src/Data/Array/Nested/Mixed/Shape/Internal.hs deleted file mode 100644 index 2a86ac1..0000000 --- a/src/Data/Array/Nested/Mixed/Shape/Internal.hs +++ /dev/null @@ -1,59 +0,0 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -module Data.Array.Nested.Mixed.Shape.Internal where - -import Language.Haskell.TH - - --- | A TH stub function to avoid having to write the same code three times for --- the three kinds of shapes. -ixFromLinearStub :: String - -> TypeQ -> TypeQ - -> PatQ -> (PatQ -> PatQ -> PatQ) - -> ExpQ -> ExpQ - -> ExpQ - -> DecsQ -ixFromLinearStub fname' ishty ixty zshC consshC ixz ixcons shtolist = do - let fname = mkName fname' - typesig <- [t| forall i sh. Num i => $ishty sh -> Int -> $ixty sh i |] - - locals <- [d| - -- Unfold first iteration of fromLin to do the range check. - -- Don't inline this function at first to allow GHC to inline the outer - -- function and realise that 'suffixes' is shared. But then later inline it - -- anyway, to avoid the function call. Removing the pragma makes GHC - -- somehow unable to recognise that 'suffixes' can be shared in a loop. - {-# NOINLINE [0] fromLin0 #-} - fromLin0 :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i - fromLin0 sh suffixes i = - if i < 0 then outrange sh i else - case (sh, suffixes) of - ($zshC, _) | i > 0 -> outrange sh i - | otherwise -> $ixz - ($(consshC (varP (mkName "n")) (varP (mkName "sh'"))), suff : suffs) -> - let (q, r) = i `quotRem` suff - in if q >= n then outrange sh i else - $ixcons (fromIntegral q) (fromLin sh' suffs r) - _ -> error "impossible" - - fromLin :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i - fromLin $zshC _ !_ = $ixz - fromLin ($(consshC wildP (varP (mkName "sh'")))) (suff : suffs) i = - let (q, r) = i `quotRem` suff -- suff == shrSize sh' - in $ixcons (fromIntegral q) (fromLin sh' suffs r) - fromLin _ _ _ = error "impossible" - - {-# NOINLINE outrange #-} - outrange :: $ishty sh -> Int -> a - outrange sh i = error $ fname' ++ ": out of range (" ++ show i ++ - " in array of shape " ++ show sh ++ ")" |] - - body <- [| - \sh -> -- give this function arity 1 so that 'suffixes' is shared when - -- it's called many times - let suffixes = drop 1 (scanr (*) 1 ($shtolist sh)) - in fromLin0 sh suffixes |] - - return [SigD fname typesig - ,FunD fname [Clause [] (NormalB body) locals]] diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 6ce0f4f..59289fb 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -39,10 +39,10 @@ import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits import GHC.TypeNats qualified as TN +import Unsafe.Coerce (unsafeCoerce) import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed.Shape -import Data.Array.Nested.Mixed.Shape.Internal import Data.Array.Nested.Permutation import Data.Array.Nested.Types @@ -300,6 +300,15 @@ ixrToLinear = \sh i -> go sh i 0 go ZSR ZIR a = a go (n :$: sh) (i :.: ix) a = go sh ix (fromIntegral n * a + i) +{-# INLINEABLE ixrFromLinear #-} +ixrFromLinear :: forall i m. Num i => IShR m -> Int -> IxR m i +ixrFromLinear (ShR sh) i + | Refl <- lemRankReplicate (Proxy @m) + = ixrFromIxX $ ixxFromLinear sh i + +ixrFromIxX :: IxX sh i -> IxR (Rank sh) i +ixrFromIxX = unsafeCoerce + -- * Ranked shapes @@ -505,6 +514,3 @@ listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i listrCastWithName _ SZ ZR = ZR listrCastWithName name (SS n) (i ::: l) = i ::: listrCastWithName name n l listrCastWithName name _ _ = error $ name ++ ": ranks don't match" - -$(ixFromLinearStub "ixrFromLinear" [t| IShR |] [t| IxR |] [p| ZSR |] (\a b -> [p| $a :$: $b |]) [| ZIR |] [| (:.:) |] [| shrToList |]) -{-# INLINEABLE ixrFromLinear #-} diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 85042f2..23a4fc8 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -70,7 +70,7 @@ sindexPartial sarr@(Shaped arr) idx = -- | __WARNING__: All values returned from the function must have equal shape. -- See the documentation of 'mgenerate' for more details. sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a -sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh)) +sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX)) -- | See 'mgeneratePrim'. {-# INLINE sgeneratePrim #-} @@ -253,11 +253,11 @@ siota sn = Shaped (miota sn) -- | Throws if the array is empty. sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -sminIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mminIndexPrim arr) +sminIndexPrim (Shaped arr) = ixsFromIxX (mminIndexPrim arr) -- | Throws if the array is empty. smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -smaxIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) +smaxIndexPrim (Shaped arr) = ixsFromIxX (mmaxIndexPrim arr) sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 521ec2f..f57e7dd 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -41,9 +41,9 @@ import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce) import Data.Array.Nested.Mixed.Shape -import Data.Array.Nested.Mixed.Shape.Internal import Data.Array.Nested.Permutation import Data.Array.Nested.Types @@ -318,6 +318,13 @@ ixsToLinear = \sh i -> go sh i 0 go ZSS ZIS a = a go (n :$$ sh) (i :.$ ix) a = go sh ix (fromIntegral (fromSNat' n) * a + i) +{-# INLINEABLE ixsFromLinear #-} +ixsFromLinear :: Num i => ShS sh -> Int -> IxS sh i +ixsFromLinear (ShS sh) i = ixsFromIxX $ ixxFromLinear sh i + +ixsFromIxX :: IxX (MapJust sh) i -> IxS sh i +ixsFromIxX = unsafeCoerce + -- * Shaped shapes @@ -533,6 +540,3 @@ instance KnownShS sh => IsList (ShS sh) where type Item (ShS sh) = Int fromList = shsFromList (knownShS @sh) toList = shsToList - -$(ixFromLinearStub "ixsFromLinear" [t| ShS |] [t| IxS |] [p| ZSS |] (\a b -> [p| (fromSNat' -> $a) :$$ $b |]) [| ZIS |] [| (:.$) |] [| shsToList |]) -{-# INLINEABLE ixsFromLinear #-} |
