diff options
Diffstat (limited to 'src/Data/Array/Nested/Ranked')
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 46 |
1 files changed, 31 insertions, 15 deletions
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index ea22a2b..9815c42 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -33,7 +33,6 @@ import Control.DeepSeq (NFData(..)) import Data.Coerce (coerce) import Data.Foldable qualified as Foldable import Data.Kind (Type) -import Data.List (genericLength) import Data.Proxy import Data.Type.Equality import GHC.Exts (Int(..), Int#, quotRemInt#) @@ -81,9 +80,7 @@ instance Foldable (ListR n) where {-# INLINE foldr #-} foldr _ z ZR = z foldr f z (x ::: xs) = f x (foldr f z xs) - {-# INLINEABLE toList #-} - toList ZR = [] - toList (i ::: is) = i : Foldable.toList is + toList = listrToList null ZR = False null _ = True @@ -137,6 +134,11 @@ listrFromList n l = error $ "listrFromList: Mismatched list length (type says " ++ show (fromSNat n) ++ ", list has length " ++ show (length l) ++ ")" +{-# INLINEABLE listrToList #-} +listrToList :: ListR n i -> [i] +listrToList ZR = [] +listrToList (i ::: is) = i : listrToList is + listrHead :: ListR (n + 1) i -> i listrHead (i ::: _) = i @@ -174,16 +176,16 @@ listrZipWith _ _ _ = error "listrZipWith: impossible pattern needlessly required" listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i -listrPermutePrefix = \perm sh -> withSomeSNat (genericLength perm) $ \case - Just permlen@SNat-> - let sperm = listrFromList permlen perm - in case listrRank sh of - shlen@SNat -> case cmpNat permlen shlen of - LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" - ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" - Nothing -> error "listrPermutePrefix: impossible negative list length" +listrPermutePrefix = \perm sh -> + TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat -> + case listrRank sh of { shlen@SNat -> + let sperm = listrFromList permlen perm in + case cmpNat permlen shlen of + LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post + EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post + GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" + ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + } where listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) listrSplitAt SZ sh = (ZR, sh) @@ -245,6 +247,13 @@ ixrZero :: SNat n -> IIxR n ixrZero SZ = ZIR ixrZero (SS n) = 0 :.: ixrZero n +ixrFromList :: forall n i. SNat n -> [i] -> IxR n i +ixrFromList = coerce (listrFromList @_ @i) + +{-# INLINEABLE ixrToList #-} +ixrToList :: forall n i. IxR n i -> [i] +ixrToList = coerce (listrToList @_ @i) + ixrHead :: IxR (n + 1) i -> i ixrHead (IxR list) = listrHead list @@ -330,6 +339,13 @@ shrSize :: IShR n -> Int shrSize ZSR = 1 shrSize (n :$: sh) = n * shrSize sh +shrFromList :: forall n i. SNat n -> [i] -> ShR n i +shrFromList = coerce (listrFromList @_ @i) + +{-# INLINEABLE shrToList #-} +shrToList :: forall n i. ShR n i -> [i] +shrToList = coerce (listrToList @_ @i) + shrHead :: ShR (n + 1) i -> i shrHead (ShR list) = listrHead list @@ -366,7 +382,7 @@ shrEnum = shrEnum' shrEnum' :: Num i => IShR sh -> [IxR sh i] shrEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shrSize sh - 1]] where - suffixes = drop 1 (scanr (*) 1 (Foldable.toList sh)) + suffixes = drop 1 (scanr (*) 1 (shrToList sh)) fromLin :: Num i => IShR sh -> [Int] -> Int# -> IxR sh i fromLin ZSR _ _ = ZIR |
