aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped/Shape.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Shaped/Shape.hs')
-rw-r--r--src/Data/Array/Nested/Shaped/Shape.hs43
1 files changed, 24 insertions, 19 deletions
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs
index bbcdbf9..eb8653d 100644
--- a/src/Data/Array/Nested/Shaped/Shape.hs
+++ b/src/Data/Array/Nested/Shaped/Shape.hs
@@ -126,6 +126,13 @@ listsRank :: ListS sh f -> SNat (Rank sh)
listsRank ZS = SNat
listsRank (_ ::$ sh) = snatSucc (listsRank sh)
+listsFromList :: ShS sh -> [i] -> ListS sh (Const i)
+listsFromList ZSS [] = ZS
+listsFromList (_ :$$ sh) (i : is) = Const i ::$ listsFromList sh is
+listsFromList sh l = error $ "listsFromList: Mismatched list length (type says "
+ ++ show (shsLength sh) ++ ", list has length "
+ ++ show (length l) ++ ")"
+
{-# INLINEABLE listsToList #-}
listsToList :: ListS sh (Const i) -> [i]
listsToList ZS = []
@@ -244,6 +251,9 @@ ixsLength (IxS l) = listsLength l
ixsRank :: IxS sh i -> SNat (Rank sh)
ixsRank (IxS l) = listsRank l
+ixsFromList :: forall sh i. ShS sh -> [i] -> IxS sh i
+ixsFromList = coerce (listsFromList @_ @i)
+
ixsZero :: ShS sh -> IIxS sh
ixsZero ZSS = ZIS
ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
@@ -339,6 +349,18 @@ shsSize :: ShS sh -> Int
shsSize ZSS = 1
shsSize (n :$$ sh) = fromSNat' n * shsSize sh
+-- This is a partial @const@ that fails when the second argument
+-- doesn't match the first.
+shsFromList :: ShS sh -> [Int] -> ShS sh
+shsFromList sh0@ZSS [] = sh0
+shsFromList sh0@(sn :$$ sh) (i : is)
+ | i == fromSNat' sn = shsFromList sh is `seq` sh0
+ | otherwise = error $ "shsFromList: Value does not match typing (type says "
+ ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
+shsFromList sh l = error $ "shsFromList: Mismatched list length (type says "
+ ++ show (shsLength sh) ++ ", list has length "
+ ++ show (length l) ++ ")"
+
{-# INLINEABLE shsToList #-}
shsToList :: ShS sh -> [Int]
shsToList ZSS = []
@@ -428,14 +450,7 @@ shsEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shsSize sh - 1]]
-- | Untyped: length is checked at runtime.
instance KnownShS sh => IsList (ListS sh (Const i)) where
type Item (ListS sh (Const i)) = i
- fromList topl = go (knownShS @sh) topl
- where
- go :: ShS sh' -> [i] -> ListS sh' (Const i)
- go ZSS [] = ZS
- go (_ :$$ sh) (i : is) = Const i ::$ go sh is
- go _ _ = error $ "IsList(ListS): Mismatched list length (type says "
- ++ show (shsLength (knownShS @sh)) ++ ", list has length "
- ++ show (length topl) ++ ")"
+ fromList = listsFromList (knownShS @sh)
toList = listsToList
-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
@@ -447,15 +462,5 @@ instance KnownShS sh => IsList (IxS sh i) where
-- | Untyped: length and values are checked at runtime.
instance KnownShS sh => IsList (ShS sh) where
type Item (ShS sh) = Int
- fromList topl = ShS (go (knownShS @sh) topl)
- where
- go :: ShS sh' -> [Int] -> ListS sh' SNat
- go ZSS [] = ZS
- go (sn :$$ sh) (i : is)
- | i == fromSNat' sn = sn ::$ go sh is
- | otherwise = error $ "IsList(ShS): Value does not match typing (type says "
- ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
- go _ _ = error $ "IsList(ShS): Mismatched list length (type says "
- ++ show (shsLength (knownShS @sh)) ++ ", list has length "
- ++ show (length topl) ++ ")"
+ fromList = shsFromList (knownShS @sh)
toList = shsToList