diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-05 10:17:18 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-05 23:21:11 +0100 |
| commit | de92940e793ddd6f53837de80ff316398d3e1809 (patch) | |
| tree | c257958d40ea64c4061890b334286b34c25dbe4f | |
| parent | 13a0ad5e2938218dd97c8db49b3da6c5bdd5a5db (diff) | |
Improve runtime and streaming of fromListOuterfromVectorsNChecked
| -rw-r--r-- | src/Data/Array/XArray.hs | 44 |
1 files changed, 38 insertions, 6 deletions
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index ee83654..64bbb0c 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -1,9 +1,11 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE MultiWayIf #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -15,10 +17,11 @@ module Data.Array.XArray where import Control.DeepSeq (NFData) +import Control.Monad (foldM) +import Control.Monad.ST import Data.Array.Internal qualified as OI import Data.Array.Internal.RankedG qualified as ORG import Data.Array.Internal.RankedS qualified as ORS -import Data.Array.Ranked qualified as ORB import Data.Array.RankedS qualified as S import Data.Coerce import Data.Foldable (toList) @@ -27,9 +30,9 @@ import Data.List.NonEmpty (NonEmpty) import Data.Proxy import Data.Type.Equality import Data.Type.Ord -import Data.Vector qualified as V import Data.Vector.Generic.Checked qualified as VGC import Data.Vector.Storable qualified as VS +import Data.Vector.Storable.Mutable qualified as VSM import Foreign.Storable (Storable) import GHC.Generics (Generic) import GHC.TypeLits @@ -237,6 +240,7 @@ transpose ssh perm (XArray arr) = XArray (unsafeCoerce (S.transpose (permToList' perm) arr)) #endif + -- | The list argument gives indices into the original dimension list. -- -- The permutation (the list) must have length <= @n@. If it is longer, this @@ -303,21 +307,49 @@ sumOuter ssh ssh' arr -- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown, -- the list's spine must be fully materialised to compute its length before --- constructing the array. +-- constructing the array. The list can't be empty (not enough information +-- in the given shape to guess the shape of the empty array, in general). fromListOuter :: forall n sh a. Storable a => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a fromListOuter ssh l - | Dict <- lemKnownNatRankSSX ssh + | Dict <- lemKnownNatRankSSX (ssxTail ssh) , let l' = coerce @[XArray sh a] @[S.Array (Rank sh) a] l = case ssh of _ :!% ZKX -> fromList1 ssh (map S.unScalar l') SKnown m :!% _ -> let n = fromSNat' m - in XArray (S.ravel (ORB.fromVector [n] (VGC.fromListNChecked n l'))) + in XArray (ravelOuterN n l') _ -> let n = length l - in XArray (S.ravel (ORB.fromVector [n] (V.fromListN n l'))) + in XArray (ravelOuterN n l') + +-- This checks that the list has the given length and that each shape +-- on the list equals the given shape. The list is streamed. +-- The list can't be empty. +ravelOuterN :: (KnownNat k, Storable a) + => Int -> [S.Array k a] -> S.Array (1 + k) a +ravelOuterN 0 _ = error "ravelOuterN: N == 0" +ravelOuterN _ [] = error "ravelOuterN: empty list" +ravelOuterN k as@(a0 : _) = runST $ do + let sh0 = S.shapeL a0 + len = product sh0 + vecSize = k * len + vec <- VSM.unsafeNew vecSize + let f !n a = + if | n >= k -> + error $ "ravelOuterN: list too long " ++ show (n, k) + -- if we do this check just once at the end, we may + -- crash instead of producing an accurate error message + | S.shapeL a == sh0 -> do + VS.unsafeCopy (VSM.slice (n * len) len vec) (S.toVector a) + return $! n + 1 + | otherwise -> + error $ "ravelOuterN: unequal shapes " ++ show (S.shapeL a, sh0) + nFinal <- foldM f 0 as + if nFinal == k + then S.fromVector (k : sh0) <$> VS.unsafeFreeze vec + else error $ "ravelOuterN: list too short " ++ show (nFinal, k) toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a] toListOuter (XArray arr@(ORS.A (ORG.A _ t))) = |
