From 5763bf70dc67c5437207ff8e9dd08585d2ea5384 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 9 Jun 2024 20:47:09 +0200 Subject: Concatenation of arrays for M and R What should the type of sconcat be? --- src/Data/Array/Nested/Internal/Mixed.hs | 23 +++++++++++++++++++++++ src/Data/Array/Nested/Internal/Ranked.hs | 9 ++++++++- src/Data/Array/Nested/Internal/Shaped.hs | 4 +++- 3 files changed, 34 insertions(+), 2 deletions(-) (limited to 'src/Data/Array/Nested/Internal') diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index dcd86d1..6d601b8 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -18,15 +18,19 @@ {-# LANGUAGE ViewPatterns #-} module Data.Array.Nested.Internal.Mixed where +import Prelude hiding (mconcat) + import Control.DeepSeq (NFData) import Control.Monad (forM_, when) import Control.Monad.ST import Data.Array.RankedS qualified as S +import Data.Bifunctor (bimap) import Data.Coerce import Data.Foldable (toList) import Data.Int import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty(..)) +import Data.List.NonEmpty qualified as NE import Data.Proxy import Data.Type.Equality import Data.Vector.Storable qualified as VS @@ -280,6 +284,10 @@ class Elt a where mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a + -- | All arrays in the input must have equal shapes, including subarrays + -- inside their elements. + mconcat :: NonEmpty (Mixed (Nothing : sh) a) -> Mixed (Nothing : sh) a + -- ====== PRIVATE METHODS ====== -- -- | Tree giving the shape of every array component. @@ -366,6 +374,11 @@ instance Storable a => Elt (Primitive a) where M_Primitive (shxPermutePrefix perm sh) (X.transpose (ssxFromShape sh) perm arr) + mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a) + mconcat l@(M_Primitive (_ :$% sh) _ :| _) = + let result = X.concat (ssxFromShape sh) (fmap (\(M_Primitive _ arr) -> arr) l) + in M_Primitive (X.shape (SUnknown () :!% ssxFromShape sh) result) result + type ShapeTree (Primitive a) = () mshapeTree _ = () mshapeTreeEq _ () () = True @@ -424,6 +437,11 @@ instance (Elt a, Elt b) => Elt (a, b) where M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b) mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) + mconcat = + let unzipT2l [] = ([], []) + unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) + unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2) + in uncurry M_Tup2 . bimap mconcat mconcat . unzipT2 type ShapeTree (a, b) = (ShapeTree a, ShapeTree b) mshapeTree (x, y) = (mshapeTree x, mshapeTree y) @@ -526,6 +544,11 @@ instance Elt a => Elt (Mixed sh' a) where = M_Nest (shxPermutePrefix perm sh) (mtranspose perm arr) + mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) + mconcat l@(M_Nest sh1 _ :| _) = + let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l) + in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape result))) result + type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 9383b08..3e911ac 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -18,7 +18,7 @@ {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Nested.Internal.Ranked where -import Prelude hiding (mappend) +import Prelude hiding (mappend, mconcat) import Control.DeepSeq (NFData) import Control.Monad.ST @@ -109,6 +109,8 @@ instance Elt a => Elt (Ranked n a) where mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) + mconcat l = M_Ranked (mconcat (coerce l)) + type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) @@ -272,6 +274,11 @@ rtranspose perm arr | otherwise = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" +rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a +rconcat + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce mconcat + rappend :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a rappend arr1 arr2 diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index b4dc80d..7d523b0 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -17,7 +17,7 @@ {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Nested.Internal.Shaped where -import Prelude hiding (mappend) +import Prelude hiding (mappend, mconcat) import Control.DeepSeq (NFData) import Control.Monad.ST @@ -104,6 +104,8 @@ instance Elt a => Elt (Shaped sh a) where mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) + mconcat l = M_Shaped (mconcat (coerce l)) + type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) -- cgit v1.2.3-70-g09d2