aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-04 14:59:40 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-04 14:59:40 +0200
commit5d769178ee804c3804c9d7bf155ac2e46407eb3a (patch)
tree0866687a78a5ecdc411a938f288e43eae774fe70 /src/Data/Array/Nested
parent4261cb045081188a48bc8306f173166a79fcb1df (diff)
Add shape checking to [rms]zip
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Mixed.hs7
-rw-r--r--src/Data/Array/Nested/Ranked.hs2
-rw-r--r--src/Data/Array/Nested/Shaped.hs2
3 files changed, 7 insertions, 4 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 9ec8d9d..54f8fe6 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -815,8 +815,11 @@ mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr
munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a
munNest (M_Nest _ arr) = arr
-mzip :: Mixed sh a -> Mixed sh b -> Mixed sh (a, b)
-mzip = M_Tup2
+-- | The arguments must have equal shapes. If they do not, an error is raised.
+mzip :: (Elt a, Elt b) => Mixed sh a -> Mixed sh b -> Mixed sh (a, b)
+mzip a b
+ | Just Refl <- shxEqual (mshape a) (mshape b) = M_Tup2 a b
+ | otherwise = error "mzip: unequal shapes"
munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b)
munzip (M_Tup2 a b) = (a, b)
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
index 8591af7..97b4c7c 100644
--- a/src/Data/Array/Nested/Ranked.hs
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -191,7 +191,7 @@ runNest rarr@(Ranked (M_Ranked (M_Nest _ arr)))
| Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat))
= Ranked arr
-rzip :: Ranked n a -> Ranked n b -> Ranked n (a, b)
+rzip :: (Elt a, Elt b) => Ranked n a -> Ranked n b -> Ranked n (a, b)
rzip = coerce mzip
runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b)
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
index aaba367..0275aad 100644
--- a/src/Data/Array/Nested/Shaped.hs
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -173,7 +173,7 @@ sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr)))
| Refl <- lemMapJustApp (sshape sarr) (Proxy @sh')
= Shaped arr
-szip :: Shaped sh a -> Shaped sh b -> Shaped sh (a, b)
+szip :: (Elt a, Elt b) => Shaped sh a -> Shaped sh b -> Shaped sh (a, b)
szip = coerce mzip
sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b)