diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 107 |
1 files changed, 58 insertions, 49 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index dd26f16..73055db 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -7,10 +7,14 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} module Data.Array.Nested.Convert ( + -- * Shape/index/list casting functions + ixrFromIxS, shrFromShS, + + -- * Array conversions castCastable, Castable(..), - -- * Special cases + -- * Special cases of array conversions -- -- | These functions can all be implemented using 'castCastable' in some way, -- but some have fewer constraints. @@ -18,8 +22,6 @@ module Data.Array.Nested.Convert ( stoMixed, scastToMixed, stoRanked, mcast, mcastToShaped, mtoRanked, - -- * Additional index/shape casting functions - ixrFromIxS, shrFromShS, ) where import Control.Category @@ -35,52 +37,7 @@ import Data.Array.Nested.Shaped.Base import Data.Array.Nested.Shaped.Shape import Data.Array.Nested.Types - -mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) - => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a -mcast ssh2 arr - | Refl <- lemAppNil @sh1 - , Refl <- lemAppNil @sh2 - = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr - -mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked = castCastable (CastXR CastId) - -rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a -rtoMixed (Ranked arr) = arr - --- | A more weakly-typed version of 'rtoMixed' that does a runtime shape --- compatibility check. -rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a -rcastToMixed sshx rarr@(Ranked arr) - | Refl <- lemRankReplicate (rrank rarr) - = mcast sshx arr - -mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') - => ShS sh' -> Mixed sh a -> Shaped sh' a -mcastToShaped targetsh = castCastable (CastXS' targetsh CastId) - -stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a -stoMixed (Shaped arr) = arr - --- | A more weakly-typed version of 'stoMixed' that does a runtime shape --- compatibility check. -scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') - => StaticShX sh' -> Shaped sh a -> Mixed sh' a -scastToMixed sshx sarr@(Shaped arr) - | Refl <- lemRankMapJust (sshape sarr) - = mcast sshx arr - -stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a -stoRanked sarr@(Shaped arr) - | Refl <- lemRankMapJust (sshape sarr) - = mtoRanked arr - -rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a -rcastToShaped (Ranked arr) targetsh - | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh)) - , Refl <- lemRankMapJust targetsh - = mcastToShaped targetsh arr +-- * Shape/index/list casting functions ixrFromIxS :: IxS sh i -> IxR (Rank sh) i ixrFromIxS ZIS = ZIR @@ -97,6 +54,9 @@ shrFromShS :: ShS sh -> IShR (Rank sh) shrFromShS ZSS = ZSR shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh + +-- * Array conversions + -- | The constructors that perform runtime shape checking are marked with a -- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure -- that the shapes are already compatible. To convert between 'Ranked' and @@ -169,3 +129,52 @@ castCastable = \c x -> munScalar (go c (mscalar x)) => Proxy esh -> Proxy sh -> Proxy sh' -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl + + +-- * Special cases of array conversions + +mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) + => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a +mcast ssh2 arr + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr + +mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a +mtoRanked = castCastable (CastXR CastId) + +rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a +rtoMixed (Ranked arr) = arr + +-- | A more weakly-typed version of 'rtoMixed' that does a runtime shape +-- compatibility check. +rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a +rcastToMixed sshx rarr@(Ranked arr) + | Refl <- lemRankReplicate (rrank rarr) + = mcast sshx arr + +mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') + => ShS sh' -> Mixed sh a -> Shaped sh' a +mcastToShaped targetsh = castCastable (CastXS' targetsh CastId) + +stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a +stoMixed (Shaped arr) = arr + +-- | A more weakly-typed version of 'stoMixed' that does a runtime shape +-- compatibility check. +scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') + => StaticShX sh' -> Shaped sh a -> Mixed sh' a +scastToMixed sshx sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mcast sshx arr + +stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a +stoRanked sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mtoRanked arr + +rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a +rcastToShaped (Ranked arr) targetsh + | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh)) + , Refl <- lemRankMapJust targetsh + = mcastToShaped targetsh arr |