aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/XArray.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/XArray.hs')
-rw-r--r--src/Data/Array/XArray.hs34
1 files changed, 24 insertions, 10 deletions
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index 29154f1..948a50e 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -27,6 +27,8 @@ import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
import Data.Type.Equality
import Data.Type.Ord
+import Data.Vector qualified as V
+import Data.Vector.Generic.Checked qualified as VGC
import Data.Vector.Storable qualified as VS
import Foreign.Storable (Storable)
import GHC.Generics (Generic)
@@ -291,15 +293,23 @@ sumOuter ssh ssh' arr
reshapePartial ssh ssh' shF $
arr
+-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown,
+-- the list's spine must be fully materialised to compute its length before
+-- constructing the array.
fromListOuter :: forall n sh a. Storable a
=> StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
fromListOuter ssh l
| Dict <- lemKnownNatRankSSX ssh
+ , let l' = coerce @[XArray sh a] @[S.Array (Rank sh) a] l
= case ssh of
- SKnown m :!% _ | fromSNat' m /= length l ->
- error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++
- "does not match the type (" ++ show (fromSNat' m) ++ ")"
- _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))
+ _ :!% ZKX ->
+ fromList1 ssh (map S.unScalar l')
+ SKnown m :!% _ ->
+ let n = fromSNat' m
+ in XArray (S.ravel (ORB.fromVector [n] (VGC.fromListNChecked n l')))
+ _ ->
+ let n = length l
+ in XArray (S.ravel (ORB.fromVector [n] (V.fromListN n l')))
toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a]
toListOuter (XArray arr@(ORS.A (ORG.A _ t))) =
@@ -310,14 +320,18 @@ toListOuter (XArray arr@(ORS.A (ORG.A _ t))) =
[_] | Refl <- (unsafeCoerceRefl :: sh :~: '[]) -> coerce (map S.scalar $ S.toList arr)
n : sh -> coerce $ map (ORG.A sh . OI.indexT t) [0 .. n - 1]
+-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown,
+-- the list's spine must be fully materialised to compute its length before
+-- constructing the array.
fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a
fromList1 ssh l =
- let n = length l
- in case ssh of
- SKnown m :!% _ | fromSNat' m /= n ->
- error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++
- "does not match the type (" ++ show (fromSNat' m) ++ ")"
- _ -> XArray (S.fromVector [n] (VS.fromListN n l))
+ case ssh of
+ SKnown m :!% _ ->
+ let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed
+ in XArray (S.fromVector [n] (VGC.fromListNChecked n l))
+ _ ->
+ let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself
+ in XArray (S.fromVector [n] (VS.fromListN n l))
toList1 :: Storable a => XArray '[n] a -> [a]
toList1 (XArray arr) = S.toList arr