aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-26 11:54:31 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-26 11:54:55 +0200
commitaf06ef345d22df015ac8a0ab069438c180ab3e94 (patch)
treefc49ba2866ac22a00039cda7e038c8806b24edec /src/Data/Array
parent34a9ac8e4497e776c3ca499c41ef749f4edf8383 (diff)
Fix bug in rerank workaround
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Mixed.hs4
1 files changed, 2 insertions, 2 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index cc74b90..080c458 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -626,7 +626,7 @@ rerank :: forall sh sh1 sh2 a b.
rerank ssh ssh1 ssh2 f xarr@(XArray arr)
| Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
= let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr)
- in if sh == completeShXzeros ssh
+ in if any (== 0) (shapeLshape sh)
then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) [])
else case () of
() | Dict <- lemKnownNatRankSSX ssh
@@ -653,7 +653,7 @@ rerank2 :: forall sh sh1 sh2 a b c.
rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2)
| Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
= let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1)
- in if sh == completeShXzeros ssh
+ in if any (== 0) (shapeLshape sh)
then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) [])
else case () of
() | Dict <- lemKnownNatRankSSX ssh