aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-16 21:52:58 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-16 21:52:58 +0200
commit2ebf15f4085f633fc2f22c05684391aa9d1c4cd9 (patch)
tree079af7d9af282694882244fe8b0823f874e0a3f3
parentbad1902b2b3d8835cfe65700893c8ed8b560c893 (diff)
Convert arrays <-> XArray
-rw-r--r--src/Data/Array/Nested.hs6
-rw-r--r--src/Data/Array/Nested/Internal.hs36
2 files changed, 42 insertions, 0 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index f451920..4b455da 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -12,6 +12,8 @@ module Data.Array.Nested (
rslice, rrev1, rreshape,
-- ** Lifting orthotope operations to 'Ranked' arrays
rlift,
+ -- ** Conversions
+ rasXArrayPrim, rfromXArrayPrim,
-- * Shaped arrays
Shaped,
@@ -24,6 +26,8 @@ module Data.Array.Nested (
sslice, srev1, sreshape,
-- ** Lifting orthotope operations to 'Shaped' arrays
slift,
+ -- ** Conversions
+ sasXArrayPrim, sfromXArrayPrim,
-- * Mixed arrays
Mixed,
@@ -31,6 +35,8 @@ module Data.Array.Nested (
KnownShapeX(..), StaticShX(..),
mgenerate, mtranspose, mappend, mfromVector, mtoVector, munScalar,
mconstant, mfromList, mtoList, mslice, mrev1, mreshape,
+ -- ** Conversions
+ masXArrayPrim, mfromXArrayPrim,
-- * Array elements
Elt(mshape, mindex, mindexPartial, mscalar, mfromList1, mtoList1, mlift, mlift2),
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index de27336..65c5419 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -605,6 +605,18 @@ mreshape :: forall sh sh' a. (KnownShapeX sh, KnownShapeX sh', Elt a)
mreshape sh' = mlift $ \(_ :: Proxy shIn) ->
X.reshapePartial (knownShapeX @sh) (knownShapeX @shIn) sh'
+masXArrayPrimP :: Mixed sh (Primitive a) -> XArray sh a
+masXArrayPrimP (M_Primitive arr) = arr
+
+masXArrayPrim :: PrimElt a => Mixed sh a -> XArray sh a
+masXArrayPrim = masXArrayPrimP . toPrimitive
+
+mfromXArrayPrimP :: XArray sh a -> Mixed sh (Primitive a)
+mfromXArrayPrimP = M_Primitive
+
+mfromXArrayPrim :: PrimElt a => XArray sh a -> Mixed sh a
+mfromXArrayPrim = fromPrimitive . mfromXArrayPrimP
+
mliftPrim :: (KnownShapeX sh, Storable a)
=> (a -> a)
-> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
@@ -1161,6 +1173,18 @@ rreshape sh' (Ranked arr)
, Dict <- lemKnownReplicate (Proxy @n')
= Ranked (mreshape (shCvtRX sh') arr)
+rasXArrayPrimP :: Ranked n (Primitive a) -> XArray (Replicate n Nothing) a
+rasXArrayPrimP (Ranked arr) = masXArrayPrimP arr
+
+rasXArrayPrim :: PrimElt a => Ranked n a -> XArray (Replicate n Nothing) a
+rasXArrayPrim (Ranked arr) = masXArrayPrim arr
+
+rfromXArrayPrimP :: XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
+rfromXArrayPrimP = Ranked . mfromXArrayPrimP
+
+rfromXArrayPrim :: PrimElt a => XArray (Replicate n Nothing) a -> Ranked n a
+rfromXArrayPrim = Ranked . mfromXArrayPrim
+
-- ====== API OF SHAPED ARRAYS ====== --
@@ -1424,3 +1448,15 @@ sreshape sh' (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
, Dict <- lemKnownMapJust (Proxy @sh')
= Shaped (mreshape (shCvtSX sh') arr)
+
+sasXArrayPrimP :: Shaped sh (Primitive a) -> XArray (MapJust sh) a
+sasXArrayPrimP (Shaped arr) = masXArrayPrimP arr
+
+sasXArrayPrim :: PrimElt a => Shaped sh a -> XArray (MapJust sh) a
+sasXArrayPrim (Shaped arr) = masXArrayPrim arr
+
+sfromXArrayPrimP :: XArray (MapJust sh) a -> Shaped sh (Primitive a)
+sfromXArrayPrimP = Shaped . mfromXArrayPrimP
+
+sfromXArrayPrim :: PrimElt a => XArray (MapJust sh) a -> Shaped sh a
+sfromXArrayPrim = Shaped . mfromXArrayPrim