aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Mixed.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-12-11 19:56:28 +0100
committerTom Smeding <tom@tomsmeding.com>2024-12-11 19:57:12 +0100
commita3299c09e0fd12cf73c4a0a9a2ae37b8f69f9b10 (patch)
tree8f28f2cb8034530f20fc56265c64af1164b35776 /src/Data/Array/Nested/Internal/Mixed.hs
parent9570a94d331facc8961be204d7a3010d33146f97 (diff)
Simpler API to mcast
Diffstat (limited to 'src/Data/Array/Nested/Internal/Mixed.hs')
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs35
1 files changed, 22 insertions, 13 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 619d9bc..8813ab3 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -321,8 +321,8 @@ class Elt a where
-> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b))
-> NonEmpty (Mixed sh1 a) -> NonEmpty (Mixed sh2 a)
- mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
- => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
+ mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
=> Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a
@@ -417,10 +417,11 @@ instance Storable a => Elt (Primitive a) where
= fmap (\arr -> M_Primitive (X.shape ssh2 arr) arr) $
f ZKX (fmap (\(M_Primitive _ arr) -> arr) l)
- mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
- => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
- mcast ssh1 sh2 _ (M_Primitive sh1' arr) =
- let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
+ mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
+ mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) =
+ let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
+ sh2 = shxCast' sh1 ssh2
in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr)
mtranspose perm (M_Primitive sh arr) =
@@ -493,8 +494,8 @@ instance (Elt a, Elt b) => Elt (a, b) where
unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2)
in uncurry (NE.zipWith M_Tup2) . bimap (mliftL ssh2 f) (mliftL ssh2 f) . unzipT2
- mcast ssh1 sh2 psh' (M_Tup2 a b) =
- M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b)
+ mcastPartial ssh1 sh2 psh' (M_Tup2 a b) =
+ M_Tup2 (mcastPartial ssh1 sh2 psh' a) (mcastPartial ssh1 sh2 psh' b)
mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b)
mconcat =
@@ -600,13 +601,14 @@ instance Elt a => Elt (Mixed sh' a) where
, Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
= f (ssxAppend ssh' sshT)
- mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2
- => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a)
- mcast ssh1 sh2 _ (M_Nest sh1T arr)
+ mcastPartial :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> StaticShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a)
+ mcastPartial ssh1 ssh2 _ (M_Nest sh1T arr)
| Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
, Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
- = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
- in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr)
+ = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
+ sh2 = shxCast' sh1 ssh2
+ in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr)
mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
=> Perm is -> Mixed sh (Mixed sh' a)
@@ -908,3 +910,10 @@ mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c)
-> Mixed sh a -> Mixed sh b -> Mixed sh c
mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =
fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2))
+
+mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
+ => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a
+mcast ssh2 arr
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr