aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Ranked
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Ranked')
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs46
1 files changed, 31 insertions, 15 deletions
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index ea22a2b..9815c42 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -33,7 +33,6 @@ import Control.DeepSeq (NFData(..))
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Kind (Type)
-import Data.List (genericLength)
import Data.Proxy
import Data.Type.Equality
import GHC.Exts (Int(..), Int#, quotRemInt#)
@@ -81,9 +80,7 @@ instance Foldable (ListR n) where
{-# INLINE foldr #-}
foldr _ z ZR = z
foldr f z (x ::: xs) = f x (foldr f z xs)
- {-# INLINEABLE toList #-}
- toList ZR = []
- toList (i ::: is) = i : Foldable.toList is
+ toList = listrToList
null ZR = False
null _ = True
@@ -137,6 +134,11 @@ listrFromList n l = error $ "listrFromList: Mismatched list length (type says "
++ show (fromSNat n) ++ ", list has length "
++ show (length l) ++ ")"
+{-# INLINEABLE listrToList #-}
+listrToList :: ListR n i -> [i]
+listrToList ZR = []
+listrToList (i ::: is) = i : listrToList is
+
listrHead :: ListR (n + 1) i -> i
listrHead (i ::: _) = i
@@ -174,16 +176,16 @@ listrZipWith _ _ _ =
error "listrZipWith: impossible pattern needlessly required"
listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
-listrPermutePrefix = \perm sh -> withSomeSNat (genericLength perm) $ \case
- Just permlen@SNat->
- let sperm = listrFromList permlen perm
- in case listrRank sh of
- shlen@SNat -> case cmpNat permlen shlen of
- LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
- ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
- Nothing -> error "listrPermutePrefix: impossible negative list length"
+listrPermutePrefix = \perm sh ->
+ TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
+ case listrRank sh of { shlen@SNat ->
+ let sperm = listrFromList permlen perm in
+ case cmpNat permlen shlen of
+ LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+ ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ }
where
listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
listrSplitAt SZ sh = (ZR, sh)
@@ -245,6 +247,13 @@ ixrZero :: SNat n -> IIxR n
ixrZero SZ = ZIR
ixrZero (SS n) = 0 :.: ixrZero n
+ixrFromList :: forall n i. SNat n -> [i] -> IxR n i
+ixrFromList = coerce (listrFromList @_ @i)
+
+{-# INLINEABLE ixrToList #-}
+ixrToList :: forall n i. IxR n i -> [i]
+ixrToList = coerce (listrToList @_ @i)
+
ixrHead :: IxR (n + 1) i -> i
ixrHead (IxR list) = listrHead list
@@ -330,6 +339,13 @@ shrSize :: IShR n -> Int
shrSize ZSR = 1
shrSize (n :$: sh) = n * shrSize sh
+shrFromList :: forall n i. SNat n -> [i] -> ShR n i
+shrFromList = coerce (listrFromList @_ @i)
+
+{-# INLINEABLE shrToList #-}
+shrToList :: forall n i. ShR n i -> [i]
+shrToList = coerce (listrToList @_ @i)
+
shrHead :: ShR (n + 1) i -> i
shrHead (ShR list) = listrHead list
@@ -366,7 +382,7 @@ shrEnum = shrEnum'
shrEnum' :: Num i => IShR sh -> [IxR sh i]
shrEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shrSize sh - 1]]
where
- suffixes = drop 1 (scanr (*) 1 (Foldable.toList sh))
+ suffixes = drop 1 (scanr (*) 1 (shrToList sh))
fromLin :: Num i => IShR sh -> [Int] -> Int# -> IxR sh i
fromLin ZSR _ _ = ZIR