aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-06-03 20:08:32 +0200
committerTom Smeding <tom@tomsmeding.com>2025-06-03 20:08:32 +0200
commit7feaa162828e1ba1e6b73db7833ab94440eeb06b (patch)
tree8ff5b99a10bdf1ef308420ad8e6da98712063e75 /src/Data/Array
parent75ee1572b75b45dcdc50e3af82ed50259ca77df0 (diff)
Add CastNest and CastUnnest
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Nested/Convert.hs10
1 files changed, 10 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index fe590d1..659d13c 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -136,6 +136,10 @@ data Castable a b where
=> Castable a (Mixed '[] a)
CastX0 :: Castable (Mixed '[] a) a
+ CastNest :: Elt a => StaticShX sh
+ -> Castable (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a))
+ CastUnnest :: Castable (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a)
+
instance Category Castable where
id = CastId
(.) = CastCmp
@@ -176,6 +180,12 @@ castCastable = \c x -> munScalar (go c (mscalar x))
go CastX0 (M_Nest @esh _ x)
| Refl <- lemAppNil @esh
= x
+ go (CastNest @_ @sh @sh' ssh) (M_Nest @esh esh x)
+ | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (mshape x) (ssxFromShX esh `ssxAppend` ssh)) x)
+ go (CastUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x))
+ | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh x
lemRankAppRankEq :: Rank sh ~ Rank sh'
=> Proxy esh -> Proxy sh -> Proxy sh'