summaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r--src/Data/Array/Nested/Internal.hs39
1 files changed, 34 insertions, 5 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 6f0cfc8..c8a0670 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -206,6 +206,8 @@ class Elt a where
-- first.
mfromList :: forall n sh. KnownShapeX (n : sh) => NonEmpty (Mixed sh a) -> Mixed (n : sh) a
+ mtoList :: Mixed (n : sh) a -> [Mixed sh a]
+
-- | Note: this library makes no particular guarantees about the shapes of
-- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the
-- full 'XArray' and as such you can distinguish different empty arrays by
@@ -260,7 +262,8 @@ instance Storable a => Elt (Primitive a) where
mindex (M_Primitive a) i = Primitive (X.index a i)
mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i)
mscalar (Primitive x) = M_Primitive (X.scalar x)
- mfromList l = M_Primitive (X.fromList knownShapeX [x | M_Primitive x <- toList l])
+ mfromList l = M_Primitive (X.fromList knownShapeX (coerce (toList l)))
+ mtoList (M_Primitive arr) = coerce (X.toList arr)
mlift :: forall sh1 sh2.
(Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
@@ -311,6 +314,7 @@ instance (Elt a, Elt b) => Elt (a, b) where
mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
mfromList l = M_Tup2 (mfromList ((\(M_Tup2 x _) -> x) <$> l))
(mfromList ((\(M_Tup2 _ y) -> y) <$> l))
+ mtoList (M_Tup2 a b) = zipWith M_Tup2 (mtoList a) (mtoList b)
mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b)
mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y)
@@ -348,13 +352,15 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
| Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
= M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
- mscalar x = M_Nest x
+ mscalar = M_Nest
mfromList :: forall n sh. KnownShapeX (n : sh)
=> NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (n : sh) (Mixed sh' a)
mfromList l
| Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @(n : sh)) (knownShapeX @sh'))
- = M_Nest (mfromList ((\(M_Nest x) -> x) <$> l))
+ = M_Nest (mfromList (coerce l))
+
+ mtoList (M_Nest arr) = coerce (mtoList arr)
mlift :: forall sh1 sh2. KnownShapeX sh2
=> (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
@@ -492,6 +498,9 @@ mfromVector sh v
mfromList1 :: (KnownShapeX '[n], Elt a) => NonEmpty a -> Mixed '[n] a
mfromList1 = mfromList . fmap mscalar
+mtoList1 :: Elt a => Mixed '[n] a -> [a]
+mtoList1 = map munScalar . mtoList
+
munScalar :: Elt a => Mixed '[] a -> a
munScalar arr = mindex arr IZX
@@ -594,7 +603,11 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
=> NonEmpty (Mixed sh (Ranked n a)) -> Mixed (m : sh) (Ranked n a)
mfromList l
| Dict <- lemKnownReplicate (Proxy @n)
- = M_Ranked (mfromList ((\(M_Ranked x) -> x) <$> l))
+ = M_Ranked (mfromList (coerce l))
+
+ mtoList (M_Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = coerce (mtoList arr)
mlift :: forall sh1 sh2. KnownShapeX sh2
=> (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
@@ -718,7 +731,11 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
=> NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (n : sh') (Shaped sh a)
mfromList l
| Dict <- lemKnownMapJust (Proxy @sh)
- = M_Shaped (mfromList ((\(M_Shaped x) -> x) <$> l))
+ = M_Shaped (mfromList (coerce l))
+
+ mtoList (M_Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = coerce (mtoList arr)
mlift :: forall sh1 sh2. KnownShapeX sh2
=> (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
@@ -924,6 +941,12 @@ rfromVector sh v
rfromList1 :: Elt a => NonEmpty a -> Ranked I1 a
rfromList1 = Ranked . mfromList . fmap mscalar
+rtoList :: Elt a => Ranked (S n) a -> [Ranked n a]
+rtoList (Ranked arr) = coerce (mtoList arr)
+
+rtoList1 :: Elt a => Ranked I1 a -> [a]
+rtoList1 = map runScalar . rtoList
+
runScalar :: Elt a => Ranked I0 a -> a
runScalar arr = rindex arr IZR
@@ -1060,6 +1083,12 @@ sfromVector v
sfromList1 :: (KnownNat n, Elt a) => NonEmpty a -> Shaped '[n] a
sfromList1 = Shaped . mfromList . fmap mscalar
+stoList :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
+stoList (Shaped arr) = coerce (mtoList arr)
+
+stoList1 :: Elt a => Shaped '[n] a -> [a]
+stoList1 = map sunScalar . stoList
+
sunScalar :: Elt a => Shaped '[] a -> a
sunScalar arr = sindex arr IZS