aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-09 15:46:26 +0100
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-09 15:46:26 +0100
commit14f4d3d2586cd9983c9b98fc9af591a25ae325b4 (patch)
tree8596400916afedd87bfb7db9fdd4ca74eb0e096a /src
parent0ac43767b7a95b208e4562688f6f74942994652d (diff)
Add method mvecsUnsafeFreeze and use itmvecsReplicate
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Nested/Mixed.hs8
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs8
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs8
3 files changed, 23 insertions, 1 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 9f24aba..6b96a15 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -378,6 +378,9 @@ class Elt a where
-- | Given the shape of this array, finalise the vectors into 'XArray's.
mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
+ -- | Given the shape of this array, finalise the vectors into 'XArray's.
+ mvecsUnsafeFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
+
-- | Element types for which we have evidence of the (static part of the) shape
-- in a type class constraint. Compare the instance contexts of the instances
-- of this class with those of 'Elt': some instances have an additional
@@ -479,6 +482,7 @@ instance Storable a => Elt (Primitive a) where
VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
+ mvecsUnsafeFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.unsafeFreeze v
-- [PRIMITIVE ELEMENT TYPES LIST]
deriving via Primitive Bool instance Elt Bool
@@ -553,6 +557,7 @@ instance (Elt a, Elt b) => Elt (a, b) where
mvecsWritePartialLinear proxy i x a
mvecsWritePartialLinear proxy i y b
mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
+ mvecsUnsafeFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsUnsafeFreeze sh a <*> mvecsUnsafeFreeze sh b
instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh)
@@ -694,6 +699,7 @@ instance Elt a => Elt (Mixed sh' a) where
= mvecsWritePartialLinear proxy idx arr vecs
mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs
+ mvecsUnsafeFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsUnsafeFreeze (shxAppend sh sh') vecs
instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh'))))
@@ -772,7 +778,7 @@ mgenerate sh f = case shxEnum sh of
when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
error "Data.Array.Nested mgenerate: generated values do not have equal shapes"
mvecsWrite sh idx val vecs
- mvecsFreeze sh vecs
+ mvecsUnsafeFreeze sh vecs
-- | An optimized special case of 'mgenerate', where the function results
-- are of a primitive type and so there's not need to check that all shapes
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
index ed194a8..97a5f6f 100644
--- a/src/Data/Array/Nested/Ranked/Base.hs
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -177,6 +177,14 @@ instance Elt a => Elt (Ranked n a) where
(coerce @(MixedVecs s sh (Ranked n a))
@(MixedVecs s sh (Mixed (Replicate n Nothing) a))
vecs)
+ mvecsUnsafeFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
+ mvecsUnsafeFreeze sh vecs =
+ coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
+ @(Mixed sh (Ranked n a))
+ <$> mvecsUnsafeFreeze sh
+ (coerce @(MixedVecs s sh (Ranked n a))
+ @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a)
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index e5dd852..e2ec416 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -170,6 +170,14 @@ instance Elt a => Elt (Shaped sh a) where
(coerce @(MixedVecs s sh' (Shaped sh a))
@(MixedVecs s sh' (Mixed (MapJust sh) a))
vecs)
+ mvecsUnsafeFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
+ mvecsUnsafeFreeze sh vecs =
+ coerce @(Mixed sh' (Mixed (MapJust sh) a))
+ @(Mixed sh' (Shaped sh a))
+ <$> mvecsUnsafeFreeze sh
+ (coerce @(MixedVecs s sh' (Shaped sh a))
+ @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)