aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-17 17:02:42 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-17 17:02:42 +0100
commit0766e22df98179ce7debb179e544716bccfbca24 (patch)
treebe29bfb6847dc39465d4a8a88af81cdba795d4d1 /src/Data/Array/Nested
parent85f3e5b2c91dded98edae8f7d1e9a4026839b556 (diff)
mshapeTreeIsEmpty: allow partially-zero shapes for nested arrays
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Mixed.hs11
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs2
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs2
3 files changed, 8 insertions, 7 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 7bda08c..6b152f7 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -364,7 +364,7 @@ class Elt a where
mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool
- mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool
+ mshapeTreeIsEmpty :: Proxy a -> ShapeTree a -> Bool
mshowShapeTree :: Proxy a -> ShapeTree a -> String
@@ -464,7 +464,7 @@ instance Storable a => Elt (Primitive a) where
type ShapeTree (Primitive a) = ()
mshapeTree _ = ()
mshapeTreeEq _ () () = True
- mshapeTreeEmpty _ () = False
+ mshapeTreeIsEmpty _ () = False
mshowShapeTree _ () = "()"
marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr)
mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
@@ -542,7 +542,7 @@ instance (Elt a, Elt b) => Elt (a, b) where
type ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
- mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
+ mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2
mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b
mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
@@ -673,7 +673,8 @@ instance Elt a => Elt (Mixed sh' a) where
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
- mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+ -- the array is empty if either there are no subarrays, or the subarrays themselves are empty
+ mshapeTreeIsEmpty _ (sh, t) = shxSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
@@ -736,7 +737,7 @@ mgenerate sh f = case shxEnum sh of
firstidx : restidxs ->
let firstelem = f (ixxZero' sh)
shapetree = mshapeTree firstelem
- in if mshapeTreeEmpty (Proxy @a) shapetree
+ in if mshapeTreeIsEmpty (Proxy @a) shapetree
then memptyArrayUnsafe sh
else runST $ do
vecs <- mvecsUnsafeNew sh firstelem
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
index babc809..54baa32 100644
--- a/src/Data/Array/Nested/Ranked/Base.hs
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -143,7 +143,7 @@ instance Elt a => Elt (Ranked n a) where
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
- mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+ mshapeTreeIsEmpty _ (sh, t) = shrSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index 879e6b5..75e6fcb 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -136,7 +136,7 @@ instance Elt a => Elt (Shaped sh a) where
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
- mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+ mshapeTreeIsEmpty _ (sh, t) = shsSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"