aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-04 15:07:04 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-04 15:07:04 +0200
commit09ec518c63424364e7521697f5c2a1b8f2d82d01 (patch)
treee066f9f89c537b791f771a28208411b0a79524b3 /src
parent5d769178ee804c3804c9d7bf155ac2e46407eb3a (diff)
Add CastZip and CastUnzip
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Nested/Convert.hs12
1 files changed, 12 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index 659d13c..beb978e 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -140,6 +140,11 @@ data Castable a b where
-> Castable (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a))
CastUnnest :: Castable (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a)
+ CastZip :: (Elt a, Elt b)
+ => Castable (Mixed sh a, Mixed sh b) (Mixed sh (a, b))
+ CastUnzip :: (Elt a, Elt b)
+ => Castable (Mixed sh (a, b)) (Mixed sh a, Mixed sh b)
+
instance Category Castable where
id = CastId
(.) = CastCmp
@@ -186,6 +191,13 @@ castCastable = \c x -> munScalar (go c (mscalar x))
go (CastUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x))
| Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
= M_Nest esh x
+ go CastZip x =
+ -- no need to check that the two esh's are equal because they were zipped previously
+ let (M_Nest esh x1, M_Nest _ x2) = munzip x
+ in M_Nest esh (mzip x1 x2)
+ go CastUnzip (M_Nest esh x) =
+ let (x1, x2) = munzip x
+ in mzip (M_Nest esh x1) (M_Nest esh x2)
lemRankAppRankEq :: Rank sh ~ Rank sh'
=> Proxy esh -> Proxy sh -> Proxy sh'