aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs42
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs11
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs9
3 files changed, 60 insertions, 2 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index b799190..4746f31 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -263,7 +263,7 @@ class Elt a where
mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
-- | Note: this library makes no particular guarantees about the shapes of
- -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the
+ -- arrays "inside" an empty array. With 'mlift', 'mlift2' and 'mliftL' you can see the
-- full 'XArray' and as such you can distinguish different empty arrays by
-- the "shapes" of their elements. This information is meaningless, so you
-- should not use it.
@@ -278,6 +278,14 @@ class Elt a where
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
-> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a
+ -- TODO: mliftL is currently unused.
+ -- | All arrays in the input must have equal shapes, including subarrays
+ -- inside their elements.
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
+ -> NonEmpty (Mixed sh1 a) -> NonEmpty (Mixed sh2 a)
+
mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
=> StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
@@ -364,6 +372,16 @@ instance Storable a => Elt (Primitive a) where
, let result = f ZKX a b
= M_Primitive (X.shape ssh3 result) result
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
+ -> NonEmpty (Mixed sh1 (Primitive a)) -> NonEmpty (Mixed sh2 (Primitive a))
+ mliftL ssh2 f l
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ = fmap (\arr -> M_Primitive (X.shape ssh2 arr) arr) $
+ f ZKX (fmap (\(M_Primitive _ arr) -> arr) l)
+
mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
=> StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
mcast ssh1 sh2 _ (M_Primitive sh1' arr) =
@@ -432,6 +450,11 @@ instance (Elt a, Elt b) => Elt (a, b) where
mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)
mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
+ mliftL ssh2 f =
+ let unzipT2l [] = ([], [])
+ unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2)
+ unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2)
+ in uncurry (NE.zipWith M_Tup2) . bimap (mliftL ssh2 f) (mliftL ssh2 f) . unzipT2
mcast ssh1 sh2 psh' (M_Tup2 a b) =
M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b)
@@ -523,6 +546,23 @@ instance Elt a => Elt (Mixed sh' a) where
, Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
= f (ssxAppend ssh' sshT)
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b))
+ -> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a))
+ mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) =
+ let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l)
+ (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result))
+ in fmap (M_Nest sh2) result
+ where
+ ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))
+
+ f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b)
+ f' sshT
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ = f (ssxAppend ssh' sshT)
+
mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2
=> StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a)
mcast ssh1 sh2 _ (M_Nest sh1T arr)
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index d6e05e6..55ae59f 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -103,7 +103,16 @@ instance Elt a => Elt (Ranked n a) where
-> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a)
mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) =
coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
- mlift2 ssh3 f arr1 arr2
+ mlift2 ssh3 f arr1 arr2
+
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
+ -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a))
+ mliftL ssh2 f l =
+ coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a)))
+ @(NonEmpty (Mixed sh2 (Ranked n a))) $
+ mliftL ssh2 f (coerce l)
mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr)
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index d1881c1..544a2fa 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -100,6 +100,15 @@ instance Elt a => Elt (Shaped sh a) where
coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
mlift2 ssh3 f arr1 arr2
+ mliftL :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
+ -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a))
+ mliftL ssh2 f l =
+ coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a)))
+ @(NonEmpty (Mixed sh2 (Shaped sh a))) $
+ mliftL ssh2 f (coerce l)
+
mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr)
mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)