aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-20 22:55:59 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-20 22:55:59 +0200
commitbeec5f9ab3f37da78bb95e2631871b4e47b46c57 (patch)
tree8efeae83646526ec7bdeaeb1674fa62f3cea2676 /src/Data/Array/Nested
parent8d495b7e6c21fc843f0538711c2203dfb213b7e1 (diff)
Better {from,to}List set
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Internal.hs123
1 files changed, 78 insertions, 45 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index feb0662..ab67dcc 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -529,11 +529,10 @@ class Elt a where
-- @Just m@ and @m@ does not equal the length of the list, a runtime error is
-- thrown.
--
- -- If you want a single-dimensional array from your list, map 'mscalar'
- -- first.
- mfromList1 :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+ -- Consider also 'mfromListPrim', which can avoid intermediate arrays.
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
- mtoList1 :: Mixed (n : sh) a -> [Mixed sh a]
+ mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
-- | Note: this library makes no particular guarantees about the shapes of
-- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the
@@ -603,10 +602,10 @@ instance Storable a => Elt (Primitive a) where
mindex (M_Primitive _ a) i = Primitive (X.index a i)
mindexPartial (M_Primitive sh a) i = M_Primitive (X.shDropIx sh i) (X.indexPartial a i)
mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
- mfromList1 l@(arr1 :| _) =
+ mfromListOuter l@(arr1 :| _) =
let sh = SUnknown (length l) :$% mshape arr1
- in M_Primitive sh (X.fromList1 (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l)))
- mtoList1 (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toList1 arr)
+ in M_Primitive sh (X.fromListOuter (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l)))
+ mtoListOuter (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toListOuter arr)
mlift :: forall sh1 sh2.
StaticShX sh2
@@ -679,10 +678,10 @@ instance (Elt a, Elt b) => Elt (a, b) where
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
- mfromList1 l =
- M_Tup2 (mfromList1 ((\(M_Tup2 x _) -> x) <$> l))
- (mfromList1 ((\(M_Tup2 _ y) -> y) <$> l))
- mtoList1 (M_Tup2 a b) = zipWith M_Tup2 (mtoList1 a) (mtoList1 b)
+ mfromListOuter l =
+ M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l))
+ (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l))
+ mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)
mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
@@ -728,12 +727,12 @@ instance Elt a => Elt (Mixed sh' a) where
mscalar = M_Nest ZSX
- mfromList1 :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
- mfromList1 l@(arr :| _) =
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
+ mfromListOuter l@(arr :| _) =
M_Nest (SUnknown (length l) :$% mshape arr)
- (mfromList1 ((\(M_Nest _ a) -> a) <$> l))
+ (mfromListOuter ((\(M_Nest _ a) -> a) <$> l))
- mtoList1 (M_Nest sh arr) = map (M_Nest (X.shTail sh)) (mtoList1 arr)
+ mtoListOuter (M_Nest sh arr) = map (M_Nest (X.shTail sh)) (mtoListOuter arr)
mlift :: forall sh1 sh2.
StaticShX sh2
@@ -890,11 +889,22 @@ mtoVectorP (M_Primitive _ v) = X.toVector v
mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
mtoVector arr = mtoVectorP (coerce toPrimitive arr)
-mfromList :: Elt a => NonEmpty a -> Mixed '[Nothing] a
-mfromList = mfromList1 . fmap mscalar
+mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
+mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
-mtoList :: Elt a => Mixed '[n] a -> [a]
-mtoList = map munScalar . mtoList1
+mtoList1 :: Elt a => Mixed '[n] a -> [a]
+mtoList1 = map munScalar . mtoListOuter
+
+mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
+mfromListPrim l =
+ let ssh = SUnknown () :!% ZKX
+ xarr = X.fromList1 ssh l
+ in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+
+mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a
+mfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
munScalar :: Elt a => Mixed '[] a -> a
munScalar arr = mindex arr ZIX
@@ -1045,12 +1055,12 @@ instance Elt a => Elt (Ranked n a) where
mscalar (Ranked x) = M_Ranked (M_Nest ZSX x)
- mfromList1 :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
- mfromList1 l = M_Ranked (mfromList1 (coerce l))
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
+ mfromListOuter l = M_Ranked (mfromListOuter (coerce l))
- mtoList1 :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
- mtoList1 (M_Ranked arr) =
- coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoList1 arr)
+ mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
+ mtoListOuter (M_Ranked arr) =
+ coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)
mlift :: forall sh1 sh2.
StaticShX sh2
@@ -1143,12 +1153,12 @@ instance Elt a => Elt (Shaped sh a) where
mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
- mfromList1 :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
- mfromList1 l = M_Shaped (mfromList1 (coerce l))
+ mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
+ mfromListOuter l = M_Shaped (mfromListOuter (coerce l))
- mtoList1 :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
- mtoList1 (M_Shaped arr)
- = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoList1 arr)
+ mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
+ mtoListOuter (M_Shaped arr)
+ = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)
mlift :: forall sh1 sh2.
StaticShX sh2
@@ -1387,21 +1397,32 @@ rtoVectorP = coerce mtoVectorP
rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector
-rfromList1 :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
-rfromList1 l
+rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
+rfromListOuter l
| Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
- = Ranked (mfromList1 (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
+ = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
-rfromList :: Elt a => NonEmpty a -> Ranked 1 a
-rfromList l = Ranked (mfromList l)
+rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
+rfromList1 l = Ranked (mfromList1 l)
-rtoList :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
-rtoList (Ranked arr)
+rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
+rtoListOuter (Ranked arr)
| Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
- = coerce (mtoList1 @a @Nothing @(Replicate n Nothing) arr)
+ = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
rtoList1 :: Elt a => Ranked 1 a -> [a]
-rtoList1 = map runScalar . rtoList
+rtoList1 = map runScalar . rtoListOuter
+
+rfromListPrim :: PrimElt a => [a] -> Ranked 1 a
+rfromListPrim l =
+ let ssh = SUnknown () :!% ZKX
+ xarr = X.fromList1 ssh l
+ in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+
+rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a
+rfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr)
rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a
rfromOrthotope sn arr
@@ -1651,17 +1672,29 @@ stoVectorP = coerce mtoVectorP
stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
stoVector = coerce mtoVector
-sfromList1 :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
-sfromList1 sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromList1 (coerce l))
+sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
+sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l))
-sfromList :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
-sfromList sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList
+sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
+sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1
-stoList :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
-stoList (Shaped arr) = coerce (mtoList1 arr)
+stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
+stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
stoList1 :: Elt a => Shaped '[n] a -> [a]
-stoList1 = map sunScalar . stoList
+stoList1 = map sunScalar . stoListOuter
+
+sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
+sfromListPrim sn l
+ | Refl <- X.lemAppNil @'[Just n]
+ = let ssh = SUnknown () :!% ZKX
+ xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)
+ in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr
+
+sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a
+sfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr)
sunScalar :: Elt a => Shaped '[] a -> a
sunScalar arr = sindex arr ZIS