diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Data/Array/Mixed.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/Nested.hs | 7 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 37 |
3 files changed, 42 insertions, 4 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index f62d781..b07f120 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -283,7 +283,7 @@ instance KnownShX sh => IsList (ShX sh Int) where type family Rank sh where Rank '[] = 0 - Rank (_ : sh) = 1 + Rank sh + Rank (_ : sh) = Rank sh + 1 type XArray :: [Maybe Nat] -> Type -> Type newtype XArray sh a = XArray (S.Array (Rank sh) a) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index f712301..438f144 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -2,7 +2,7 @@ {-# LANGUAGE PatternSynonyms #-} module Data.Array.Nested ( -- * Ranked arrays - Ranked, + Ranked(Ranked), ListR(ZR, (:::)), IxR(.., ZIR, (:.:)), IIxR, ShR(.., ZSR, (:$:)), @@ -14,9 +14,10 @@ module Data.Array.Nested ( rlift, -- ** Conversions rasXArrayPrim, rfromXArrayPrim, + rcastToShaped, -- * Shaped arrays - Shaped, + Shaped(Shaped), ListS(ZS, (::$)), IxS(.., ZIS, (:.$)), IIxS, ShS(.., ZSS, (:$$)), KnownShS(..), @@ -28,6 +29,7 @@ module Data.Array.Nested ( slift, -- ** Conversions sasXArrayPrim, sfromXArrayPrim, + stoRanked, -- * Mixed arrays Mixed, @@ -37,6 +39,7 @@ module Data.Array.Nested ( mconstant, mfromList, mtoList, mslice, mrev1, mreshape, -- ** Conversions masXArrayPrim, mfromXArrayPrim, + mtoRanked, mcastToShaped, -- * Array elements Elt(mshape, mindex, mindexPartial, mscalar, mfromList1, mtoList1, mlift, mlift2), diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index ef2179c..3863556 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -28,7 +28,6 @@ {-| TODO: -* Write conversions between Mixed, Ranked and Shaped * Write `rerank` * Write `rconst :: OR.Array n a -> Ranked n a` @@ -163,6 +162,10 @@ lemTakeLenApp :: X.Rank l1 <= X.Rank l2 -> X.TakeLen l1 l2 :~: X.TakeLen l1 (l2 ++ rest) lemTakeLenApp _ _ _ = unsafeCoerce Refl +srankSh :: ShX sh f -> SNat (X.Rank sh) +srankSh ZSX = SNat +srankSh (_ :$% sh) | SNat <- srankSh sh = SNat + -- === NEW INDEX TYPES === -- @@ -953,6 +956,27 @@ instance (Storable a, Num a, PrimElt a) => Num (Mixed sh a) where signum = mliftPrim signum fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mconstant" +mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (X.Rank sh) a +mtoRanked arr + | Refl <- X.lemAppNil @sh + , Refl <- X.lemAppNil @(Replicate (X.Rank sh) (Nothing @Nat)) + , Refl <- lemRankReplicate (srankSh (mshape arr)) + = Ranked (mcast (X.staticShapeFrom (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr) + where + convSh :: IShX sh' -> IShX (Replicate (X.Rank sh') Nothing) + convSh ZSX = ZSX + convSh (smn :$% (sh :: IShX sh'T)) + | Refl <- X.lemReplicateSucc @(Nothing @Nat) @(X.Rank sh'T) + = SUnknown (fromSMayNat' smn) :$% convSh sh + +mcastToShaped :: forall sh sh' a. (Elt a, X.Rank sh ~ X.Rank sh') + => Mixed sh a -> ShS sh' -> Shaped sh' a +mcastToShaped arr targetsh + | Refl <- X.lemAppNil @sh + , Refl <- X.lemAppNil @(MapJust sh') + , Refl <- lemRankMapJust targetsh + = Shaped (mcast (X.staticShapeFrom (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr) + -- | A rank-typed array: the number of dimensions of the array (its /rank/) is -- represented on the type level as a 'Nat'. @@ -1408,6 +1432,12 @@ rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (X.staticShapeFrom (X.shape ( rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr) +rcastToShaped :: Elt a => Ranked (X.Rank sh) a -> ShS sh -> Shaped sh a +rcastToShaped (Ranked arr) targetsh + | Refl <- lemRankReplicate (srankSh (shCvtSX targetsh)) + , Refl <- lemRankMapJust targetsh + = mcastToShaped arr targetsh + -- ====== API OF SHAPED ARRAYS ====== -- @@ -1617,3 +1647,8 @@ sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (X.staticShapeFrom (shCvtSX s sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (X.staticShapeFrom (shCvtSX sh)) arr) + +stoRanked :: Elt a => Shaped sh a -> Ranked (X.Rank sh) a +stoRanked sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mtoRanked arr |