diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-01 18:35:58 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-01 18:35:58 +0100 |
| commit | 45c429917c95713b339cc4d9210a842546e72a0d (patch) | |
| tree | c6f540f1478390c0874f4b566de480593db17e9b | |
| parent | 9faf7fb877119bd52d664940c4326d326b3326fa (diff) | |
Unify fromList functions for shapes
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 34 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 9 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 43 |
3 files changed, 39 insertions, 47 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 4626481..066ae8e 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -131,14 +131,11 @@ listxShow f l = showString "[" . go "" l . showString "]" go prefix (x ::% xs) = showString prefix . f x . go "," xs listxFromList :: StaticShX sh -> [i] -> ListX sh (Const i) -listxFromList topssh topl = go topssh topl - where - go :: StaticShX sh' -> [i] -> ListX sh' (Const i) - go ZKX [] = ZX - go (_ :!% sh) (i : is) = Const i ::% go sh is - go _ _ = error $ "listxFromList: Mismatched list length (type says " - ++ show (ssxLength topssh) ++ ", list has length " - ++ show (length topl) ++ ")" +listxFromList ZKX [] = ZX +listxFromList (_ :!% sh) (i : is) = Const i ::% listxFromList sh is +listxFromList sh l = error $ "listxFromList: Mismatched list length (type says " + ++ show (ssxLength sh) ++ ", list has length " + ++ show (length l) ++ ")" {-# INLINEABLE listxToList #-} listxToList :: ListX sh' (Const i) -> [i] @@ -432,18 +429,15 @@ shxSize ZSX = 1 shxSize (n :$% sh) = fromSMayNat' n * shxSize sh shxFromList :: StaticShX sh -> [Int] -> IShX sh -shxFromList topssh topl = go topssh topl - where - go :: StaticShX sh' -> [Int] -> IShX sh' - go ZKX [] = ZSX - go (SKnown sn :!% sh) (i : is) - | i == fromSNat' sn = SKnown sn :$% go sh is - | otherwise = error $ "shxFromList: Value does not match typing (type says " - ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" - go (SUnknown () :!% sh) (i : is) = SUnknown i :$% go sh is - go _ _ = error $ "shxFromList: Mismatched list length (type says " - ++ show (ssxLength topssh) ++ ", list has length " - ++ show (length topl) ++ ")" +shxFromList ZKX [] = ZSX +shxFromList (SKnown sn :!% sh) (i : is) + | i == fromSNat' sn = SKnown sn :$% shxFromList sh is + | otherwise = error $ "shxFromList: Value does not match typing (type says " + ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" +shxFromList (SUnknown () :!% sh) (i : is) = SUnknown i :$% shxFromList sh is +shxFromList sh l = error $ "shxFromList: Mismatched list length (type says " + ++ show (ssxLength sh) ++ ", list has length " + ++ show (length l) ++ ")" {-# INLINEABLE shxToList #-} shxToList :: IShX sh -> [Int] diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 4d581a0..ea22a2b 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -379,14 +379,7 @@ shrEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shrSize sh - 1]] -- | Untyped: length is checked at runtime. instance KnownNat n => IsList (ListR n i) where type Item (ListR n i) = i - fromList topl = go (SNat @n) topl - where - go :: SNat n' -> [i] -> ListR n' i - go SZ [] = ZR - go (SS n) (i : is) = i ::: go n is - go _ _ = error $ "IsList(ListR): Mismatched list length (type says " - ++ show (fromSNat (SNat @n)) ++ ", list has length " - ++ show (length topl) ++ ")" + fromList = listrFromList (SNat @n) toList = Foldable.toList -- | Untyped: length is checked at runtime. diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index bbcdbf9..eb8653d 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -126,6 +126,13 @@ listsRank :: ListS sh f -> SNat (Rank sh) listsRank ZS = SNat listsRank (_ ::$ sh) = snatSucc (listsRank sh) +listsFromList :: ShS sh -> [i] -> ListS sh (Const i) +listsFromList ZSS [] = ZS +listsFromList (_ :$$ sh) (i : is) = Const i ::$ listsFromList sh is +listsFromList sh l = error $ "listsFromList: Mismatched list length (type says " + ++ show (shsLength sh) ++ ", list has length " + ++ show (length l) ++ ")" + {-# INLINEABLE listsToList #-} listsToList :: ListS sh (Const i) -> [i] listsToList ZS = [] @@ -244,6 +251,9 @@ ixsLength (IxS l) = listsLength l ixsRank :: IxS sh i -> SNat (Rank sh) ixsRank (IxS l) = listsRank l +ixsFromList :: forall sh i. ShS sh -> [i] -> IxS sh i +ixsFromList = coerce (listsFromList @_ @i) + ixsZero :: ShS sh -> IIxS sh ixsZero ZSS = ZIS ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh @@ -339,6 +349,18 @@ shsSize :: ShS sh -> Int shsSize ZSS = 1 shsSize (n :$$ sh) = fromSNat' n * shsSize sh +-- This is a partial @const@ that fails when the second argument +-- doesn't match the first. +shsFromList :: ShS sh -> [Int] -> ShS sh +shsFromList sh0@ZSS [] = sh0 +shsFromList sh0@(sn :$$ sh) (i : is) + | i == fromSNat' sn = shsFromList sh is `seq` sh0 + | otherwise = error $ "shsFromList: Value does not match typing (type says " + ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" +shsFromList sh l = error $ "shsFromList: Mismatched list length (type says " + ++ show (shsLength sh) ++ ", list has length " + ++ show (length l) ++ ")" + {-# INLINEABLE shsToList #-} shsToList :: ShS sh -> [Int] shsToList ZSS = [] @@ -428,14 +450,7 @@ shsEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shsSize sh - 1]] -- | Untyped: length is checked at runtime. instance KnownShS sh => IsList (ListS sh (Const i)) where type Item (ListS sh (Const i)) = i - fromList topl = go (knownShS @sh) topl - where - go :: ShS sh' -> [i] -> ListS sh' (Const i) - go ZSS [] = ZS - go (_ :$$ sh) (i : is) = Const i ::$ go sh is - go _ _ = error $ "IsList(ListS): Mismatched list length (type says " - ++ show (shsLength (knownShS @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" + fromList = listsFromList (knownShS @sh) toList = listsToList -- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. @@ -447,15 +462,5 @@ instance KnownShS sh => IsList (IxS sh i) where -- | Untyped: length and values are checked at runtime. instance KnownShS sh => IsList (ShS sh) where type Item (ShS sh) = Int - fromList topl = ShS (go (knownShS @sh) topl) - where - go :: ShS sh' -> [Int] -> ListS sh' SNat - go ZSS [] = ZS - go (sn :$$ sh) (i : is) - | i == fromSNat' sn = sn ::$ go sh is - | otherwise = error $ "IsList(ShS): Value does not match typing (type says " - ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" - go _ _ = error $ "IsList(ShS): Mismatched list length (type says " - ++ show (shsLength (knownShS @sh)) ++ ", list has length " - ++ show (length topl) ++ ")" + fromList = shsFromList (knownShS @sh) toList = shsToList |
