aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-20 16:16:19 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-20 16:16:19 +0200
commit6d624d24871ee30e36c73c890c8f4a7cdae54c1c (patch)
treec4a222df5e336a28fa1af9daf655b62fc4cde49c /src
parent60c3927c4694f7c212f73498aee96a663e17c88c (diff)
Some conversions
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Mixed.hs2
-rw-r--r--src/Data/Array/Nested.hs7
-rw-r--r--src/Data/Array/Nested/Internal.hs37
3 files changed, 42 insertions, 4 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index f62d781..b07f120 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -283,7 +283,7 @@ instance KnownShX sh => IsList (ShX sh Int) where
type family Rank sh where
Rank '[] = 0
- Rank (_ : sh) = 1 + Rank sh
+ Rank (_ : sh) = Rank sh + 1
type XArray :: [Maybe Nat] -> Type -> Type
newtype XArray sh a = XArray (S.Array (Rank sh) a)
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index f712301..438f144 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -2,7 +2,7 @@
{-# LANGUAGE PatternSynonyms #-}
module Data.Array.Nested (
-- * Ranked arrays
- Ranked,
+ Ranked(Ranked),
ListR(ZR, (:::)),
IxR(.., ZIR, (:.:)), IIxR,
ShR(.., ZSR, (:$:)),
@@ -14,9 +14,10 @@ module Data.Array.Nested (
rlift,
-- ** Conversions
rasXArrayPrim, rfromXArrayPrim,
+ rcastToShaped,
-- * Shaped arrays
- Shaped,
+ Shaped(Shaped),
ListS(ZS, (::$)),
IxS(.., ZIS, (:.$)), IIxS,
ShS(.., ZSS, (:$$)), KnownShS(..),
@@ -28,6 +29,7 @@ module Data.Array.Nested (
slift,
-- ** Conversions
sasXArrayPrim, sfromXArrayPrim,
+ stoRanked,
-- * Mixed arrays
Mixed,
@@ -37,6 +39,7 @@ module Data.Array.Nested (
mconstant, mfromList, mtoList, mslice, mrev1, mreshape,
-- ** Conversions
masXArrayPrim, mfromXArrayPrim,
+ mtoRanked, mcastToShaped,
-- * Array elements
Elt(mshape, mindex, mindexPartial, mscalar, mfromList1, mtoList1, mlift, mlift2),
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index ef2179c..3863556 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -28,7 +28,6 @@
{-|
TODO:
-* Write conversions between Mixed, Ranked and Shaped
* Write `rerank`
* Write `rconst :: OR.Array n a -> Ranked n a`
@@ -163,6 +162,10 @@ lemTakeLenApp :: X.Rank l1 <= X.Rank l2
-> X.TakeLen l1 l2 :~: X.TakeLen l1 (l2 ++ rest)
lemTakeLenApp _ _ _ = unsafeCoerce Refl
+srankSh :: ShX sh f -> SNat (X.Rank sh)
+srankSh ZSX = SNat
+srankSh (_ :$% sh) | SNat <- srankSh sh = SNat
+
-- === NEW INDEX TYPES === --
@@ -953,6 +956,27 @@ instance (Storable a, Num a, PrimElt a) => Num (Mixed sh a) where
signum = mliftPrim signum
fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mconstant"
+mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (X.Rank sh) a
+mtoRanked arr
+ | Refl <- X.lemAppNil @sh
+ , Refl <- X.lemAppNil @(Replicate (X.Rank sh) (Nothing @Nat))
+ , Refl <- lemRankReplicate (srankSh (mshape arr))
+ = Ranked (mcast (X.staticShapeFrom (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr)
+ where
+ convSh :: IShX sh' -> IShX (Replicate (X.Rank sh') Nothing)
+ convSh ZSX = ZSX
+ convSh (smn :$% (sh :: IShX sh'T))
+ | Refl <- X.lemReplicateSucc @(Nothing @Nat) @(X.Rank sh'T)
+ = SUnknown (fromSMayNat' smn) :$% convSh sh
+
+mcastToShaped :: forall sh sh' a. (Elt a, X.Rank sh ~ X.Rank sh')
+ => Mixed sh a -> ShS sh' -> Shaped sh' a
+mcastToShaped arr targetsh
+ | Refl <- X.lemAppNil @sh
+ , Refl <- X.lemAppNil @(MapJust sh')
+ , Refl <- lemRankMapJust targetsh
+ = Shaped (mcast (X.staticShapeFrom (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr)
+
-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
-- represented on the type level as a 'Nat'.
@@ -1408,6 +1432,12 @@ rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (X.staticShapeFrom (X.shape (
rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a
rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr)
+rcastToShaped :: Elt a => Ranked (X.Rank sh) a -> ShS sh -> Shaped sh a
+rcastToShaped (Ranked arr) targetsh
+ | Refl <- lemRankReplicate (srankSh (shCvtSX targetsh))
+ , Refl <- lemRankMapJust targetsh
+ = mcastToShaped arr targetsh
+
-- ====== API OF SHAPED ARRAYS ====== --
@@ -1617,3 +1647,8 @@ sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (X.staticShapeFrom (shCvtSX s
sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a
sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (X.staticShapeFrom (shCvtSX sh)) arr)
+
+stoRanked :: Elt a => Shaped sh a -> Ranked (X.Rank sh) a
+stoRanked sarr@(Shaped arr)
+ | Refl <- lemRankMapJust (sshape sarr)
+ = mtoRanked arr