diff options
Diffstat (limited to 'src/Data/Array')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 14 | ||||
| -rw-r--r-- | src/Data/Array/XArray.hs | 48 |
2 files changed, 32 insertions, 30 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 667edd7..113a6e3 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -312,7 +312,7 @@ class Elt a where -- | See 'mfromListOuter'. If the list does not have the given length, a -- runtime error is thrown. 'mfromListPrimSN' is faster if applicable. - mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a + mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] @@ -421,8 +421,8 @@ instance Storable a => Elt (Primitive a) where mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i) mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) mfromListOuterSN sn l@(arr1 :| _) = - let sh = SKnown sn :$% mshape arr1 - in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l))) + let sh = mshape arr1 + in M_Primitive (SKnown sn :$% sh) (X.fromListOuterSN sn sh (map (\(M_Primitive _ a) -> a) (toList l))) mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) {-# INLINE mlift #-} @@ -922,7 +922,7 @@ mfromListLinear sh l = mreshape sh (mfromList1N (shxSize sh) l) mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a mfromList1Prim l = let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l + xarr = X.fromList1 l in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr mfromList1PrimN :: PrimElt a => Int -> [a] -> Mixed '[Nothing] a @@ -933,9 +933,9 @@ mfromList1PrimN n l = mfromList1PrimSN :: PrimElt a => SNat n -> [a] -> Mixed '[Just n] a mfromList1PrimSN sn l = - let ssh = SKnown sn :!% ZKX - xarr = X.fromList1 ssh l - in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + let ssh = SKnown sn :$% ZSX + xarr = X.fromList1SN sn l + in fromPrimitive $ M_Primitive ssh xarr mfromListPrimLinear :: forall sh a. PrimElt a => IShX sh -> [a] -> Mixed sh a mfromListPrimLinear sh l = diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index 42aed6e..77c0dc0 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -310,22 +310,19 @@ sumOuter ssh ssh' arr -- the list's spine must be fully materialised to compute its length before -- constructing the array. The list can't be empty (not enough information -- in the given shape to guess the shape of the empty array, in general). -fromListOuter :: forall n sh a. Storable a - => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a -fromListOuter ssh l - | Dict <- lemKnownNatRankSSX (ssxTail ssh) +{-# INLINE fromListOuterSN #-} +fromListOuterSN :: forall n sh a. Storable a + => SNat n -> IShX sh -> [XArray sh a] -> XArray (Just n : sh) a +fromListOuterSN m sh l + | Dict <- lemKnownNatRank sh , let l' = coerce @[XArray sh a] @[S.Array (Rank sh) a] l - = case ssh of - _ :!% ZKX -> - fromList1 ssh (map S.unScalar l') - _ -> - let n = case ssh of - SKnown m :!% _ -> fromSNat' m - _ -> length l - in XArray (ravelOuterN n l') + = case sh of + ZSX -> fromList1SN m (map S.unScalar l') + _ -> XArray (ravelOuterN (fromSNat' m) l') -- | This checks that the list has the given length and that all shapes in the -- list are equal. The list must be non-empty, and is streamed. +{-# INLINEABLE ravelOuterN #-} ravelOuterN :: (KnownNat k, Storable a) => Int -> [S.Array k a] -> S.Array (1 + k) a ravelOuterN 0 _ = error "ravelOuterN: N == 0" @@ -351,8 +348,8 @@ ravelOuterN k as@(a0 : _) = runST $ do else error $ "ravelOuterN: list too short " ++ show (nFinal, k) toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a] -toListOuter (XArray arr@(ORS.A (ORG.A _ t))) = - case S.shapeL arr of +toListOuter (XArray arr@(ORS.A (ORG.A shArr t))) = + case shArr of [] -> error "impossible" 0 : _ -> [] -- using orthotope's functions here would entail using rerank, which is slow, so we don't @@ -362,15 +359,20 @@ toListOuter (XArray arr@(ORS.A (ORG.A _ t))) = -- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown, -- the list's spine must be fully materialised to compute its length before -- constructing the array. -fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a -fromList1 ssh l = - case ssh of - SKnown m :!% _ -> - let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed - in XArray (S.fromVector [n] (VGC.fromListNChecked n l)) - _ -> - let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself - in XArray (S.fromVector [n] (VS.fromListN n l)) +{-# INLINE fromList1 #-} +fromList1 :: Storable a => [a] -> XArray '[Nothing] a +fromList1 l = + let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself + in XArray (S.fromVector [n] (VS.fromListN n l)) + +-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown, +-- the list's spine must be fully materialised to compute its length before +-- constructing the array. +{-# INLINE fromList1SN #-} +fromList1SN :: Storable a => SNat n -> [a] -> XArray '[Just n] a +fromList1SN m l = + let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed + in XArray (S.fromVector [n] (VGC.fromListNChecked n l)) toList1 :: Storable a => XArray '[n] a -> [a] toList1 (XArray arr) = S.toList arr |
