aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Shape.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Internal/Shape.hs')
-rw-r--r--src/Data/Array/Nested/Internal/Shape.hs48
1 files changed, 45 insertions, 3 deletions
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs
index ca04840..7077053 100644
--- a/src/Data/Array/Nested/Internal/Shape.hs
+++ b/src/Data/Array/Nested/Internal/Shape.hs
@@ -87,6 +87,16 @@ listrTail :: ListR (n + 1) i -> ListR n i
listrTail (_ ::: sh) = sh
listrTail ZR = error "unreachable"
+listrInit :: ListR (n + 1) i -> ListR n i
+listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh
+listrInit (_ ::: ZR) = ZR
+listrInit ZR = error "unreachable"
+
+listrLast :: ListR (n + 1) i -> i
+listrLast (_ ::: sh@(_ ::: _)) = listrLast sh
+listrLast (n ::: ZR) = n
+listrLast ZR = error "unreachable"
+
listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
listrIndex SZ (x ::: _) = x
listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
@@ -166,6 +176,12 @@ ixrHead (IxR list) = listrHead list
ixrTail :: IxR (n + 1) i -> IxR n i
ixrTail (IxR list) = IxR (listrTail list)
+ixrInit :: IxR (n + 1) i -> IxR n i
+ixrInit (IxR list) = IxR (listrInit list)
+
+ixrLast :: IxR (n + 1) i -> i
+ixrLast (IxR list) = listrLast list
+
ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
ixrAppend = coerce (listrAppend @_ @i)
@@ -235,6 +251,12 @@ shrHead (ShR list) = listrHead list
shrTail :: ShR (n + 1) i -> ShR n i
shrTail (ShR list) = ShR (listrTail list)
+shrInit :: ShR (n + 1) i -> ShR n i
+shrInit (ShR list) = ShR (listrInit list)
+
+shrLast :: ShR (n + 1) i -> i
+shrLast (ShR list) = listrLast list
+
shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
shrAppend = coerce (listrAppend @_ @i)
@@ -310,17 +332,25 @@ listsToList :: ListS sh (Const i) -> [i]
listsToList ZS = []
listsToList (Const i ::$ is) = i : listsToList is
-listsHead :: ListS (n : sh) i -> i n
+listsHead :: ListS (n : sh) f -> f n
listsHead (i ::$ _) = i
-listsTail :: ListS (n : sh) i -> ListS sh i
+listsTail :: ListS (n : sh) f -> ListS sh f
listsTail (_ ::$ sh) = sh
+listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f
+listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh
+listsInit (_ ::$ ZS) = ZS
+
+listsLast :: ListS (n : sh) f -> f (Last (n : sh))
+listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh
+listsLast (n ::$ ZS) = n
+
listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
listsAppend ZS idx' = idx'
listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
-listsRank :: ListS sh i -> SNat (Rank sh)
+listsRank :: ListS sh f -> SNat (Rank sh)
listsRank ZS = SNat
listsRank (_ ::$ sh) = snatSucc (listsRank sh)
@@ -403,6 +433,12 @@ ixsHead (IxS list) = getConst (listsHead list)
ixsTail :: IxS (n : sh) i -> IxS sh i
ixsTail (IxS list) = IxS (listsTail list)
+ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i
+ixsInit (IxS list) = IxS (listsInit list)
+
+ixsLast :: IxS (n : sh) i -> i
+ixsLast (IxS list) = getConst (listsLast list)
+
ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i
ixsAppend = coerce (listsAppend @_ @(Const i))
@@ -469,6 +505,12 @@ shsHead (ShS list) = listsHead list
shsTail :: ShS (n : sh) -> ShS sh
shsTail (ShS list) = ShS (listsTail list)
+shsInit :: ShS (n : sh) -> ShS (Init (n : sh))
+shsInit (ShS list) = ShS (listsInit list)
+
+shsLast :: ShS (n : sh) -> SNat (Last (n : sh))
+shsLast (ShS list) = listsLast list
+
shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh')
shsAppend = coerce (listsAppend @_ @SNat)