From d0f887c9045960b59ae40bb7d77f3f55cbb9ed02 Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Tue, 13 Jan 2026 12:11:18 +0100 Subject: Let X.fromListOuterSN and ravelOuterN take NonEmpty --- src/Data/Array/Nested/Mixed.hs | 3 +-- src/Data/Array/XArray.hs | 13 ++++++------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index b0d32db..f0dadb4 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -28,7 +28,6 @@ import Control.Monad.ST import Data.Array.RankedS qualified as S import Data.Bifunctor (bimap) import Data.Coerce -import Data.Foldable (toList) import Data.Int import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty(..)) @@ -423,7 +422,7 @@ instance Storable a => Elt (Primitive a) where mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) mfromListOuterSN sn l@(arr1 :| _) = let sh = mshape arr1 - in M_Primitive (SKnown sn :$% sh) (X.fromListOuterSN sn sh (map (\(M_Primitive _ a) -> a) (toList l))) + in M_Primitive (SKnown sn :$% sh) (X.fromListOuterSN sn sh ((\(M_Primitive _ a) -> a) <$> l)) mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) {-# INLINE mlift #-} diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index 77c0dc0..f779a53 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -26,7 +26,7 @@ import Data.Array.RankedS qualified as S import Data.Coerce import Data.Foldable (toList) import Data.Kind -import Data.List.NonEmpty (NonEmpty) +import Data.List.NonEmpty (NonEmpty(..)) import Data.Proxy import Data.Type.Equality import Data.Type.Ord @@ -312,22 +312,21 @@ sumOuter ssh ssh' arr -- in the given shape to guess the shape of the empty array, in general). {-# INLINE fromListOuterSN #-} fromListOuterSN :: forall n sh a. Storable a - => SNat n -> IShX sh -> [XArray sh a] -> XArray (Just n : sh) a + => SNat n -> IShX sh -> NonEmpty (XArray sh a) -> XArray (Just n : sh) a fromListOuterSN m sh l | Dict <- lemKnownNatRank sh - , let l' = coerce @[XArray sh a] @[S.Array (Rank sh) a] l + , let l' = coerce @(NonEmpty (XArray sh a)) @(NonEmpty (S.Array (Rank sh) a)) l = case sh of - ZSX -> fromList1SN m (map S.unScalar l') + ZSX -> fromList1SN m (map S.unScalar (toList l')) _ -> XArray (ravelOuterN (fromSNat' m) l') -- | This checks that the list has the given length and that all shapes in the -- list are equal. The list must be non-empty, and is streamed. {-# INLINEABLE ravelOuterN #-} ravelOuterN :: (KnownNat k, Storable a) - => Int -> [S.Array k a] -> S.Array (1 + k) a + => Int -> NonEmpty (S.Array k a) -> S.Array (1 + k) a ravelOuterN 0 _ = error "ravelOuterN: N == 0" -ravelOuterN _ [] = error "ravelOuterN: empty list" -ravelOuterN k as@(a0 : _) = runST $ do +ravelOuterN k as@(a0 :| _) = runST $ do let sh0 = S.shapeL a0 len = product sh0 vecSize = k * len -- cgit v1.2.3-70-g09d2