aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Nested.hs16
-rw-r--r--src/Data/Array/Nested/Mixed.hs105
-rw-r--r--src/Data/Array/Nested/Ranked.hs57
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs4
-rw-r--r--src/Data/Array/Nested/Shaped.hs40
-rw-r--r--src/Data/Array/Nested/Shaped/Base.hs4
-rw-r--r--src/Data/Array/Nested/Trace.hs2
-rw-r--r--src/Data/Array/XArray.hs34
-rw-r--r--src/Data/Vector/Generic/Checked.hs39
9 files changed, 225 insertions, 76 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index c3635e9..8fb3bd1 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -11,7 +11,11 @@ module Data.Array.Nested (
remptyArray,
rrerank,
rreplicate, rreplicateScal,
- rfromList1, rfromListOuter, rfromListLinear, rfromListPrim, rfromListPrimLinear,
+ rfromList1, rfromList1N,
+ rfromListOuter, rfromListOuterN,
+ rfromListLinear,
+ rfromList1Prim, rfromList1PrimN,
+ rfromListPrimLinear,
rtoList, rtoListOuter, rtoListLinear,
rslice, rrev1, rreshape, rflatten, riota,
rminIndexPrim, rmaxIndexPrim, rdot1Inner, rdot,
@@ -38,7 +42,7 @@ module Data.Array.Nested (
semptyArray,
srerank,
sreplicate, sreplicateScal,
- sfromList1, sfromListOuter, sfromListLinear, sfromListPrim, sfromListPrimLinear,
+ sfromList1, sfromListOuter, sfromListLinear, sfromList1Prim, sfromListPrimLinear,
stoList, stoListOuter, stoListLinear,
sslice, srev1, sreshape, sflatten, siota,
sminIndexPrim, smaxIndexPrim, sdot1Inner, sdot,
@@ -66,9 +70,13 @@ module Data.Array.Nested (
memptyArray,
mrerank,
mreplicate, mreplicateScal,
- mfromList1, mfromListOuter, mfromListLinear, mfromListPrim, mfromListPrimLinear,
+ mfromList1, mfromList1N, mfromList1SN,
+ mfromListOuter, mfromListOuterN, mfromListOuterSN,
+ mfromListLinear,
+ mfromList1Prim, mfromList1PrimN, mfromList1PrimSN,
+ mfromListPrimLinear,
mtoList, mtoListOuter, mtoListLinear,
- mslice, mrev1, mreshape, mflatten, miota,
+ msliceN, msliceSN, mrev1, mreshape, mflatten, miota,
mminIndexPrim, mmaxIndexPrim, mdot1Inner, mdot,
mnest, munNest, mzip, munzip,
-- ** Lifting orthotope operations to 'Mixed' arrays
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 6b152f7..a2787b8 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -7,6 +7,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -307,15 +308,9 @@ class Elt a where
mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a
mscalar :: a -> Mixed '[] a
- -- | All arrays in the list, even subarrays inside @a@, must have the same
- -- shape; if they do not, a runtime error will be thrown. See the
- -- documentation of 'mgenerate' for more information about this restriction.
- -- Furthermore, the length of the list must correspond with @n@: if @n@ is
- -- @Just m@ and @m@ does not equal the length of the list, a runtime error is
- -- thrown.
- --
- -- Consider also 'mfromListPrim', which can avoid intermediate arrays.
- mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+ -- | See 'mfromListOuter'. If the list does not have the given length, a
+ -- runtime error is thrown. 'mfromListPrimSN' is faster if applicable.
+ mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a
mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
@@ -407,8 +402,8 @@ instance Storable a => Elt (Primitive a) where
mindex (M_Primitive _ a) i = Primitive (X.index a i)
mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i)
mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
- mfromListOuter l@(arr1 :| _) =
- let sh = SUnknown (length l) :$% mshape arr1
+ mfromListOuterSN sn l@(arr1 :| _) =
+ let sh = SKnown sn :$% mshape arr1
in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l)))
mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
@@ -515,9 +510,9 @@ instance (Elt a, Elt b) => Elt (a, b) where
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
- mfromListOuter l =
- M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l))
- (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l))
+ mfromListOuterSN sn l =
+ M_Tup2 (mfromListOuterSN sn ((\(M_Tup2 x _) -> x) <$> l))
+ (mfromListOuterSN sn ((\(M_Tup2 _ y) -> y) <$> l))
mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)
mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
@@ -578,10 +573,9 @@ instance Elt a => Elt (Mixed sh' a) where
mscalar = M_Nest ZSX
- mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
- mfromListOuter l@(arr :| _) =
- M_Nest (SUnknown (length l) :$% mshape arr)
- (mfromListOuter ((\(M_Nest _ a) -> a) <$> l))
+ mfromListOuterSN sn l@(arr :| _) =
+ M_Nest (SKnown sn :$% mshape arr)
+ (mfromListOuterSN sn ((\(M_Nest _ a) -> a) <$> l))
mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)
@@ -793,23 +787,76 @@ mtoVectorP (M_Primitive _ v) = X.toVector v
mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
mtoVector arr = mtoVectorP (toPrimitive arr)
+-- | All arrays in the list, even subarrays inside @a@, must have the same
+-- shape; if they do not, a runtime error will be thrown. See the
+-- documentation of 'mgenerate' for more information about this restriction.
+--
+-- Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'mfromListOuterN' or 'mfromListOuterSN' to be able to
+-- stream the list.
+--
+-- If your array is 1-dimensional and contains scalars, use 'mfromList1Prim'.
+mfromListOuter :: Elt a => NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+mfromListOuter l = mfromListOuterN (length l) l
+
+-- | See 'mfromListOuter'. If the list does not have the given length, a
+-- runtime error is thrown. 'mfromList1PrimN' is faster if applicable.
+mfromListOuterN :: Elt a => Int -> NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+mfromListOuterN n l =
+ withSomeSNat (fromIntegral n) $ \case
+ Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromListOuterSN sn l)
+ Nothing -> error $ "mfromListOuterN: length negative (" ++ show n ++ ")"
+
+-- | Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'mfromList1N' or 'mfromList1SN' to be able to stream the
+-- list.
+--
+-- If the elements are scalars, 'mfromList1Prim' is faster.
mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
-mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
+mfromList1 = mfromListOuter . fmap mscalar
+
+-- | If the elements are scalars, 'mfromList1PrimN' is faster. A runtime error
+-- is thrown if the list length does not match the given length.
+mfromList1N :: Elt a => Int -> NonEmpty a -> Mixed '[Nothing] a
+mfromList1N n = mfromListOuterN n . fmap mscalar
+
+-- | If the elements are scalars, 'mfromList1PrimSN' is faster. A runtime error
+-- is thrown if the list length does not match the given length.
+mfromList1SN :: Elt a => SNat n -> NonEmpty a -> Mixed '[Just n] a
+mfromList1SN sn = mfromListOuterSN sn . fmap mscalar
-- This forall is there so that a simple type application can constrain the
-- shape, in case the user wants to use OverloadedLists for the shape.
+-- | If the elements are scalars, 'mfromListPrimLinear' is faster.
mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a
-mfromListLinear sh l = mreshape sh (mfromList1 l)
+mfromListLinear sh l = mreshape sh (mfromList1N (shxSize sh) l)
-mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
-mfromListPrim l =
+-- | Because the length of the list is unknown, its spine must be materialised
+-- in memory in order to compute its length. If its length is already known,
+-- use 'mfromList1PrimN' or 'mfromList1PrimSN' to be able to stream the list.
+mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
+mfromList1Prim l =
let ssh = SUnknown () :!% ZKX
xarr = X.fromList1 ssh l
in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a
+mfromList1PrimN :: PrimElt a => Int -> [a] -> Mixed '[Nothing] a
+mfromList1PrimN n l =
+ withSomeSNat (fromIntegral n) $ \case
+ Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromList1PrimSN sn l)
+ Nothing -> error $ "mfromList1PrimN: length negative (" ++ show n ++ ")"
+
+mfromList1PrimSN :: PrimElt a => SNat n -> [a] -> Mixed '[Just n] a
+mfromList1PrimSN sn l =
+ let ssh = SKnown sn :!% ZKX
+ xarr = X.fromList1 ssh l
+ in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+
+mfromListPrimLinear :: forall sh a. PrimElt a => IShX sh -> [a] -> Mixed sh a
mfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ let M_Primitive _ xarr = toPrimitive (mfromList1PrimN (shxSize sh) l)
in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
mtoList :: Elt a => Mixed '[n] a -> [a]
@@ -872,14 +919,14 @@ mreplicateScal :: forall sh a. PrimElt a
=> IShX sh -> a -> Mixed sh a
mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
-mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
-mslice i n arr =
+msliceN :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
+msliceN i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr
+
+msliceSN :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
+msliceSN i n arr =
let _ :$% sh = mshape arr
in mlift (SKnown n :!% ssxFromShX sh) (\_ -> X.slice i n) arr
-msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
-msliceU i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr
-
mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a
mrev1 arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.rev1) arr
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index 8b95d0f..5cda531 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -137,24 +137,55 @@ rtoVectorP = coerce mtoVectorP
rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector
-rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
-rfromList1 l = Ranked (mfromList1 l)
-
+-- | All arrays in the list, even subarrays inside @a@, must have the same
+-- shape; if they do not, a runtime error will be thrown. See the
+-- documentation of 'mgenerate' for more information about this restriction.
+--
+-- Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'rfromListOuterN' to be able to stream the list.
+--
+-- If your array is 1-dimensional and contains scalars, use 'rfromList1Prim'.
rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
rfromListOuter l
| Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
= Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
+-- | See 'rfromListOuter'. If the list does not have the given length, a
+-- runtime error is thrown. 'rfromList1PrimN' is faster if applicable.
+rfromListOuterN :: forall n a. Elt a => Int -> NonEmpty (Ranked n a) -> Ranked (n + 1) a
+rfromListOuterN n l
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = Ranked (mfromListOuterN n (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
+
+-- | Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'rfromList1N' to be able to stream the list.
+--
+-- If the elements are scalars, 'rfromList1Prim' is faster.
+rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
+rfromList1 = coerce mfromList1
+
+-- | If the elements are scalars, 'rfromList1PrimN' is faster. A runtime error
+-- is thrown if the list length does not match the given length.
+rfromList1N :: Elt a => Int -> NonEmpty a -> Ranked 1 a
+rfromList1N = coerce mfromList1N
+
+-- | If the elements are scalars, 'rfromListPrimLinear' is faster.
rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
-rfromListLinear sh l = rreshape sh (rfromList1 l)
+rfromListLinear sh l = Ranked (mfromListLinear (shxFromShR sh) l)
+
+-- | Because the length of the list is unknown, its spine must be materialised
+-- in memory in order to compute its length. If its length is already known,
+-- use 'rfromList1PrimN' to be able to stream the list.
+rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a
+rfromList1Prim = coerce mfromList1Prim
-rfromListPrim :: PrimElt a => [a] -> Ranked 1 a
-rfromListPrim l = Ranked (mfromListPrim l)
+rfromList1PrimN :: PrimElt a => Int -> [a] -> Ranked 1 a
+rfromList1PrimN = coerce mfromList1PrimN
-rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a
-rfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
- in Ranked $ fromPrimitive $ M_Primitive (shxFromShR sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShR sh) xarr)
+rfromListPrimLinear :: forall n a. PrimElt a => IShR n -> [a] -> Ranked n a
+rfromListPrimLinear sh l = Ranked (mfromListPrimLinear (shxFromShR sh) l)
rtoList :: Elt a => Ranked 1 a -> [a]
rtoList = map runScalar . rtoListOuter
@@ -254,11 +285,9 @@ rreplicateScal :: forall n a. PrimElt a
rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
-rslice i n arr
+rslice i n (Ranked arr)
| Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
- = rlift (rrank arr)
- (\_ -> X.sliceU i n)
- arr
+ = Ranked (msliceN i n arr)
rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
rrev1 arr =
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
index 54baa32..e8aba3f 100644
--- a/src/Data/Array/Nested/Ranked/Base.hs
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -97,8 +97,8 @@ instance Elt a => Elt (Ranked n a) where
mscalar (Ranked x) = M_Ranked (M_Nest ZSX x)
- mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
- mfromListOuter l = M_Ranked (mfromListOuter (coerce l))
+ mfromListOuterSN :: SNat m -> NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Just m : sh) (Ranked n a)
+ mfromListOuterSN sn l = M_Ranked (mfromListOuterSN sn (coerce l))
mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
mtoListOuter (M_Ranked arr) =
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index 198a068..4a3ed8d 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -123,26 +123,38 @@ stoVectorP = coerce mtoVectorP
stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
stoVector = coerce mtoVector
-sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
-sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1
-
+-- | All arrays in the list, even subarrays inside @a@, must have the same
+-- shape; if they do not, a runtime error will be thrown. See the
+-- documentation of 'mgenerate' for more information about this restriction.
+--
+-- Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'sfromListOuterSN' to be able to stream the list.
+--
+-- If your array is 1-dimensional and contains scalars, use 'sfromList1Prim'.
sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
-sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l))
+sfromListOuter = coerce mfromListOuterSN
+
+-- | Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'sfromList1SN' to be able to stream the list.
+--
+-- If the elements are scalars, 'sfromList1Prim' is faster.
+sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
+sfromList1 = coerce mfromList1SN
+-- | If the elements are scalars, 'sfromListPrimLinear' is faster.
sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a
sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l)
-sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
-sfromListPrim sn l
- | Refl <- lemAppNil @'[Just n]
- = let ssh = SUnknown () :!% ZKX
- xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)
- in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr
+-- | Because the length of the list is unknown, its spine must be materialised
+-- in memory in order to compute its length. If its length is already known,
+-- use 'sfromList1PrimN' to be able to stream the list.
+sfromList1Prim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
+sfromList1Prim = coerce mfromList1PrimSN
-sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a
-sfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
- in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr)
+sfromListPrimLinear :: forall sh a. PrimElt a => ShS sh -> [a] -> Shaped sh a
+sfromListPrimLinear sh l = Shaped (mfromListPrimLinear (shxFromShS sh) l)
stoList :: Elt a => Shaped '[n] a -> [a]
stoList = map sunScalar . stoListOuter
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs
index 75e6fcb..b313b2d 100644
--- a/src/Data/Array/Nested/Shaped/Base.hs
+++ b/src/Data/Array/Nested/Shaped/Base.hs
@@ -90,8 +90,8 @@ instance Elt a => Elt (Shaped sh a) where
mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
- mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
- mfromListOuter l = M_Shaped (mfromListOuter (coerce l))
+ mfromListOuterSN :: SNat n -> NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Just n : sh') (Shaped sh a)
+ mfromListOuterSN sn l = M_Shaped (mfromListOuterSN sn (coerce l))
mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
mtoListOuter (M_Shaped arr)
diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs
index 8a29aa5..6a2890f 100644
--- a/src/Data/Array/Nested/Trace.hs
+++ b/src/Data/Array/Nested/Trace.hs
@@ -69,4 +69,4 @@ import Data.Array.Nested.Trace.TH
$(concat <$> mapM convertFun
- ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromList1, 'rfromListOuter, 'rfromListLinear, 'rfromListPrim, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromList1, 'sfromListOuter, 'sfromListLinear, 'sfromListPrim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromList1, 'mfromListOuter, 'mfromListLinear, 'mfromListPrim, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array])
+ ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromList1, 'rfromList1N, 'rfromListOuter, 'rfromListOuterN, 'rfromListLinear, 'rfromList1Prim, 'rfromList1PrimN, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromList1, 'sfromListOuter, 'sfromListLinear, 'sfromList1Prim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromList1, 'mfromList1N, 'mfromList1SN, 'mfromListOuter, 'mfromListOuterN, 'mfromListOuterSN, 'mfromListLinear, 'mfromList1Prim, 'mfromList1PrimN, 'mfromList1PrimSN, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'msliceN, 'msliceSN, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array])
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index 29154f1..948a50e 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -27,6 +27,8 @@ import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
import Data.Type.Equality
import Data.Type.Ord
+import Data.Vector qualified as V
+import Data.Vector.Generic.Checked qualified as VGC
import Data.Vector.Storable qualified as VS
import Foreign.Storable (Storable)
import GHC.Generics (Generic)
@@ -291,15 +293,23 @@ sumOuter ssh ssh' arr
reshapePartial ssh ssh' shF $
arr
+-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown,
+-- the list's spine must be fully materialised to compute its length before
+-- constructing the array.
fromListOuter :: forall n sh a. Storable a
=> StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
fromListOuter ssh l
| Dict <- lemKnownNatRankSSX ssh
+ , let l' = coerce @[XArray sh a] @[S.Array (Rank sh) a] l
= case ssh of
- SKnown m :!% _ | fromSNat' m /= length l ->
- error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++
- "does not match the type (" ++ show (fromSNat' m) ++ ")"
- _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))
+ _ :!% ZKX ->
+ fromList1 ssh (map S.unScalar l')
+ SKnown m :!% _ ->
+ let n = fromSNat' m
+ in XArray (S.ravel (ORB.fromVector [n] (VGC.fromListNChecked n l')))
+ _ ->
+ let n = length l
+ in XArray (S.ravel (ORB.fromVector [n] (V.fromListN n l')))
toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a]
toListOuter (XArray arr@(ORS.A (ORG.A _ t))) =
@@ -310,14 +320,18 @@ toListOuter (XArray arr@(ORS.A (ORG.A _ t))) =
[_] | Refl <- (unsafeCoerceRefl :: sh :~: '[]) -> coerce (map S.scalar $ S.toList arr)
n : sh -> coerce $ map (ORG.A sh . OI.indexT t) [0 .. n - 1]
+-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown,
+-- the list's spine must be fully materialised to compute its length before
+-- constructing the array.
fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a
fromList1 ssh l =
- let n = length l
- in case ssh of
- SKnown m :!% _ | fromSNat' m /= n ->
- error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++
- "does not match the type (" ++ show (fromSNat' m) ++ ")"
- _ -> XArray (S.fromVector [n] (VS.fromListN n l))
+ case ssh of
+ SKnown m :!% _ ->
+ let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed
+ in XArray (S.fromVector [n] (VGC.fromListNChecked n l))
+ _ ->
+ let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself
+ in XArray (S.fromVector [n] (VS.fromListN n l))
toList1 :: Storable a => XArray '[n] a -> [a]
toList1 (XArray arr) = S.toList arr
diff --git a/src/Data/Vector/Generic/Checked.hs b/src/Data/Vector/Generic/Checked.hs
new file mode 100644
index 0000000..d173bbf
--- /dev/null
+++ b/src/Data/Vector/Generic/Checked.hs
@@ -0,0 +1,39 @@
+{-# LANGUAGE CPP #-}
+module Data.Vector.Generic.Checked (
+ fromListNChecked,
+) where
+
+import qualified Data.Stream.Monadic as Stream
+import qualified Data.Vector.Fusion.Bundle.Monadic as VBM
+import qualified Data.Vector.Fusion.Bundle.Size as VBS
+import qualified Data.Vector.Fusion.Util as VFU
+import qualified Data.Vector.Generic as VG
+
+-- for INLINE_FUSED and INLINE_INNER
+#include "vector.h"
+
+
+-- These functions are copied over and lightly edited from the vector and
+-- vector-stream packages, and thus inherit their BSD-3-Clause license with:
+-- Copyright (c) 2008-2012, Roman Leshchinskiy
+-- 2020-2022, Alexey Kuleshevich
+-- 2020-2022, Aleksey Khudyakov
+-- 2020-2022, Andrew Lelechenko
+
+fromListNChecked :: VG.Vector v a => Int -> [a] -> v a
+{-# INLINE fromListNChecked #-}
+fromListNChecked n = VG.unstream . bundleFromListNChecked n
+
+bundleFromListNChecked :: Int -> [a] -> VBM.Bundle VFU.Id v a
+{-# INLINE_FUSED bundleFromListNChecked #-}
+bundleFromListNChecked nTop xsTop
+ | nTop < 0 = error "fromListNChecked: length negative"
+ | otherwise =
+ VBM.fromStream (Stream.Stream step (xsTop, nTop)) (VBS.Max (VFU.delay_inline max nTop 0))
+ where
+ {-# INLINE_INNER step #-}
+ step (xs,n) | n == 0 = case xs of
+ [] -> return Stream.Done
+ _:_ -> error "fromListNChecked: list too long"
+ step (x:xs,n) = return (Stream.Yield x (xs,n-1))
+ step ([],_) = error "fromListNChecked: list too short"