aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2026-01-16 19:30:31 +0100
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2026-01-16 19:30:31 +0100
commita02528a1402238add2820d7203ccb38ed9b59f29 (patch)
tree7d6a94b004cc34f258ddf6983b4bd994a86f0e96 /src/Data
parent67eea51eb3cc205c2884de613f1102655276b191 (diff)
Add a bang not to overlap big allocations
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/XArray.hs15
1 files changed, 9 insertions, 6 deletions
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index 4f5bb08..a4cc997 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -302,10 +302,9 @@ sumOuter ssh ssh' arr
reshapePartial ssh ssh' shF $
arr
--- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown,
--- the list's spine must be fully materialised to compute its length before
--- constructing the array. The list can't be empty (not enough information
--- in the given shape to guess the shape of the empty array, in general).
+-- | This creates an array from a list of arrays of one less dimension.
+-- The list is streamed, it's length is checked and it's verified
+-- that all arrays on the list have the same shape.
{-# INLINE fromListOuterSN #-}
fromListOuterSN :: forall n sh a. Storable a
=> SNat n -> IShX sh -> NonEmpty (XArray sh a) -> XArray (Just n : sh) a
@@ -317,12 +316,16 @@ fromListOuterSN m sh l
_ -> XArray (ravelOuterN (fromSNat' m) l')
-- | This checks that the list has the given length and that all shapes in the
--- list are equal. The list must be non-empty, and is streamed.
+-- list are equal. The list is streamed.
+-- We force the first array on the list early to free some previously used
+-- memory (a lot of memory if it triggers evaluation of a big tensor
+-- all list elements are made from) before @unsafeNew@ allocates
+-- a big chunk of memory again.
{-# INLINEABLE ravelOuterN #-}
ravelOuterN :: (KnownNat k, Storable a)
=> Int -> NonEmpty (S.Array k a) -> S.Array (1 + k) a
ravelOuterN 0 _ = error "ravelOuterN: N == 0"
-ravelOuterN k as@(a0 :| _) = runST $ do
+ravelOuterN k as@(!a0 :| _) = runST $ do
let sh0 = S.shapeL a0
len = product sh0
vecSize = k * len