diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-06-09 21:06:13 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-06-09 21:06:13 +0200 |
commit | f70a381a05ec86767365b7d16b674ceff318d07d (patch) | |
tree | 09cbcf7877ca24df087978cdfb50175a80be5080 /src/Data/Array | |
parent | 5763bf70dc67c5437207ff8e9dd08585d2ea5384 (diff) |
nest, unNest
Diffstat (limited to 'src/Data/Array')
-rw-r--r-- | src/Data/Array/Nested.hs | 3 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 6 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 10 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 10 |
4 files changed, 29 insertions, 0 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 7cc1de3..3e4e1a0 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -11,6 +11,7 @@ module Data.Array.Nested ( rrerank, rreplicate, rreplicateScal, rfromListOuter, rfromList1, rfromList1Prim, rtoListOuter, rtoList1, rslice, rrev1, rreshape, riota, + rnest, runNest, -- ** Lifting orthotope operations to 'Ranked' arrays rlift, rlift2, -- ** Conversions @@ -29,6 +30,7 @@ module Data.Array.Nested ( srerank, sreplicate, sreplicateScal, sfromListOuter, sfromList1, sfromList1Prim, stoListOuter, stoList1, sslice, srev1, sreshape, siota, + snest, sunNest, -- ** Lifting orthotope operations to 'Shaped' arrays slift, slift2, -- ** Conversions @@ -44,6 +46,7 @@ module Data.Array.Nested ( mrerank, mreplicate, mreplicateScal, mfromListOuter, mfromList1, mfromList1Prim, mtoListOuter, mtoList1, mslice, mrev1, mreshape, miota, + mnest, munNest, -- ** Lifting orthotope operations to 'Mixed' arrays mlift, mlift2, -- ** Conversions diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 6d601b8..b799190 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -686,6 +686,12 @@ mfromListPrimLinear sh l = munScalar :: Elt a => Mixed '[] a -> a munScalar arr = mindex arr ZIX +mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a) +mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr + +munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a +munNest (M_Nest _ arr) = arr + mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) => StaticShX sh -> IShX sh2 -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 3e911ac..d6e05e6 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -349,6 +349,16 @@ rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr))) runScalar :: Elt a => Ranked 0 a -> a runScalar arr = rindex arr ZIR +rnest :: forall n m a. Elt a => SNat n -> Ranked (n + m) a -> Ranked n (Ranked m a) +rnest n arr + | Refl <- lemReplicatePlusApp n (Proxy @m) (Proxy @(Nothing @Nat)) + = coerce (mnest (ssxFromSNat n) (coerce arr)) + +runNest :: forall n m a. Elt a => Ranked n (Ranked m a) -> Ranked (n + m) a +runNest rarr@(Ranked (M_Ranked (M_Nest _ arr))) + | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat)) + = Ranked arr + rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) => SNat n -> IShR n2 -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index 7d523b0..d1881c1 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -310,6 +310,16 @@ sfromListPrimLinear sh l = sunScalar :: Elt a => Shaped '[] a -> a sunScalar arr = sindex arr ZIS +snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a) +snest sh arr + | Refl <- lemMapJustApp sh (Proxy @sh') + = coerce (mnest (ssxFromShape (shCvtSX sh)) (coerce arr)) + +sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a +sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr))) + | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh') + = Shaped arr + srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) => ShS sh -> ShS sh2 -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) |