diff options
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 23 | ||||
-rw-r--r-- | test/Main.hs | 7 |
2 files changed, 25 insertions, 5 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 8edf5be..b0c0e56 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -192,10 +192,11 @@ class Elt a where mshowShapeTree :: Proxy a -> ShapeTree a -> String -- | Create uninitialised vectors for this array type, given the shape of - -- this vector and an example for the contents. The shape must not have size - -- zero; an error may be thrown otherwise. + -- this vector and an example for the contents. mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a) + mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) + -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s () @@ -239,6 +240,7 @@ instance Storable a => Elt (Primitive a) where mshapeTreeEmpty _ () = False mshowShapeTree _ () = "()" mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh) + mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x -- TODO: this use of toVector is suboptimal @@ -271,6 +273,7 @@ instance (Elt a, Elt b) => Elt (a, b) where mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y + mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) mvecsWrite sh i (x, y) (MV_Tup2 a b) = do mvecsWrite sh i x a mvecsWrite sh i y b @@ -342,12 +345,14 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" mvecsUnsafeNew sh example - | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example" + | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example)) (mindex example (X.zeroIxX (knownShapeX @sh'))) where sh' = mshape example + mvecsNewEmpty _ = MV_Nest (X.zeroIxX (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a) + mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 @@ -391,8 +396,8 @@ mgenerate sh f -- TODO: Do we need this checkBounds check elsewhere as well? | not (checkBounds sh (knownShapeX @sh)) = error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) - -- We need to be very careful here to ensure that neither 'sh' nor - -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty. + -- If the shape is empty, there is no first element, so we should not try to + -- generate it. | X.shapeSize sh == 0 = memptyArray sh | otherwise = let firstidx = X.zeroIxX' sh @@ -580,6 +585,10 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where | Dict <- lemKnownReplicate (Proxy @n) = MV_Ranked <$> mvecsUnsafeNew idx arr + mvecsNewEmpty _ + | Dict <- lemKnownReplicate (Proxy @n) + = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) + mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () mvecsWrite sh idx (Ranked arr) vecs | Dict <- lemKnownReplicate (Proxy @n) @@ -693,6 +702,10 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where | Dict <- lemKnownMapJust (Proxy @sh) = MV_Shaped <$> mvecsUnsafeNew idx arr + mvecsNewEmpty _ + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) + mvecsWrite :: forall sh' s. IxX sh' -> IxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () mvecsWrite sh idx (Shaped arr) vecs | Dict <- lemKnownMapJust (Proxy @sh) diff --git a/test/Main.hs b/test/Main.hs index 1619c00..d29e4d5 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -15,8 +15,15 @@ arr = rgenerate (3 ::: 4 ::: IZR) $ \(i ::: j ::: IZR) -> foo :: (Double, Int) foo = arr `rindex` (2 ::: 1 ::: IZR) `sindex` (1 ::$ 1 ::$ IZS) +bad :: Ranked I2 (Ranked I1 Double) +bad = rgenerate (3 ::: 4 ::: IZR) $ \(i ::: j ::: IZR) -> + rgenerate (i ::: IZR) $ \(k ::: IZR) -> + let s = 24*i + 6*j + 3*k + in fromIntegral s + main :: IO () main = do print arr print foo print (rtranspose [1,0] arr) + -- print bad |