From beec5f9ab3f37da78bb95e2631871b4e47b46c57 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 20 May 2024 22:55:59 +0200 Subject: Better {from,to}List set --- src/Data/Array/Nested/Internal.hs | 123 ++++++++++++++++++++++++-------------- 1 file changed, 78 insertions(+), 45 deletions(-) (limited to 'src/Data/Array/Nested') diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index feb0662..ab67dcc 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -529,11 +529,10 @@ class Elt a where -- @Just m@ and @m@ does not equal the length of the list, a runtime error is -- thrown. -- - -- If you want a single-dimensional array from your list, map 'mscalar' - -- first. - mfromList1 :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a + -- Consider also 'mfromListPrim', which can avoid intermediate arrays. + mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a - mtoList1 :: Mixed (n : sh) a -> [Mixed sh a] + mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] -- | Note: this library makes no particular guarantees about the shapes of -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the @@ -603,10 +602,10 @@ instance Storable a => Elt (Primitive a) where mindex (M_Primitive _ a) i = Primitive (X.index a i) mindexPartial (M_Primitive sh a) i = M_Primitive (X.shDropIx sh i) (X.indexPartial a i) mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) - mfromList1 l@(arr1 :| _) = + mfromListOuter l@(arr1 :| _) = let sh = SUnknown (length l) :$% mshape arr1 - in M_Primitive sh (X.fromList1 (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l))) - mtoList1 (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toList1 arr) + in M_Primitive sh (X.fromListOuter (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l))) + mtoListOuter (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toListOuter arr) mlift :: forall sh1 sh2. StaticShX sh2 @@ -679,10 +678,10 @@ instance (Elt a, Elt b) => Elt (a, b) where mindex (M_Tup2 a b) i = (mindex a i, mindex b i) mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) - mfromList1 l = - M_Tup2 (mfromList1 ((\(M_Tup2 x _) -> x) <$> l)) - (mfromList1 ((\(M_Tup2 _ y) -> y) <$> l)) - mtoList1 (M_Tup2 a b) = zipWith M_Tup2 (mtoList1 a) (mtoList1 b) + mfromListOuter l = + M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) + (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) + mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) @@ -728,12 +727,12 @@ instance Elt a => Elt (Mixed sh' a) where mscalar = M_Nest ZSX - mfromList1 :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) - mfromList1 l@(arr :| _) = + mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) + mfromListOuter l@(arr :| _) = M_Nest (SUnknown (length l) :$% mshape arr) - (mfromList1 ((\(M_Nest _ a) -> a) <$> l)) + (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) - mtoList1 (M_Nest sh arr) = map (M_Nest (X.shTail sh)) (mtoList1 arr) + mtoListOuter (M_Nest sh arr) = map (M_Nest (X.shTail sh)) (mtoListOuter arr) mlift :: forall sh1 sh2. StaticShX sh2 @@ -890,11 +889,22 @@ mtoVectorP (M_Primitive _ v) = X.toVector v mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a mtoVector arr = mtoVectorP (coerce toPrimitive arr) -mfromList :: Elt a => NonEmpty a -> Mixed '[Nothing] a -mfromList = mfromList1 . fmap mscalar +mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a +mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise? -mtoList :: Elt a => Mixed '[n] a -> [a] -mtoList = map munScalar . mtoList1 +mtoList1 :: Elt a => Mixed '[n] a -> [a] +mtoList1 = map munScalar . mtoListOuter + +mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a +mfromListPrim l = + let ssh = SUnknown () :!% ZKX + xarr = X.fromList1 ssh l + in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a +mfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) munScalar :: Elt a => Mixed '[] a -> a munScalar arr = mindex arr ZIX @@ -1045,12 +1055,12 @@ instance Elt a => Elt (Ranked n a) where mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) - mfromList1 :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) - mfromList1 l = M_Ranked (mfromList1 (coerce l)) + mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) + mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) - mtoList1 :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] - mtoList1 (M_Ranked arr) = - coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoList1 arr) + mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] + mtoListOuter (M_Ranked arr) = + coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) mlift :: forall sh1 sh2. StaticShX sh2 @@ -1143,12 +1153,12 @@ instance Elt a => Elt (Shaped sh a) where mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) - mfromList1 :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) - mfromList1 l = M_Shaped (mfromList1 (coerce l)) + mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) + mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) - mtoList1 :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] - mtoList1 (M_Shaped arr) - = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoList1 arr) + mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] + mtoListOuter (M_Shaped arr) + = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) mlift :: forall sh1 sh2. StaticShX sh2 @@ -1387,21 +1397,32 @@ rtoVectorP = coerce mtoVectorP rtoVector :: PrimElt a => Ranked n a -> VS.Vector a rtoVector = coerce mtoVector -rfromList1 :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a -rfromList1 l +rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a +rfromListOuter l | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n - = Ranked (mfromList1 (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) + = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) -rfromList :: Elt a => NonEmpty a -> Ranked 1 a -rfromList l = Ranked (mfromList l) +rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a +rfromList1 l = Ranked (mfromList1 l) -rtoList :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] -rtoList (Ranked arr) +rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] +rtoListOuter (Ranked arr) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n - = coerce (mtoList1 @a @Nothing @(Replicate n Nothing) arr) + = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) rtoList1 :: Elt a => Ranked 1 a -> [a] -rtoList1 = map runScalar . rtoList +rtoList1 = map runScalar . rtoListOuter + +rfromListPrim :: PrimElt a => [a] -> Ranked 1 a +rfromListPrim l = + let ssh = SUnknown () :!% ZKX + xarr = X.fromList1 ssh l + in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a +rfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr) rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a rfromOrthotope sn arr @@ -1651,17 +1672,29 @@ stoVectorP = coerce mtoVectorP stoVector :: PrimElt a => Shaped sh a -> VS.Vector a stoVector = coerce mtoVector -sfromList1 :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromList1 sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromList1 (coerce l)) +sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a +sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l)) -sfromList :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a -sfromList sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList +sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a +sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1 -stoList :: Elt a => Shaped (n : sh) a -> [Shaped sh a] -stoList (Shaped arr) = coerce (mtoList1 arr) +stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] +stoListOuter (Shaped arr) = coerce (mtoListOuter arr) stoList1 :: Elt a => Shaped '[n] a -> [a] -stoList1 = map sunScalar . stoList +stoList1 = map sunScalar . stoListOuter + +sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a +sfromListPrim sn l + | Refl <- X.lemAppNil @'[Just n] + = let ssh = SUnknown () :!% ZKX + xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) + in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr + +sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a +sfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr) sunScalar :: Elt a => Shaped '[] a -> a sunScalar arr = sindex arr ZIS -- cgit v1.2.3-70-g09d2