From 09ec518c63424364e7521697f5c2a1b8f2d82d01 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 4 Jun 2025 15:07:04 +0200 Subject: Add CastZip and CastUnzip --- src/Data/Array/Nested/Convert.hs | 12 ++++++++++++ 1 file changed, 12 insertions(+) (limited to 'src/Data') diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 659d13c..beb978e 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -140,6 +140,11 @@ data Castable a b where -> Castable (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a)) CastUnnest :: Castable (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a) + CastZip :: (Elt a, Elt b) + => Castable (Mixed sh a, Mixed sh b) (Mixed sh (a, b)) + CastUnzip :: (Elt a, Elt b) + => Castable (Mixed sh (a, b)) (Mixed sh a, Mixed sh b) + instance Category Castable where id = CastId (.) = CastCmp @@ -186,6 +191,13 @@ castCastable = \c x -> munScalar (go c (mscalar x)) go (CastUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x)) | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') = M_Nest esh x + go CastZip x = + -- no need to check that the two esh's are equal because they were zipped previously + let (M_Nest esh x1, M_Nest _ x2) = munzip x + in M_Nest esh (mzip x1 x2) + go CastUnzip (M_Nest esh x) = + let (x1, x2) = munzip x + in mzip (M_Nest esh x1) (M_Nest esh x2) lemRankAppRankEq :: Rank sh ~ Rank sh' => Proxy esh -> Proxy sh -> Proxy sh' -- cgit v1.2.3-70-g09d2