diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 28 |
1 files changed, 14 insertions, 14 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index cdd2b6d..17ccc4d 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -5,12 +5,22 @@ {-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -module Data.Array.Nested.Convert where +module Data.Array.Nested.Convert ( + castCastable, + Castable(..), + + -- * Special cases + -- + -- | These functions can all be implemented using 'castCastable' in some way, + -- but some have fewer constraints. + rtoMixed, rcastToMixed, rcastToShaped, + stoMixed, scastToMixed, stoRanked, + mcast, mcastToShaped, mtoRanked, +) where import Control.Category import Data.Proxy import Data.Type.Equality -import GHC.TypeLits (Nat) import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Types @@ -30,15 +40,7 @@ mcast ssh2 arr = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked arr - | Refl <- lemRankReplicate (shxRank (mshape arr)) - = Ranked (mcast (ssxFromShape (convSh (mshape arr))) arr) - where - convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing) - convSh ZSX = ZSX - convSh (smn :$% (sh :: IShX sh'T)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T) - = SUnknown (fromSMayNat' smn) :$% convSh sh +mtoRanked = castCastable (CastXR CastId) rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a rtoMixed (Ranked arr) = arr @@ -52,9 +54,7 @@ rcastToMixed sshx rarr@(Ranked arr) mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') => Mixed sh a -> ShS sh' -> Shaped sh' a -mcastToShaped arr targetsh - | Refl <- lemRankMapJust targetsh - = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr) +mcastToShaped arr targetsh = castCastable (CastXS' targetsh CastId) arr stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a stoMixed (Shaped arr) = arr |