aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Nested/Mixed.hs3
-rw-r--r--src/Data/Array/XArray.hs13
2 files changed, 7 insertions, 9 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index b0d32db..f0dadb4 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -28,7 +28,6 @@ import Control.Monad.ST
import Data.Array.RankedS qualified as S
import Data.Bifunctor (bimap)
import Data.Coerce
-import Data.Foldable (toList)
import Data.Int
import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty(..))
@@ -423,7 +422,7 @@ instance Storable a => Elt (Primitive a) where
mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
mfromListOuterSN sn l@(arr1 :| _) =
let sh = mshape arr1
- in M_Primitive (SKnown sn :$% sh) (X.fromListOuterSN sn sh (map (\(M_Primitive _ a) -> a) (toList l)))
+ in M_Primitive (SKnown sn :$% sh) (X.fromListOuterSN sn sh ((\(M_Primitive _ a) -> a) <$> l))
mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
{-# INLINE mlift #-}
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index 77c0dc0..f779a53 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -26,7 +26,7 @@ import Data.Array.RankedS qualified as S
import Data.Coerce
import Data.Foldable (toList)
import Data.Kind
-import Data.List.NonEmpty (NonEmpty)
+import Data.List.NonEmpty (NonEmpty(..))
import Data.Proxy
import Data.Type.Equality
import Data.Type.Ord
@@ -312,22 +312,21 @@ sumOuter ssh ssh' arr
-- in the given shape to guess the shape of the empty array, in general).
{-# INLINE fromListOuterSN #-}
fromListOuterSN :: forall n sh a. Storable a
- => SNat n -> IShX sh -> [XArray sh a] -> XArray (Just n : sh) a
+ => SNat n -> IShX sh -> NonEmpty (XArray sh a) -> XArray (Just n : sh) a
fromListOuterSN m sh l
| Dict <- lemKnownNatRank sh
- , let l' = coerce @[XArray sh a] @[S.Array (Rank sh) a] l
+ , let l' = coerce @(NonEmpty (XArray sh a)) @(NonEmpty (S.Array (Rank sh) a)) l
= case sh of
- ZSX -> fromList1SN m (map S.unScalar l')
+ ZSX -> fromList1SN m (map S.unScalar (toList l'))
_ -> XArray (ravelOuterN (fromSNat' m) l')
-- | This checks that the list has the given length and that all shapes in the
-- list are equal. The list must be non-empty, and is streamed.
{-# INLINEABLE ravelOuterN #-}
ravelOuterN :: (KnownNat k, Storable a)
- => Int -> [S.Array k a] -> S.Array (1 + k) a
+ => Int -> NonEmpty (S.Array k a) -> S.Array (1 + k) a
ravelOuterN 0 _ = error "ravelOuterN: N == 0"
-ravelOuterN _ [] = error "ravelOuterN: empty list"
-ravelOuterN k as@(a0 : _) = runST $ do
+ravelOuterN k as@(a0 :| _) = runST $ do
let sh0 = S.shapeL a0
len = product sh0
vecSize = k * len