diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-08 20:09:31 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-08 22:56:43 +0100 |
| commit | 4cd80d336f78b419e9e1b80c2c20e0e07ecf10a0 (patch) | |
| tree | 9ec08ed3adbc388e79d537959cf011d5e43fb49a | |
| parent | 9624e5d90d5f0815bec230cdd2f4e5b406805885 (diff) | |
Express mvecsWrite and mvecsWritePartial using the new methods
and change the type of the latter to make it possible.
This slightly improves performance of horde-ad tests,
before horde-ad even starts using the Linear methods, which improves
performance even more.
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 46 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 20 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 20 |
3 files changed, 12 insertions, 74 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 0b263e1..fc1c108 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -367,18 +367,10 @@ class Elt a where -- this mixed array. marrayStrides :: Mixed sh a -> Bag [Int] - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () - -- | Given a linear index and a value, write the value at -- that index in the vectors. mvecsWriteLinear :: Int -> a -> MixedVecs s sh a -> ST s () - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () - -- | Given a linear index and a value, write the value at -- that index in the vectors. mvecsWritePartialLinear :: Proxy sh -> Int -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () @@ -386,7 +378,6 @@ class Elt a where -- | Given the shape of this array, finalise the vectors into 'XArray's. mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) - -- | Element types for which we have evidence of the (static part of the) shape -- in a type class constraint. Compare the instance contexts of the instances -- of this class with those of 'Elt': some instances have an additional @@ -476,18 +467,9 @@ instance Storable a => Elt (Primitive a) where mshapeTreeIsEmpty _ () = False mshowShapeTree _ () = "()" marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x -- TODO: this use of toVector is suboptimal - mvecsWritePartial - :: forall sh' sh s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do - let arrsh = X.shape (ssxFromShX sh') arr - offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) - VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) - mvecsWritePartialLinear :: forall sh' sh s. Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () @@ -564,15 +546,9 @@ instance (Elt a, Elt b) => Elt (a, b) where mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b - mvecsWrite sh i (x, y) (MV_Tup2 a b) = do - mvecsWrite sh i x a - mvecsWrite sh i y b mvecsWriteLinear i (x, y) (MV_Tup2 a b) = do mvecsWriteLinear i x a mvecsWriteLinear i y b - mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do - mvecsWritePartial sh i x a - mvecsWritePartial sh i y b mvecsWritePartialLinear proxy i (M_Tup2 x y) (MV_Tup2 a b) = do mvecsWritePartialLinear proxy i x a mvecsWritePartialLinear proxy i y b @@ -705,19 +681,9 @@ instance Elt a => Elt (Mixed sh' a) where marrayStrides (M_Nest _ arr) = marrayStrides arr - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs mvecsWriteLinear :: forall s sh. Int -> Mixed sh' a -> MixedVecs s sh (Mixed sh' a) -> ST s () mvecsWriteLinear idx val (MV_Nest _ vecs) = mvecsWritePartialLinear (Proxy @sh) idx val vecs - mvecsWritePartial - :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () - mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) - | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs - mvecsWritePartialLinear :: forall sh1 sh2 s. Proxy sh1 -> Int -> Mixed sh2 (Mixed sh' a) @@ -748,6 +714,18 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) +-- | Given the shape of this array, an index and a value, write the value at +-- that index in the vectors. +{-# INLINE mvecsWrite #-} +mvecsWrite :: Elt a => IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () +mvecsWrite sh idx val vecs = mvecsWriteLinear (ixxToLinear sh idx) val vecs + +-- | Given the shape of this array, an index and a value, write the value at +-- that index in the vectors. +{-# INLINE mvecsWritePartial #-} +mvecsWritePartial :: forall sh sh' s a. Elt a => IShX sh -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () +mvecsWritePartial sh idx val vecs = mvecsWritePartialLinear (Proxy @sh) (ixxToLinear sh idx) val vecs + -- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere memptyArray :: forall sh a. KnownElt a => IShX sh -> Mixed (Just 0 : sh) a memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh) diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index 1f12830..ed194a8 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -149,32 +149,12 @@ instance Elt a => Elt (Ranked n a) where marrayStrides (M_Ranked arr) = marrayStrides arr - mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () - mvecsWrite sh idx (Ranked arr) vecs = - mvecsWrite sh idx arr - (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - mvecsWriteLinear :: forall sh s. Int -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () mvecsWriteLinear idx (Ranked arr) vecs = mvecsWriteLinear idx arr (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) vecs) - mvecsWritePartial - :: forall sh sh' s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) - -> MixedVecs s (sh ++ sh') (Ranked n a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx - (coerce @(Mixed sh' (Ranked n a)) - @(Mixed sh' (Mixed (Replicate n Nothing) a)) - arr) - (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) - @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) - vecs) - mvecsWritePartialLinear :: forall sh sh' s. Proxy sh -> Int -> Mixed sh' (Ranked n a) diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index 9fc2c9a..e5dd852 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -142,32 +142,12 @@ instance Elt a => Elt (Shaped sh a) where marrayStrides (M_Shaped arr) = marrayStrides arr - mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () - mvecsWrite sh idx (Shaped arr) vecs = - mvecsWrite sh idx arr - (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - mvecsWriteLinear :: forall sh' s. Int -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () mvecsWriteLinear idx (Shaped arr) vecs = mvecsWriteLinear idx arr (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) vecs) - mvecsWritePartial - :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx - (coerce @(Mixed sh2 (Shaped sh a)) - @(Mixed sh2 (Mixed (MapJust sh) a)) - arr) - (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) - @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) - vecs) - mvecsWritePartialLinear :: forall sh1 sh2 s. Proxy sh1 -> Int -> Mixed sh2 (Shaped sh a) |
