diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 43 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 29 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 29 |
3 files changed, 89 insertions, 12 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 654ce3c..0b263e1 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -371,10 +371,18 @@ class Elt a where -- that index in the vectors. mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () + -- | Given a linear index and a value, write the value at + -- that index in the vectors. + mvecsWriteLinear :: Int -> a -> MixedVecs s sh a -> ST s () + -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + -- | Given a linear index and a value, write the value at + -- that index in the vectors. + mvecsWritePartialLinear :: Proxy sh -> Int -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + -- | Given the shape of this array, finalise the vectors into 'XArray's. mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) @@ -469,6 +477,7 @@ instance Storable a => Elt (Primitive a) where mshowShapeTree _ () = "()" marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x + mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x -- TODO: this use of toVector is suboptimal mvecsWritePartial @@ -479,6 +488,14 @@ instance Storable a => Elt (Primitive a) where offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) + mvecsWritePartialLinear + :: forall sh' sh s. + Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + mvecsWritePartialLinear _ i (M_Primitive sh' arr) (MV_Primitive v) = do + let arrsh = X.shape (ssxFromShX sh') arr + offset = i * shxSize arrsh + VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) + mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v -- [PRIMITIVE ELEMENT TYPES LIST] @@ -550,9 +567,15 @@ instance (Elt a, Elt b) => Elt (a, b) where mvecsWrite sh i (x, y) (MV_Tup2 a b) = do mvecsWrite sh i x a mvecsWrite sh i y b + mvecsWriteLinear i (x, y) (MV_Tup2 a b) = do + mvecsWriteLinear i x a + mvecsWriteLinear i y b mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do mvecsWritePartial sh i x a mvecsWritePartial sh i y b + mvecsWritePartialLinear proxy i (M_Tup2 x y) (MV_Tup2 a b) = do + mvecsWritePartialLinear proxy i x a + mvecsWritePartialLinear proxy i y b mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b instance (KnownElt a, KnownElt b) => KnownElt (a, b) where @@ -683,15 +706,27 @@ instance Elt a => Elt (Mixed sh' a) where marrayStrides (M_Nest _ arr) = marrayStrides arr mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs + mvecsWriteLinear :: forall s sh. Int -> Mixed sh' a -> MixedVecs s sh (Mixed sh' a) -> ST s () + mvecsWriteLinear idx val (MV_Nest _ vecs) = mvecsWritePartialLinear (Proxy @sh) idx val vecs - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () + mvecsWritePartial + :: forall sh1 sh2 s. + IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) + -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) + -> ST s () mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs + mvecsWritePartialLinear + :: forall sh1 sh2 s. + Proxy sh1 -> Int -> Mixed sh2 (Mixed sh' a) + -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) + -> ST s () + mvecsWritePartialLinear proxy idx (M_Nest _ arr) (MV_Nest _ vecs) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = mvecsWritePartialLinear proxy idx arr vecs + mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index ef3af31..1f12830 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -155,10 +155,17 @@ instance Elt a => Elt (Ranked n a) where (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) vecs) - mvecsWritePartial :: forall sh sh' s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) - -> MixedVecs s (sh ++ sh') (Ranked n a) - -> ST s () + mvecsWriteLinear :: forall sh s. Int -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () + mvecsWriteLinear idx (Ranked arr) vecs = + mvecsWriteLinear idx arr + (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + + mvecsWritePartial + :: forall sh sh' s. + IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) + -> MixedVecs s (sh ++ sh') (Ranked n a) + -> ST s () mvecsWritePartial sh idx arr vecs = mvecsWritePartial sh idx (coerce @(Mixed sh' (Ranked n a)) @@ -168,6 +175,20 @@ instance Elt a => Elt (Ranked n a) where @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) vecs) + mvecsWritePartialLinear + :: forall sh sh' s. + Proxy sh -> Int -> Mixed sh' (Ranked n a) + -> MixedVecs s (sh ++ sh') (Ranked n a) + -> ST s () + mvecsWritePartialLinear proxy idx arr vecs = + mvecsWritePartialLinear proxy idx + (coerce @(Mixed sh' (Ranked n a)) + @(Mixed sh' (Mixed (Replicate n Nothing) a)) + arr) + (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) + @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) + vecs) + mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) mvecsFreeze sh vecs = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index ded1175..9fc2c9a 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -148,10 +148,17 @@ instance Elt a => Elt (Shaped sh a) where (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) vecs) - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) - -> ST s () + mvecsWriteLinear :: forall sh' s. Int -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWriteLinear idx (Shaped arr) vecs = + mvecsWriteLinear idx arr + (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + + mvecsWritePartial + :: forall sh1 sh2 s. + IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) + -> ST s () mvecsWritePartial sh idx arr vecs = mvecsWritePartial sh idx (coerce @(Mixed sh2 (Shaped sh a)) @@ -161,6 +168,20 @@ instance Elt a => Elt (Shaped sh a) where @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) vecs) + mvecsWritePartialLinear + :: forall sh1 sh2 s. + Proxy sh1 -> Int -> Mixed sh2 (Shaped sh a) + -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) + -> ST s () + mvecsWritePartialLinear proxy idx arr vecs = + mvecsWritePartialLinear proxy idx + (coerce @(Mixed sh2 (Shaped sh a)) + @(Mixed sh2 (Mixed (MapJust sh) a)) + arr) + (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) + @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) + vecs) + mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) mvecsFreeze sh vecs = coerce @(Mixed sh' (Mixed (MapJust sh) a)) |
