diff options
Diffstat (limited to 'src/Data/Array/XArray.hs')
| -rw-r--r-- | src/Data/Array/XArray.hs | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index f779a53..38ccee6 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -17,7 +17,7 @@ module Data.Array.XArray where import Control.DeepSeq (NFData) -import Control.Monad (foldM) +import Control.Monad (foldM_, foldM) import Control.Monad.ST import Data.Array.Internal qualified as OI import Data.Array.Internal.RankedG qualified as ORG @@ -331,16 +331,19 @@ ravelOuterN k as@(a0 :| _) = runST $ do len = product sh0 vecSize = k * len vec <- VSM.unsafeNew vecSize - let f !n a = + let f !n (ORS.A (ORG.A sht t)) = if | n >= k -> error $ "ravelOuterN: list too long " ++ show (n, k) -- if we do this check just once at the end, we may -- crash instead of producing an accurate error message - | S.shapeL a == sh0 -> do - VS.unsafeCopy (VSM.slice (n * len) len vec) (S.toVector a) - return $! n + 1 + | sht == sh0 -> do + let g off el = do + VS.unsafeCopy (VSM.slice off (VS.length el) vec) el + return $! off + VS.length el + foldM_ g (n * len) (OI.toVectorListT sht t) + return $! n + 1 | otherwise -> - error $ "ravelOuterN: unequal shapes " ++ show (S.shapeL a, sh0) + error $ "ravelOuterN: unequal shapes " ++ show (sht, sh0) nFinal <- foldM f 0 as if nFinal == k then S.fromVector (k : sh0) <$> VS.unsafeFreeze vec |
