From 22f8f053f9ea2a3273d25f49ecd88a30ad506972 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 14 Jun 2024 10:08:03 +0200 Subject: Export full [mrs]{shape,rank,size} set --- src/Data/Array/Nested.hs | 6 +++--- src/Data/Array/Nested/Internal/Mixed.hs | 7 +++++++ src/Data/Array/Nested/Internal/Ranked.hs | 6 +++++- src/Data/Array/Nested/Internal/Shape.hs | 9 ++++++++- src/Data/Array/Nested/Internal/Shaped.hs | 7 +++++++ 5 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 65d619a..911a525 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -6,7 +6,7 @@ module Data.Array.Nested ( ListR(ZR, (:::)), IxR(.., ZIR, (:.:)), IIxR, ShR(.., ZSR, (:$:)), IShR, - rshape, rrank, rindex, rindexPartial, rgenerate, rsumOuter1, + rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rsumOuter1, rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar, rrerank, rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1, @@ -25,7 +25,7 @@ module Data.Array.Nested ( ListS(ZS, (::$)), IxS(.., ZIS, (:.$)), IIxS, ShS(.., ZSS, (:$$)), KnownShS(..), - sshape, sindex, sindexPartial, sgenerate, ssumOuter1, + sshape, srank, ssize, sindex, sindexPartial, sgenerate, ssumOuter1, stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar, -- TODO: sconcat? What should its type be? srerank, @@ -44,7 +44,7 @@ module Data.Array.Nested ( Mixed, IxX(..), IIxX, KnownShX(..), StaticShX(..), - mshape, mindex, mindexPartial, mgenerate, msumOuter1, + mshape, mrank, msize, mindex, mindexPartial, mgenerate, msumOuter1, mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar, mrerank, mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1, diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index e5d53e8..8421372 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -636,6 +636,13 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) +mrank :: Elt a => Mixed sh a -> SNat (Rank sh) +mrank = shxRank . mshape + +-- | The total number of elements in the array. +msize :: Elt a => Mixed sh a -> Int +msize = shxSize . mshape + -- | Create an array given a size and a function that computes the element at a -- given index. -- diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 6a4db8e..306acc0 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -223,12 +223,16 @@ instance (FloatElt a, NumElt a, PrimElt a, Num a) => Floating (Ranked n a) where log1mexp = arithPromoteRanked GHC.Float.log1mexp -rshape :: forall n a. Elt a => Ranked n a -> IShR n +rshape :: Elt a => Ranked n a -> IShR n rshape (Ranked arr) = shCvtXR' (mshape arr) rrank :: Elt a => Ranked n a -> SNat n rrank = shrRank . rshape +-- | The total number of elements in the array. +rsize :: Elt a => Ranked n a -> Int +rsize = shrSize . rshape + rindex :: Elt a => Ranked n a -> IIxR n -> a rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs index f449536..ca04840 100644 --- a/src/Data/Array/Nested/Internal/Shape.hs +++ b/src/Data/Array/Nested/Internal/Shape.hs @@ -94,7 +94,7 @@ listrIndex _ ZR = error "k + 1 <= 0" listrRank :: ListR n i -> SNat n listrRank ZR = SNat -listrRank (_ ::: (sh :: ListR n i)) = snatSucc (listrRank sh) +listrRank (_ ::: sh) = snatSucc (listrRank sh) listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i listrPermutePrefix = \perm sh -> @@ -320,6 +320,10 @@ listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f listsAppend ZS idx' = idx' listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' +listsRank :: ListS sh i -> SNat (Rank sh) +listsRank ZS = SNat +listsRank (_ ::$ sh) = snatSucc (listsRank sh) + listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f listsTakeLenPerm PNil _ = ZS listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh @@ -435,6 +439,9 @@ instance Show (ShS sh) where shsLength :: ShS sh -> Int shsLength (ShS l) = getSum (listsFold (\_ -> Sum 1) l) +shsRank :: ShS sh -> SNat (Rank sh) +shsRank (ShS l) = listsRank l + shsToList :: ShS sh -> [Int] shsToList ZSS = [] shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index 5765595..9588017 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -225,6 +225,13 @@ instance (FloatElt a, NumElt a, PrimElt a, Num a) => Floating (Shaped sh a) wher sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh sshape (Shaped arr) = shCvtXS' (mshape arr) +srank :: Elt a => Shaped sh a -> SNat (Rank sh) +srank = shsRank . sshape + +-- | The total number of elements in the array. +ssize :: Elt a => Shaped sh a -> Int +ssize = shsSize . sshape + sindex :: Elt a => Shaped sh a -> IIxS sh -> a sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) -- cgit v1.2.3-70-g09d2