diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-17 17:02:42 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-17 17:02:42 +0100 |
| commit | 0766e22df98179ce7debb179e544716bccfbca24 (patch) | |
| tree | be29bfb6847dc39465d4a8a88af81cdba795d4d1 /src/Data/Array | |
| parent | 85f3e5b2c91dded98edae8f7d1e9a4026839b556 (diff) | |
mshapeTreeIsEmpty: allow partially-zero shapes for nested arrays
Diffstat (limited to 'src/Data/Array')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 11 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 2 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 2 |
3 files changed, 8 insertions, 7 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 7bda08c..6b152f7 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -364,7 +364,7 @@ class Elt a where mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool - mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool + mshapeTreeIsEmpty :: Proxy a -> ShapeTree a -> Bool mshowShapeTree :: Proxy a -> ShapeTree a -> String @@ -464,7 +464,7 @@ instance Storable a => Elt (Primitive a) where type ShapeTree (Primitive a) = () mshapeTree _ = () mshapeTreeEq _ () () = True - mshapeTreeEmpty _ () = False + mshapeTreeIsEmpty _ () = False mshowShapeTree _ () = "()" marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x @@ -542,7 +542,7 @@ instance (Elt a, Elt b) => Elt (a, b) where type ShapeTree (a, b) = (ShapeTree a, ShapeTree b) mshapeTree (x, y) = (mshapeTree x, mshapeTree y) mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' - mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 + mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b mvecsWrite sh i (x, y) (MV_Tup2 a b) = do @@ -673,7 +673,8 @@ instance Elt a => Elt (Mixed sh' a) where mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + -- the array is empty if either there are no subarrays, or the subarrays themselves are empty + mshapeTreeIsEmpty _ (sh, t) = shxSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" @@ -736,7 +737,7 @@ mgenerate sh f = case shxEnum sh of firstidx : restidxs -> let firstelem = f (ixxZero' sh) shapetree = mshapeTree firstelem - in if mshapeTreeEmpty (Proxy @a) shapetree + in if mshapeTreeIsEmpty (Proxy @a) shapetree then memptyArrayUnsafe sh else runST $ do vecs <- mvecsUnsafeNew sh firstelem diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index babc809..54baa32 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -143,7 +143,7 @@ instance Elt a => Elt (Ranked n a) where mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + mshapeTreeIsEmpty _ (sh, t) = shrSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index 879e6b5..75e6fcb 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -136,7 +136,7 @@ instance Elt a => Elt (Shaped sh a) where mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + mshapeTreeIsEmpty _ (sh, t) = shsSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" |
