aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Data/Array/Nested/Convert.hs12
1 files changed, 12 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index 659d13c..beb978e 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -140,6 +140,11 @@ data Castable a b where
-> Castable (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a))
CastUnnest :: Castable (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a)
+ CastZip :: (Elt a, Elt b)
+ => Castable (Mixed sh a, Mixed sh b) (Mixed sh (a, b))
+ CastUnzip :: (Elt a, Elt b)
+ => Castable (Mixed sh (a, b)) (Mixed sh a, Mixed sh b)
+
instance Category Castable where
id = CastId
(.) = CastCmp
@@ -186,6 +191,13 @@ castCastable = \c x -> munScalar (go c (mscalar x))
go (CastUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x))
| Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
= M_Nest esh x
+ go CastZip x =
+ -- no need to check that the two esh's are equal because they were zipped previously
+ let (M_Nest esh x1, M_Nest _ x2) = munzip x
+ in M_Nest esh (mzip x1 x2)
+ go CastUnzip (M_Nest esh x) =
+ let (x1, x2) = munzip x
+ in mzip (M_Nest esh x1) (M_Nest esh x2)
lemRankAppRankEq :: Rank sh ~ Rank sh'
=> Proxy esh -> Proxy sh -> Proxy sh'