diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2026-01-13 12:11:18 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2026-01-13 12:11:18 +0100 |
| commit | d0f887c9045960b59ae40bb7d77f3f55cbb9ed02 (patch) | |
| tree | 3de1d53bc941e8e798e970195faea96a3298236d /src/Data | |
| parent | 6959b7e4769289983f008d558a71fe0dd2e3d279 (diff) | |
Let X.fromListOuterSN and ravelOuterN take NonEmpty
Diffstat (limited to 'src/Data')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 3 | ||||
| -rw-r--r-- | src/Data/Array/XArray.hs | 13 |
2 files changed, 7 insertions, 9 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index b0d32db..f0dadb4 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -28,7 +28,6 @@ import Control.Monad.ST import Data.Array.RankedS qualified as S import Data.Bifunctor (bimap) import Data.Coerce -import Data.Foldable (toList) import Data.Int import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty(..)) @@ -423,7 +422,7 @@ instance Storable a => Elt (Primitive a) where mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) mfromListOuterSN sn l@(arr1 :| _) = let sh = mshape arr1 - in M_Primitive (SKnown sn :$% sh) (X.fromListOuterSN sn sh (map (\(M_Primitive _ a) -> a) (toList l))) + in M_Primitive (SKnown sn :$% sh) (X.fromListOuterSN sn sh ((\(M_Primitive _ a) -> a) <$> l)) mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) {-# INLINE mlift #-} diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index 77c0dc0..f779a53 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -26,7 +26,7 @@ import Data.Array.RankedS qualified as S import Data.Coerce import Data.Foldable (toList) import Data.Kind -import Data.List.NonEmpty (NonEmpty) +import Data.List.NonEmpty (NonEmpty(..)) import Data.Proxy import Data.Type.Equality import Data.Type.Ord @@ -312,22 +312,21 @@ sumOuter ssh ssh' arr -- in the given shape to guess the shape of the empty array, in general). {-# INLINE fromListOuterSN #-} fromListOuterSN :: forall n sh a. Storable a - => SNat n -> IShX sh -> [XArray sh a] -> XArray (Just n : sh) a + => SNat n -> IShX sh -> NonEmpty (XArray sh a) -> XArray (Just n : sh) a fromListOuterSN m sh l | Dict <- lemKnownNatRank sh - , let l' = coerce @[XArray sh a] @[S.Array (Rank sh) a] l + , let l' = coerce @(NonEmpty (XArray sh a)) @(NonEmpty (S.Array (Rank sh) a)) l = case sh of - ZSX -> fromList1SN m (map S.unScalar l') + ZSX -> fromList1SN m (map S.unScalar (toList l')) _ -> XArray (ravelOuterN (fromSNat' m) l') -- | This checks that the list has the given length and that all shapes in the -- list are equal. The list must be non-empty, and is streamed. {-# INLINEABLE ravelOuterN #-} ravelOuterN :: (KnownNat k, Storable a) - => Int -> [S.Array k a] -> S.Array (1 + k) a + => Int -> NonEmpty (S.Array k a) -> S.Array (1 + k) a ravelOuterN 0 _ = error "ravelOuterN: N == 0" -ravelOuterN _ [] = error "ravelOuterN: empty list" -ravelOuterN k as@(a0 : _) = runST $ do +ravelOuterN k as@(a0 :| _) = runST $ do let sh0 = S.shapeL a0 len = product sh0 vecSize = k * len |
