aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-20 22:55:59 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-20 22:55:59 +0200
commitbeec5f9ab3f37da78bb95e2631871b4e47b46c57 (patch)
tree8efeae83646526ec7bdeaeb1674fa62f3cea2676 /src/Data
parent8d495b7e6c21fc843f0538711c2203dfb213b7e1 (diff)
Better {from,to}List set
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Mixed.hs28
-rw-r--r--src/Data/Array/Nested.hs8
-rw-r--r--src/Data/Array/Nested/Internal.hs123
3 files changed, 102 insertions, 57 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 2f23903..1dc6b58 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -728,18 +728,30 @@ sumOuter ssh ssh'
| Refl <- lemAppNil @sh
= sumInner ssh' ssh . transpose2 ssh ssh'
-fromList1 :: forall n sh a. Storable a
- => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
-fromList1 ssh l
+fromListOuter :: forall n sh a. Storable a
+ => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
+fromListOuter ssh l
| Dict <- lemKnownNatRankSSX ssh
= case ssh of
- SKnown m@SNat :!% _ | natVal m /= fromIntegral (length l) ->
- error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++
- "does not match the type (" ++ show (natVal m) ++ ")"
+ SKnown m :!% _ | fromSNat' m /= length l ->
+ error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++
+ "does not match the type (" ++ show (fromSNat' m) ++ ")"
_ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))
-toList1 :: Storable a => XArray (n : sh) a -> [XArray sh a]
-toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr))
+toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a]
+toListOuter (XArray arr) = coerce (ORB.toList (S.unravel arr))
+
+fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a
+fromList1 ssh l =
+ let n = length l
+ in case ssh of
+ SKnown m :!% _ | fromSNat' m /= n ->
+ error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++
+ "does not match the type (" ++ show (fromSNat' m) ++ ")"
+ _ -> XArray (S.fromVector [n] (VS.fromListN n l))
+
+toList1 :: Storable a => XArray '[n] a -> [a]
+toList1 (XArray arr) = S.toList arr
-- | Throws if the given shape is not, in fact, empty.
empty :: forall sh a. Storable a => IShX sh -> XArray sh a
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 7298918..e804748 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -9,7 +9,7 @@ module Data.Array.Nested (
rshape, rindex, rindexPartial, rgenerate, rsumOuter1,
rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar,
rrerank,
- rreplicate, rfromList, rfromList1, rtoList, rtoList1,
+ rreplicate, rfromListOuter, rfromList1, rtoListOuter, rtoList1,
rslice, rrev1, rreshape,
-- ** Lifting orthotope operations to 'Ranked' arrays
rlift,
@@ -26,7 +26,7 @@ module Data.Array.Nested (
sshape, sindex, sindexPartial, sgenerate, ssumOuter1,
stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,
srerank,
- sreplicate, sfromList, sfromList1, stoList, stoList1,
+ sreplicate, sfromListOuter, sfromList1, stoListOuter, stoList1,
sslice, srev1, sreshape,
-- ** Lifting orthotope operations to 'Shaped' arrays
slift,
@@ -40,13 +40,13 @@ module Data.Array.Nested (
KnownShX(..), StaticShX(..),
mgenerate, mtranspose, mappend, mfromVector, mtoVector, munScalar,
mrerank,
- mreplicate, mfromList, mtoList, mslice, mrev1, mreshape,
+ mreplicate, mfromList1, mtoList1, mslice, mrev1, mreshape,
-- ** Conversions
masXArrayPrim, mfromXArrayPrim,
mtoRanked, mcastToShaped,
-- * Array elements
- Elt(mshape, mindex, mindexPartial, mscalar, mfromList1, mtoList1, mlift, mlift2),
+ Elt(mshape, mindex, mindexPartial, mscalar, mfromListOuter, mtoListOuter, mlift, mlift2),
PrimElt,
Primitive(..),
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index feb0662..ab67dcc 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -529,11 +529,10 @@ class Elt a where
-- @Just m@ and @m@ does not equal the length of the list, a runtime error is
-- thrown.
--
- -- If you want a single-dimensional array from your list, map 'mscalar'
- -- first.
- mfromList1 :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+ -- Consider also 'mfromListPrim', which can avoid intermediate arrays.
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
- mtoList1 :: Mixed (n : sh) a -> [Mixed sh a]
+ mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
-- | Note: this library makes no particular guarantees about the shapes of
-- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the
@@ -603,10 +602,10 @@ instance Storable a => Elt (Primitive a) where
mindex (M_Primitive _ a) i = Primitive (X.index a i)
mindexPartial (M_Primitive sh a) i = M_Primitive (X.shDropIx sh i) (X.indexPartial a i)
mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
- mfromList1 l@(arr1 :| _) =
+ mfromListOuter l@(arr1 :| _) =
let sh = SUnknown (length l) :$% mshape arr1
- in M_Primitive sh (X.fromList1 (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l)))
- mtoList1 (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toList1 arr)
+ in M_Primitive sh (X.fromListOuter (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l)))
+ mtoListOuter (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toListOuter arr)
mlift :: forall sh1 sh2.
StaticShX sh2
@@ -679,10 +678,10 @@ instance (Elt a, Elt b) => Elt (a, b) where
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
- mfromList1 l =
- M_Tup2 (mfromList1 ((\(M_Tup2 x _) -> x) <$> l))
- (mfromList1 ((\(M_Tup2 _ y) -> y) <$> l))
- mtoList1 (M_Tup2 a b) = zipWith M_Tup2 (mtoList1 a) (mtoList1 b)
+ mfromListOuter l =
+ M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l))
+ (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l))
+ mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)
mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
@@ -728,12 +727,12 @@ instance Elt a => Elt (Mixed sh' a) where
mscalar = M_Nest ZSX
- mfromList1 :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
- mfromList1 l@(arr :| _) =
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
+ mfromListOuter l@(arr :| _) =
M_Nest (SUnknown (length l) :$% mshape arr)
- (mfromList1 ((\(M_Nest _ a) -> a) <$> l))
+ (mfromListOuter ((\(M_Nest _ a) -> a) <$> l))
- mtoList1 (M_Nest sh arr) = map (M_Nest (X.shTail sh)) (mtoList1 arr)
+ mtoListOuter (M_Nest sh arr) = map (M_Nest (X.shTail sh)) (mtoListOuter arr)
mlift :: forall sh1 sh2.
StaticShX sh2
@@ -890,11 +889,22 @@ mtoVectorP (M_Primitive _ v) = X.toVector v
mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
mtoVector arr = mtoVectorP (coerce toPrimitive arr)
-mfromList :: Elt a => NonEmpty a -> Mixed '[Nothing] a
-mfromList = mfromList1 . fmap mscalar
+mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
+mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
-mtoList :: Elt a => Mixed '[n] a -> [a]
-mtoList = map munScalar . mtoList1
+mtoList1 :: Elt a => Mixed '[n] a -> [a]
+mtoList1 = map munScalar . mtoListOuter
+
+mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
+mfromListPrim l =
+ let ssh = SUnknown () :!% ZKX
+ xarr = X.fromList1 ssh l
+ in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+
+mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a
+mfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
munScalar :: Elt a => Mixed '[] a -> a
munScalar arr = mindex arr ZIX
@@ -1045,12 +1055,12 @@ instance Elt a => Elt (Ranked n a) where
mscalar (Ranked x) = M_Ranked (M_Nest ZSX x)
- mfromList1 :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
- mfromList1 l = M_Ranked (mfromList1 (coerce l))
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
+ mfromListOuter l = M_Ranked (mfromListOuter (coerce l))
- mtoList1 :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
- mtoList1 (M_Ranked arr) =
- coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoList1 arr)
+ mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
+ mtoListOuter (M_Ranked arr) =
+ coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)
mlift :: forall sh1 sh2.
StaticShX sh2
@@ -1143,12 +1153,12 @@ instance Elt a => Elt (Shaped sh a) where
mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
- mfromList1 :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
- mfromList1 l = M_Shaped (mfromList1 (coerce l))
+ mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
+ mfromListOuter l = M_Shaped (mfromListOuter (coerce l))
- mtoList1 :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
- mtoList1 (M_Shaped arr)
- = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoList1 arr)
+ mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
+ mtoListOuter (M_Shaped arr)
+ = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)
mlift :: forall sh1 sh2.
StaticShX sh2
@@ -1387,21 +1397,32 @@ rtoVectorP = coerce mtoVectorP
rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector
-rfromList1 :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
-rfromList1 l
+rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
+rfromListOuter l
| Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
- = Ranked (mfromList1 (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
+ = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
-rfromList :: Elt a => NonEmpty a -> Ranked 1 a
-rfromList l = Ranked (mfromList l)
+rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
+rfromList1 l = Ranked (mfromList1 l)
-rtoList :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
-rtoList (Ranked arr)
+rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
+rtoListOuter (Ranked arr)
| Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
- = coerce (mtoList1 @a @Nothing @(Replicate n Nothing) arr)
+ = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
rtoList1 :: Elt a => Ranked 1 a -> [a]
-rtoList1 = map runScalar . rtoList
+rtoList1 = map runScalar . rtoListOuter
+
+rfromListPrim :: PrimElt a => [a] -> Ranked 1 a
+rfromListPrim l =
+ let ssh = SUnknown () :!% ZKX
+ xarr = X.fromList1 ssh l
+ in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+
+rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a
+rfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr)
rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a
rfromOrthotope sn arr
@@ -1651,17 +1672,29 @@ stoVectorP = coerce mtoVectorP
stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
stoVector = coerce mtoVector
-sfromList1 :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
-sfromList1 sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromList1 (coerce l))
+sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
+sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l))
-sfromList :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
-sfromList sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList
+sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
+sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1
-stoList :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
-stoList (Shaped arr) = coerce (mtoList1 arr)
+stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
+stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
stoList1 :: Elt a => Shaped '[n] a -> [a]
-stoList1 = map sunScalar . stoList
+stoList1 = map sunScalar . stoListOuter
+
+sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
+sfromListPrim sn l
+ | Refl <- X.lemAppNil @'[Just n]
+ = let ssh = SUnknown () :!% ZKX
+ xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)
+ in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr
+
+sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a
+sfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr)
sunScalar :: Elt a => Shaped '[] a -> a
sunScalar arr = sindex arr ZIS