summaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-04-17 10:23:00 +0200
committerTom Smeding <tom@tomsmeding.com>2024-04-17 10:23:00 +0200
commitf88dfced066710e42e8e48d06b55e5661f73d617 (patch)
treed9b23c7724e1d3525e8ee14cb7713974f9ea50b5 /src/Data/Array/Nested/Internal.hs
parente7df04c53ee43603d4a75bf31c7bddae1575366a (diff)
Fix bug in mgenerate
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r--src/Data/Array/Nested/Internal.hs32
1 files changed, 30 insertions, 2 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index e470907..e6e9ab4 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -187,6 +187,10 @@ class Elt a where
mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool
+ mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool
+
+ mshowShapeTree :: Proxy a -> ShapeTree a -> String
+
-- | Create uninitialised vectors for this array type, given the shape of
-- this vector and an example for the contents. The shape must not have size
-- zero; an error may be thrown otherwise.
@@ -228,10 +232,12 @@ instance Storable a => Elt (Primitive a) where
, Refl <- X.lemAppNil @sh3
= M_Primitive (f Proxy a b)
- memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty"))
+ memptyArray sh = M_Primitive (X.generate sh (error $ "memptyArray Int: shape was not empty (" ++ show sh ++ ")"))
mshapeTree _ = ()
mshapeTreeZero _ = ()
mshapeTreeEq _ () () = True
+ mshapeTreeEmpty _ () = False
+ mshowShapeTree _ () = "()"
mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh)
mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x
@@ -262,6 +268,8 @@ instance (Elt a, Elt b) => Elt (a, b) where
mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
mshapeTreeZero _ = (mshapeTreeZero (Proxy @a), mshapeTreeZero (Proxy @b))
mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
+ mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
+ mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
mvecsWrite sh i x a
@@ -329,6 +337,10 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+ mshapeTreeEmpty _ (sh, t) = X.shapeSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
mvecsUnsafeNew sh example
| X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example"
| otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example))
@@ -386,7 +398,7 @@ mgenerate sh f
let firstidx = X.zeroIxX' sh
firstelem = f (X.zeroIxX' sh)
shapetree = mshapeTree firstelem
- in if mshapeTreeEq (Proxy @a) shapetree (mshapeTreeZero (Proxy @a))
+ in if mshapeTreeEmpty (Proxy @a) shapetree
then memptyArray sh
else runST $ do
vecs <- mvecsUnsafeNew sh firstelem
@@ -560,6 +572,10 @@ instance (KnownINat n, Elt a) => Elt (Ranked n a) where
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+ mshapeTreeEmpty _ (sh, t) = shapeSizeR sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
mvecsUnsafeNew idx (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
= MV_Ranked <$> mvecsUnsafeNew idx arr
@@ -669,6 +685,10 @@ instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+ mshapeTreeEmpty _ (sh, t) = shapeSizeS sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
mvecsUnsafeNew idx (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
= MV_Shaped <$> mvecsUnsafeNew idx arr
@@ -760,6 +780,10 @@ ixCvtRX :: IxR n -> IxX (Replicate n Nothing)
ixCvtRX IZR = IZX
ixCvtRX (n ::: idx) = n ::? ixCvtRX idx
+shapeSizeR :: IxR n -> Int
+shapeSizeR IZR = 1
+shapeSizeR (n ::: sh) = n * shapeSizeR sh
+
rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IxR n
rshape (Ranked arr)
@@ -883,6 +907,10 @@ ixCvtSX :: IxS sh -> IxX (MapJust sh)
ixCvtSX IZS = IZX
ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh
+shapeSizeS :: IxS sh -> Int
+shapeSizeS IZS = 1
+shapeSizeS (n ::$ sh) = n * shapeSizeS sh
+
-- | This does not touch the passed array, all information comes from 'KnownShape'.
sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IxS sh