aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
-rw-r--r--src/Data/Array/Nested/Mixed.hs46
1 files changed, 12 insertions, 34 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 0b263e1..fc1c108 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -367,18 +367,10 @@ class Elt a where
-- this mixed array.
marrayStrides :: Mixed sh a -> Bag [Int]
- -- | Given the shape of this array, an index and a value, write the value at
- -- that index in the vectors.
- mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
-
-- | Given a linear index and a value, write the value at
-- that index in the vectors.
mvecsWriteLinear :: Int -> a -> MixedVecs s sh a -> ST s ()
- -- | Given the shape of this array, an index and a value, write the value at
- -- that index in the vectors.
- mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
-
-- | Given a linear index and a value, write the value at
-- that index in the vectors.
mvecsWritePartialLinear :: Proxy sh -> Int -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
@@ -386,7 +378,6 @@ class Elt a where
-- | Given the shape of this array, finalise the vectors into 'XArray's.
mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
-
-- | Element types for which we have evidence of the (static part of the) shape
-- in a type class constraint. Compare the instance contexts of the instances
-- of this class with those of 'Elt': some instances have an additional
@@ -476,18 +467,9 @@ instance Storable a => Elt (Primitive a) where
mshapeTreeIsEmpty _ () = False
mshowShapeTree _ () = "()"
marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr)
- mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x
-- TODO: this use of toVector is suboptimal
- mvecsWritePartial
- :: forall sh' sh s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
- mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
- let arrsh = X.shape (ssxFromShX sh') arr
- offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
- VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
-
mvecsWritePartialLinear
:: forall sh' sh s.
Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
@@ -564,15 +546,9 @@ instance (Elt a, Elt b) => Elt (a, b) where
mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2
mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b
- mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
- mvecsWrite sh i x a
- mvecsWrite sh i y b
mvecsWriteLinear i (x, y) (MV_Tup2 a b) = do
mvecsWriteLinear i x a
mvecsWriteLinear i y b
- mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
- mvecsWritePartial sh i x a
- mvecsWritePartial sh i y b
mvecsWritePartialLinear proxy i (M_Tup2 x y) (MV_Tup2 a b) = do
mvecsWritePartialLinear proxy i x a
mvecsWritePartialLinear proxy i y b
@@ -705,19 +681,9 @@ instance Elt a => Elt (Mixed sh' a) where
marrayStrides (M_Nest _ arr) = marrayStrides arr
- mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
mvecsWriteLinear :: forall s sh. Int -> Mixed sh' a -> MixedVecs s sh (Mixed sh' a) -> ST s ()
mvecsWriteLinear idx val (MV_Nest _ vecs) = mvecsWritePartialLinear (Proxy @sh) idx val vecs
- mvecsWritePartial
- :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
- -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
- -> ST s ()
- mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
- | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs
-
mvecsWritePartialLinear
:: forall sh1 sh2 s.
Proxy sh1 -> Int -> Mixed sh2 (Mixed sh' a)
@@ -748,6 +714,18 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
+-- | Given the shape of this array, an index and a value, write the value at
+-- that index in the vectors.
+{-# INLINE mvecsWrite #-}
+mvecsWrite :: Elt a => IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+mvecsWrite sh idx val vecs = mvecsWriteLinear (ixxToLinear sh idx) val vecs
+
+-- | Given the shape of this array, an index and a value, write the value at
+-- that index in the vectors.
+{-# INLINE mvecsWritePartial #-}
+mvecsWritePartial :: forall sh sh' s a. Elt a => IShX sh -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+mvecsWritePartial sh idx val vecs = mvecsWritePartialLinear (Proxy @sh) (ixxToLinear sh idx) val vecs
+
-- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere
memptyArray :: forall sh a. KnownElt a => IShX sh -> Mixed (Just 0 : sh) a
memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh)