diff options
Diffstat (limited to 'src/Data/Array')
| -rw-r--r-- | src/Data/Array/Mixed.hs | 82 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 24 | 
2 files changed, 83 insertions, 23 deletions
| diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 1e8cee2..cc74b90 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -590,22 +590,52 @@ append ssh (XArray a) (XArray b)    | Dict <- lemKnownNatRankSSX ssh    = XArray (S.append a b) +-- | If the prefix of the shape of the input array (@sh@) is empty (i.e. +-- contains a zero), then there is no way to deduce the full shape of the output +-- array (more precisely, the @sh2@ part): that could only come from calling +-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in +-- this case; we choose to fill the shape with zeros wherever we cannot deduce +-- what it should be. +-- +-- For example, if: +-- +-- @ +-- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int   -- of shape [3, 0, 4, 2, 21] +-- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float +-- @ +-- +-- then: +-- +-- @ +-- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float +-- @ +-- +-- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@ +-- in this shape: we don't know if @f@ intended to return an array with shape 0 +-- here (it probably didn't), but there is no better number to put here absent +-- a subarray of the input to pass to @f@. +-- +-- In this particular case the fact that @sh@ is empty was evident from the +-- type-level information, but the same situation occurs when @sh@ consists of +-- @Nothing@s, and some of those happen to be zero at runtime.  rerank :: forall sh sh1 sh2 a b.            (Storable a, Storable b)         => StaticShX sh -> StaticShX sh1 -> StaticShX sh2         -> (XArray sh1 a -> XArray sh2 b)         -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b -rerank ssh ssh1 ssh2 f (XArray arr) -  | Dict <- lemKnownNatRankSSX ssh -  , Dict <- lemKnownNatRankSSX ssh2 -  , Refl <- lemRankApp ssh ssh1 -  , Refl <- lemRankApp ssh ssh2 -  , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)      -- should be redundant but the solver is not clever enough -  = XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2) -                (\a -> unXArray (f (XArray a))) -                arr) -  where -    unXArray (XArray a) = a +rerank ssh ssh1 ssh2 f xarr@(XArray arr) +  | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) +  = let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) +    in if sh == completeShXzeros ssh +         then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) []) +         else case () of +           () | Dict <- lemKnownNatRankSSX ssh +              , Dict <- lemKnownNatRankSSX ssh2 +              , Refl <- lemRankApp ssh ssh1 +              , Refl <- lemRankApp ssh ssh2 +              -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2) +                          (\a -> let XArray r = f (XArray a) in r) +                          arr)  rerankTop :: forall sh1 sh2 sh a b.               (Storable a, Storable b) @@ -614,22 +644,25 @@ rerankTop :: forall sh1 sh2 sh a b.            -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b  rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh +-- | The caveat about empty arrays at @rerank@ applies here too.  rerank2 :: forall sh sh1 sh2 a b c.             (Storable a, Storable b, Storable c)          => StaticShX sh -> StaticShX sh1 -> StaticShX sh2          -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)          -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c -rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) -  | Dict <- lemKnownNatRankSSX ssh -  , Dict <- lemKnownNatRankSSX ssh2 -  , Refl <- lemRankApp ssh ssh1 -  , Refl <- lemRankApp ssh ssh2 -  , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)  -- should be redundant but the solver is not clever enough -  = XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) -                (\a b -> unXArray (f (XArray a) (XArray b))) -                arr1 arr2) -  where -    unXArray (XArray a) = a +rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) +  | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) +  = let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) +    in if sh == completeShXzeros ssh +         then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) []) +         else case () of +           () | Dict <- lemKnownNatRankSSX ssh +              , Dict <- lemKnownNatRankSSX ssh2 +              , Refl <- lemRankApp ssh ssh1 +              , Refl <- lemRankApp ssh ssh2 +              -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) +                          (\a b -> let XArray r = f (XArray a) (XArray b) in r) +                          arr1 arr2)  type family Elem x l where    Elem x '[] = 'False @@ -829,7 +862,10 @@ fromListOuter ssh l        _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))  toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a] -toListOuter (XArray arr) = coerce (ORB.toList (S.unravel arr)) +toListOuter (XArray arr) = +  case S.shapeL arr of +    0 : _ -> [] +    _ -> coerce (ORB.toList (S.unravel arr))  fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a  fromList1 ssh l = diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index a440ccc..fb2ba0b 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -961,6 +961,7 @@ mrerankP ssh sh2 f (M_Primitive sh arr) =                             (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)                             arr) +-- | See the caveats at @X.rerank@.  mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)          => StaticShX sh -> IShX sh2          -> (Mixed sh1 a -> Mixed sh2 b) @@ -1582,6 +1583,29 @@ rrerankP sn sh2 f (Ranked arr)                       (\a -> let Ranked r = f (Ranked a) in r)                       arr) +-- | If there is a zero-sized dimension in the @n@-prefix of the shape of the +-- input array, then there is no way to deduce the full shape of the output +-- array (more precisely, the @n2@ part): that could only come from calling +-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in +-- this case; we choose to fill the @n2@ part of the output shape with zeros. +-- +-- For example, if: +-- +-- @ +-- arr :: Ranked 5 Int   -- of shape [3, 0, 4, 2, 21] +-- f :: Ranked 2 Int -> Ranked 3 Float +-- @ +-- +-- then: +-- +-- @ +-- rrerank _ _ _ f arr :: Ranked 5 Float +-- @ +-- +-- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the +-- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended +-- to return an array with shape all-0 here (it probably didn't), but there is +-- no better number to put here absent a subarray of the input to pass to @f@.  rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b)           => SNat n -> IShR n2           -> (Ranked n1 a -> Ranked n2 b) | 
