diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-04-17 12:18:49 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-17 12:18:49 +0200 |
commit | 4fa4f193bdba187deb7ead0ff839c78c25125c7b (patch) | |
tree | 572de7a812eb69805da77bd7eecf3ed48356a626 /src/Data/Array | |
parent | e29ab5d55be1c9cf60d4c795dc85388181a2e64b (diff) |
Allow generating an empty array
Diffstat (limited to 'src/Data/Array')
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 23 |
1 files changed, 18 insertions, 5 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 8edf5be..b0c0e56 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -192,10 +192,11 @@ class Elt a where mshowShapeTree :: Proxy a -> ShapeTree a -> String -- | Create uninitialised vectors for this array type, given the shape of - -- this vector and an example for the contents. The shape must not have size - -- zero; an error may be thrown otherwise. + -- this vector and an example for the contents. mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a) + mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) + -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s () @@ -239,6 +240,7 @@ instance Storable a => Elt (Primitive a) where mshapeTreeEmpty _ () = False mshowShapeTree _ () = "()" mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh) + mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x -- TODO: this use of toVector is suboptimal @@ -271,6 +273,7 @@ instance (Elt a, Elt b) => Elt (a, b) where mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y + mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) mvecsWrite sh i (x, y) (MV_Tup2 a b) = do mvecsWrite sh i x a mvecsWrite sh i y b @@ -342,12 +345,14 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" mvecsUnsafeNew sh example - | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example" + | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example)) (mindex example (X.zeroIxX (knownShapeX @sh'))) where sh' = mshape example + mvecsNewEmpty _ = MV_Nest (X.zeroIxX (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a) + mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 @@ -391,8 +396,8 @@ mgenerate sh f -- TODO: Do we need this checkBounds check elsewhere as well? | not (checkBounds sh (knownShapeX @sh)) = error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) - -- We need to be very careful here to ensure that neither 'sh' nor - -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty. + -- If the shape is empty, there is no first element, so we should not try to + -- generate it. | X.shapeSize sh == 0 = memptyArray sh | otherwise = let firstidx = X.zeroIxX' sh @@ -580,6 +585,10 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where | Dict <- lemKnownReplicate (Proxy @n) = MV_Ranked <$> mvecsUnsafeNew idx arr + mvecsNewEmpty _ + | Dict <- lemKnownReplicate (Proxy @n) + = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) + mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () mvecsWrite sh idx (Ranked arr) vecs | Dict <- lemKnownReplicate (Proxy @n) @@ -693,6 +702,10 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where | Dict <- lemKnownMapJust (Proxy @sh) = MV_Shaped <$> mvecsUnsafeNew idx arr + mvecsNewEmpty _ + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) + mvecsWrite :: forall sh' s. IxX sh' -> IxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () mvecsWrite sh idx (Shaped arr) vecs | Dict <- lemKnownMapJust (Proxy @sh) |