diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-04-20 20:11:24 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-20 22:08:07 +0200 |
commit | 18139715c7e11e7d3dbb2cf769f64c2a725832e2 (patch) | |
tree | a99ce84925018cf26ff043c54e86cc519039a7b1 /src/Data/Array/Nested/Internal.hs | |
parent | 3e37091f172b846e93a268695aec72838cc1bdf3 (diff) |
fromList
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 71 |
1 files changed, 71 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index f769870..759094e 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -37,7 +37,9 @@ import Control.Monad.ST import qualified Data.Array.RankedS as S import Data.Bifunctor (first) import Data.Coerce (coerce, Coercible) +import Data.Foldable (toList) import Data.Kind +import Data.List.NonEmpty (NonEmpty) import Data.Proxy import Data.Type.Equality import qualified Data.Vector.Storable as VS @@ -50,6 +52,29 @@ import qualified Data.Array.Mixed as X import Data.INat +-- Invariant in the API +-- ==================== +-- +-- In the underlying XArray, there is some shape for elements of an empty +-- array. For example, for this array: +-- +-- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float) +-- rshape arr == 0 ::: 0 ::: 0 ::: IZR +-- +-- the two underlying XArrays have a shape, and those shapes might be anything. +-- The invariant is that these element shapes are unobservable in the API. +-- (This is possible because you ought to not be able to get to such an element +-- without indexing out of bounds.) +-- +-- Note, though, that the converse situation may arise: the outer array might +-- be nonempty but then the inner arrays might. This is fine, an invariant only +-- applies if the _outer_ array is empty. +-- +-- TODO: can we enforce that the elements of an empty (nested) array have +-- all-zero shape? +-- -> no, because mlift and also any kind of internals probing from outsiders + + type family Replicate n a where Replicate Z a = '[] Replicate (S n) a = a : Replicate n a @@ -105,6 +130,9 @@ newtype Primitive a = Primitive a -- class. type Mixed :: [Maybe Nat] -> Type -> Type data family Mixed sh a +-- NOTE: When opening up the Mixed abstraction, you might see dimension sizes +-- that you're not supposed to see. In particular, an empty array may have +-- elements with nonempty sizes, but then the whole array is still empty. newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a) deriving (Show) @@ -167,6 +195,17 @@ class Elt a where mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a mscalar :: a -> Mixed '[] a + -- | All arrays in the list, even subarrays inside @a@, must have the same + -- shape; if they do not, a runtime error will be thrown. See the + -- documentation of 'mgenerate' for more information about this restriction. + -- Furthermore, the length of the list must correspond with @n@: if @n@ is + -- @Just m@ and @m@ does not equal the length of the list, a runtime error is + -- thrown. + -- + -- If you want a single-dimensional array from your list, map 'mscalar' + -- first. + mfromList :: forall n sh. KnownShapeX (n : sh) => NonEmpty (Mixed sh a) -> Mixed (n : sh) a + mlift :: forall sh1 sh2. KnownShapeX sh2 => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 a -> Mixed sh2 a @@ -215,6 +254,7 @@ instance Storable a => Elt (Primitive a) where mindex (M_Primitive a) i = Primitive (X.index a i) mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i) mscalar (Primitive x) = M_Primitive (X.scalar x) + mfromList l = M_Primitive (X.fromList knownShapeX [x | M_Primitive x <- toList l]) mlift :: forall sh1 sh2. (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) @@ -263,6 +303,8 @@ instance (Elt a, Elt b) => Elt (a, b) where mindex (M_Tup2 a b) i = (mindex a i, mindex b i) mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) + mfromList l = M_Tup2 (mfromList ((\(M_Tup2 x _) -> x) <$> l)) + (mfromList ((\(M_Tup2 _ y) -> y) <$> l)) mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b) mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y) @@ -302,6 +344,12 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where mscalar x = M_Nest x + mfromList :: forall n sh. KnownShapeX (n : sh) + => NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (n : sh) (Mixed sh' a) + mfromList l + | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @(n : sh)) (knownShapeX @sh')) + = M_Nest (mfromList ((\(M_Nest x) -> x) <$> l)) + mlift :: forall sh1 sh2. KnownShapeX sh2 => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) @@ -546,6 +594,12 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where mscalar (Ranked x) = M_Ranked (M_Nest x) + mfromList :: forall m sh. KnownShapeX (m : sh) + => NonEmpty (Mixed sh (Ranked n a)) -> Mixed (m : sh) (Ranked n a) + mfromList l + | Dict <- lemKnownReplicate (Proxy @n) + = M_Ranked (mfromList ((\(M_Ranked x) -> x) <$> l)) + mlift :: forall sh1 sh2. KnownShapeX sh2 => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) @@ -664,6 +718,12 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where mscalar (Shaped x) = M_Shaped (M_Nest x) + mfromList :: forall n sh'. KnownShapeX (n : sh') + => NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (n : sh') (Shaped sh a) + mfromList l + | Dict <- lemKnownMapJust (Proxy @sh) + = M_Shaped (mfromList ((\(M_Shaped x) -> x) <$> l)) + mlift :: forall sh1 sh2. KnownShapeX sh2 => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) @@ -870,6 +930,11 @@ rconstant sh x | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mconstant (ixCvtRX sh) x) +rfromList :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a +rfromList l + | Dict <- lemKnownReplicate (Proxy @n) + = Ranked (mfromList ((\(Ranked x) -> x) <$> l)) + -- ====== API OF SHAPED ARRAYS ====== -- @@ -993,3 +1058,9 @@ sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sconstant x | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mconstant (ixCvtSX (cvtSShapeIxS (knownShape @sh))) x) + +sfromList :: forall n sh a. (KnownNat n, KnownShape sh, Elt a) + => NonEmpty (Shaped sh a) -> Shaped (n : sh) a +sfromList l + | Dict <- lemKnownMapJust (Proxy @sh) + = Shaped (mfromList ((\(Shaped x) -> x) <$> l)) |