diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-06-03 20:08:32 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-06-03 20:08:32 +0200 |
commit | 7feaa162828e1ba1e6b73db7833ab94440eeb06b (patch) | |
tree | 8ff5b99a10bdf1ef308420ad8e6da98712063e75 /src/Data/Array | |
parent | 75ee1572b75b45dcdc50e3af82ed50259ca77df0 (diff) |
Add CastNest and CastUnnest
Diffstat (limited to 'src/Data/Array')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index fe590d1..659d13c 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -136,6 +136,10 @@ data Castable a b where => Castable a (Mixed '[] a) CastX0 :: Castable (Mixed '[] a) a + CastNest :: Elt a => StaticShX sh + -> Castable (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a)) + CastUnnest :: Castable (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a) + instance Category Castable where id = CastId (.) = CastCmp @@ -176,6 +180,12 @@ castCastable = \c x -> munScalar (go c (mscalar x)) go CastX0 (M_Nest @esh _ x) | Refl <- lemAppNil @esh = x + go (CastNest @_ @sh @sh' ssh) (M_Nest @esh esh x) + | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (mshape x) (ssxFromShX esh `ssxAppend` ssh)) x) + go (CastUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x)) + | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh x lemRankAppRankEq :: Rank sh ~ Rank sh' => Proxy esh -> Proxy sh -> Proxy sh' |