diff options
Diffstat (limited to 'src/Data/Array/Nested/Convert.hs')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 149 |
1 files changed, 74 insertions, 75 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index d07bab9..723e965 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -14,12 +14,12 @@ module Data.Array.Nested.Convert ( ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS, -- * Array conversions - castCastable, - Castable(..), + convert, + Conversion(..), -- * Special cases of array conversions -- - -- | These functions can all be implemented using 'castCastable' in some way, + -- | These functions can all be implemented using 'convert' in some way, -- but some have fewer constraints. rtoMixed, rcastToMixed, rcastToShaped, stoMixed, scastToMixed, stoRanked, @@ -102,108 +102,107 @@ shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh -- * Array conversions -- | The constructors that perform runtime shape checking are marked with a --- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure +-- @'@: 'ConvXS'' and 'ConvXX''. For the other constructors, the types ensure -- that the shapes are already compatible. To convert between 'Ranked' and -- 'Shaped', go via 'Mixed'. -- --- The guiding principle behind 'Castable' is that it should represent the +-- The guiding principle behind 'Conversion' is that it should represent the -- array restructurings, or perhaps re-presentations, that do not change the -- underlying 'XArray's. This leads to the inclusion of some operations that do --- not look like a cast at first glance, like 'CastZip'; with the underlying --- representation in mind, however, they are very much like a cast. -data Castable a b where - CastId :: Castable a a - CastCmp :: Castable b c -> Castable a b -> Castable a c - - CastRX :: Castable (Ranked n a) (Mixed (Replicate n Nothing) a) - CastSX :: Castable (Shaped sh a) (Mixed (MapJust sh) a) - - CastXR :: Elt a - => Castable (Mixed sh a) (Ranked (Rank sh) a) - CastXS :: Castable (Mixed (MapJust sh) a) (Shaped sh a) - CastXS' :: (Rank sh ~ Rank sh', Elt a) +-- not look like simple conversions (casts) at first glance, like 'ConvZip'. +data Conversion a b where + ConvId :: Conversion a a + ConvCmp :: Conversion b c -> Conversion a b -> Conversion a c + + ConvRX :: Conversion (Ranked n a) (Mixed (Replicate n Nothing) a) + ConvSX :: Conversion (Shaped sh a) (Mixed (MapJust sh) a) + + ConvXR :: Elt a + => Conversion (Mixed sh a) (Ranked (Rank sh) a) + ConvXS :: Conversion (Mixed (MapJust sh) a) (Shaped sh a) + ConvXS' :: (Rank sh ~ Rank sh', Elt a) => ShS sh' - -> Castable (Mixed sh a) (Shaped sh' a) + -> Conversion (Mixed sh a) (Shaped sh' a) - CastXX' :: (Rank sh ~ Rank sh', Elt a) + ConvXX' :: (Rank sh ~ Rank sh', Elt a) => StaticShX sh' - -> Castable (Mixed sh a) (Mixed sh' a) - - CastRR :: Castable a b - -> Castable (Ranked n a) (Ranked n b) - CastSS :: Castable a b - -> Castable (Shaped sh a) (Shaped sh b) - CastXX :: Castable a b - -> Castable (Mixed sh a) (Mixed sh b) - CastT2 :: Castable a a' - -> Castable b b' - -> Castable (a, b) (a', b') - - Cast0X :: Elt a - => Castable a (Mixed '[] a) - CastX0 :: Castable (Mixed '[] a) a - - CastNest :: Elt a => StaticShX sh - -> Castable (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a)) - CastUnnest :: Castable (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a) - - CastZip :: (Elt a, Elt b) - => Castable (Mixed sh a, Mixed sh b) (Mixed sh (a, b)) - CastUnzip :: (Elt a, Elt b) - => Castable (Mixed sh (a, b)) (Mixed sh a, Mixed sh b) -deriving instance Show (Castable a b) - -instance Category Castable where - id = CastId - (.) = CastCmp - -castCastable :: (Elt a, Elt b) => Castable a b -> a -> b -castCastable = \c x -> munScalar (go c (mscalar x)) + -> Conversion (Mixed sh a) (Mixed sh' a) + + ConvRR :: Conversion a b + -> Conversion (Ranked n a) (Ranked n b) + ConvSS :: Conversion a b + -> Conversion (Shaped sh a) (Shaped sh b) + ConvXX :: Conversion a b + -> Conversion (Mixed sh a) (Mixed sh b) + ConvT2 :: Conversion a a' + -> Conversion b b' + -> Conversion (a, b) (a', b') + + Conv0X :: Elt a + => Conversion a (Mixed '[] a) + ConvX0 :: Conversion (Mixed '[] a) a + + ConvNest :: Elt a => StaticShX sh + -> Conversion (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a)) + ConvUnnest :: Conversion (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a) + + ConvZip :: (Elt a, Elt b) + => Conversion (Mixed sh a, Mixed sh b) (Mixed sh (a, b)) + ConvUnzip :: (Elt a, Elt b) + => Conversion (Mixed sh (a, b)) (Mixed sh a, Mixed sh b) +deriving instance Show (Conversion a b) + +instance Category Conversion where + id = ConvId + (.) = ConvCmp + +convert :: (Elt a, Elt b) => Conversion a b -> a -> b +convert = \c x -> munScalar (go c (mscalar x)) where - -- The 'esh' is the extension shape: the casting happens under a whole + -- The 'esh' is the extension shape: the conversion happens under a whole -- bunch of additional dimensions that it does not touch. These dimensions -- are 'esh'. -- The strategy is to unwind step-by-step to a large Mixed array, and to - -- perform the required checks and castings when re-nesting back up. - go :: Castable a b -> Mixed esh a -> Mixed esh b - go CastId x = x - go (CastCmp c1 c2) x = go c1 (go c2 x) - go CastRX (M_Ranked x) = x - go CastSX (M_Shaped x) = x - go (CastXR @_ @sh) (M_Nest @esh esh x) + -- perform the required checks and conversions when re-nesting back up. + go :: Conversion a b -> Mixed esh a -> Mixed esh b + go ConvId x = x + go (ConvCmp c1 c2) x = go c1 (go c2 x) + go ConvRX (M_Ranked x) = x + go ConvSX (M_Shaped x) = x + go (ConvXR @_ @sh) (M_Nest @esh esh x) | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) = let ssx' = ssxAppend (ssxFromShX esh) (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x) (ssxFromShX esh)))) in M_Ranked (M_Nest esh (mcast ssx' x)) - go CastXS (M_Nest esh x) = M_Shaped (M_Nest esh x) - go (CastXS' @sh @sh' sh') (M_Nest @esh esh x) + go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x) + go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x) | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh'))) x)) - go (CastXX' @sh @sh' ssx) (M_Nest @esh esh x) + go (ConvXX' @sh @sh' ssx) (M_Nest @esh esh x) | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x - go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) - go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) - go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) - go (CastT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2) - go Cast0X (x :: Mixed esh a) + go (ConvRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) + go (ConvSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) + go (ConvXX c) (M_Nest esh x) = M_Nest esh (go c x) + go (ConvT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2) + go Conv0X (x :: Mixed esh a) | Refl <- lemAppNil @esh = M_Nest (mshape x) x - go CastX0 (M_Nest @esh _ x) + go ConvX0 (M_Nest @esh _ x) | Refl <- lemAppNil @esh = x - go (CastNest @_ @sh @sh' ssh) (M_Nest @esh esh x) + go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x) | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (mshape x) (ssxFromShX esh `ssxAppend` ssh)) x) - go (CastUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x)) + go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x)) | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') = M_Nest esh x - go CastZip x = + go ConvZip x = -- no need to check that the two esh's are equal because they were zipped previously let (M_Nest esh x1, M_Nest _ x2) = munzip x in M_Nest esh (mzip x1 x2) - go CastUnzip (M_Nest esh x) = + go ConvUnzip (M_Nest esh x) = let (x1, x2) = munzip x in mzip (M_Nest esh x1) (M_Nest esh x2) @@ -232,7 +231,7 @@ mcast ssh2 arr = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked = castCastable CastXR +mtoRanked = convert ConvXR rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a rtoMixed (Ranked arr) = arr @@ -246,7 +245,7 @@ rcastToMixed sshx rarr@(Ranked arr) mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') => ShS sh' -> Mixed sh a -> Shaped sh' a -mcastToShaped targetsh = castCastable (CastXS' targetsh) +mcastToShaped targetsh = convert (ConvXS' targetsh) stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a stoMixed (Shaped arr) = arr |