aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-04-20 20:11:24 +0200
committerTom Smeding <tom@tomsmeding.com>2024-04-20 22:08:07 +0200
commit18139715c7e11e7d3dbb2cf769f64c2a725832e2 (patch)
treea99ce84925018cf26ff043c54e86cc519039a7b1 /src/Data/Array/Nested/Internal.hs
parent3e37091f172b846e93a268695aec72838cc1bdf3 (diff)
fromList
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r--src/Data/Array/Nested/Internal.hs71
1 files changed, 71 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index f769870..759094e 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -37,7 +37,9 @@ import Control.Monad.ST
import qualified Data.Array.RankedS as S
import Data.Bifunctor (first)
import Data.Coerce (coerce, Coercible)
+import Data.Foldable (toList)
import Data.Kind
+import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
import Data.Type.Equality
import qualified Data.Vector.Storable as VS
@@ -50,6 +52,29 @@ import qualified Data.Array.Mixed as X
import Data.INat
+-- Invariant in the API
+-- ====================
+--
+-- In the underlying XArray, there is some shape for elements of an empty
+-- array. For example, for this array:
+--
+-- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float)
+-- rshape arr == 0 ::: 0 ::: 0 ::: IZR
+--
+-- the two underlying XArrays have a shape, and those shapes might be anything.
+-- The invariant is that these element shapes are unobservable in the API.
+-- (This is possible because you ought to not be able to get to such an element
+-- without indexing out of bounds.)
+--
+-- Note, though, that the converse situation may arise: the outer array might
+-- be nonempty but then the inner arrays might. This is fine, an invariant only
+-- applies if the _outer_ array is empty.
+--
+-- TODO: can we enforce that the elements of an empty (nested) array have
+-- all-zero shape?
+-- -> no, because mlift and also any kind of internals probing from outsiders
+
+
type family Replicate n a where
Replicate Z a = '[]
Replicate (S n) a = a : Replicate n a
@@ -105,6 +130,9 @@ newtype Primitive a = Primitive a
-- class.
type Mixed :: [Maybe Nat] -> Type -> Type
data family Mixed sh a
+-- NOTE: When opening up the Mixed abstraction, you might see dimension sizes
+-- that you're not supposed to see. In particular, an empty array may have
+-- elements with nonempty sizes, but then the whole array is still empty.
newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a)
deriving (Show)
@@ -167,6 +195,17 @@ class Elt a where
mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a
mscalar :: a -> Mixed '[] a
+ -- | All arrays in the list, even subarrays inside @a@, must have the same
+ -- shape; if they do not, a runtime error will be thrown. See the
+ -- documentation of 'mgenerate' for more information about this restriction.
+ -- Furthermore, the length of the list must correspond with @n@: if @n@ is
+ -- @Just m@ and @m@ does not equal the length of the list, a runtime error is
+ -- thrown.
+ --
+ -- If you want a single-dimensional array from your list, map 'mscalar'
+ -- first.
+ mfromList :: forall n sh. KnownShapeX (n : sh) => NonEmpty (Mixed sh a) -> Mixed (n : sh) a
+
mlift :: forall sh1 sh2. KnownShapeX sh2
=> (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
-> Mixed sh1 a -> Mixed sh2 a
@@ -215,6 +254,7 @@ instance Storable a => Elt (Primitive a) where
mindex (M_Primitive a) i = Primitive (X.index a i)
mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i)
mscalar (Primitive x) = M_Primitive (X.scalar x)
+ mfromList l = M_Primitive (X.fromList knownShapeX [x | M_Primitive x <- toList l])
mlift :: forall sh1 sh2.
(Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
@@ -263,6 +303,8 @@ instance (Elt a, Elt b) => Elt (a, b) where
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
+ mfromList l = M_Tup2 (mfromList ((\(M_Tup2 x _) -> x) <$> l))
+ (mfromList ((\(M_Tup2 _ y) -> y) <$> l))
mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b)
mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y)
@@ -302,6 +344,12 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
mscalar x = M_Nest x
+ mfromList :: forall n sh. KnownShapeX (n : sh)
+ => NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (n : sh) (Mixed sh' a)
+ mfromList l
+ | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @(n : sh)) (knownShapeX @sh'))
+ = M_Nest (mfromList ((\(M_Nest x) -> x) <$> l))
+
mlift :: forall sh1 sh2. KnownShapeX sh2
=> (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
@@ -546,6 +594,12 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
mscalar (Ranked x) = M_Ranked (M_Nest x)
+ mfromList :: forall m sh. KnownShapeX (m : sh)
+ => NonEmpty (Mixed sh (Ranked n a)) -> Mixed (m : sh) (Ranked n a)
+ mfromList l
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = M_Ranked (mfromList ((\(M_Ranked x) -> x) <$> l))
+
mlift :: forall sh1 sh2. KnownShapeX sh2
=> (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
-> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
@@ -664,6 +718,12 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
mscalar (Shaped x) = M_Shaped (M_Nest x)
+ mfromList :: forall n sh'. KnownShapeX (n : sh')
+ => NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (n : sh') (Shaped sh a)
+ mfromList l
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = M_Shaped (mfromList ((\(M_Shaped x) -> x) <$> l))
+
mlift :: forall sh1 sh2. KnownShapeX sh2
=> (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
-> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
@@ -870,6 +930,11 @@ rconstant sh x
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mconstant (ixCvtRX sh) x)
+rfromList :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a
+rfromList l
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = Ranked (mfromList ((\(Ranked x) -> x) <$> l))
+
-- ====== API OF SHAPED ARRAYS ====== --
@@ -993,3 +1058,9 @@ sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust
sconstant x
| Dict <- lemKnownMapJust (Proxy @sh)
= Shaped (mconstant (ixCvtSX (cvtSShapeIxS (knownShape @sh))) x)
+
+sfromList :: forall n sh a. (KnownNat n, KnownShape sh, Elt a)
+ => NonEmpty (Shaped sh a) -> Shaped (n : sh) a
+sfromList l
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = Shaped (mfromList ((\(Shaped x) -> x) <$> l))