From 0766e22df98179ce7debb179e544716bccfbca24 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 17 Nov 2025 17:02:42 +0100 Subject: mshapeTreeIsEmpty: allow partially-zero shapes for nested arrays --- src/Data/Array/Nested/Mixed.hs | 11 ++++++----- src/Data/Array/Nested/Ranked/Base.hs | 2 +- src/Data/Array/Nested/Shaped/Base.hs | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) (limited to 'src/Data') diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 7bda08c..6b152f7 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -364,7 +364,7 @@ class Elt a where mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool - mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool + mshapeTreeIsEmpty :: Proxy a -> ShapeTree a -> Bool mshowShapeTree :: Proxy a -> ShapeTree a -> String @@ -464,7 +464,7 @@ instance Storable a => Elt (Primitive a) where type ShapeTree (Primitive a) = () mshapeTree _ = () mshapeTreeEq _ () () = True - mshapeTreeEmpty _ () = False + mshapeTreeIsEmpty _ () = False mshowShapeTree _ () = "()" marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x @@ -542,7 +542,7 @@ instance (Elt a, Elt b) => Elt (a, b) where type ShapeTree (a, b) = (ShapeTree a, ShapeTree b) mshapeTree (x, y) = (mshapeTree x, mshapeTree y) mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' - mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 + mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b mvecsWrite sh i (x, y) (MV_Tup2 a b) = do @@ -673,7 +673,8 @@ instance Elt a => Elt (Mixed sh' a) where mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + -- the array is empty if either there are no subarrays, or the subarrays themselves are empty + mshapeTreeIsEmpty _ (sh, t) = shxSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" @@ -736,7 +737,7 @@ mgenerate sh f = case shxEnum sh of firstidx : restidxs -> let firstelem = f (ixxZero' sh) shapetree = mshapeTree firstelem - in if mshapeTreeEmpty (Proxy @a) shapetree + in if mshapeTreeIsEmpty (Proxy @a) shapetree then memptyArrayUnsafe sh else runST $ do vecs <- mvecsUnsafeNew sh firstelem diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index babc809..54baa32 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -143,7 +143,7 @@ instance Elt a => Elt (Ranked n a) where mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + mshapeTreeIsEmpty _ (sh, t) = shrSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index 879e6b5..75e6fcb 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -136,7 +136,7 @@ instance Elt a => Elt (Shaped sh a) where mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + mshapeTreeIsEmpty _ (sh, t) = shsSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" -- cgit v1.2.3-70-g09d2