diff options
Diffstat (limited to 'src/Data/Array/Nested/Convert.hs')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 32 |
1 files changed, 26 insertions, 6 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 01abae3..92bc3b4 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} @@ -16,6 +17,9 @@ module Data.Array.Nested.Convert ( rtoMixed, rcastToMixed, rcastToShaped, stoMixed, scastToMixed, stoRanked, mcast, mcastToShaped, mtoRanked, + + -- * Additional index/shape casting functions + ixrFromIxS, shrFromShS, ) where import Control.Category @@ -28,6 +32,7 @@ import Data.Array.Nested.Internal.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Ranked.Base +import Data.Array.Nested.Ranked.Shape import Data.Array.Nested.Shaped.Base import Data.Array.Nested.Shaped.Shape @@ -37,7 +42,7 @@ mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) mcast ssh2 arr | Refl <- lemAppNil @sh1 , Refl <- lemAppNil @sh2 - = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr + = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a mtoRanked = castCastable (CastXR CastId) @@ -74,10 +79,25 @@ stoRanked sarr@(Shaped arr) rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a rcastToShaped (Ranked arr) targetsh - | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh)) + | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh)) , Refl <- lemRankMapJust targetsh = mcastToShaped targetsh arr +ixrFromIxS :: IIxS sh -> IIxR (Rank sh) +ixrFromIxS ZIS = ZIR +ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix + +-- ixsFromIxR :: IIxR (Rank sh) -> IIxS sh +-- ixsFromIxR = \ix -> go ix _ +-- where +-- go :: IIxR n -> (forall sh. KnownShS sh => IIxS sh -> r) -> r +-- go ZIR k = k ZIS +-- go (i :.: ix) k = go ix (i :.$) + +shrFromShS :: ShS sh -> IShR (Rank sh) +shrFromShS ZSS = ZSR +shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh + -- | The constructors that perform runtime shape checking are marked with a -- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure -- that the shapes are already compatible. To convert between 'Ranked' and @@ -122,20 +142,20 @@ castCastable = \c x -> munScalar (go c (mscalar x)) go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) = let x' = go c x - ssx' = ssxAppend (ssxFromShape esh) - (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShape esh)))) + ssx' = ssxAppend (ssxFromShX esh) + (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShX esh)))) in M_Ranked (M_Nest esh (mcast ssx' x')) go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x)) go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x) | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') - = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh'))) + = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh'))) (go c x))) go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) go (CastXX' @sh @sh' ssx c) (M_Nest @esh esh x) | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') - = M_Nest esh $ mcast (ssxFromShape esh `ssxAppend` ssx) (go c x) + = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) (go c x) lemRankAppRankEq :: Rank sh ~ Rank sh' => Proxy esh -> Proxy sh -> Proxy sh' |