aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-10 17:50:23 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-10 17:50:23 +0200
commita31367cc657198237a8ff911c8c78a399d51e2b8 (patch)
tree34b4d11e474d0d157822c9125e5d80b586db69b1
parent890f4afd45ea416134ddfaf8a9115602316e17dc (diff)
Add head functions for the nested list types
-rw-r--r--src/Data/Array/Nested/Internal/Shape.hs19
1 files changed, 19 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs
index c66b467..9d718cc 100644
--- a/src/Data/Array/Nested/Internal/Shape.hs
+++ b/src/Data/Array/Nested/Internal/Shape.hs
@@ -76,6 +76,10 @@ listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
listrFromList [] k = k ZR
listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
+listrHead :: ListR (n + 1) i -> i
+listrHead (i ::: _) = i
+listrHead ZR = error "unreachable"
+
listrTail :: ListR (n + 1) i -> ListR n i
listrTail (_ ::: sh) = sh
listrTail ZR = error "unreachable"
@@ -153,6 +157,9 @@ ixCvtRX (n :.: (idx :: IxR m Int)) =
castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
(n :.% ixCvtRX idx)
+ixrHead :: IxR (n + 1) i -> i
+ixrHead (IxR list) = listrHead list
+
ixrTail :: IxR (n + 1) i -> IxR n i
ixrTail (IxR list) = IxR (listrTail list)
@@ -219,6 +226,9 @@ shrSize :: IShR n -> Int
shrSize ZSR = 1
shrSize (n :$: sh) = n * shrSize sh
+shrHead :: ShR (n + 1) i -> i
+shrHead (ShR list) = listrHead list
+
shrTail :: ShR (n + 1) i -> ShR n i
shrTail (ShR list) = ShR (listrTail list)
@@ -294,6 +304,9 @@ listsToList :: ListS sh (Const i) -> [i]
listsToList ZS = []
listsToList (Const i ::$ is) = i : listsToList is
+listsHead :: ListS (n : sh) i -> i n
+listsHead (i ::$ _) = i
+
listsTail :: ListS (n : sh) i -> ListS sh i
listsTail (_ ::$ sh) = sh
@@ -374,6 +387,9 @@ ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
ixCvtSX ZIS = ZIX
ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh
+ixsHead :: IxS (n : sh) i -> i
+ixsHead (IxS list) = getConst (listsHead list)
+
ixsTail :: IxS (n : sh) i -> IxS sh i
ixsTail (IxS list) = IxS (listsTail list)
@@ -434,6 +450,9 @@ shCvtSX :: ShS sh -> IShX (MapJust sh)
shCvtSX ZSS = ZSX
shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh
+shsHead :: ShS (n : sh) -> SNat n
+shsHead (ShS list) = listsHead list
+
shsTail :: ShS (n : sh) -> ShS sh
shsTail (ShS list) = ShS (listsTail list)