aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-08 15:43:41 +0100
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-08 18:06:03 +0100
commit9624e5d90d5f0815bec230cdd2f4e5b406805885 (patch)
tree44647d1b726f77a5ec1bf211f76e5786397efb73 /src/Data
parentd69d96270be6d91565a8461b10b10c72efdd955e (diff)
Add mvecsWriteLinear and mvecsWritePartialLinear (unused yet)
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Nested/Mixed.hs43
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs29
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs29
3 files changed, 89 insertions, 12 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 654ce3c..0b263e1 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -371,10 +371,18 @@ class Elt a where
-- that index in the vectors.
mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+ -- | Given a linear index and a value, write the value at
+ -- that index in the vectors.
+ mvecsWriteLinear :: Int -> a -> MixedVecs s sh a -> ST s ()
+
-- | Given the shape of this array, an index and a value, write the value at
-- that index in the vectors.
mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+ -- | Given a linear index and a value, write the value at
+ -- that index in the vectors.
+ mvecsWritePartialLinear :: Proxy sh -> Int -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+
-- | Given the shape of this array, finalise the vectors into 'XArray's.
mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
@@ -469,6 +477,7 @@ instance Storable a => Elt (Primitive a) where
mshowShapeTree _ () = "()"
marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr)
mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
+ mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x
-- TODO: this use of toVector is suboptimal
mvecsWritePartial
@@ -479,6 +488,14 @@ instance Storable a => Elt (Primitive a) where
offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
+ mvecsWritePartialLinear
+ :: forall sh' sh s.
+ Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ mvecsWritePartialLinear _ i (M_Primitive sh' arr) (MV_Primitive v) = do
+ let arrsh = X.shape (ssxFromShX sh') arr
+ offset = i * shxSize arrsh
+ VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
+
mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
-- [PRIMITIVE ELEMENT TYPES LIST]
@@ -550,9 +567,15 @@ instance (Elt a, Elt b) => Elt (a, b) where
mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
mvecsWrite sh i x a
mvecsWrite sh i y b
+ mvecsWriteLinear i (x, y) (MV_Tup2 a b) = do
+ mvecsWriteLinear i x a
+ mvecsWriteLinear i y b
mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
mvecsWritePartial sh i x a
mvecsWritePartial sh i y b
+ mvecsWritePartialLinear proxy i (M_Tup2 x y) (MV_Tup2 a b) = do
+ mvecsWritePartialLinear proxy i x a
+ mvecsWritePartialLinear proxy i y b
mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
@@ -683,15 +706,27 @@ instance Elt a => Elt (Mixed sh' a) where
marrayStrides (M_Nest _ arr) = marrayStrides arr
mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
+ mvecsWriteLinear :: forall s sh. Int -> Mixed sh' a -> MixedVecs s sh (Mixed sh' a) -> ST s ()
+ mvecsWriteLinear idx val (MV_Nest _ vecs) = mvecsWritePartialLinear (Proxy @sh) idx val vecs
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
- -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
- -> ST s ()
+ mvecsWritePartial
+ :: forall sh1 sh2 s.
+ IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
+ -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
+ -> ST s ()
mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
| Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
= mvecsWritePartial (shxAppend sh12 sh') idx arr vecs
+ mvecsWritePartialLinear
+ :: forall sh1 sh2 s.
+ Proxy sh1 -> Int -> Mixed sh2 (Mixed sh' a)
+ -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx (M_Nest _ arr) (MV_Nest _ vecs)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = mvecsWritePartialLinear proxy idx arr vecs
+
mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs
instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
index ef3af31..1f12830 100644
--- a/src/Data/Array/Nested/Ranked/Base.hs
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -155,10 +155,17 @@ instance Elt a => Elt (Ranked n a) where
(coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
vecs)
- mvecsWritePartial :: forall sh sh' s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
- -> MixedVecs s (sh ++ sh') (Ranked n a)
- -> ST s ()
+ mvecsWriteLinear :: forall sh s. Int -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWriteLinear idx (Ranked arr) vecs =
+ mvecsWriteLinear idx arr
+ (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsWritePartial
+ :: forall sh sh' s.
+ IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
+ -> MixedVecs s (sh ++ sh') (Ranked n a)
+ -> ST s ()
mvecsWritePartial sh idx arr vecs =
mvecsWritePartial sh idx
(coerce @(Mixed sh' (Ranked n a))
@@ -168,6 +175,20 @@ instance Elt a => Elt (Ranked n a) where
@(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
vecs)
+ mvecsWritePartialLinear
+ :: forall sh sh' s.
+ Proxy sh -> Int -> Mixed sh' (Ranked n a)
+ -> MixedVecs s (sh ++ sh') (Ranked n a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx arr vecs =
+ mvecsWritePartialLinear proxy idx
+ (coerce @(Mixed sh' (Ranked n a))
+ @(Mixed sh' (Mixed (Replicate n Nothing) a))
+ arr)
+ (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
+ @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
+ vecs)
+
mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
mvecsFreeze sh vecs =
coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index ded1175..9fc2c9a 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -148,10 +148,17 @@ instance Elt a => Elt (Shaped sh a) where
(coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
vecs)
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
- -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
- -> ST s ()
+ mvecsWriteLinear :: forall sh' s. Int -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
+ mvecsWriteLinear idx (Shaped arr) vecs =
+ mvecsWriteLinear idx arr
+ (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsWritePartial
+ :: forall sh1 sh2 s.
+ IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
+ -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
+ -> ST s ()
mvecsWritePartial sh idx arr vecs =
mvecsWritePartial sh idx
(coerce @(Mixed sh2 (Shaped sh a))
@@ -161,6 +168,20 @@ instance Elt a => Elt (Shaped sh a) where
@(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
vecs)
+ mvecsWritePartialLinear
+ :: forall sh1 sh2 s.
+ Proxy sh1 -> Int -> Mixed sh2 (Shaped sh a)
+ -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
+ -> ST s ()
+ mvecsWritePartialLinear proxy idx arr vecs =
+ mvecsWritePartialLinear proxy idx
+ (coerce @(Mixed sh2 (Shaped sh a))
+ @(Mixed sh2 (Mixed (MapJust sh) a))
+ arr)
+ (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
+ @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
+ vecs)
+
mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
mvecsFreeze sh vecs =
coerce @(Mixed sh' (Mixed (MapJust sh) a))