aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs82
1 files changed, 59 insertions, 23 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 1e8cee2..cc74b90 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -590,22 +590,52 @@ append ssh (XArray a) (XArray b)
| Dict <- lemKnownNatRankSSX ssh
= XArray (S.append a b)
+-- | If the prefix of the shape of the input array (@sh@) is empty (i.e.
+-- contains a zero), then there is no way to deduce the full shape of the output
+-- array (more precisely, the @sh2@ part): that could only come from calling
+-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in
+-- this case; we choose to fill the shape with zeros wherever we cannot deduce
+-- what it should be.
+--
+-- For example, if:
+--
+-- @
+-- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int -- of shape [3, 0, 4, 2, 21]
+-- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float
+-- @
+--
+-- then:
+--
+-- @
+-- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float
+-- @
+--
+-- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@
+-- in this shape: we don't know if @f@ intended to return an array with shape 0
+-- here (it probably didn't), but there is no better number to put here absent
+-- a subarray of the input to pass to @f@.
+--
+-- In this particular case the fact that @sh@ is empty was evident from the
+-- type-level information, but the same situation occurs when @sh@ consists of
+-- @Nothing@s, and some of those happen to be zero at runtime.
rerank :: forall sh sh1 sh2 a b.
(Storable a, Storable b)
=> StaticShX sh -> StaticShX sh1 -> StaticShX sh2
-> (XArray sh1 a -> XArray sh2 b)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
-rerank ssh ssh1 ssh2 f (XArray arr)
- | Dict <- lemKnownNatRankSSX ssh
- , Dict <- lemKnownNatRankSSX ssh2
- , Refl <- lemRankApp ssh ssh1
- , Refl <- lemRankApp ssh ssh2
- , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- should be redundant but the solver is not clever enough
- = XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2)
- (\a -> unXArray (f (XArray a)))
- arr)
- where
- unXArray (XArray a) = a
+rerank ssh ssh1 ssh2 f xarr@(XArray arr)
+ | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
+ = let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr)
+ in if sh == completeShXzeros ssh
+ then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) [])
+ else case () of
+ () | Dict <- lemKnownNatRankSSX ssh
+ , Dict <- lemKnownNatRankSSX ssh2
+ , Refl <- lemRankApp ssh ssh1
+ , Refl <- lemRankApp ssh ssh2
+ -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2)
+ (\a -> let XArray r = f (XArray a) in r)
+ arr)
rerankTop :: forall sh1 sh2 sh a b.
(Storable a, Storable b)
@@ -614,22 +644,25 @@ rerankTop :: forall sh1 sh2 sh a b.
-> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b
rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh
+-- | The caveat about empty arrays at @rerank@ applies here too.
rerank2 :: forall sh sh1 sh2 a b c.
(Storable a, Storable b, Storable c)
=> StaticShX sh -> StaticShX sh1 -> StaticShX sh2
-> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
-rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
- | Dict <- lemKnownNatRankSSX ssh
- , Dict <- lemKnownNatRankSSX ssh2
- , Refl <- lemRankApp ssh ssh1
- , Refl <- lemRankApp ssh ssh2
- , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- should be redundant but the solver is not clever enough
- = XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2)
- (\a b -> unXArray (f (XArray a) (XArray b)))
- arr1 arr2)
- where
- unXArray (XArray a) = a
+rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2)
+ | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
+ = let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1)
+ in if sh == completeShXzeros ssh
+ then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) [])
+ else case () of
+ () | Dict <- lemKnownNatRankSSX ssh
+ , Dict <- lemKnownNatRankSSX ssh2
+ , Refl <- lemRankApp ssh ssh1
+ , Refl <- lemRankApp ssh ssh2
+ -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2)
+ (\a b -> let XArray r = f (XArray a) (XArray b) in r)
+ arr1 arr2)
type family Elem x l where
Elem x '[] = 'False
@@ -829,7 +862,10 @@ fromListOuter ssh l
_ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))
toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a]
-toListOuter (XArray arr) = coerce (ORB.toList (S.unravel arr))
+toListOuter (XArray arr) =
+ case S.shapeL arr of
+ 0 : _ -> []
+ _ -> coerce (ORB.toList (S.unravel arr))
fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a
fromList1 ssh l =