From ba798129655503d7e69de271d956cceaef4cef56 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 17 Nov 2025 22:16:11 +0100 Subject: Provide explicit-length versions of fromList functions --- src/Data/Array/XArray.hs | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) (limited to 'src/Data/Array/XArray.hs') diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index 29154f1..948a50e 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -27,6 +27,8 @@ import Data.List.NonEmpty (NonEmpty) import Data.Proxy import Data.Type.Equality import Data.Type.Ord +import Data.Vector qualified as V +import Data.Vector.Generic.Checked qualified as VGC import Data.Vector.Storable qualified as VS import Foreign.Storable (Storable) import GHC.Generics (Generic) @@ -291,15 +293,23 @@ sumOuter ssh ssh' arr reshapePartial ssh ssh' shF $ arr +-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown, +-- the list's spine must be fully materialised to compute its length before +-- constructing the array. fromListOuter :: forall n sh a. Storable a => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a fromListOuter ssh l | Dict <- lemKnownNatRankSSX ssh + , let l' = coerce @[XArray sh a] @[S.Array (Rank sh) a] l = case ssh of - SKnown m :!% _ | fromSNat' m /= length l -> - error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++ - "does not match the type (" ++ show (fromSNat' m) ++ ")" - _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) + _ :!% ZKX -> + fromList1 ssh (map S.unScalar l') + SKnown m :!% _ -> + let n = fromSNat' m + in XArray (S.ravel (ORB.fromVector [n] (VGC.fromListNChecked n l'))) + _ -> + let n = length l + in XArray (S.ravel (ORB.fromVector [n] (V.fromListN n l'))) toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a] toListOuter (XArray arr@(ORS.A (ORG.A _ t))) = @@ -310,14 +320,18 @@ toListOuter (XArray arr@(ORS.A (ORG.A _ t))) = [_] | Refl <- (unsafeCoerceRefl :: sh :~: '[]) -> coerce (map S.scalar $ S.toList arr) n : sh -> coerce $ map (ORG.A sh . OI.indexT t) [0 .. n - 1] +-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown, +-- the list's spine must be fully materialised to compute its length before +-- constructing the array. fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a fromList1 ssh l = - let n = length l - in case ssh of - SKnown m :!% _ | fromSNat' m /= n -> - error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++ - "does not match the type (" ++ show (fromSNat' m) ++ ")" - _ -> XArray (S.fromVector [n] (VS.fromListN n l)) + case ssh of + SKnown m :!% _ -> + let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed + in XArray (S.fromVector [n] (VGC.fromListNChecked n l)) + _ -> + let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself + in XArray (S.fromVector [n] (VS.fromListN n l)) toList1 :: Storable a => XArray '[n] a -> [a] toList1 (XArray arr) = S.toList arr -- cgit v1.2.3-70-g09d2