diff options
Diffstat (limited to 'src/Data')
| -rw-r--r-- | src/Data/Array/Mixed/XArray.hs | 10 | ||||
| -rw-r--r-- | src/Data/Array/Nested.hs | 7 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 23 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 9 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 4 | 
5 files changed, 48 insertions, 5 deletions
| diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs index e93ffca..20f5c7a 100644 --- a/src/Data/Array/Mixed/XArray.hs +++ b/src/Data/Array/Mixed/XArray.hs @@ -18,7 +18,9 @@ import Control.DeepSeq (NFData(..))  import Data.Array.Ranked qualified as ORB  import Data.Array.RankedS qualified as S  import Data.Coerce +import Data.Foldable (toList)  import Data.Kind +import Data.List.NonEmpty (NonEmpty)  import Data.Proxy  import Data.Type.Equality  import Data.Type.Ord @@ -117,6 +119,14 @@ append ssh (XArray a) (XArray b)    | Dict <- lemKnownNatRankSSX ssh    = XArray (S.append a b) +-- | All arrays must have the same shape, except possibly for the outermost +-- dimension. +concat :: Storable a +       => StaticShX sh -> NonEmpty (XArray (Nothing : sh) a) -> XArray (Nothing : sh) a +concat ssh l +  | Dict <- lemKnownNatRankSSX ssh +  = XArray (S.concatOuter (coerce (toList l))) +  -- | If the prefix of the shape of the input array (@sh@) is empty (i.e.  -- contains a zero), then there is no way to deduce the full shape of the output  -- array (more precisely, the @sh2@ part): that could only come from calling diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 6a919e8..7cc1de3 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -7,7 +7,7 @@ module Data.Array.Nested (    IxR(.., ZIR, (:.:)), IIxR,    ShR(.., ZSR, (:$:)), IShR,    rshape, rrank, rindex, rindexPartial, rgenerate, rsumOuter1, -  rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar, +  rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar,    rrerank,    rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,    rslice, rrev1, rreshape, riota, @@ -25,6 +25,7 @@ module Data.Array.Nested (    ShS(.., ZSS, (:$$)), KnownShS(..),    sshape, sindex, sindexPartial, sgenerate, ssumOuter1,    stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar, +  -- TODO: sconcat? What should its type be?    srerank,    sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,    sslice, srev1, sreshape, siota, @@ -39,7 +40,7 @@ module Data.Array.Nested (    IxX(..), IIxX,    KnownShX(..), StaticShX(..),    mshape, mindex, mindexPartial, mgenerate, msumOuter1, -  mtranspose, mappend, mscalar, mfromVector, mtoVector, munScalar, +  mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar,    mrerank,    mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,    mslice, mrev1, mreshape, miota, @@ -66,7 +67,7 @@ module Data.Array.Nested (    NumElt, FloatElt,  ) where -import Prelude hiding (mappend) +import Prelude hiding (mappend, mconcat)  import Data.Array.Mixed.Internal.Arith  import Data.Array.Mixed.Permutation diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index dcd86d1..6d601b8 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -18,15 +18,19 @@  {-# LANGUAGE ViewPatterns #-}  module Data.Array.Nested.Internal.Mixed where +import Prelude hiding (mconcat) +  import Control.DeepSeq (NFData)  import Control.Monad (forM_, when)  import Control.Monad.ST  import Data.Array.RankedS qualified as S +import Data.Bifunctor (bimap)  import Data.Coerce  import Data.Foldable (toList)  import Data.Int  import Data.Kind (Type)  import Data.List.NonEmpty (NonEmpty(..)) +import Data.List.NonEmpty qualified as NE  import Data.Proxy  import Data.Type.Equality  import Data.Vector.Storable qualified as VS @@ -280,6 +284,10 @@ class Elt a where    mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)               => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a +  -- | All arrays in the input must have equal shapes, including subarrays +  -- inside their elements. +  mconcat :: NonEmpty (Mixed (Nothing : sh) a) -> Mixed (Nothing : sh) a +    -- ====== PRIVATE METHODS ====== --    -- | Tree giving the shape of every array component. @@ -366,6 +374,11 @@ instance Storable a => Elt (Primitive a) where      M_Primitive (shxPermutePrefix perm sh)                  (X.transpose (ssxFromShape sh) perm arr) +  mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a) +  mconcat l@(M_Primitive (_ :$% sh) _ :| _) = +    let result = X.concat (ssxFromShape sh) (fmap (\(M_Primitive _ arr) -> arr) l) +    in M_Primitive (X.shape (SUnknown () :!% ssxFromShape sh) result) result +    type ShapeTree (Primitive a) = ()    mshapeTree _ = ()    mshapeTreeEq _ () () = True @@ -424,6 +437,11 @@ instance (Elt a, Elt b) => Elt (a, b) where      M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b)    mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) +  mconcat = +    let unzipT2l [] = ([], []) +        unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) +        unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2) +    in uncurry M_Tup2 . bimap mconcat mconcat . unzipT2    type ShapeTree (a, b) = (ShapeTree a, ShapeTree b)    mshapeTree (x, y) = (mshapeTree x, mshapeTree y) @@ -526,6 +544,11 @@ instance Elt a => Elt (Mixed sh' a) where      = M_Nest (shxPermutePrefix perm sh)               (mtranspose perm arr) +  mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) +  mconcat l@(M_Nest sh1 _ :| _) = +    let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l) +    in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape result))) result +    type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a)    mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 9383b08..3e911ac 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -18,7 +18,7 @@  {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}  module Data.Array.Nested.Internal.Ranked where -import Prelude hiding (mappend) +import Prelude hiding (mappend, mconcat)  import Control.DeepSeq (NFData)  import Control.Monad.ST @@ -109,6 +109,8 @@ instance Elt a => Elt (Ranked n a) where    mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) +  mconcat l = M_Ranked (mconcat (coerce l)) +    type ShapeTree (Ranked n a) = (IShR n, ShapeTree a)    mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) @@ -272,6 +274,11 @@ rtranspose perm arr    | otherwise    = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" +rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a +rconcat +  | Refl <- lemReplicateSucc @(Nothing @Nat) @n +  = coerce mconcat +  rappend :: forall n a. Elt a          => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a  rappend arr1 arr2 diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index b4dc80d..7d523b0 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -17,7 +17,7 @@  {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}  module Data.Array.Nested.Internal.Shaped where -import Prelude hiding (mappend) +import Prelude hiding (mappend, mconcat)  import Control.DeepSeq (NFData)  import Control.Monad.ST @@ -104,6 +104,8 @@ instance Elt a => Elt (Shaped sh a) where    mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) +  mconcat l = M_Shaped (mconcat (coerce l)) +    type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)    mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) | 
