diff options
Diffstat (limited to 'src/Data/Array/Nested/Internal')
-rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 42 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 11 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 9 |
3 files changed, 60 insertions, 2 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index b799190..4746f31 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -263,7 +263,7 @@ class Elt a where mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] -- | Note: this library makes no particular guarantees about the shapes of - -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the + -- arrays "inside" an empty array. With 'mlift', 'mlift2' and 'mliftL' you can see the -- full 'XArray' and as such you can distinguish different empty arrays by -- the "shapes" of their elements. This information is meaningless, so you -- should not use it. @@ -278,6 +278,14 @@ class Elt a where -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a + -- TODO: mliftL is currently unused. + -- | All arrays in the input must have equal shapes, including subarrays + -- inside their elements. + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 a) -> NonEmpty (Mixed sh2 a) + mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a @@ -364,6 +372,16 @@ instance Storable a => Elt (Primitive a) where , let result = f ZKX a b = M_Primitive (X.shape ssh3 result) result + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Primitive a)) -> NonEmpty (Mixed sh2 (Primitive a)) + mliftL ssh2 f l + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + = fmap (\arr -> M_Primitive (X.shape ssh2 arr) arr) $ + f ZKX (fmap (\(M_Primitive _ arr) -> arr) l) + mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) mcast ssh1 sh2 _ (M_Primitive sh1' arr) = @@ -432,6 +450,11 @@ instance (Elt a, Elt b) => Elt (a, b) where mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) + mliftL ssh2 f = + let unzipT2l [] = ([], []) + unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) + unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2) + in uncurry (NE.zipWith M_Tup2) . bimap (mliftL ssh2 f) (mliftL ssh2 f) . unzipT2 mcast ssh1 sh2 psh' (M_Tup2 a b) = M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b) @@ -523,6 +546,23 @@ instance Elt a => Elt (Mixed sh' a) where , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) = f (ssxAppend ssh' sshT) + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b)) + -> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a)) + mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) = + let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l) + (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result)) + in fmap (M_Nest sh2) result + where + ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) + + f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b) + f' sshT + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) + mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) mcast ssh1 sh2 _ (M_Nest sh1T arr) diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index d6e05e6..55ae59f 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -103,7 +103,16 @@ instance Elt a => Elt (Ranked n a) where -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) = coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ - mlift2 ssh3 f arr1 arr2 + mlift2 ssh3 f arr1 arr2 + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a)) + mliftL ssh2 f l = + coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a))) + @(NonEmpty (Mixed sh2 (Ranked n a))) $ + mliftL ssh2 f (coerce l) mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr) diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index d1881c1..544a2fa 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -100,6 +100,15 @@ instance Elt a => Elt (Shaped sh a) where coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ mlift2 ssh3 f arr1 arr2 + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a)) + mliftL ssh2 f l = + coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a))) + @(NonEmpty (Mixed sh2 (Shaped sh a))) $ + mliftL ssh2 f (coerce l) + mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr) mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) |