aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-03 21:27:55 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-03 21:27:55 +0200
commit3286a65fe6e4735aaadef5addecbe3c3f7ed3468 (patch)
treefb8787a3bc2b2740733a110c21c580d0c56b9381 /src
parentac061cf450b1c8e153de06f7b12256914c496788 (diff)
Rename *ToSNat to *LengthSNat
For consistency with existing functions
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Mixed/Types.hs1
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs26
-rw-r--r--src/Data/Array/Nested/Internal/Shape.hs16
3 files changed, 22 insertions, 21 deletions
diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs
index 35e6fd3..b8f0824 100644
--- a/src/Data/Array/Mixed/Types.hs
+++ b/src/Data/Array/Mixed/Types.hs
@@ -20,6 +20,7 @@ module Data.Array.Mixed.Types (
pattern SZ, pattern SS,
fromSNat', sameNat',
snatPlus, snatMul,
+ snatSucc,
-- * Type-level lists
type (++),
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 894ed0d..6b1547d 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -216,7 +216,7 @@ rshape :: forall n a. Elt a => Ranked n a -> IShR n
rshape (Ranked arr) = shCvtXR' (mshape arr)
rrank :: Elt a => Ranked n a -> SNat n
-rrank = shrToSNat . rshape
+rrank = shrLengthSNat . rshape
rindex :: Elt a => Ranked n a -> IIxR n -> a
rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
@@ -224,14 +224,14 @@ rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
- (castWith (subst2 (lemReplicatePlusApp (ixrToSNat idx) (Proxy @m) (Proxy @Nothing))) arr)
+ (castWith (subst2 (lemReplicatePlusApp (ixrLengthSNat idx) (Proxy @m) (Proxy @Nothing))) arr)
(ixCvtRX idx))
-- | __WARNING__: All values returned from the function must have equal shape.
-- See the documentation of 'mgenerate' for more details.
rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a
rgenerate sh f
- | sn@SNat <- shrToSNat sh
+ | sn@SNat <- shrLengthSNat sh
, Dict <- lemKnownReplicate sn
, Refl <- lemRankReplicate sn
= Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR))
@@ -263,7 +263,7 @@ rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive
rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a
rtranspose perm arr
- | sn@SNat <- shrToSNat (rshape arr)
+ | sn@SNat <- shrLengthSNat (rshape arr)
, Dict <- lemKnownReplicate sn
, length perm <= fromIntegral (natVal (Proxy @n))
= rlift sn
@@ -275,7 +275,7 @@ rtranspose perm arr
rappend :: forall n a. Elt a
=> Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
rappend arr1 arr2
- | sn@SNat <- shrToSNat (rshape arr1)
+ | sn@SNat <- shrLengthSNat (rshape arr1)
, Dict <- lemKnownReplicate sn
, Refl <- lemReplicateSucc @(Nothing @Nat) @n
= coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
@@ -286,7 +286,7 @@ rscalar x = Ranked (mscalar x)
rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)
rfromVectorP sh v
- | Dict <- lemKnownReplicate (shrToSNat sh)
+ | Dict <- lemKnownReplicate (shrLengthSNat sh)
= Ranked (mfromVectorP (shCvtRX sh) v)
rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a
@@ -336,7 +336,7 @@ rfromOrthotope sn arr
rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a
rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr)))
- | Refl <- lemRankReplicate (shrToSNat $ shCvtXR' sh)
+ | Refl <- lemRankReplicate (shrLengthSNat $ shCvtXR' sh)
= arr
runScalar :: Elt a => Ranked 0 a -> a
@@ -386,12 +386,12 @@ rrerank ssh sh2 f (rtoPrimitive -> arr) =
rreplicate :: forall n m a. Elt a
=> IShR n -> Ranked m a -> Ranked (n + m) a
rreplicate sh (Ranked arr)
- | Refl <- lemReplicatePlusApp (shrToSNat sh) (Proxy @m) (Proxy @(Nothing @Nat))
+ | Refl <- lemReplicatePlusApp (shrLengthSNat sh) (Proxy @m) (Proxy @(Nothing @Nat))
= Ranked (mreplicate (shCvtRX sh) arr)
rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
rreplicateScalP sh x
- | Dict <- lemKnownReplicate (shrToSNat sh)
+ | Dict <- lemKnownReplicate (shrLengthSNat sh)
= Ranked (mreplicateScalP (shCvtRX sh) x)
rreplicateScal :: forall n a. PrimElt a
@@ -401,13 +401,13 @@ rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
rslice i n arr
| Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = rlift (shrToSNat (rshape arr))
+ = rlift (shrLengthSNat (rshape arr))
(\_ -> X.sliceU i n)
arr
rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
rrev1 arr =
- rlift (shrToSNat (rshape arr))
+ rlift (shrLengthSNat (rshape arr))
(\(_ :: StaticShX sh') ->
case lemReplicateSucc @(Nothing @Nat) @n of
Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
@@ -416,8 +416,8 @@ rrev1 arr =
rreshape :: forall n n' a. Elt a
=> IShR n' -> Ranked n a -> Ranked n' a
rreshape sh' rarr@(Ranked arr)
- | Dict <- lemKnownReplicate (shrToSNat (rshape rarr))
- , Dict <- lemKnownReplicate (shrToSNat sh')
+ | Dict <- lemKnownReplicate (shrLengthSNat (rshape rarr))
+ , Dict <- lemKnownReplicate (shrLengthSNat sh')
= Ranked (mreshape (shCvtRX sh') arr)
riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs
index 4cc58dd..bce85b0 100644
--- a/src/Data/Array/Nested/Internal/Shape.hs
+++ b/src/Data/Array/Nested/Internal/Shape.hs
@@ -85,14 +85,14 @@ listrIndex SZ (x ::: _) = x
listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
listrIndex _ ZR = error "k + 1 <= 0"
-listrToSNat :: ListR n i -> SNat n
-listrToSNat ZR = SNat
-listrToSNat (_ ::: (l :: ListR n i)) | SNat <- listrToSNat l, Dict <- lemKnownNatSucc @n = SNat
+listrLengthSNat :: ListR n i -> SNat n
+listrLengthSNat ZR = SNat
+listrLengthSNat (_ ::: (sh :: ListR n i)) = snatSucc (listrLengthSNat sh)
listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
listrPermutePrefix = \perm sh ->
listrFromList perm $ \sperm ->
- case (listrToSNat sperm, listrToSNat sh) of
+ case (listrLengthSNat sperm, listrLengthSNat sh) of
(permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
@@ -156,8 +156,8 @@ ixCvtRX (n :.: (idx :: IxR m Int)) =
ixrTail :: IxR (n + 1) i -> IxR n i
ixrTail (IxR list) = IxR (listrTail list)
-ixrToSNat :: IxR n i -> SNat n
-ixrToSNat (IxR sh) = listrToSNat sh
+ixrLengthSNat :: IxR n i -> SNat n
+ixrLengthSNat (IxR sh) = listrLengthSNat sh
ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
ixrPermutePrefix = coerce (listrPermutePrefix @i)
@@ -219,8 +219,8 @@ shrSize (n :$: sh) = n * shrSize sh
shrTail :: ShR (n + 1) i -> ShR n i
shrTail (ShR list) = ShR (listrTail list)
-shrToSNat :: ShR n i -> SNat n
-shrToSNat (ShR sh) = listrToSNat sh
+shrLengthSNat :: ShR n i -> SNat n
+shrLengthSNat (ShR sh) = listrLengthSNat sh
shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
shrPermutePrefix = coerce (listrPermutePrefix @i)