aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2026-01-15 14:05:05 +0100
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2026-01-15 14:29:46 +0100
commit103a4fa07ee16106261cfc627422353277667cf8 (patch)
tree871cad9672a7a85d509f175c72c4583a4dd1d7de
parentd0f887c9045960b59ae40bb7d77f3f55cbb9ed02 (diff)
Save some VS.concat by using toVectorListT
-rw-r--r--src/Data/Array/Nested/Mixed.hs14
-rw-r--r--src/Data/Array/XArray.hs15
2 files changed, 19 insertions, 10 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index f0dadb4..39f00fa 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -23,8 +23,11 @@ module Data.Array.Nested.Mixed where
import Prelude hiding (mconcat)
import Control.DeepSeq (NFData(..))
-import Control.Monad (forM_, when)
+import Control.Monad (foldM_, forM_, when)
import Control.Monad.ST
+import Data.Array.Internal qualified as OI
+import Data.Array.Internal.RankedG qualified as ORG
+import Data.Array.Internal.RankedS qualified as ORS
import Data.Array.RankedS qualified as S
import Data.Bifunctor (bimap)
import Data.Coerce
@@ -485,14 +488,17 @@ instance Storable a => Elt (Primitive a) where
marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr)
mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x
- -- TODO: this use of toVector is suboptimal
+ -- TODO: this use of toVectorListT is suboptimal
mvecsWritePartialLinear
:: forall sh' sh s.
Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
- mvecsWritePartialLinear _ i (M_Primitive sh' arr) (MV_Primitive v) = do
+ mvecsWritePartialLinear _ i (M_Primitive sh' arr@(XArray (ORS.A (ORG.A sht t)))) (MV_Primitive v) = do
let arrsh = X.shape (ssxFromShX sh') arr
offset = i * shxSize arrsh
- VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
+ f off el = do
+ VS.copy (VSM.slice off (VS.length el) v) el
+ return $! off + VS.length el
+ foldM_ f offset (OI.toVectorListT sht t)
mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
mvecsUnsafeFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.unsafeFreeze v
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index f779a53..38ccee6 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -17,7 +17,7 @@
module Data.Array.XArray where
import Control.DeepSeq (NFData)
-import Control.Monad (foldM)
+import Control.Monad (foldM_, foldM)
import Control.Monad.ST
import Data.Array.Internal qualified as OI
import Data.Array.Internal.RankedG qualified as ORG
@@ -331,16 +331,19 @@ ravelOuterN k as@(a0 :| _) = runST $ do
len = product sh0
vecSize = k * len
vec <- VSM.unsafeNew vecSize
- let f !n a =
+ let f !n (ORS.A (ORG.A sht t)) =
if | n >= k ->
error $ "ravelOuterN: list too long " ++ show (n, k)
-- if we do this check just once at the end, we may
-- crash instead of producing an accurate error message
- | S.shapeL a == sh0 -> do
- VS.unsafeCopy (VSM.slice (n * len) len vec) (S.toVector a)
- return $! n + 1
+ | sht == sh0 -> do
+ let g off el = do
+ VS.unsafeCopy (VSM.slice off (VS.length el) vec) el
+ return $! off + VS.length el
+ foldM_ g (n * len) (OI.toVectorListT sht t)
+ return $! n + 1
| otherwise ->
- error $ "ravelOuterN: unequal shapes " ++ show (S.shapeL a, sh0)
+ error $ "ravelOuterN: unequal shapes " ++ show (sht, sh0)
nFinal <- foldM f 0 as
if nFinal == k
then S.fromVector (k : sh0) <$> VS.unsafeFreeze vec