aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Convert.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Convert.hs')
-rw-r--r--src/Data/Array/Nested/Convert.hs50
1 files changed, 48 insertions, 2 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index e9bc20e..cdd2b6d 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -10,17 +10,63 @@ module Data.Array.Nested.Convert where
import Control.Category
import Data.Proxy
import Data.Type.Equality
+import GHC.TypeLits (Nat)
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Types
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
-import Data.Array.Nested.Ranked
-import Data.Array.Nested.Shaped
+import Data.Array.Nested.Ranked.Base
+import Data.Array.Nested.Shaped.Base
import Data.Array.Nested.Shaped.Shape
+mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
+ => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a
+mcast ssh2 arr
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr
+
+mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
+mtoRanked arr
+ | Refl <- lemRankReplicate (shxRank (mshape arr))
+ = Ranked (mcast (ssxFromShape (convSh (mshape arr))) arr)
+ where
+ convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing)
+ convSh ZSX = ZSX
+ convSh (smn :$% (sh :: IShX sh'T))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T)
+ = SUnknown (fromSMayNat' smn) :$% convSh sh
+
+rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
+rtoMixed (Ranked arr) = arr
+
+-- | A more weakly-typed version of 'rtoMixed' that does a runtime shape
+-- compatibility check.
+rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a
+rcastToMixed sshx rarr@(Ranked arr)
+ | Refl <- lemRankReplicate (rrank rarr)
+ = mcast sshx arr
+
+mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
+ => Mixed sh a -> ShS sh' -> Shaped sh' a
+mcastToShaped arr targetsh
+ | Refl <- lemRankMapJust targetsh
+ = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr)
+
+stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
+stoMixed (Shaped arr) = arr
+
+-- | A more weakly-typed version of 'stoMixed' that does a runtime shape
+-- compatibility check.
+scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
+ => StaticShX sh' -> Shaped sh a -> Mixed sh' a
+scastToMixed sshx sarr@(Shaped arr)
+ | Refl <- lemRankMapJust (sshape sarr)
+ = mcast sshx arr
+
stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
stoRanked sarr@(Shaped arr)
| Refl <- lemRankMapJust (sshape sarr)