aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-09 20:47:09 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-09 20:47:30 +0200
commit5763bf70dc67c5437207ff8e9dd08585d2ea5384 (patch)
tree8b68dae165940368925a3cbe816a61a65eb23b68 /src/Data/Array
parentcb98a56767d50fe92790ae4f48a3efbb28aab90a (diff)
Concatenation of arrays for M and R
What should the type of sconcat be?
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Mixed/XArray.hs10
-rw-r--r--src/Data/Array/Nested.hs7
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs23
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs9
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs4
5 files changed, 48 insertions, 5 deletions
diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs
index e93ffca..20f5c7a 100644
--- a/src/Data/Array/Mixed/XArray.hs
+++ b/src/Data/Array/Mixed/XArray.hs
@@ -18,7 +18,9 @@ import Control.DeepSeq (NFData(..))
import Data.Array.Ranked qualified as ORB
import Data.Array.RankedS qualified as S
import Data.Coerce
+import Data.Foldable (toList)
import Data.Kind
+import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
import Data.Type.Equality
import Data.Type.Ord
@@ -117,6 +119,14 @@ append ssh (XArray a) (XArray b)
| Dict <- lemKnownNatRankSSX ssh
= XArray (S.append a b)
+-- | All arrays must have the same shape, except possibly for the outermost
+-- dimension.
+concat :: Storable a
+ => StaticShX sh -> NonEmpty (XArray (Nothing : sh) a) -> XArray (Nothing : sh) a
+concat ssh l
+ | Dict <- lemKnownNatRankSSX ssh
+ = XArray (S.concatOuter (coerce (toList l)))
+
-- | If the prefix of the shape of the input array (@sh@) is empty (i.e.
-- contains a zero), then there is no way to deduce the full shape of the output
-- array (more precisely, the @sh2@ part): that could only come from calling
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 6a919e8..7cc1de3 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -7,7 +7,7 @@ module Data.Array.Nested (
IxR(.., ZIR, (:.:)), IIxR,
ShR(.., ZSR, (:$:)), IShR,
rshape, rrank, rindex, rindexPartial, rgenerate, rsumOuter1,
- rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar,
+ rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar,
rrerank,
rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
rslice, rrev1, rreshape, riota,
@@ -25,6 +25,7 @@ module Data.Array.Nested (
ShS(.., ZSS, (:$$)), KnownShS(..),
sshape, sindex, sindexPartial, sgenerate, ssumOuter1,
stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,
+ -- TODO: sconcat? What should its type be?
srerank,
sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,
sslice, srev1, sreshape, siota,
@@ -39,7 +40,7 @@ module Data.Array.Nested (
IxX(..), IIxX,
KnownShX(..), StaticShX(..),
mshape, mindex, mindexPartial, mgenerate, msumOuter1,
- mtranspose, mappend, mscalar, mfromVector, mtoVector, munScalar,
+ mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar,
mrerank,
mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
mslice, mrev1, mreshape, miota,
@@ -66,7 +67,7 @@ module Data.Array.Nested (
NumElt, FloatElt,
) where
-import Prelude hiding (mappend)
+import Prelude hiding (mappend, mconcat)
import Data.Array.Mixed.Internal.Arith
import Data.Array.Mixed.Permutation
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index dcd86d1..6d601b8 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -18,15 +18,19 @@
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Nested.Internal.Mixed where
+import Prelude hiding (mconcat)
+
import Control.DeepSeq (NFData)
import Control.Monad (forM_, when)
import Control.Monad.ST
import Data.Array.RankedS qualified as S
+import Data.Bifunctor (bimap)
import Data.Coerce
import Data.Foldable (toList)
import Data.Int
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty(..))
+import Data.List.NonEmpty qualified as NE
import Data.Proxy
import Data.Type.Equality
import Data.Vector.Storable qualified as VS
@@ -280,6 +284,10 @@ class Elt a where
mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
=> Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a
+ -- | All arrays in the input must have equal shapes, including subarrays
+ -- inside their elements.
+ mconcat :: NonEmpty (Mixed (Nothing : sh) a) -> Mixed (Nothing : sh) a
+
-- ====== PRIVATE METHODS ====== --
-- | Tree giving the shape of every array component.
@@ -366,6 +374,11 @@ instance Storable a => Elt (Primitive a) where
M_Primitive (shxPermutePrefix perm sh)
(X.transpose (ssxFromShape sh) perm arr)
+ mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a)
+ mconcat l@(M_Primitive (_ :$% sh) _ :| _) =
+ let result = X.concat (ssxFromShape sh) (fmap (\(M_Primitive _ arr) -> arr) l)
+ in M_Primitive (X.shape (SUnknown () :!% ssxFromShape sh) result) result
+
type ShapeTree (Primitive a) = ()
mshapeTree _ = ()
mshapeTreeEq _ () () = True
@@ -424,6 +437,11 @@ instance (Elt a, Elt b) => Elt (a, b) where
M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b)
mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b)
+ mconcat =
+ let unzipT2l [] = ([], [])
+ unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2)
+ unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2)
+ in uncurry M_Tup2 . bimap mconcat mconcat . unzipT2
type ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
@@ -526,6 +544,11 @@ instance Elt a => Elt (Mixed sh' a) where
= M_Nest (shxPermutePrefix perm sh)
(mtranspose perm arr)
+ mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
+ mconcat l@(M_Nest sh1 _ :| _) =
+ let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l)
+ in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape result))) result
+
type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a)
mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 9383b08..3e911ac 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -18,7 +18,7 @@
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Internal.Ranked where
-import Prelude hiding (mappend)
+import Prelude hiding (mappend, mconcat)
import Control.DeepSeq (NFData)
import Control.Monad.ST
@@ -109,6 +109,8 @@ instance Elt a => Elt (Ranked n a) where
mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr)
+ mconcat l = M_Ranked (mconcat (coerce l))
+
type ShapeTree (Ranked n a) = (IShR n, ShapeTree a)
mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr)
@@ -272,6 +274,11 @@ rtranspose perm arr
| otherwise
= error "Data.Array.Nested.rtranspose: Permutation longer than rank of array"
+rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a
+rconcat
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = coerce mconcat
+
rappend :: forall n a. Elt a
=> Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
rappend arr1 arr2
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index b4dc80d..7d523b0 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -17,7 +17,7 @@
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Internal.Shaped where
-import Prelude hiding (mappend)
+import Prelude hiding (mappend, mconcat)
import Control.DeepSeq (NFData)
import Control.Monad.ST
@@ -104,6 +104,8 @@ instance Elt a => Elt (Shaped sh a) where
mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)
+ mconcat l = M_Shaped (mconcat (coerce l))
+
type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr)