diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-25 21:37:38 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-25 21:37:38 +0200 |
commit | 85593969debadbf11ad3c159de71e7b480ca367c (patch) | |
tree | aa0f05401409d5a6241d61ea76297444a81fd3d9 /src/Data | |
parent | 13433346340e4376d8bc286f2e883f57e3962314 (diff) |
Choose behaviour for rerank of empty array
This works around an undocumented runtime error in orthotope.
Diffstat (limited to 'src/Data')
-rw-r--r-- | src/Data/Array/Mixed.hs | 82 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 24 |
2 files changed, 83 insertions, 23 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 1e8cee2..cc74b90 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -590,22 +590,52 @@ append ssh (XArray a) (XArray b) | Dict <- lemKnownNatRankSSX ssh = XArray (S.append a b) +-- | If the prefix of the shape of the input array (@sh@) is empty (i.e. +-- contains a zero), then there is no way to deduce the full shape of the output +-- array (more precisely, the @sh2@ part): that could only come from calling +-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in +-- this case; we choose to fill the shape with zeros wherever we cannot deduce +-- what it should be. +-- +-- For example, if: +-- +-- @ +-- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int -- of shape [3, 0, 4, 2, 21] +-- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float +-- @ +-- +-- then: +-- +-- @ +-- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float +-- @ +-- +-- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@ +-- in this shape: we don't know if @f@ intended to return an array with shape 0 +-- here (it probably didn't), but there is no better number to put here absent +-- a subarray of the input to pass to @f@. +-- +-- In this particular case the fact that @sh@ is empty was evident from the +-- type-level information, but the same situation occurs when @sh@ consists of +-- @Nothing@s, and some of those happen to be zero at runtime. rerank :: forall sh sh1 sh2 a b. (Storable a, Storable b) => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 -> (XArray sh1 a -> XArray sh2 b) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b -rerank ssh ssh1 ssh2 f (XArray arr) - | Dict <- lemKnownNatRankSSX ssh - , Dict <- lemKnownNatRankSSX ssh2 - , Refl <- lemRankApp ssh ssh1 - , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- should be redundant but the solver is not clever enough - = XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2) - (\a -> unXArray (f (XArray a))) - arr) - where - unXArray (XArray a) = a +rerank ssh ssh1 ssh2 f xarr@(XArray arr) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) + = let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) + in if sh == completeShXzeros ssh + then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) []) + else case () of + () | Dict <- lemKnownNatRankSSX ssh + , Dict <- lemKnownNatRankSSX ssh2 + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2) + (\a -> let XArray r = f (XArray a) in r) + arr) rerankTop :: forall sh1 sh2 sh a b. (Storable a, Storable b) @@ -614,22 +644,25 @@ rerankTop :: forall sh1 sh2 sh a b. -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh +-- | The caveat about empty arrays at @rerank@ applies here too. rerank2 :: forall sh sh1 sh2 a b c. (Storable a, Storable b, Storable c) => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c -rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) - | Dict <- lemKnownNatRankSSX ssh - , Dict <- lemKnownNatRankSSX ssh2 - , Refl <- lemRankApp ssh ssh1 - , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- should be redundant but the solver is not clever enough - = XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) - (\a b -> unXArray (f (XArray a) (XArray b))) - arr1 arr2) - where - unXArray (XArray a) = a +rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) + = let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) + in if sh == completeShXzeros ssh + then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) []) + else case () of + () | Dict <- lemKnownNatRankSSX ssh + , Dict <- lemKnownNatRankSSX ssh2 + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) + (\a b -> let XArray r = f (XArray a) (XArray b) in r) + arr1 arr2) type family Elem x l where Elem x '[] = 'False @@ -829,7 +862,10 @@ fromListOuter ssh l _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a] -toListOuter (XArray arr) = coerce (ORB.toList (S.unravel arr)) +toListOuter (XArray arr) = + case S.shapeL arr of + 0 : _ -> [] + _ -> coerce (ORB.toList (S.unravel arr)) fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a fromList1 ssh l = diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index a440ccc..fb2ba0b 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -961,6 +961,7 @@ mrerankP ssh sh2 f (M_Primitive sh arr) = (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) arr) +-- | See the caveats at @X.rerank@. mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) => StaticShX sh -> IShX sh2 -> (Mixed sh1 a -> Mixed sh2 b) @@ -1582,6 +1583,29 @@ rrerankP sn sh2 f (Ranked arr) (\a -> let Ranked r = f (Ranked a) in r) arr) +-- | If there is a zero-sized dimension in the @n@-prefix of the shape of the +-- input array, then there is no way to deduce the full shape of the output +-- array (more precisely, the @n2@ part): that could only come from calling +-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in +-- this case; we choose to fill the @n2@ part of the output shape with zeros. +-- +-- For example, if: +-- +-- @ +-- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21] +-- f :: Ranked 2 Int -> Ranked 3 Float +-- @ +-- +-- then: +-- +-- @ +-- rrerank _ _ _ f arr :: Ranked 5 Float +-- @ +-- +-- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the +-- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended +-- to return an array with shape all-0 here (it probably didn't), but there is +-- no better number to put here absent a subarray of the input to pass to @f@. rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) => SNat n -> IShR n2 -> (Ranked n1 a -> Ranked n2 b) |