From af06ef345d22df015ac8a0ab069438c180ab3e94 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Sun, 26 May 2024 11:54:31 +0200
Subject: Fix bug in rerank workaround

---
 src/Data/Array/Mixed.hs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

(limited to 'src/Data/Array')

diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index cc74b90..080c458 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -626,7 +626,7 @@ rerank :: forall sh sh1 sh2 a b.
 rerank ssh ssh1 ssh2 f xarr@(XArray arr)
   | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
   = let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr)
-    in if sh == completeShXzeros ssh
+    in if any (== 0) (shapeLshape sh)
          then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) [])
          else case () of
            () | Dict <- lemKnownNatRankSSX ssh
@@ -653,7 +653,7 @@ rerank2 :: forall sh sh1 sh2 a b c.
 rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2)
   | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
   = let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1)
-    in if sh == completeShXzeros ssh
+    in if any (== 0) (shapeLshape sh)
          then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) [])
          else case () of
            () | Dict <- lemKnownNatRankSSX ssh
-- 
cgit v1.2.3-70-g09d2