diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-08 15:43:41 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-08 18:06:03 +0100 |
| commit | 9624e5d90d5f0815bec230cdd2f4e5b406805885 (patch) | |
| tree | 44647d1b726f77a5ec1bf211f76e5786397efb73 /src/Data/Array/Nested/Mixed.hs | |
| parent | d69d96270be6d91565a8461b10b10c72efdd955e (diff) | |
Add mvecsWriteLinear and mvecsWritePartialLinear (unused yet)
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 43 |
1 files changed, 39 insertions, 4 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 654ce3c..0b263e1 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -371,10 +371,18 @@ class Elt a where -- that index in the vectors. mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () + -- | Given a linear index and a value, write the value at + -- that index in the vectors. + mvecsWriteLinear :: Int -> a -> MixedVecs s sh a -> ST s () + -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + -- | Given a linear index and a value, write the value at + -- that index in the vectors. + mvecsWritePartialLinear :: Proxy sh -> Int -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + -- | Given the shape of this array, finalise the vectors into 'XArray's. mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) @@ -469,6 +477,7 @@ instance Storable a => Elt (Primitive a) where mshowShapeTree _ () = "()" marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x + mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x -- TODO: this use of toVector is suboptimal mvecsWritePartial @@ -479,6 +488,14 @@ instance Storable a => Elt (Primitive a) where offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) + mvecsWritePartialLinear + :: forall sh' sh s. + Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + mvecsWritePartialLinear _ i (M_Primitive sh' arr) (MV_Primitive v) = do + let arrsh = X.shape (ssxFromShX sh') arr + offset = i * shxSize arrsh + VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) + mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v -- [PRIMITIVE ELEMENT TYPES LIST] @@ -550,9 +567,15 @@ instance (Elt a, Elt b) => Elt (a, b) where mvecsWrite sh i (x, y) (MV_Tup2 a b) = do mvecsWrite sh i x a mvecsWrite sh i y b + mvecsWriteLinear i (x, y) (MV_Tup2 a b) = do + mvecsWriteLinear i x a + mvecsWriteLinear i y b mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do mvecsWritePartial sh i x a mvecsWritePartial sh i y b + mvecsWritePartialLinear proxy i (M_Tup2 x y) (MV_Tup2 a b) = do + mvecsWritePartialLinear proxy i x a + mvecsWritePartialLinear proxy i y b mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b instance (KnownElt a, KnownElt b) => KnownElt (a, b) where @@ -683,15 +706,27 @@ instance Elt a => Elt (Mixed sh' a) where marrayStrides (M_Nest _ arr) = marrayStrides arr mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs + mvecsWriteLinear :: forall s sh. Int -> Mixed sh' a -> MixedVecs s sh (Mixed sh' a) -> ST s () + mvecsWriteLinear idx val (MV_Nest _ vecs) = mvecsWritePartialLinear (Proxy @sh) idx val vecs - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () + mvecsWritePartial + :: forall sh1 sh2 s. + IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) + -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) + -> ST s () mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs + mvecsWritePartialLinear + :: forall sh1 sh2 s. + Proxy sh1 -> Int -> Mixed sh2 (Mixed sh' a) + -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) + -> ST s () + mvecsWritePartialLinear proxy idx (M_Nest _ arr) (MV_Nest _ vecs) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = mvecsWritePartialLinear proxy idx arr vecs + mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where |
