diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-04-17 10:23:00 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-17 10:23:00 +0200 |
commit | f88dfced066710e42e8e48d06b55e5661f73d617 (patch) | |
tree | d9b23c7724e1d3525e8ee14cb7713974f9ea50b5 /src/Data/Array/Nested | |
parent | e7df04c53ee43603d4a75bf31c7bddae1575366a (diff) |
Fix bug in mgenerate
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 32 |
1 files changed, 30 insertions, 2 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index e470907..e6e9ab4 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -187,6 +187,10 @@ class Elt a where mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool + mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool + + mshowShapeTree :: Proxy a -> ShapeTree a -> String + -- | Create uninitialised vectors for this array type, given the shape of -- this vector and an example for the contents. The shape must not have size -- zero; an error may be thrown otherwise. @@ -228,10 +232,12 @@ instance Storable a => Elt (Primitive a) where , Refl <- X.lemAppNil @sh3 = M_Primitive (f Proxy a b) - memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty")) + memptyArray sh = M_Primitive (X.generate sh (error $ "memptyArray Int: shape was not empty (" ++ show sh ++ ")")) mshapeTree _ = () mshapeTreeZero _ = () mshapeTreeEq _ () () = True + mshapeTreeEmpty _ () = False + mshowShapeTree _ () = "()" mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh) mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x @@ -262,6 +268,8 @@ instance (Elt a, Elt b) => Elt (a, b) where mshapeTree (x, y) = (mshapeTree x, mshapeTree y) mshapeTreeZero _ = (mshapeTreeZero (Proxy @a), mshapeTreeZero (Proxy @b)) mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' + mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 + mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y mvecsWrite sh i (x, y) (MV_Tup2 a b) = do mvecsWrite sh i x a @@ -329,6 +337,10 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + mshapeTreeEmpty _ (sh, t) = X.shapeSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + mvecsUnsafeNew sh example | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example" | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example)) @@ -386,7 +398,7 @@ mgenerate sh f let firstidx = X.zeroIxX' sh firstelem = f (X.zeroIxX' sh) shapetree = mshapeTree firstelem - in if mshapeTreeEq (Proxy @a) shapetree (mshapeTreeZero (Proxy @a)) + in if mshapeTreeEmpty (Proxy @a) shapetree then memptyArray sh else runST $ do vecs <- mvecsUnsafeNew sh firstelem @@ -560,6 +572,10 @@ instance (KnownINat n, Elt a) => Elt (Ranked n a) where mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + mshapeTreeEmpty _ (sh, t) = shapeSizeR sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + mvecsUnsafeNew idx (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = MV_Ranked <$> mvecsUnsafeNew idx arr @@ -669,6 +685,10 @@ instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + mshapeTreeEmpty _ (sh, t) = shapeSizeS sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + mvecsUnsafeNew idx (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = MV_Shaped <$> mvecsUnsafeNew idx arr @@ -760,6 +780,10 @@ ixCvtRX :: IxR n -> IxX (Replicate n Nothing) ixCvtRX IZR = IZX ixCvtRX (n ::: idx) = n ::? ixCvtRX idx +shapeSizeR :: IxR n -> Int +shapeSizeR IZR = 1 +shapeSizeR (n ::: sh) = n * shapeSizeR sh + rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IxR n rshape (Ranked arr) @@ -883,6 +907,10 @@ ixCvtSX :: IxS sh -> IxX (MapJust sh) ixCvtSX IZS = IZX ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh +shapeSizeS :: IxS sh -> Int +shapeSizeS IZS = 1 +shapeSizeS (n ::$ sh) = n * shapeSizeS sh + -- | This does not touch the passed array, all information comes from 'KnownShape'. sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IxS sh |