diff options
| -rw-r--r-- | src/Data/Array/Mixed/Shape.hs | 8 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shape.hs | 19 | ||||
| -rw-r--r-- | test/Tests/C.hs | 5 | ||||
| -rw-r--r-- | test/Util.hs | 5 | 
4 files changed, 29 insertions, 8 deletions
| diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs index a13a176..4ab3c26 100644 --- a/src/Data/Array/Mixed/Shape.hs +++ b/src/Data/Array/Mixed/Shape.hs @@ -96,6 +96,9 @@ listxToList :: ListX sh' (Const i) -> [i]  listxToList ZX = []  listxToList (Const i ::% is) = i : listxToList is +listxTail :: ListX (n : sh) i -> ListX sh i +listxTail (_ ::% sh) = sh +  listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f  listxAppend ZX idx' = idx'  listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' @@ -147,6 +150,9 @@ ixxZero' :: IShX sh -> IIxX sh  ixxZero' ZSX = ZIX  ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh +ixxTail :: IxX (n : sh) i -> IxX sh i +ixxTail (IxX list) = IxX (listxTail list) +  ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i  ixxAppend = coerce (listxAppend @_ @(Const i)) @@ -273,7 +279,7 @@ shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i  shxAppend = coerce (listxAppend @_ @(SMayNat i SNat))  shxTail :: ShX (n : sh) i -> ShX sh i -shxTail (_ :$% sh) = sh +shxTail (ShX list) = ShX (listxTail list)  shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i  shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs index 1e0d7cc..4cc58dd 100644 --- a/src/Data/Array/Nested/Internal/Shape.hs +++ b/src/Data/Array/Nested/Internal/Shape.hs @@ -76,6 +76,10 @@ listrFromList :: [i] -> (forall n. ListR n i -> r) -> r  listrFromList [] k = k ZR  listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) +listrTail :: ListR (n + 1) i -> ListR n i +listrTail (_ ::: sh) = sh +listrTail ZR = error "unreachable" +  listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i  listrIndex SZ (x ::: _) = x  listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs @@ -149,6 +153,9 @@ ixCvtRX (n :.: (idx :: IxR m Int)) =    castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))      (n :.% ixCvtRX idx) +ixrTail :: IxR (n + 1) i -> IxR n i +ixrTail (IxR list) = IxR (listrTail list) +  ixrToSNat :: IxR n i -> SNat n  ixrToSNat (IxR sh) = listrToSNat sh @@ -209,6 +216,9 @@ shrSize :: IShR n -> Int  shrSize ZSR = 1  shrSize (n :$: sh) = n * shrSize sh +shrTail :: ShR (n + 1) i -> ShR n i +shrTail (ShR list) = ShR (listrTail list) +  shrToSNat :: ShR n i -> SNat n  shrToSNat (ShR sh) = listrToSNat sh @@ -278,6 +288,9 @@ listsToList :: ListS sh (Const i) -> [i]  listsToList ZS = []  listsToList (Const i ::$ is) = i : listsToList is +listsTail :: ListS (n : sh) i -> ListS sh i +listsTail (_ ::$ sh) = sh +  listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f  listsAppend ZS idx' = idx'  listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' @@ -370,6 +383,9 @@ ixCvtSX :: IIxS sh -> IIxX (MapJust sh)  ixCvtSX ZIS = ZIX  ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh +ixsTail :: IxS (n : sh) i -> IxS sh i +ixsTail (IxS list) = IxS (listsTail list) +  -- | The shape of a shape-typed array given as a list of 'SNat' values.  -- @@ -421,6 +437,9 @@ shCvtSX :: ShS sh -> IShX (MapJust sh)  shCvtSX ZSS = ZSX  shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh +shsTail :: ShS (n : sh) -> ShS sh +shsTail (ShS list) = ShS (listsTail list) +  shsSize :: ShS sh -> Int  shsSize ZSS = 1  shsSize (n :$$ sh) = fromSNat' n * shsSize sh diff --git a/test/Tests/C.hs b/test/Tests/C.hs index b98c23f..2a3949f 100644 --- a/test/Tests/C.hs +++ b/test/Tests/C.hs @@ -20,6 +20,7 @@ import Data.Array.Mixed.XArray qualified as X  import Data.Array.Mixed.Lemmas  import Data.Array.Nested  import Data.Array.Nested.Internal.Mixed +import Data.Array.Nested.Internal.Shape  import Hedgehog  import Hedgehog.Internal.Property (forAllT) @@ -42,7 +43,7 @@ tests = testGroup "C"        let inrank = SNat @(n + 1)        sh <- forAll $ genShR inrank        -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) -      guard (all (> 0) (toList (rshTail sh)))  -- only constrain the tail +      guard (all (> 0) (toList (shrTail sh)))  -- only constrain the tail        arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$>                 genStorables (Range.singleton (product sh))                              (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) @@ -64,7 +65,7 @@ tests = testGroup "C"          sht <- shuffleShR (0 :$: shtt)  -- n          n <- Gen.int (Range.linear 0 20)          return (n :$: sht)  -- n + 1 -      guard (any (== 0) (toList (rshTail sh))) +      guard (any (== 0) (toList (shrTail sh)))        -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))        let arr = OR.fromList @Double @(n + 1) (toList sh) []        let rarr = rfromOrthotope inrank arr diff --git a/test/Util.hs b/test/Util.hs index a358d30..f377e5b 100644 --- a/test/Util.hs +++ b/test/Util.hs @@ -13,7 +13,6 @@ import Data.Array.RankedS qualified as OR  import GHC.TypeLits  import Data.Array.Mixed.Types (fromSNat') -import Data.Array.Nested  -- Returns highest value that satisfies the predicate, or `lo` if none does @@ -33,7 +32,3 @@ orSumOuter1 :: (OR.Unbox a, Num a) => SNat n -> OR.Array (n + 1) a -> OR.Array n  orSumOuter1 (sn@SNat :: SNat n) =    let n = fromSNat' sn    in OR.rerank @n @1 @0 (OR.scalar . OR.sumA) . OR.transpose ([1 .. n] ++ [0]) - -rshTail :: ShR (n + 1) i -> ShR n i -rshTail (_ :$: sh) = sh -rshTail ZSR = error "unreachable" | 
