aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-09 21:06:13 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-09 21:06:13 +0200
commitf70a381a05ec86767365b7d16b674ceff318d07d (patch)
tree09cbcf7877ca24df087978cdfb50175a80be5080 /src/Data
parent5763bf70dc67c5437207ff8e9dd08585d2ea5384 (diff)
nest, unNest
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Nested.hs3
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs6
-rw-r--r--src/Data/Array/Nested/Internal/Ranked.hs10
-rw-r--r--src/Data/Array/Nested/Internal/Shaped.hs10
4 files changed, 29 insertions, 0 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 7cc1de3..3e4e1a0 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -11,6 +11,7 @@ module Data.Array.Nested (
rrerank,
rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1,
rslice, rrev1, rreshape, riota,
+ rnest, runNest,
-- ** Lifting orthotope operations to 'Ranked' arrays
rlift, rlift2,
-- ** Conversions
@@ -29,6 +30,7 @@ module Data.Array.Nested (
srerank,
sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1,
sslice, srev1, sreshape, siota,
+ snest, sunNest,
-- ** Lifting orthotope operations to 'Shaped' arrays
slift, slift2,
-- ** Conversions
@@ -44,6 +46,7 @@ module Data.Array.Nested (
mrerank,
mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1,
mslice, mrev1, mreshape, miota,
+ mnest, munNest,
-- ** Lifting orthotope operations to 'Mixed' arrays
mlift, mlift2,
-- ** Conversions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 6d601b8..b799190 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -686,6 +686,12 @@ mfromListPrimLinear sh l =
munScalar :: Elt a => Mixed '[] a -> a
munScalar arr = mindex arr ZIX
+mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a)
+mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr
+
+munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a
+munNest (M_Nest _ arr) = arr
+
mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
=> StaticShX sh -> IShX sh2
-> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs
index 3e911ac..d6e05e6 100644
--- a/src/Data/Array/Nested/Internal/Ranked.hs
+++ b/src/Data/Array/Nested/Internal/Ranked.hs
@@ -349,6 +349,16 @@ rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr)))
runScalar :: Elt a => Ranked 0 a -> a
runScalar arr = rindex arr ZIR
+rnest :: forall n m a. Elt a => SNat n -> Ranked (n + m) a -> Ranked n (Ranked m a)
+rnest n arr
+ | Refl <- lemReplicatePlusApp n (Proxy @m) (Proxy @(Nothing @Nat))
+ = coerce (mnest (ssxFromSNat n) (coerce arr))
+
+runNest :: forall n m a. Elt a => Ranked n (Ranked m a) -> Ranked (n + m) a
+runNest rarr@(Ranked (M_Ranked (M_Nest _ arr)))
+ | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat))
+ = Ranked arr
+
rrerankP :: forall n1 n2 n a b. (Storable a, Storable b)
=> SNat n -> IShR n2
-> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs
index 7d523b0..d1881c1 100644
--- a/src/Data/Array/Nested/Internal/Shaped.hs
+++ b/src/Data/Array/Nested/Internal/Shaped.hs
@@ -310,6 +310,16 @@ sfromListPrimLinear sh l =
sunScalar :: Elt a => Shaped '[] a -> a
sunScalar arr = sindex arr ZIS
+snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a)
+snest sh arr
+ | Refl <- lemMapJustApp sh (Proxy @sh')
+ = coerce (mnest (ssxFromShape (shCvtSX sh)) (coerce arr))
+
+sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a
+sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr)))
+ | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh')
+ = Shaped arr
+
srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
=> ShS sh -> ShS sh2
-> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))