From a25d4061e219cec153f066fddbf710abd63b5e48 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Mon, 3 Jun 2024 19:56:05 +0200
Subject: Move sh*Tail to main ox-arrays

---
 src/Data/Array/Mixed/Shape.hs           |  8 +++++++-
 src/Data/Array/Nested/Internal/Shape.hs | 19 +++++++++++++++++++
 test/Tests/C.hs                         |  5 +++--
 test/Util.hs                            |  5 -----
 4 files changed, 29 insertions(+), 8 deletions(-)

diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs
index a13a176..4ab3c26 100644
--- a/src/Data/Array/Mixed/Shape.hs
+++ b/src/Data/Array/Mixed/Shape.hs
@@ -96,6 +96,9 @@ listxToList :: ListX sh' (Const i) -> [i]
 listxToList ZX = []
 listxToList (Const i ::% is) = i : listxToList is
 
+listxTail :: ListX (n : sh) i -> ListX sh i
+listxTail (_ ::% sh) = sh
+
 listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
 listxAppend ZX idx' = idx'
 listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx'
@@ -147,6 +150,9 @@ ixxZero' :: IShX sh -> IIxX sh
 ixxZero' ZSX = ZIX
 ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh
 
+ixxTail :: IxX (n : sh) i -> IxX sh i
+ixxTail (IxX list) = IxX (listxTail list)
+
 ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
 ixxAppend = coerce (listxAppend @_ @(Const i))
 
@@ -273,7 +279,7 @@ shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
 shxAppend = coerce (listxAppend @_ @(SMayNat i SNat))
 
 shxTail :: ShX (n : sh) i -> ShX sh i
-shxTail (_ :$% sh) = sh
+shxTail (ShX list) = ShX (listxTail list)
 
 shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i
 shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat))
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs
index 1e0d7cc..4cc58dd 100644
--- a/src/Data/Array/Nested/Internal/Shape.hs
+++ b/src/Data/Array/Nested/Internal/Shape.hs
@@ -76,6 +76,10 @@ listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
 listrFromList [] k = k ZR
 listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
 
+listrTail :: ListR (n + 1) i -> ListR n i
+listrTail (_ ::: sh) = sh
+listrTail ZR = error "unreachable"
+
 listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
 listrIndex SZ (x ::: _) = x
 listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
@@ -149,6 +153,9 @@ ixCvtRX (n :.: (idx :: IxR m Int)) =
   castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
     (n :.% ixCvtRX idx)
 
+ixrTail :: IxR (n + 1) i -> IxR n i
+ixrTail (IxR list) = IxR (listrTail list)
+
 ixrToSNat :: IxR n i -> SNat n
 ixrToSNat (IxR sh) = listrToSNat sh
 
@@ -209,6 +216,9 @@ shrSize :: IShR n -> Int
 shrSize ZSR = 1
 shrSize (n :$: sh) = n * shrSize sh
 
+shrTail :: ShR (n + 1) i -> ShR n i
+shrTail (ShR list) = ShR (listrTail list)
+
 shrToSNat :: ShR n i -> SNat n
 shrToSNat (ShR sh) = listrToSNat sh
 
@@ -278,6 +288,9 @@ listsToList :: ListS sh (Const i) -> [i]
 listsToList ZS = []
 listsToList (Const i ::$ is) = i : listsToList is
 
+listsTail :: ListS (n : sh) i -> ListS sh i
+listsTail (_ ::$ sh) = sh
+
 listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
 listsAppend ZS idx' = idx'
 listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
@@ -370,6 +383,9 @@ ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
 ixCvtSX ZIS = ZIX
 ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh
 
+ixsTail :: IxS (n : sh) i -> IxS sh i
+ixsTail (IxS list) = IxS (listsTail list)
+
 
 -- | The shape of a shape-typed array given as a list of 'SNat' values.
 --
@@ -421,6 +437,9 @@ shCvtSX :: ShS sh -> IShX (MapJust sh)
 shCvtSX ZSS = ZSX
 shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh
 
+shsTail :: ShS (n : sh) -> ShS sh
+shsTail (ShS list) = ShS (listsTail list)
+
 shsSize :: ShS sh -> Int
 shsSize ZSS = 1
 shsSize (n :$$ sh) = fromSNat' n * shsSize sh
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
index b98c23f..2a3949f 100644
--- a/test/Tests/C.hs
+++ b/test/Tests/C.hs
@@ -20,6 +20,7 @@ import Data.Array.Mixed.XArray qualified as X
 import Data.Array.Mixed.Lemmas
 import Data.Array.Nested
 import Data.Array.Nested.Internal.Mixed
+import Data.Array.Nested.Internal.Shape
 
 import Hedgehog
 import Hedgehog.Internal.Property (forAllT)
@@ -42,7 +43,7 @@ tests = testGroup "C"
       let inrank = SNat @(n + 1)
       sh <- forAll $ genShR inrank
       -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
-      guard (all (> 0) (toList (rshTail sh)))  -- only constrain the tail
+      guard (all (> 0) (toList (shrTail sh)))  -- only constrain the tail
       arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>
                genStorables (Range.singleton (product sh))
                             (\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
@@ -64,7 +65,7 @@ tests = testGroup "C"
         sht <- shuffleShR (0 :$: shtt)  -- n
         n <- Gen.int (Range.linear 0 20)
         return (n :$: sht)  -- n + 1
-      guard (any (== 0) (toList (rshTail sh)))
+      guard (any (== 0) (toList (shrTail sh)))
       -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
       let arr = OR.fromList @Double @(n + 1) (toList sh) []
       let rarr = rfromOrthotope inrank arr
diff --git a/test/Util.hs b/test/Util.hs
index a358d30..f377e5b 100644
--- a/test/Util.hs
+++ b/test/Util.hs
@@ -13,7 +13,6 @@ import Data.Array.RankedS qualified as OR
 import GHC.TypeLits
 
 import Data.Array.Mixed.Types (fromSNat')
-import Data.Array.Nested
 
 
 -- Returns highest value that satisfies the predicate, or `lo` if none does
@@ -33,7 +32,3 @@ orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n
 orSumOuter1 (sn@SNat :: SNat n) =
   let n = fromSNat' sn
   in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0])
-
-rshTail :: ShR (n + 1) i -> ShR n i
-rshTail (_ :$: sh) = sh
-rshTail ZSR = error "unreachable"
-- 
cgit v1.2.3-70-g09d2