aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-14 10:08:03 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-14 10:08:03 +0200
commit22f8f053f9ea2a3273d25f49ecd88a30ad506972 (patch)
treeda1b7879fbb87479874490d5a2d2680a7979f593
parentc6b912051ddac25c9d7efe2f8162eac9068a335c (diff)
Export full [mrs]{shape,rank,size} set
-rw-r--r--src/Data/Array/Nested.hs6
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs7
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs6
-rw-r--r--src/Data/Array/Nested/Internal/Shape.hs9
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs7
5 files changed, 30 insertions, 5 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 65d619a..911a525 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -6,7 +6,7 @@ module Data.Array.Nested (
ListR(ZR, (:::)),
IxR(.., ZIR, (:.:)), IIxR,
ShR(.., ZSR, (:$:)), IShR,
- rshape, rrank, rindex, rindexPartial, rgenerate, rsumOuter1,
+ rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rsumOuter1,
rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar,
rrerank,
rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
@@ -25,7 +25,7 @@ module Data.Array.Nested (
ListS(ZS, (::$)),
IxS(.., ZIS, (:.$)), IIxS,
ShS(.., ZSS, (:$$)), KnownShS(..),
- sshape, sindex, sindexPartial, sgenerate, ssumOuter1,
+ sshape, srank, ssize, sindex, sindexPartial, sgenerate, ssumOuter1,
stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,
-- TODO: sconcat? What should its type be?
srerank,
@@ -44,7 +44,7 @@ module Data.Array.Nested (
Mixed,
IxX(..), IIxX,
KnownShX(..), StaticShX(..),
- mshape, mindex, mindexPartial, mgenerate, msumOuter1,
+ mshape, mrank, msize, mindex, mindexPartial, mgenerate, msumOuter1,
mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar,
mrerank,
mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index e5d53e8..8421372 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -636,6 +636,13 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
+mrank :: Elt a => Mixed sh a -> SNat (Rank sh)
+mrank = shxRank . mshape
+
+-- | The total number of elements in the array.
+msize :: Elt a => Mixed sh a -> Int
+msize = shxSize . mshape
+
-- | Create an array given a size and a function that computes the element at a
-- given index.
--
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 6a4db8e..306acc0 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -223,12 +223,16 @@ instance (FloatElt a, NumElt a, PrimElt a, Num a) => Floating (Ranked n a) where
log1mexp = arithPromoteRanked GHC.Float.log1mexp
-rshape :: forall n a. Elt a => Ranked n a -> IShR n
+rshape :: Elt a => Ranked n a -> IShR n
rshape (Ranked arr) = shCvtXR' (mshape arr)
rrank :: Elt a => Ranked n a -> SNat n
rrank = shrRank . rshape
+-- | The total number of elements in the array.
+rsize :: Elt a => Ranked n a -> Int
+rsize = shrSize . rshape
+
rindex :: Elt a => Ranked n a -> IIxR n -> a
rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs
index f449536..ca04840 100644
--- a/src/Data/Array/Nested/Internal/Shape.hs
+++ b/src/Data/Array/Nested/Internal/Shape.hs
@@ -94,7 +94,7 @@ listrIndex _ ZR = error "k + 1 <= 0"
listrRank :: ListR n i -> SNat n
listrRank ZR = SNat
-listrRank (_ ::: (sh :: ListR n i)) = snatSucc (listrRank sh)
+listrRank (_ ::: sh) = snatSucc (listrRank sh)
listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
listrPermutePrefix = \perm sh ->
@@ -320,6 +320,10 @@ listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
listsAppend ZS idx' = idx'
listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
+listsRank :: ListS sh i -> SNat (Rank sh)
+listsRank ZS = SNat
+listsRank (_ ::$ sh) = snatSucc (listsRank sh)
+
listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
listsTakeLenPerm PNil _ = ZS
listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh
@@ -435,6 +439,9 @@ instance Show (ShS sh) where
shsLength :: ShS sh -> Int
shsLength (ShS l) = getSum (listsFold (\_ -> Sum 1) l)
+shsRank :: ShS sh -> SNat (Rank sh)
+shsRank (ShS l) = listsRank l
+
shsToList :: ShS sh -> [Int]
shsToList ZSS = []
shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index 5765595..9588017 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -225,6 +225,13 @@ instance (FloatElt a, NumElt a, PrimElt a, Num a) => Floating (Shaped sh a) wher
sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
sshape (Shaped arr) = shCvtXS' (mshape arr)
+srank :: Elt a => Shaped sh a -> SNat (Rank sh)
+srank = shsRank . sshape
+
+-- | The total number of elements in the array.
+ssize :: Elt a => Shaped sh a -> Int
+ssize = shsSize . sshape
+
sindex :: Elt a => Shaped sh a -> IIxS sh -> a
sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)