diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-04-20 20:11:24 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-20 22:08:07 +0200 |
commit | 18139715c7e11e7d3dbb2cf769f64c2a725832e2 (patch) | |
tree | a99ce84925018cf26ff043c54e86cc519039a7b1 /src/Data/Array | |
parent | 3e37091f172b846e93a268695aec72838cc1bdf3 (diff) |
fromList
Diffstat (limited to 'src/Data/Array')
-rw-r--r-- | src/Data/Array/Mixed.hs | 13 | ||||
-rw-r--r-- | src/Data/Array/Nested.hs | 6 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 71 |
3 files changed, 87 insertions, 3 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 17b0ab4..246f8fc 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -13,6 +13,8 @@ module Data.Array.Mixed where import qualified Data.Array.RankedS as S +import qualified Data.Array.Ranked as ORB +import Data.Coerce import Data.Kind import Data.Proxy import Data.Type.Equality @@ -347,3 +349,14 @@ sumOuter :: forall sh sh' a. (Storable a, Num a) sumOuter ssh ssh' | Refl <- lemAppNil @sh = sumInner ssh' ssh . transpose2 ssh ssh' + +fromList :: forall n sh a. Storable a + => StaticShapeX (n : sh) -> [XArray sh a] -> XArray (n : sh) a +fromList ssh l + | Dict <- lemKnownINatRankSSX ssh + , Dict <- knownNatFromINat (Proxy @(Rank (n : sh))) + = case ssh of + m@GHC_SNat :$@ _ | natVal m /= fromIntegral (length l) -> + error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++ + "does not match the type (" ++ show (natVal m) ++ ")" + _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l))) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index cd2dde7..9219a74 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -5,7 +5,7 @@ module Data.Array.Nested ( IxR(..), rshape, rindex, rindexPartial, rgenerate, rsumOuter1, rtranspose, rappend, rscalar, rfromVector, runScalar, - rconstant, + rconstant, rfromList, -- ** Lifting orthotope operations to 'Ranked' arrays rlift, @@ -15,7 +15,7 @@ module Data.Array.Nested ( KnownShape(..), SShape(..), sshape, sindex, sindexPartial, sgenerate, ssumOuter1, stranspose, sappend, sscalar, sfromVector, sunScalar, - sconstant, + sconstant, sfromList, -- ** Lifting orthotope operations to 'Shaped' arrays slift, @@ -27,7 +27,7 @@ module Data.Array.Nested ( mconstant, -- * Array elements - Elt(mshape, mindex, mindexPartial, mscalar, mlift, mlift2), + Elt(mshape, mindex, mindexPartial, mscalar, mfromList, mlift, mlift2), Primitive(..), -- * Inductive natural numbers diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index f769870..759094e 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -37,7 +37,9 @@ import Control.Monad.ST import qualified Data.Array.RankedS as S import Data.Bifunctor (first) import Data.Coerce (coerce, Coercible) +import Data.Foldable (toList) import Data.Kind +import Data.List.NonEmpty (NonEmpty) import Data.Proxy import Data.Type.Equality import qualified Data.Vector.Storable as VS @@ -50,6 +52,29 @@ import qualified Data.Array.Mixed as X import Data.INat +-- Invariant in the API +-- ==================== +-- +-- In the underlying XArray, there is some shape for elements of an empty +-- array. For example, for this array: +-- +-- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float) +-- rshape arr == 0 ::: 0 ::: 0 ::: IZR +-- +-- the two underlying XArrays have a shape, and those shapes might be anything. +-- The invariant is that these element shapes are unobservable in the API. +-- (This is possible because you ought to not be able to get to such an element +-- without indexing out of bounds.) +-- +-- Note, though, that the converse situation may arise: the outer array might +-- be nonempty but then the inner arrays might. This is fine, an invariant only +-- applies if the _outer_ array is empty. +-- +-- TODO: can we enforce that the elements of an empty (nested) array have +-- all-zero shape? +-- -> no, because mlift and also any kind of internals probing from outsiders + + type family Replicate n a where Replicate Z a = '[] Replicate (S n) a = a : Replicate n a @@ -105,6 +130,9 @@ newtype Primitive a = Primitive a -- class. type Mixed :: [Maybe Nat] -> Type -> Type data family Mixed sh a +-- NOTE: When opening up the Mixed abstraction, you might see dimension sizes +-- that you're not supposed to see. In particular, an empty array may have +-- elements with nonempty sizes, but then the whole array is still empty. newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a) deriving (Show) @@ -167,6 +195,17 @@ class Elt a where mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a mscalar :: a -> Mixed '[] a + -- | All arrays in the list, even subarrays inside @a@, must have the same + -- shape; if they do not, a runtime error will be thrown. See the + -- documentation of 'mgenerate' for more information about this restriction. + -- Furthermore, the length of the list must correspond with @n@: if @n@ is + -- @Just m@ and @m@ does not equal the length of the list, a runtime error is + -- thrown. + -- + -- If you want a single-dimensional array from your list, map 'mscalar' + -- first. + mfromList :: forall n sh. KnownShapeX (n : sh) => NonEmpty (Mixed sh a) -> Mixed (n : sh) a + mlift :: forall sh1 sh2. KnownShapeX sh2 => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 a -> Mixed sh2 a @@ -215,6 +254,7 @@ instance Storable a => Elt (Primitive a) where mindex (M_Primitive a) i = Primitive (X.index a i) mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i) mscalar (Primitive x) = M_Primitive (X.scalar x) + mfromList l = M_Primitive (X.fromList knownShapeX [x | M_Primitive x <- toList l]) mlift :: forall sh1 sh2. (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) @@ -263,6 +303,8 @@ instance (Elt a, Elt b) => Elt (a, b) where mindex (M_Tup2 a b) i = (mindex a i, mindex b i) mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) + mfromList l = M_Tup2 (mfromList ((\(M_Tup2 x _) -> x) <$> l)) + (mfromList ((\(M_Tup2 _ y) -> y) <$> l)) mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b) mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y) @@ -302,6 +344,12 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where mscalar x = M_Nest x + mfromList :: forall n sh. KnownShapeX (n : sh) + => NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (n : sh) (Mixed sh' a) + mfromList l + | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @(n : sh)) (knownShapeX @sh')) + = M_Nest (mfromList ((\(M_Nest x) -> x) <$> l)) + mlift :: forall sh1 sh2. KnownShapeX sh2 => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) @@ -546,6 +594,12 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where mscalar (Ranked x) = M_Ranked (M_Nest x) + mfromList :: forall m sh. KnownShapeX (m : sh) + => NonEmpty (Mixed sh (Ranked n a)) -> Mixed (m : sh) (Ranked n a) + mfromList l + | Dict <- lemKnownReplicate (Proxy @n) + = M_Ranked (mfromList ((\(M_Ranked x) -> x) <$> l)) + mlift :: forall sh1 sh2. KnownShapeX sh2 => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) @@ -664,6 +718,12 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where mscalar (Shaped x) = M_Shaped (M_Nest x) + mfromList :: forall n sh'. KnownShapeX (n : sh') + => NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (n : sh') (Shaped sh a) + mfromList l + | Dict <- lemKnownMapJust (Proxy @sh) + = M_Shaped (mfromList ((\(M_Shaped x) -> x) <$> l)) + mlift :: forall sh1 sh2. KnownShapeX sh2 => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) @@ -870,6 +930,11 @@ rconstant sh x | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mconstant (ixCvtRX sh) x) +rfromList :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a +rfromList l + | Dict <- lemKnownReplicate (Proxy @n) + = Ranked (mfromList ((\(Ranked x) -> x) <$> l)) + -- ====== API OF SHAPED ARRAYS ====== -- @@ -993,3 +1058,9 @@ sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sconstant x | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mconstant (ixCvtSX (cvtSShapeIxS (knownShape @sh))) x) + +sfromList :: forall n sh a. (KnownNat n, KnownShape sh, Elt a) + => NonEmpty (Shaped sh a) -> Shaped (n : sh) a +sfromList l + | Dict <- lemKnownMapJust (Proxy @sh) + = Shaped (mfromList ((\(Shaped x) -> x) <$> l)) |