diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-05-17 10:17:30 +0200 | 
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-05-17 10:17:30 +0200 | 
| commit | 3361aa23c6a415adf50194d69680d7d2f519b512 (patch) | |
| tree | 1170229d6c4704d0a57512faaa0bbd34e05db1ff /src/Data/Array/Nested | |
| parent | 713d76559c42129afb24843af4386d18f1827727 (diff) | |
Move code around in Data.Array.Nested.Convert
Diffstat (limited to 'src/Data/Array/Nested')
| -rw-r--r-- | src/Data/Array/Nested/Convert.hs | 107 | 
1 files changed, 58 insertions, 49 deletions
| diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index dd26f16..73055db 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -7,10 +7,14 @@  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeOperators #-}  module Data.Array.Nested.Convert ( +  -- * Shape/index/list casting functions +  ixrFromIxS, shrFromShS, + +  -- * Array conversions    castCastable,    Castable(..), -  -- * Special cases +  -- * Special cases of array conversions    --    -- | These functions can all be implemented using 'castCastable' in some way,    -- but some have fewer constraints. @@ -18,8 +22,6 @@ module Data.Array.Nested.Convert (    stoMixed, scastToMixed, stoRanked,    mcast, mcastToShaped, mtoRanked, -  -- * Additional index/shape casting functions -  ixrFromIxS, shrFromShS,  ) where  import Control.Category @@ -35,52 +37,7 @@ import Data.Array.Nested.Shaped.Base  import Data.Array.Nested.Shaped.Shape  import Data.Array.Nested.Types - -mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) -      => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a -mcast ssh2 arr -  | Refl <- lemAppNil @sh1 -  , Refl <- lemAppNil @sh2 -  = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr - -mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked = castCastable (CastXR CastId) - -rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a -rtoMixed (Ranked arr) = arr - --- | A more weakly-typed version of 'rtoMixed' that does a runtime shape --- compatibility check. -rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a -rcastToMixed sshx rarr@(Ranked arr) -  | Refl <- lemRankReplicate (rrank rarr) -  = mcast sshx arr - -mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') -              => ShS sh' -> Mixed sh a -> Shaped sh' a -mcastToShaped targetsh = castCastable (CastXS' targetsh CastId) - -stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a -stoMixed (Shaped arr) = arr - --- | A more weakly-typed version of 'stoMixed' that does a runtime shape --- compatibility check. -scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') -             => StaticShX sh' -> Shaped sh a -> Mixed sh' a -scastToMixed sshx sarr@(Shaped arr) -  | Refl <- lemRankMapJust (sshape sarr) -  = mcast sshx arr - -stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a -stoRanked sarr@(Shaped arr) -  | Refl <- lemRankMapJust (sshape sarr) -  = mtoRanked arr - -rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a -rcastToShaped (Ranked arr) targetsh -  | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh)) -  , Refl <- lemRankMapJust targetsh -  = mcastToShaped targetsh arr +-- * Shape/index/list casting functions  ixrFromIxS :: IxS sh i -> IxR (Rank sh) i  ixrFromIxS ZIS = ZIR @@ -97,6 +54,9 @@ shrFromShS :: ShS sh -> IShR (Rank sh)  shrFromShS ZSS = ZSR  shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh + +-- * Array conversions +  -- | The constructors that perform runtime shape checking are marked with a  -- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure  -- that the shapes are already compatible. To convert between 'Ranked' and @@ -169,3 +129,52 @@ castCastable = \c x -> munScalar (go c (mscalar x))                              => Proxy esh -> Proxy sh -> Proxy sh'                              -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')      lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl + + +-- * Special cases of array conversions + +mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) +      => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a +mcast ssh2 arr +  | Refl <- lemAppNil @sh1 +  , Refl <- lemAppNil @sh2 +  = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr + +mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a +mtoRanked = castCastable (CastXR CastId) + +rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a +rtoMixed (Ranked arr) = arr + +-- | A more weakly-typed version of 'rtoMixed' that does a runtime shape +-- compatibility check. +rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a +rcastToMixed sshx rarr@(Ranked arr) +  | Refl <- lemRankReplicate (rrank rarr) +  = mcast sshx arr + +mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') +              => ShS sh' -> Mixed sh a -> Shaped sh' a +mcastToShaped targetsh = castCastable (CastXS' targetsh CastId) + +stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a +stoMixed (Shaped arr) = arr + +-- | A more weakly-typed version of 'stoMixed' that does a runtime shape +-- compatibility check. +scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') +             => StaticShX sh' -> Shaped sh a -> Mixed sh' a +scastToMixed sshx sarr@(Shaped arr) +  | Refl <- lemRankMapJust (sshape sarr) +  = mcast sshx arr + +stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a +stoRanked sarr@(Shaped arr) +  | Refl <- lemRankMapJust (sshape sarr) +  = mtoRanked arr + +rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a +rcastToShaped (Ranked arr) targetsh +  | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh)) +  , Refl <- lemRankMapJust targetsh +  = mcastToShaped targetsh arr | 
