aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-13 13:09:04 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-13 13:09:04 +0200
commit20173c939486ed6e27b8170e94f666d8ae3df152 (patch)
tree36c02005c3f2a20567388c6291e54bc2e4a4e6db /src/Data/Array
parent275847827d7550436eaf8cd10969f1430dae821d (diff)
Rename *LengthSNat to *Rank
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Mixed/Permutation.hs8
-rw-r--r--src/Data/Array/Mixed/Shape.hs14
-rw-r--r--src/Data/Array/Mixed/XArray.hs2
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs26
-rw-r--r--src/Data/Array/Nested/Internal/Shape.hs16
5 files changed, 31 insertions, 35 deletions
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs
index ca99b02..331d5e0 100644
--- a/src/Data/Array/Mixed/Permutation.hs
+++ b/src/Data/Array/Mixed/Permutation.hs
@@ -45,9 +45,9 @@ infixr 5 `PCons`
deriving instance Show (Perm list)
deriving instance Eq (Perm list)
-permLengthSNat :: Perm list -> SNat (Rank list)
-permLengthSNat PNil = SNat
-permLengthSNat (_ `PCons` l) | SNat <- permLengthSNat l = SNat
+permRank :: Perm list -> SNat (Rank list)
+permRank PNil = SNat
+permRank (_ `PCons` l) | SNat <- permRank l = SNat
permFromList :: [Int] -> (forall list. Perm list -> r) -> r
permFromList [] k = k PNil
@@ -67,7 +67,7 @@ permToList' = map fromIntegral . permToList
-- then @Nothing@ is returned.
permCheckPermutation :: forall r list. Perm list -> (IsPermutation list => r) -> Maybe r
permCheckPermutation = \p k ->
- let n = permLengthSNat p
+ let n = permRank p
in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of
(Just Refl, Just Refl) -> Just k
_ -> Nothing
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs
index e46105d..95cd4ef 100644
--- a/src/Data/Array/Mixed/Shape.hs
+++ b/src/Data/Array/Mixed/Shape.hs
@@ -81,9 +81,9 @@ listxFold f (x ::% xs) = f x <> listxFold f xs
listxLength :: ListX sh f -> Int
listxLength = getSum . listxFold (\_ -> Sum 1)
-listxLengthSNat :: ListX sh f -> SNat (Rank sh)
-listxLengthSNat ZX = SNat
-listxLengthSNat (_ ::% l) | SNat <- listxLengthSNat l = SNat
+listxRank :: ListX sh f -> SNat (Rank sh)
+listxRank ZX = SNat
+listxRank (_ ::% l) | SNat <- listxRank l = SNat
listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
listxShow f l = showString "[" . go "" l . showString "]"
@@ -265,8 +265,8 @@ instance NFData i => NFData (ShX sh i) where
shxLength :: ShX sh i -> Int
shxLength (ShX l) = listxLength l
-shxLengthSNat :: ShX sh f -> SNat (Rank sh)
-shxLengthSNat (ShX list) = listxLengthSNat list
+shxRank :: ShX sh f -> SNat (Rank sh)
+shxRank (ShX list) = listxRank list
-- | This is more than @geq@: it also checks that the integers (the unknown
-- dimensions) are the same.
@@ -344,10 +344,6 @@ shxEnum = \sh -> go sh id []
go ZSX f = (f ZIX :)
go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]]
-shxRank :: ShX sh f -> SNat (Rank sh)
-shxRank ZSX = SNat
-shxRank (_ :$% sh) | SNat <- shxRank sh = SNat
-
-- * Static mixed shapes
diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs
index 20f5c7a..08295cd 100644
--- a/src/Data/Array/Mixed/XArray.hs
+++ b/src/Data/Array/Mixed/XArray.hs
@@ -258,7 +258,7 @@ sumInner ssh ssh' arr
go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a
go (XArray arr')
| Refl <- lemRankApp ssh ssh'F
- , let sn = listxLengthSNat (let StaticShX l = ssh in l)
+ , let sn = listxRank (let StaticShX l = ssh in l)
= XArray (numEltSum1Inner sn arr')
in go $
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 1518791..6a4db8e 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -227,7 +227,7 @@ rshape :: forall n a. Elt a => Ranked n a -> IShR n
rshape (Ranked arr) = shCvtXR' (mshape arr)
rrank :: Elt a => Ranked n a -> SNat n
-rrank = shrLengthSNat . rshape
+rrank = shrRank . rshape
rindex :: Elt a => Ranked n a -> IIxR n -> a
rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
@@ -235,14 +235,14 @@ rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
- (castWith (subst2 (lemReplicatePlusApp (ixrLengthSNat idx) (Proxy @m) (Proxy @Nothing))) arr)
+ (castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr)
(ixCvtRX idx))
-- | __WARNING__: All values returned from the function must have equal shape.
-- See the documentation of 'mgenerate' for more details.
rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a
rgenerate sh f
- | sn@SNat <- shrLengthSNat sh
+ | sn@SNat <- shrRank sh
, Dict <- lemKnownReplicate sn
, Refl <- lemRankReplicate sn
= Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR))
@@ -274,7 +274,7 @@ rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive
rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a
rtranspose perm arr
- | sn@SNat <- shrLengthSNat (rshape arr)
+ | sn@SNat <- rrank arr
, Dict <- lemKnownReplicate sn
, length perm <= fromIntegral (natVal (Proxy @n))
= rlift sn
@@ -291,7 +291,7 @@ rconcat
rappend :: forall n a. Elt a
=> Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
rappend arr1 arr2
- | sn@SNat <- shrLengthSNat (rshape arr1)
+ | sn@SNat <- rrank arr1
, Dict <- lemKnownReplicate sn
, Refl <- lemReplicateSucc @(Nothing @Nat) @n
= coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
@@ -302,7 +302,7 @@ rscalar x = Ranked (mscalar x)
rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)
rfromVectorP sh v
- | Dict <- lemKnownReplicate (shrLengthSNat sh)
+ | Dict <- lemKnownReplicate (shrRank sh)
= Ranked (mfromVectorP (shCvtRX sh) v)
rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a
@@ -352,7 +352,7 @@ rfromOrthotope sn arr
rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a
rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr)))
- | Refl <- lemRankReplicate (shrLengthSNat $ shCvtXR' sh)
+ | Refl <- lemRankReplicate (shrRank $ shCvtXR' sh)
= arr
runScalar :: Elt a => Ranked 0 a -> a
@@ -412,12 +412,12 @@ rrerank sn sh2 f (rtoPrimitive -> arr) =
rreplicate :: forall n m a. Elt a
=> IShR n -> Ranked m a -> Ranked (n + m) a
rreplicate sh (Ranked arr)
- | Refl <- lemReplicatePlusApp (shrLengthSNat sh) (Proxy @m) (Proxy @(Nothing @Nat))
+ | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat))
= Ranked (mreplicate (shCvtRX sh) arr)
rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
rreplicateScalP sh x
- | Dict <- lemKnownReplicate (shrLengthSNat sh)
+ | Dict <- lemKnownReplicate (shrRank sh)
= Ranked (mreplicateScalP (shCvtRX sh) x)
rreplicateScal :: forall n a. PrimElt a
@@ -427,13 +427,13 @@ rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
rslice i n arr
| Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = rlift (shrLengthSNat (rshape arr))
+ = rlift (rrank arr)
(\_ -> X.sliceU i n)
arr
rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
rrev1 arr =
- rlift (shrLengthSNat (rshape arr))
+ rlift (rrank arr)
(\(_ :: StaticShX sh') ->
case lemReplicateSucc @(Nothing @Nat) @n of
Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
@@ -442,8 +442,8 @@ rrev1 arr =
rreshape :: forall n n' a. Elt a
=> IShR n' -> Ranked n a -> Ranked n' a
rreshape sh' rarr@(Ranked arr)
- | Dict <- lemKnownReplicate (shrLengthSNat (rshape rarr))
- , Dict <- lemKnownReplicate (shrLengthSNat sh')
+ | Dict <- lemKnownReplicate (rrank rarr)
+ , Dict <- lemKnownReplicate (shrRank sh')
= Ranked (mreshape (shCvtRX sh') arr)
rflatten :: Elt a => Ranked n a -> Ranked 1 a
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs
index 7d95f61..4fa4284 100644
--- a/src/Data/Array/Nested/Internal/Shape.hs
+++ b/src/Data/Array/Nested/Internal/Shape.hs
@@ -91,14 +91,14 @@ listrIndex SZ (x ::: _) = x
listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
listrIndex _ ZR = error "k + 1 <= 0"
-listrLengthSNat :: ListR n i -> SNat n
-listrLengthSNat ZR = SNat
-listrLengthSNat (_ ::: (sh :: ListR n i)) = snatSucc (listrLengthSNat sh)
+listrRank :: ListR n i -> SNat n
+listrRank ZR = SNat
+listrRank (_ ::: (sh :: ListR n i)) = snatSucc (listrRank sh)
listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
listrPermutePrefix = \perm sh ->
listrFromList perm $ \sperm ->
- case (listrLengthSNat sperm, listrLengthSNat sh) of
+ case (listrRank sperm, listrRank sh) of
(permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
@@ -168,8 +168,8 @@ ixrTail (IxR list) = IxR (listrTail list)
ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
ixrAppend = coerce (listrAppend @_ @i)
-ixrLengthSNat :: IxR n i -> SNat n
-ixrLengthSNat (IxR sh) = listrLengthSNat sh
+ixrRank :: IxR n i -> SNat n
+ixrRank (IxR sh) = listrRank sh
ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
ixrPermutePrefix = coerce (listrPermutePrefix @i)
@@ -237,8 +237,8 @@ shrTail (ShR list) = ShR (listrTail list)
shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
shrAppend = coerce (listrAppend @_ @i)
-shrLengthSNat :: ShR n i -> SNat n
-shrLengthSNat (ShR sh) = listrLengthSNat sh
+shrRank :: ShR n i -> SNat n
+shrRank (ShR sh) = listrRank sh
shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
shrPermutePrefix = coerce (listrPermutePrefix @i)