diff options
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 78 |
1 files changed, 55 insertions, 23 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 182943d..6b96a15 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -367,17 +367,19 @@ class Elt a where -- this mixed array. marrayStrides :: Mixed sh a -> Bag [Int] - -- | Given the shape of this array, an index and a value, write the value at + -- | Given a linear index and a value, write the value at -- that index in the vectors. - mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () + mvecsWriteLinear :: Int -> a -> MixedVecs s sh a -> ST s () - -- | Given the shape of this array, an index and a value, write the value at + -- | Given a linear index and a value, write the value at -- that index in the vectors. - mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + mvecsWritePartialLinear :: Proxy sh -> Int -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () -- | Given the shape of this array, finalise the vectors into 'XArray's. mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + -- | Given the shape of this array, finalise the vectors into 'XArray's. + mvecsUnsafeFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) -- | Element types for which we have evidence of the (static part of the) shape -- in a type class constraint. Compare the instance contexts of the instances @@ -393,6 +395,10 @@ class Elt a => KnownElt a where -- this vector and an example for the contents. mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) + -- | Create initialised vectors for this array type, given the shape of + -- this vector and the chosen element. + mvecsReplicate :: IShX sh -> a -> ST s (MixedVecs s sh a) + mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) @@ -464,18 +470,19 @@ instance Storable a => Elt (Primitive a) where mshapeTreeIsEmpty _ () = False mshowShapeTree _ () = "()" marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x + mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x -- TODO: this use of toVector is suboptimal - mvecsWritePartial + mvecsWritePartialLinear :: forall sh' sh s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do + Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + mvecsWritePartialLinear _ i (M_Primitive sh' arr) (MV_Primitive v) = do let arrsh = X.shape (ssxFromShX sh') arr - offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) + offset = i * shxSize arrsh VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v + mvecsUnsafeFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.unsafeFreeze v -- [PRIMITIVE ELEMENT TYPES LIST] deriving via Primitive Bool instance Elt Bool @@ -492,6 +499,7 @@ deriving via Primitive () instance Elt () instance Storable a => KnownElt (Primitive a) where memptyArrayUnsafe sh = M_Primitive sh (X.empty sh) mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) + mvecsReplicate sh (Primitive a) = MV_Primitive <$> VSM.replicate (shxSize sh) a mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 -- [PRIMITIVE ELEMENT TYPES LIST] @@ -542,17 +550,19 @@ instance (Elt a, Elt b) => Elt (a, b) where mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b - mvecsWrite sh i (x, y) (MV_Tup2 a b) = do - mvecsWrite sh i x a - mvecsWrite sh i y b - mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do - mvecsWritePartial sh i x a - mvecsWritePartial sh i y b + mvecsWriteLinear i (x, y) (MV_Tup2 a b) = do + mvecsWriteLinear i x a + mvecsWriteLinear i y b + mvecsWritePartialLinear proxy i (M_Tup2 x y) (MV_Tup2 a b) = do + mvecsWritePartialLinear proxy i x a + mvecsWritePartialLinear proxy i y b mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b + mvecsUnsafeFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsUnsafeFreeze sh a <*> mvecsUnsafeFreeze sh b instance (KnownElt a, KnownElt b) => KnownElt (a, b) where memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh) mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y + mvecsReplicate sh (x, y) = MV_Tup2 <$> mvecsReplicate sh x <*> mvecsReplicate sh y mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) -- Arrays of arrays are just arrays, but with more dimensions. @@ -676,17 +686,20 @@ instance Elt a => Elt (Mixed sh' a) where marrayStrides (M_Nest _ arr) = marrayStrides arr - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs + mvecsWriteLinear :: forall s sh. Int -> Mixed sh' a -> MixedVecs s sh (Mixed sh' a) -> ST s () + mvecsWriteLinear idx val (MV_Nest _ vecs) = mvecsWritePartialLinear (Proxy @sh) idx val vecs - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () - mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) + mvecsWritePartialLinear + :: forall sh1 sh2 s. + Proxy sh1 -> Int -> Mixed sh2 (Mixed sh' a) + -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) + -> ST s () + mvecsWritePartialLinear proxy idx (M_Nest _ arr) (MV_Nest _ vecs) | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs + = mvecsWritePartialLinear proxy idx arr vecs mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs + mvecsUnsafeFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsUnsafeFreeze (shxAppend sh sh') vecs instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh')))) @@ -697,9 +710,28 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where where sh' = mshape example + mvecsReplicate sh example = do + vecs <- mvecsUnsafeNew sh example + forM_ [0 .. shxSize sh - 1] $ \idx -> mvecsWriteLinear idx example vecs + -- this is a slow case, but the alternative, mvecsUnsafeNew with manual + -- writing in a loop, leads to every case being as slow + return vecs + mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) +-- | Given the shape of this array, an index and a value, write the value at +-- that index in the vectors. +{-# INLINE mvecsWrite #-} +mvecsWrite :: Elt a => IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () +mvecsWrite sh idx val vecs = mvecsWriteLinear (ixxToLinear sh idx) val vecs + +-- | Given the shape of this array, an index and a value, write the value at +-- that index in the vectors. +{-# INLINE mvecsWritePartial #-} +mvecsWritePartial :: forall sh sh' s a. Elt a => IShX sh -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () +mvecsWritePartial sh idx val vecs = mvecsWritePartialLinear (Proxy @sh) (ixxToLinear sh idx) val vecs + -- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere memptyArray :: forall sh a. KnownElt a => IShX sh -> Mixed (Just 0 : sh) a memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh) @@ -746,7 +778,7 @@ mgenerate sh f = case shxEnum sh of when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ error "Data.Array.Nested mgenerate: generated values do not have equal shapes" mvecsWrite sh idx val vecs - mvecsFreeze sh vecs + mvecsUnsafeFreeze sh vecs -- | An optimized special case of 'mgenerate', where the function results -- are of a primitive type and so there's not need to check that all shapes |
