From 103a4fa07ee16106261cfc627422353277667cf8 Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Thu, 15 Jan 2026 14:05:05 +0100 Subject: Save some VS.concat by using toVectorListT --- src/Data/Array/Nested/Mixed.hs | 14 ++++++++++---- src/Data/Array/XArray.hs | 15 +++++++++------ 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index f0dadb4..39f00fa 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -23,8 +23,11 @@ module Data.Array.Nested.Mixed where import Prelude hiding (mconcat) import Control.DeepSeq (NFData(..)) -import Control.Monad (forM_, when) +import Control.Monad (foldM_, forM_, when) import Control.Monad.ST +import Data.Array.Internal qualified as OI +import Data.Array.Internal.RankedG qualified as ORG +import Data.Array.Internal.RankedS qualified as ORS import Data.Array.RankedS qualified as S import Data.Bifunctor (bimap) import Data.Coerce @@ -485,14 +488,17 @@ instance Storable a => Elt (Primitive a) where marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x - -- TODO: this use of toVector is suboptimal + -- TODO: this use of toVectorListT is suboptimal mvecsWritePartialLinear :: forall sh' sh s. Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartialLinear _ i (M_Primitive sh' arr) (MV_Primitive v) = do + mvecsWritePartialLinear _ i (M_Primitive sh' arr@(XArray (ORS.A (ORG.A sht t)))) (MV_Primitive v) = do let arrsh = X.shape (ssxFromShX sh') arr offset = i * shxSize arrsh - VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) + f off el = do + VS.copy (VSM.slice off (VS.length el) v) el + return $! off + VS.length el + foldM_ f offset (OI.toVectorListT sht t) mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v mvecsUnsafeFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.unsafeFreeze v diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index f779a53..38ccee6 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -17,7 +17,7 @@ module Data.Array.XArray where import Control.DeepSeq (NFData) -import Control.Monad (foldM) +import Control.Monad (foldM_, foldM) import Control.Monad.ST import Data.Array.Internal qualified as OI import Data.Array.Internal.RankedG qualified as ORG @@ -331,16 +331,19 @@ ravelOuterN k as@(a0 :| _) = runST $ do len = product sh0 vecSize = k * len vec <- VSM.unsafeNew vecSize - let f !n a = + let f !n (ORS.A (ORG.A sht t)) = if | n >= k -> error $ "ravelOuterN: list too long " ++ show (n, k) -- if we do this check just once at the end, we may -- crash instead of producing an accurate error message - | S.shapeL a == sh0 -> do - VS.unsafeCopy (VSM.slice (n * len) len vec) (S.toVector a) - return $! n + 1 + | sht == sh0 -> do + let g off el = do + VS.unsafeCopy (VSM.slice off (VS.length el) vec) el + return $! off + VS.length el + foldM_ g (n * len) (OI.toVectorListT sht t) + return $! n + 1 | otherwise -> - error $ "ravelOuterN: unequal shapes " ++ show (S.shapeL a, sh0) + error $ "ravelOuterN: unequal shapes " ++ show (sht, sh0) nFinal <- foldM f 0 as if nFinal == k then S.fromVector (k : sh0) <$> VS.unsafeFreeze vec -- cgit v1.2.3-70-g09d2