diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-20 22:55:59 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-20 22:55:59 +0200 |
commit | beec5f9ab3f37da78bb95e2631871b4e47b46c57 (patch) | |
tree | 8efeae83646526ec7bdeaeb1674fa62f3cea2676 | |
parent | 8d495b7e6c21fc843f0538711c2203dfb213b7e1 (diff) |
Better {from,to}List set
-rw-r--r-- | src/Data/Array/Mixed.hs | 28 | ||||
-rw-r--r-- | src/Data/Array/Nested.hs | 8 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 123 |
3 files changed, 102 insertions, 57 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 2f23903..1dc6b58 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -728,18 +728,30 @@ sumOuter ssh ssh' | Refl <- lemAppNil @sh = sumInner ssh' ssh . transpose2 ssh ssh' -fromList1 :: forall n sh a. Storable a - => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a -fromList1 ssh l +fromListOuter :: forall n sh a. Storable a + => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a +fromListOuter ssh l | Dict <- lemKnownNatRankSSX ssh = case ssh of - SKnown m@SNat :!% _ | natVal m /= fromIntegral (length l) -> - error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++ - "does not match the type (" ++ show (natVal m) ++ ")" + SKnown m :!% _ | fromSNat' m /= length l -> + error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++ + "does not match the type (" ++ show (fromSNat' m) ++ ")" _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) -toList1 :: Storable a => XArray (n : sh) a -> [XArray sh a] -toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr)) +toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a] +toListOuter (XArray arr) = coerce (ORB.toList (S.unravel arr)) + +fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a +fromList1 ssh l = + let n = length l + in case ssh of + SKnown m :!% _ | fromSNat' m /= n -> + error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++ + "does not match the type (" ++ show (fromSNat' m) ++ ")" + _ -> XArray (S.fromVector [n] (VS.fromListN n l)) + +toList1 :: Storable a => XArray '[n] a -> [a] +toList1 (XArray arr) = S.toList arr -- | Throws if the given shape is not, in fact, empty. empty :: forall sh a. Storable a => IShX sh -> XArray sh a diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 7298918..e804748 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -9,7 +9,7 @@ module Data.Array.Nested ( rshape, rindex, rindexPartial, rgenerate, rsumOuter1, rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar, rrerank, - rreplicate, rfromList, rfromList1, rtoList, rtoList1, + rreplicate, rfromListOuter, rfromList1, rtoListOuter, rtoList1, rslice, rrev1, rreshape, -- ** Lifting orthotope operations to 'Ranked' arrays rlift, @@ -26,7 +26,7 @@ module Data.Array.Nested ( sshape, sindex, sindexPartial, sgenerate, ssumOuter1, stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar, srerank, - sreplicate, sfromList, sfromList1, stoList, stoList1, + sreplicate, sfromListOuter, sfromList1, stoListOuter, stoList1, sslice, srev1, sreshape, -- ** Lifting orthotope operations to 'Shaped' arrays slift, @@ -40,13 +40,13 @@ module Data.Array.Nested ( KnownShX(..), StaticShX(..), mgenerate, mtranspose, mappend, mfromVector, mtoVector, munScalar, mrerank, - mreplicate, mfromList, mtoList, mslice, mrev1, mreshape, + mreplicate, mfromList1, mtoList1, mslice, mrev1, mreshape, -- ** Conversions masXArrayPrim, mfromXArrayPrim, mtoRanked, mcastToShaped, -- * Array elements - Elt(mshape, mindex, mindexPartial, mscalar, mfromList1, mtoList1, mlift, mlift2), + Elt(mshape, mindex, mindexPartial, mscalar, mfromListOuter, mtoListOuter, mlift, mlift2), PrimElt, Primitive(..), diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index feb0662..ab67dcc 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -529,11 +529,10 @@ class Elt a where -- @Just m@ and @m@ does not equal the length of the list, a runtime error is -- thrown. -- - -- If you want a single-dimensional array from your list, map 'mscalar' - -- first. - mfromList1 :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a + -- Consider also 'mfromListPrim', which can avoid intermediate arrays. + mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a - mtoList1 :: Mixed (n : sh) a -> [Mixed sh a] + mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] -- | Note: this library makes no particular guarantees about the shapes of -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the @@ -603,10 +602,10 @@ instance Storable a => Elt (Primitive a) where mindex (M_Primitive _ a) i = Primitive (X.index a i) mindexPartial (M_Primitive sh a) i = M_Primitive (X.shDropIx sh i) (X.indexPartial a i) mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) - mfromList1 l@(arr1 :| _) = + mfromListOuter l@(arr1 :| _) = let sh = SUnknown (length l) :$% mshape arr1 - in M_Primitive sh (X.fromList1 (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l))) - mtoList1 (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toList1 arr) + in M_Primitive sh (X.fromListOuter (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l))) + mtoListOuter (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toListOuter arr) mlift :: forall sh1 sh2. StaticShX sh2 @@ -679,10 +678,10 @@ instance (Elt a, Elt b) => Elt (a, b) where mindex (M_Tup2 a b) i = (mindex a i, mindex b i) mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) - mfromList1 l = - M_Tup2 (mfromList1 ((\(M_Tup2 x _) -> x) <$> l)) - (mfromList1 ((\(M_Tup2 _ y) -> y) <$> l)) - mtoList1 (M_Tup2 a b) = zipWith M_Tup2 (mtoList1 a) (mtoList1 b) + mfromListOuter l = + M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) + (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) + mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) @@ -728,12 +727,12 @@ instance Elt a => Elt (Mixed sh' a) where mscalar = M_Nest ZSX - mfromList1 :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) - mfromList1 l@(arr :| _) = + mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) + mfromListOuter l@(arr :| _) = M_Nest (SUnknown (length l) :$% mshape arr) - (mfromList1 ((\(M_Nest _ a) -> a) <$> l)) + (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) - mtoList1 (M_Nest sh arr) = map (M_Nest (X.shTail sh)) (mtoList1 arr) + mtoListOuter (M_Nest sh arr) = map (M_Nest (X.shTail sh)) (mtoListOuter arr) mlift :: forall sh1 sh2. StaticShX sh2 @@ -890,11 +889,22 @@ mtoVectorP (M_Primitive _ v) = X.toVector v mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a mtoVector arr = mtoVectorP (coerce toPrimitive arr) -mfromList :: Elt a => NonEmpty a -> Mixed '[Nothing] a -mfromList = mfromList1 . fmap mscalar +mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a +mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise? -mtoList :: Elt a => Mixed '[n] a -> [a] -mtoList = map munScalar . mtoList1 +mtoList1 :: Elt a => Mixed '[n] a -> [a] +mtoList1 = map munScalar . mtoListOuter + +mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a +mfromListPrim l = + let ssh = SUnknown () :!% ZKX + xarr = X.fromList1 ssh l + in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a +mfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) munScalar :: Elt a => Mixed '[] a -> a munScalar arr = mindex arr ZIX @@ -1045,12 +1055,12 @@ instance Elt a => Elt (Ranked n a) where mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) - mfromList1 :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) - mfromList1 l = M_Ranked (mfromList1 (coerce l)) + mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) + mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) - mtoList1 :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] - mtoList1 (M_Ranked arr) = - coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoList1 arr) + mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] + mtoListOuter (M_Ranked arr) = + coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) mlift :: forall sh1 sh2. StaticShX sh2 @@ -1143,12 +1153,12 @@ instance Elt a => Elt (Shaped sh a) where mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) - mfromList1 :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) - mfromList1 l = M_Shaped (mfromList1 (coerce l)) + mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) + mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) - mtoList1 :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] - mtoList1 (M_Shaped arr) - = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoList1 arr) + mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] + mtoListOuter (M_Shaped arr) + = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) mlift :: forall sh1 sh2. StaticShX sh2 @@ -1387,21 +1397,32 @@ rtoVectorP = coerce mtoVectorP rtoVector :: PrimElt a => Ranked n a -> VS.Vector a rtoVector = coerce mtoVector -rfromList1 :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a -rfromList1 l +rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a +rfromListOuter l | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n - = Ranked (mfromList1 (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) + = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) -rfromList :: Elt a => NonEmpty a -> Ranked 1 a -rfromList l = Ranked (mfromList l) +rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a +rfromList1 l = Ranked (mfromList1 l) -rtoList :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] -rtoList (Ranked arr) +rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] +rtoListOuter (Ranked arr) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n - = coerce (mtoList1 @a @Nothing @(Replicate n Nothing) arr) + = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) rtoList1 :: Elt a => Ranked 1 a -> [a] -rtoList1 = map runScalar . rtoList +rtoList1 = map runScalar . rtoListOuter + +rfromListPrim :: PrimElt a => [a] -> Ranked 1 a +rfromListPrim l = + let ssh = SUnknown () :!% ZKX + xarr = X.fromList1 ssh l + in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a +rfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr) rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a rfromOrthotope sn arr @@ -1651,17 +1672,29 @@ stoVectorP = coerce mtoVectorP stoVector :: PrimElt a => Shaped sh a -> VS.Vector a stoVector = coerce mtoVector -sfromList1 :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromList1 sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromList1 (coerce l)) +sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a +sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l)) -sfromList :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a -sfromList sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList +sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a +sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1 -stoList :: Elt a => Shaped (n : sh) a -> [Shaped sh a] -stoList (Shaped arr) = coerce (mtoList1 arr) +stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] +stoListOuter (Shaped arr) = coerce (mtoListOuter arr) stoList1 :: Elt a => Shaped '[n] a -> [a] -stoList1 = map sunScalar . stoList +stoList1 = map sunScalar . stoListOuter + +sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a +sfromListPrim sn l + | Refl <- X.lemAppNil @'[Just n] + = let ssh = SUnknown () :!% ZKX + xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) + in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr + +sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a +sfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr) sunScalar :: Elt a => Shaped '[] a -> a sunScalar arr = sindex arr ZIS |