aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-04-17 12:18:49 +0200
committerTom Smeding <tom@tomsmeding.com>2024-04-17 12:18:49 +0200
commit4fa4f193bdba187deb7ead0ff839c78c25125c7b (patch)
tree572de7a812eb69805da77bd7eecf3ed48356a626 /src/Data/Array/Nested/Internal.hs
parente29ab5d55be1c9cf60d4c795dc85388181a2e64b (diff)
Allow generating an empty array
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r--src/Data/Array/Nested/Internal.hs23
1 files changed, 18 insertions, 5 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 8edf5be..b0c0e56 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -192,10 +192,11 @@ class Elt a where
mshowShapeTree :: Proxy a -> ShapeTree a -> String
-- | Create uninitialised vectors for this array type, given the shape of
- -- this vector and an example for the contents. The shape must not have size
- -- zero; an error may be thrown otherwise.
+ -- this vector and an example for the contents.
mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a)
+ mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
+
-- | Given the shape of this array, an index and a value, write the value at
-- that index in the vectors.
mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s ()
@@ -239,6 +240,7 @@ instance Storable a => Elt (Primitive a) where
mshapeTreeEmpty _ () = False
mshowShapeTree _ () = "()"
mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh)
+ mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x
-- TODO: this use of toVector is suboptimal
@@ -271,6 +273,7 @@ instance (Elt a, Elt b) => Elt (a, b) where
mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
+ mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)
mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
mvecsWrite sh i x a
mvecsWrite sh i y b
@@ -342,12 +345,14 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
mvecsUnsafeNew sh example
- | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example"
+ | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
| otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example))
(mindex example (X.zeroIxX (knownShapeX @sh')))
where
sh' = mshape example
+ mvecsNewEmpty _ = MV_Nest (X.zeroIxX (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a)
+
mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs
mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
@@ -391,8 +396,8 @@ mgenerate sh f
-- TODO: Do we need this checkBounds check elsewhere as well?
| not (checkBounds sh (knownShapeX @sh)) =
error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
- -- We need to be very careful here to ensure that neither 'sh' nor
- -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty.
+ -- If the shape is empty, there is no first element, so we should not try to
+ -- generate it.
| X.shapeSize sh == 0 = memptyArray sh
| otherwise =
let firstidx = X.zeroIxX' sh
@@ -580,6 +585,10 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
| Dict <- lemKnownReplicate (Proxy @n)
= MV_Ranked <$> mvecsUnsafeNew idx arr
+ mvecsNewEmpty _
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
+
mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
mvecsWrite sh idx (Ranked arr) vecs
| Dict <- lemKnownReplicate (Proxy @n)
@@ -693,6 +702,10 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
| Dict <- lemKnownMapJust (Proxy @sh)
= MV_Shaped <$> mvecsUnsafeNew idx arr
+ mvecsNewEmpty _
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
+
mvecsWrite :: forall sh' s. IxX sh' -> IxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
mvecsWrite sh idx (Shaped arr) vecs
| Dict <- lemKnownMapJust (Proxy @sh)