aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Nested/Mixed.hs14
-rw-r--r--src/Data/Array/XArray.hs48
2 files changed, 32 insertions, 30 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 667edd7..113a6e3 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -312,7 +312,7 @@ class Elt a where
-- | See 'mfromListOuter'. If the list does not have the given length, a
-- runtime error is thrown. 'mfromListPrimSN' is faster if applicable.
- mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a
+ mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a
mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
@@ -421,8 +421,8 @@ instance Storable a => Elt (Primitive a) where
mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i)
mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
mfromListOuterSN sn l@(arr1 :| _) =
- let sh = SKnown sn :$% mshape arr1
- in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l)))
+ let sh = mshape arr1
+ in M_Primitive (SKnown sn :$% sh) (X.fromListOuterSN sn sh (map (\(M_Primitive _ a) -> a) (toList l)))
mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
{-# INLINE mlift #-}
@@ -922,7 +922,7 @@ mfromListLinear sh l = mreshape sh (mfromList1N (shxSize sh) l)
mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
mfromList1Prim l =
let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
+ xarr = X.fromList1 l
in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
mfromList1PrimN :: PrimElt a => Int -> [a] -> Mixed '[Nothing] a
@@ -933,9 +933,9 @@ mfromList1PrimN n l =
mfromList1PrimSN :: PrimElt a => SNat n -> [a] -> Mixed '[Just n] a
mfromList1PrimSN sn l =
- let ssh = SKnown sn :!% ZKX
- xarr = X.fromList1 ssh l
- in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+ let ssh = SKnown sn :$% ZSX
+ xarr = X.fromList1SN sn l
+ in fromPrimitive $ M_Primitive ssh xarr
mfromListPrimLinear :: forall sh a. PrimElt a => IShX sh -> [a] -> Mixed sh a
mfromListPrimLinear sh l =
diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs
index 42aed6e..77c0dc0 100644
--- a/src/Data/Array/XArray.hs
+++ b/src/Data/Array/XArray.hs
@@ -310,22 +310,19 @@ sumOuter ssh ssh' arr
-- the list's spine must be fully materialised to compute its length before
-- constructing the array. The list can't be empty (not enough information
-- in the given shape to guess the shape of the empty array, in general).
-fromListOuter :: forall n sh a. Storable a
- => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
-fromListOuter ssh l
- | Dict <- lemKnownNatRankSSX (ssxTail ssh)
+{-# INLINE fromListOuterSN #-}
+fromListOuterSN :: forall n sh a. Storable a
+ => SNat n -> IShX sh -> [XArray sh a] -> XArray (Just n : sh) a
+fromListOuterSN m sh l
+ | Dict <- lemKnownNatRank sh
, let l' = coerce @[XArray sh a] @[S.Array (Rank sh) a] l
- = case ssh of
- _ :!% ZKX ->
- fromList1 ssh (map S.unScalar l')
- _ ->
- let n = case ssh of
- SKnown m :!% _ -> fromSNat' m
- _ -> length l
- in XArray (ravelOuterN n l')
+ = case sh of
+ ZSX -> fromList1SN m (map S.unScalar l')
+ _ -> XArray (ravelOuterN (fromSNat' m) l')
-- | This checks that the list has the given length and that all shapes in the
-- list are equal. The list must be non-empty, and is streamed.
+{-# INLINEABLE ravelOuterN #-}
ravelOuterN :: (KnownNat k, Storable a)
=> Int -> [S.Array k a] -> S.Array (1 + k) a
ravelOuterN 0 _ = error "ravelOuterN: N == 0"
@@ -351,8 +348,8 @@ ravelOuterN k as@(a0 : _) = runST $ do
else error $ "ravelOuterN: list too short " ++ show (nFinal, k)
toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a]
-toListOuter (XArray arr@(ORS.A (ORG.A _ t))) =
- case S.shapeL arr of
+toListOuter (XArray arr@(ORS.A (ORG.A shArr t))) =
+ case shArr of
[] -> error "impossible"
0 : _ -> []
-- using orthotope's functions here would entail using rerank, which is slow, so we don't
@@ -362,15 +359,20 @@ toListOuter (XArray arr@(ORS.A (ORG.A _ t))) =
-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown,
-- the list's spine must be fully materialised to compute its length before
-- constructing the array.
-fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a
-fromList1 ssh l =
- case ssh of
- SKnown m :!% _ ->
- let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed
- in XArray (S.fromVector [n] (VGC.fromListNChecked n l))
- _ ->
- let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself
- in XArray (S.fromVector [n] (VS.fromListN n l))
+{-# INLINE fromList1 #-}
+fromList1 :: Storable a => [a] -> XArray '[Nothing] a
+fromList1 l =
+ let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself
+ in XArray (S.fromVector [n] (VS.fromListN n l))
+
+-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown,
+-- the list's spine must be fully materialised to compute its length before
+-- constructing the array.
+{-# INLINE fromList1SN #-}
+fromList1SN :: Storable a => SNat n -> [a] -> XArray '[Just n] a
+fromList1SN m l =
+ let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed
+ in XArray (S.fromVector [n] (VGC.fromListNChecked n l))
toList1 :: Storable a => XArray '[n] a -> [a]
toList1 (XArray arr) = S.toList arr