diff options
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 46 |
1 files changed, 12 insertions, 34 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 0b263e1..fc1c108 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -367,18 +367,10 @@ class Elt a where -- this mixed array. marrayStrides :: Mixed sh a -> Bag [Int] - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () - -- | Given a linear index and a value, write the value at -- that index in the vectors. mvecsWriteLinear :: Int -> a -> MixedVecs s sh a -> ST s () - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () - -- | Given a linear index and a value, write the value at -- that index in the vectors. mvecsWritePartialLinear :: Proxy sh -> Int -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () @@ -386,7 +378,6 @@ class Elt a where -- | Given the shape of this array, finalise the vectors into 'XArray's. mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) - -- | Element types for which we have evidence of the (static part of the) shape -- in a type class constraint. Compare the instance contexts of the instances -- of this class with those of 'Elt': some instances have an additional @@ -476,18 +467,9 @@ instance Storable a => Elt (Primitive a) where mshapeTreeIsEmpty _ () = False mshowShapeTree _ () = "()" marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x -- TODO: this use of toVector is suboptimal - mvecsWritePartial - :: forall sh' sh s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do - let arrsh = X.shape (ssxFromShX sh') arr - offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) - VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) - mvecsWritePartialLinear :: forall sh' sh s. Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () @@ -564,15 +546,9 @@ instance (Elt a, Elt b) => Elt (a, b) where mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b - mvecsWrite sh i (x, y) (MV_Tup2 a b) = do - mvecsWrite sh i x a - mvecsWrite sh i y b mvecsWriteLinear i (x, y) (MV_Tup2 a b) = do mvecsWriteLinear i x a mvecsWriteLinear i y b - mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do - mvecsWritePartial sh i x a - mvecsWritePartial sh i y b mvecsWritePartialLinear proxy i (M_Tup2 x y) (MV_Tup2 a b) = do mvecsWritePartialLinear proxy i x a mvecsWritePartialLinear proxy i y b @@ -705,19 +681,9 @@ instance Elt a => Elt (Mixed sh' a) where marrayStrides (M_Nest _ arr) = marrayStrides arr - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs mvecsWriteLinear :: forall s sh. Int -> Mixed sh' a -> MixedVecs s sh (Mixed sh' a) -> ST s () mvecsWriteLinear idx val (MV_Nest _ vecs) = mvecsWritePartialLinear (Proxy @sh) idx val vecs - mvecsWritePartial - :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () - mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs - mvecsWritePartialLinear :: forall sh1 sh2 s. Proxy sh1 -> Int -> Mixed sh2 (Mixed sh' a) @@ -748,6 +714,18 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) +-- | Given the shape of this array, an index and a value, write the value at +-- that index in the vectors. +{-# INLINE mvecsWrite #-} +mvecsWrite :: Elt a => IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () +mvecsWrite sh idx val vecs = mvecsWriteLinear (ixxToLinear sh idx) val vecs + +-- | Given the shape of this array, an index and a value, write the value at +-- that index in the vectors. +{-# INLINE mvecsWritePartial #-} +mvecsWritePartial :: forall sh sh' s a. Elt a => IShX sh -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () +mvecsWritePartial sh idx val vecs = mvecsWritePartialLinear (Proxy @sh) (ixxToLinear sh idx) val vecs + -- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere memptyArray :: forall sh a. KnownElt a => IShX sh -> Mixed (Just 0 : sh) a memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh) |
