aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Nested/Convert.hs10
1 files changed, 10 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index fe590d1..659d13c 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -136,6 +136,10 @@ data Castable a b where
=> Castable a (Mixed '[] a)
CastX0 :: Castable (Mixed '[] a) a
+ CastNest :: Elt a => StaticShX sh
+ -> Castable (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a))
+ CastUnnest :: Castable (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a)
+
instance Category Castable where
id = CastId
(.) = CastCmp
@@ -176,6 +180,12 @@ castCastable = \c x -> munScalar (go c (mscalar x))
go CastX0 (M_Nest @esh _ x)
| Refl <- lemAppNil @esh
= x
+ go (CastNest @_ @sh @sh' ssh) (M_Nest @esh esh x)
+ | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (mshape x) (ssxFromShX esh `ssxAppend` ssh)) x)
+ go (CastUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x))
+ | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh x
lemRankAppRankEq :: Rank sh ~ Rank sh'
=> Proxy esh -> Proxy sh -> Proxy sh'