diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-04-17 12:18:49 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-17 12:18:49 +0200 | 
| commit | 4fa4f193bdba187deb7ead0ff839c78c25125c7b (patch) | |
| tree | 572de7a812eb69805da77bd7eecf3ed48356a626 /src/Data | |
| parent | e29ab5d55be1c9cf60d4c795dc85388181a2e64b (diff) | |
Allow generating an empty array
Diffstat (limited to 'src/Data')
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 23 | 
1 files changed, 18 insertions, 5 deletions
| diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 8edf5be..b0c0e56 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -192,10 +192,11 @@ class Elt a where    mshowShapeTree :: Proxy a -> ShapeTree a -> String    -- | Create uninitialised vectors for this array type, given the shape of -  -- this vector and an example for the contents. The shape must not have size -  -- zero; an error may be thrown otherwise. +  -- this vector and an example for the contents.    mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a) +  mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) +    -- | Given the shape of this array, an index and a value, write the value at    -- that index in the vectors.    mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s () @@ -239,6 +240,7 @@ instance Storable a => Elt (Primitive a) where    mshapeTreeEmpty _ () = False    mshowShapeTree _ () = "()"    mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh) +  mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0    mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x    -- TODO: this use of toVector is suboptimal @@ -271,6 +273,7 @@ instance (Elt a, Elt b) => Elt (a, b) where    mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2    mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"    mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y +  mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)    mvecsWrite sh i (x, y) (MV_Tup2 a b) = do      mvecsWrite sh i x a      mvecsWrite sh i y b @@ -342,12 +345,14 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where    mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"    mvecsUnsafeNew sh example -    | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example" +    | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))      | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example))                                                   (mindex example (X.zeroIxX (knownShapeX @sh')))      where        sh' = mshape example +  mvecsNewEmpty _ = MV_Nest (X.zeroIxX (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a) +    mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs    mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 @@ -391,8 +396,8 @@ mgenerate sh f    -- TODO: Do we need this checkBounds check elsewhere as well?    | not (checkBounds sh (knownShapeX @sh)) =        error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) -  -- We need to be very careful here to ensure that neither 'sh' nor -  -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty. +  -- If the shape is empty, there is no first element, so we should not try to +  -- generate it.    | X.shapeSize sh == 0 = memptyArray sh    | otherwise =        let firstidx = X.zeroIxX' sh @@ -580,6 +585,10 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where      | Dict <- lemKnownReplicate (Proxy @n)      = MV_Ranked <$> mvecsUnsafeNew idx arr +  mvecsNewEmpty _ +    | Dict <- lemKnownReplicate (Proxy @n) +    = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) +    mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()    mvecsWrite sh idx (Ranked arr) vecs      | Dict <- lemKnownReplicate (Proxy @n) @@ -693,6 +702,10 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where      | Dict <- lemKnownMapJust (Proxy @sh)      = MV_Shaped <$> mvecsUnsafeNew idx arr +  mvecsNewEmpty _ +    | Dict <- lemKnownMapJust (Proxy @sh) +    = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) +    mvecsWrite :: forall sh' s. IxX sh' -> IxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()    mvecsWrite sh idx (Shaped arr) vecs      | Dict <- lemKnownMapJust (Proxy @sh) | 
