diff options
Diffstat (limited to 'src/Data/Array/Nested/Internal.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 37 |
1 files changed, 36 insertions, 1 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index ef2179c..3863556 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -28,7 +28,6 @@ {-| TODO: -* Write conversions between Mixed, Ranked and Shaped * Write `rerank` * Write `rconst :: OR.Array n a -> Ranked n a` @@ -163,6 +162,10 @@ lemTakeLenApp :: X.Rank l1 <= X.Rank l2 -> X.TakeLen l1 l2 :~: X.TakeLen l1 (l2 ++ rest) lemTakeLenApp _ _ _ = unsafeCoerce Refl +srankSh :: ShX sh f -> SNat (X.Rank sh) +srankSh ZSX = SNat +srankSh (_ :$% sh) | SNat <- srankSh sh = SNat + -- === NEW INDEX TYPES === -- @@ -953,6 +956,27 @@ instance (Storable a, Num a, PrimElt a) => Num (Mixed sh a) where signum = mliftPrim signum fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mconstant" +mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (X.Rank sh) a +mtoRanked arr + | Refl <- X.lemAppNil @sh + , Refl <- X.lemAppNil @(Replicate (X.Rank sh) (Nothing @Nat)) + , Refl <- lemRankReplicate (srankSh (mshape arr)) + = Ranked (mcast (X.staticShapeFrom (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr) + where + convSh :: IShX sh' -> IShX (Replicate (X.Rank sh') Nothing) + convSh ZSX = ZSX + convSh (smn :$% (sh :: IShX sh'T)) + | Refl <- X.lemReplicateSucc @(Nothing @Nat) @(X.Rank sh'T) + = SUnknown (fromSMayNat' smn) :$% convSh sh + +mcastToShaped :: forall sh sh' a. (Elt a, X.Rank sh ~ X.Rank sh') + => Mixed sh a -> ShS sh' -> Shaped sh' a +mcastToShaped arr targetsh + | Refl <- X.lemAppNil @sh + , Refl <- X.lemAppNil @(MapJust sh') + , Refl <- lemRankMapJust targetsh + = Shaped (mcast (X.staticShapeFrom (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr) + -- | A rank-typed array: the number of dimensions of the array (its /rank/) is -- represented on the type level as a 'Nat'. @@ -1408,6 +1432,12 @@ rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (X.staticShapeFrom (X.shape ( rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr) +rcastToShaped :: Elt a => Ranked (X.Rank sh) a -> ShS sh -> Shaped sh a +rcastToShaped (Ranked arr) targetsh + | Refl <- lemRankReplicate (srankSh (shCvtSX targetsh)) + , Refl <- lemRankMapJust targetsh + = mcastToShaped arr targetsh + -- ====== API OF SHAPED ARRAYS ====== -- @@ -1617,3 +1647,8 @@ sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (X.staticShapeFrom (shCvtSX s sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (X.staticShapeFrom (shCvtSX sh)) arr) + +stoRanked :: Elt a => Shaped sh a -> Ranked (X.Rank sh) a +stoRanked sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mtoRanked arr |