aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Data/Array/Nested/Internal.hs23
-rw-r--r--test/Main.hs7
2 files changed, 25 insertions, 5 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 8edf5be..b0c0e56 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -192,10 +192,11 @@ class Elt a where
mshowShapeTree :: Proxy a -> ShapeTree a -> String
-- | Create uninitialised vectors for this array type, given the shape of
- -- this vector and an example for the contents. The shape must not have size
- -- zero; an error may be thrown otherwise.
+ -- this vector and an example for the contents.
mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a)
+ mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
+
-- | Given the shape of this array, an index and a value, write the value at
-- that index in the vectors.
mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s ()
@@ -239,6 +240,7 @@ instance Storable a => Elt (Primitive a) where
mshapeTreeEmpty _ () = False
mshowShapeTree _ () = "()"
mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh)
+ mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x
-- TODO: this use of toVector is suboptimal
@@ -271,6 +273,7 @@ instance (Elt a, Elt b) => Elt (a, b) where
mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
+ mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)
mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
mvecsWrite sh i x a
mvecsWrite sh i y b
@@ -342,12 +345,14 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
mvecsUnsafeNew sh example
- | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example"
+ | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
| otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example))
(mindex example (X.zeroIxX (knownShapeX @sh')))
where
sh' = mshape example
+ mvecsNewEmpty _ = MV_Nest (X.zeroIxX (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a)
+
mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs
mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
@@ -391,8 +396,8 @@ mgenerate sh f
-- TODO: Do we need this checkBounds check elsewhere as well?
| not (checkBounds sh (knownShapeX @sh)) =
error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
- -- We need to be very careful here to ensure that neither 'sh' nor
- -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty.
+ -- If the shape is empty, there is no first element, so we should not try to
+ -- generate it.
| X.shapeSize sh == 0 = memptyArray sh
| otherwise =
let firstidx = X.zeroIxX' sh
@@ -580,6 +585,10 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
| Dict <- lemKnownReplicate (Proxy @n)
= MV_Ranked <$> mvecsUnsafeNew idx arr
+ mvecsNewEmpty _
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
+
mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
mvecsWrite sh idx (Ranked arr) vecs
| Dict <- lemKnownReplicate (Proxy @n)
@@ -693,6 +702,10 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
| Dict <- lemKnownMapJust (Proxy @sh)
= MV_Shaped <$> mvecsUnsafeNew idx arr
+ mvecsNewEmpty _
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
+
mvecsWrite :: forall sh' s. IxX sh' -> IxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
mvecsWrite sh idx (Shaped arr) vecs
| Dict <- lemKnownMapJust (Proxy @sh)
diff --git a/test/Main.hs b/test/Main.hs
index 1619c00..d29e4d5 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -15,8 +15,15 @@ arr = rgenerate (3 ::: 4 ::: IZR) $ \(i ::: j ::: IZR) ->
foo :: (Double, Int)
foo = arr `rindex` (2 ::: 1 ::: IZR) `sindex` (1 ::$ 1 ::$ IZS)
+bad :: Ranked I2 (Ranked I1 Double)
+bad = rgenerate (3 ::: 4 ::: IZR) $ \(i ::: j ::: IZR) ->
+ rgenerate (i ::: IZR) $ \(k ::: IZR) ->
+ let s = 24*i + 6*j + 3*k
+ in fromIntegral s
+
main :: IO ()
main = do
print arr
print foo
print (rtranspose [1,0] arr)
+ -- print bad