aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-03 19:56:57 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-03 19:56:57 +0200
commitac061cf450b1c8e153de06f7b12256914c496788 (patch)
tree6774c5752674d749518986d575f64ce95728568f /src
parenta25d4061e219cec153f066fddbf710abd63b5e48 (diff)
rrank, rtoOrthotope
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Mixed/Shape.hs3
-rw-r--r--src/Data/Array/Nested.hs4
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs8
3 files changed, 13 insertions, 2 deletions
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs
index 4ab3c26..4343574 100644
--- a/src/Data/Array/Mixed/Shape.hs
+++ b/src/Data/Array/Mixed/Shape.hs
@@ -252,6 +252,9 @@ instance NFData i => NFData (ShX sh i) where
shxLength :: ShX sh i -> Int
shxLength (ShX l) = listxLength l
+shxLengthSNat :: ShX sh f -> SNat (Rank sh)
+shxLengthSNat (ShX list) = listxLengthSNat list
+
-- | This is more than @geq@: it also checks that the integers (the unknown
-- dimensions) are the same.
shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 1c9cebc..370cfc8 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -6,7 +6,7 @@ module Data.Array.Nested (
ListR(ZR, (:::)),
IxR(.., ZIR, (:.:)), IIxR,
ShR(.., ZSR, (:$:)), IShR,
- rshape, rindex, rindexPartial, rgenerate, rsumOuter1,
+ rshape, rrank, rindex, rindexPartial, rgenerate, rsumOuter1,
rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar,
rrerank,
rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
@@ -16,7 +16,7 @@ module Data.Array.Nested (
-- ** Conversions
rtoXArrayPrim, rfromXArrayPrim,
rcastToShaped,
- rfromOrthotope,
+ rfromOrthotope, rtoOrthotope,
-- * Shaped arrays
Shaped(Shaped),
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index d6eff31..894ed0d 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -215,6 +215,9 @@ instance (FloatElt a, NumElt a, PrimElt a) => Floating (Ranked n a) where
rshape :: forall n a. Elt a => Ranked n a -> IShR n
rshape (Ranked arr) = shCvtXR' (mshape arr)
+rrank :: Elt a => Ranked n a -> SNat n
+rrank = shrToSNat . rshape
+
rindex :: Elt a => Ranked n a -> IIxR n -> a
rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
@@ -331,6 +334,11 @@ rfromOrthotope sn arr
= let xarr = XArray arr
in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr))
+rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a
+rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr)))
+ | Refl <- lemRankReplicate (shrToSNat $ shCvtXR' sh)
+ = arr
+
runScalar :: Elt a => Ranked 0 a -> a
runScalar arr = rindex arr ZIR