aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r--src/Data/Array/Nested/Internal.hs37
1 files changed, 36 insertions, 1 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index ef2179c..3863556 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -28,7 +28,6 @@
{-|
TODO:
-* Write conversions between Mixed, Ranked and Shaped
* Write `rerank`
* Write `rconst :: OR.Array n a -> Ranked n a`
@@ -163,6 +162,10 @@ lemTakeLenApp :: X.Rank l1 <= X.Rank l2
-> X.TakeLen l1 l2 :~: X.TakeLen l1 (l2 ++ rest)
lemTakeLenApp _ _ _ = unsafeCoerce Refl
+srankSh :: ShX sh f -> SNat (X.Rank sh)
+srankSh ZSX = SNat
+srankSh (_ :$% sh) | SNat <- srankSh sh = SNat
+
-- === NEW INDEX TYPES === --
@@ -953,6 +956,27 @@ instance (Storable a, Num a, PrimElt a) => Num (Mixed sh a) where
signum = mliftPrim signum
fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mconstant"
+mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (X.Rank sh) a
+mtoRanked arr
+ | Refl <- X.lemAppNil @sh
+ , Refl <- X.lemAppNil @(Replicate (X.Rank sh) (Nothing @Nat))
+ , Refl <- lemRankReplicate (srankSh (mshape arr))
+ = Ranked (mcast (X.staticShapeFrom (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr)
+ where
+ convSh :: IShX sh' -> IShX (Replicate (X.Rank sh') Nothing)
+ convSh ZSX = ZSX
+ convSh (smn :$% (sh :: IShX sh'T))
+ | Refl <- X.lemReplicateSucc @(Nothing @Nat) @(X.Rank sh'T)
+ = SUnknown (fromSMayNat' smn) :$% convSh sh
+
+mcastToShaped :: forall sh sh' a. (Elt a, X.Rank sh ~ X.Rank sh')
+ => Mixed sh a -> ShS sh' -> Shaped sh' a
+mcastToShaped arr targetsh
+ | Refl <- X.lemAppNil @sh
+ , Refl <- X.lemAppNil @(MapJust sh')
+ , Refl <- lemRankMapJust targetsh
+ = Shaped (mcast (X.staticShapeFrom (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr)
+
-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
-- represented on the type level as a 'Nat'.
@@ -1408,6 +1432,12 @@ rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (X.staticShapeFrom (X.shape (
rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a
rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr)
+rcastToShaped :: Elt a => Ranked (X.Rank sh) a -> ShS sh -> Shaped sh a
+rcastToShaped (Ranked arr) targetsh
+ | Refl <- lemRankReplicate (srankSh (shCvtSX targetsh))
+ , Refl <- lemRankMapJust targetsh
+ = mcastToShaped arr targetsh
+
-- ====== API OF SHAPED ARRAYS ====== --
@@ -1617,3 +1647,8 @@ sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (X.staticShapeFrom (shCvtSX s
sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a
sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (X.staticShapeFrom (shCvtSX sh)) arr)
+
+stoRanked :: Elt a => Shaped sh a -> Ranked (X.Rank sh) a
+stoRanked sarr@(Shaped arr)
+ | Refl <- lemRankMapJust (sshape sarr)
+ = mtoRanked arr