aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-12-02 15:01:26 +0100
committerTom Smeding <tom@tomsmeding.com>2025-12-02 15:01:26 +0100
commitf76d781c75105b7b04ed2e602f0139d35846ab92 (patch)
tree2ae7357a52f8c2e302332cd8d7754c23fe5be511 /src
parent9f47aa6a2bcd772388a5d5150ca7254e4eb95bc2 (diff)
Style and uniformity of shape/index/list functions
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs8
-rw-r--r--src/Data/Array/Nested/Permutation.hs2
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs46
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs9
4 files changed, 42 insertions, 23 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index ed03310..8aa5a77 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -217,9 +217,7 @@ instance Foldable (IxX sh) where
{-# INLINE foldr #-}
foldr _ z ZIX = z
foldr f z (x :.% xs) = f x (foldr f z xs)
- {-# INLINEABLE toList #-}
- toList ZIX = []
- toList (i :.% is) = i : Foldable.toList is
+ toList = ixxToList
null ZIX = False
null _ = True
@@ -242,6 +240,10 @@ ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh
ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i
ixxFromList = coerce (listxFromList @_ @i)
+{-# INLINEABLE ixxToList #-}
+ixxToList :: forall sh i. IxX sh i -> [i]
+ixxToList = coerce (listxToList @_ @i)
+
ixxHead :: IxX (n : sh) i -> i
ixxHead (IxX list) = getConst (listxHead list)
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
index 9eae73d..065c9fd 100644
--- a/src/Data/Array/Nested/Permutation.hs
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -61,7 +61,7 @@ permFromListCont :: [Int] -> (forall list. Perm list -> r) -> r
permFromListCont [] k = k PNil
permFromListCont (x : xs) k = withSomeSNat (fromIntegral x) $ \case
Just sn -> permFromListCont xs $ \list -> k (sn `PCons` list)
- Nothing -> error $ "Data.Array.Mixed.permFromListCont: negative number in list: " ++ show x
+ Nothing -> error $ "Data.Array.Nested.Permutation.permFromListCont: negative number in list: " ++ show x
permToList :: Perm list -> [Natural]
permToList PNil = mempty
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index ea22a2b..9815c42 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -33,7 +33,6 @@ import Control.DeepSeq (NFData(..))
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Kind (Type)
-import Data.List (genericLength)
import Data.Proxy
import Data.Type.Equality
import GHC.Exts (Int(..), Int#, quotRemInt#)
@@ -81,9 +80,7 @@ instance Foldable (ListR n) where
{-# INLINE foldr #-}
foldr _ z ZR = z
foldr f z (x ::: xs) = f x (foldr f z xs)
- {-# INLINEABLE toList #-}
- toList ZR = []
- toList (i ::: is) = i : Foldable.toList is
+ toList = listrToList
null ZR = False
null _ = True
@@ -137,6 +134,11 @@ listrFromList n l = error $ "listrFromList: Mismatched list length (type says "
++ show (fromSNat n) ++ ", list has length "
++ show (length l) ++ ")"
+{-# INLINEABLE listrToList #-}
+listrToList :: ListR n i -> [i]
+listrToList ZR = []
+listrToList (i ::: is) = i : listrToList is
+
listrHead :: ListR (n + 1) i -> i
listrHead (i ::: _) = i
@@ -174,16 +176,16 @@ listrZipWith _ _ _ =
error "listrZipWith: impossible pattern needlessly required"
listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
-listrPermutePrefix = \perm sh -> withSomeSNat (genericLength perm) $ \case
- Just permlen@SNat->
- let sperm = listrFromList permlen perm
- in case listrRank sh of
- shlen@SNat -> case cmpNat permlen shlen of
- LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
- ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
- Nothing -> error "listrPermutePrefix: impossible negative list length"
+listrPermutePrefix = \perm sh ->
+ TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
+ case listrRank sh of { shlen@SNat ->
+ let sperm = listrFromList permlen perm in
+ case cmpNat permlen shlen of
+ LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+ ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ }
where
listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
listrSplitAt SZ sh = (ZR, sh)
@@ -245,6 +247,13 @@ ixrZero :: SNat n -> IIxR n
ixrZero SZ = ZIR
ixrZero (SS n) = 0 :.: ixrZero n
+ixrFromList :: forall n i. SNat n -> [i] -> IxR n i
+ixrFromList = coerce (listrFromList @_ @i)
+
+{-# INLINEABLE ixrToList #-}
+ixrToList :: forall n i. IxR n i -> [i]
+ixrToList = coerce (listrToList @_ @i)
+
ixrHead :: IxR (n + 1) i -> i
ixrHead (IxR list) = listrHead list
@@ -330,6 +339,13 @@ shrSize :: IShR n -> Int
shrSize ZSR = 1
shrSize (n :$: sh) = n * shrSize sh
+shrFromList :: forall n i. SNat n -> [i] -> ShR n i
+shrFromList = coerce (listrFromList @_ @i)
+
+{-# INLINEABLE shrToList #-}
+shrToList :: forall n i. ShR n i -> [i]
+shrToList = coerce (listrToList @_ @i)
+
shrHead :: ShR (n + 1) i -> i
shrHead (ShR list) = listrHead list
@@ -366,7 +382,7 @@ shrEnum = shrEnum'
shrEnum' :: Num i => IShR sh -> [IxR sh i]
shrEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shrSize sh - 1]]
where
- suffixes = drop 1 (scanr (*) 1 (Foldable.toList sh))
+ suffixes = drop 1 (scanr (*) 1 (shrToList sh))
fromLin :: Num i => IShR sh -> [Int] -> Int# -> IxR sh i
fromLin ZSR _ _ = ZIR
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index eb8653d..0a4c1b9 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -237,9 +237,7 @@ instance Foldable (IxS sh) where
{-# INLINE foldr #-}
foldr _ z ZIS = z
foldr f z (x :.$ xs) = f x (foldr f z xs)
- {-# INLINEABLE toList #-}
- toList ZIS = []
- toList (i :.$ is) = i : Foldable.toList is
+ toList = ixsToList
null ZIS = False
null _ = True
@@ -254,6 +252,9 @@ ixsRank (IxS l) = listsRank l
ixsFromList :: forall sh i. ShS sh -> [i] -> IxS sh i
ixsFromList = coerce (listsFromList @_ @i)
+ixsToList :: forall sh i. IxS sh i -> [i]
+ixsToList = coerce (listsToList @_ @i)
+
ixsZero :: ShS sh -> IIxS sh
ixsZero ZSS = ZIS
ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
@@ -349,7 +350,7 @@ shsSize :: ShS sh -> Int
shsSize ZSS = 1
shsSize (n :$$ sh) = fromSNat' n * shsSize sh
--- This is a partial @const@ that fails when the second argument
+-- | This is a partial @const@ that fails when the second argument
-- doesn't match the first.
shsFromList :: ShS sh -> [Int] -> ShS sh
shsFromList sh0@ZSS [] = sh0