diff options
Diffstat (limited to 'src/Data/Array/Nested/Internal/Mixed.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 35 |
1 files changed, 22 insertions, 13 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 619d9bc..8813ab3 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -321,8 +321,8 @@ class Elt a where -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) -> NonEmpty (Mixed sh1 a) -> NonEmpty (Mixed sh2 a) - mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a + mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a @@ -417,10 +417,11 @@ instance Storable a => Elt (Primitive a) where = fmap (\arr -> M_Primitive (X.shape ssh2 arr) arr) $ f ZKX (fmap (\(M_Primitive _ arr) -> arr) l) - mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) - mcast ssh1 sh2 _ (M_Primitive sh1' arr) = - let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' + mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) + mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) = + let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' + sh2 = shxCast' sh1 ssh2 in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) mtranspose perm (M_Primitive sh arr) = @@ -493,8 +494,8 @@ instance (Elt a, Elt b) => Elt (a, b) where unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2) in uncurry (NE.zipWith M_Tup2) . bimap (mliftL ssh2 f) (mliftL ssh2 f) . unzipT2 - mcast ssh1 sh2 psh' (M_Tup2 a b) = - M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b) + mcastPartial ssh1 sh2 psh' (M_Tup2 a b) = + M_Tup2 (mcastPartial ssh1 sh2 psh' a) (mcastPartial ssh1 sh2 psh' b) mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) mconcat = @@ -600,13 +601,14 @@ instance Elt a => Elt (Mixed sh' a) where , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) = f (ssxAppend ssh' sshT) - mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 - => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) - mcast ssh1 sh2 _ (M_Nest sh1T arr) + mcastPartial :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> StaticShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) + mcastPartial ssh1 ssh2 _ (M_Nest sh1T arr) | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') - = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T - in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) + = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T + sh2 = shxCast' sh1 ssh2 + in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr) mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) => Perm is -> Mixed sh (Mixed sh' a) @@ -908,3 +910,10 @@ mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c) -> Mixed sh a -> Mixed sh b -> Mixed sh c mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) + +mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) + => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a +mcast ssh2 arr + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr |