diff options
Diffstat (limited to 'src/Data/Array/Nested/Convert.hs')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 26 |
1 files changed, 17 insertions, 9 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 639f5fd..813155f 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -42,7 +42,8 @@ data Castable a b where CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b) CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b) - CastXR :: Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b) + CastXR :: Elt b + => Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b) CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b) CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh' -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b) @@ -68,19 +69,26 @@ castCastable = \c x -> munScalar (go c (mscalar x)) go (CastCmp c1 c2) x = go c1 (go c2 x) go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x) go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x) - go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) = - M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy - (go c x))) + go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) + | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) + = let x' = go c x + ssx' = ssxAppend (ssxFromShape esh) + (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShape esh)))) + in M_Ranked (M_Nest esh (mcast ssx' x')) go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x)) go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x) - | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') + | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh'))) (go c x))) go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) - lemRankAppMapJust :: Rank sh ~ Rank sh' - => Proxy esh -> Proxy sh -> Proxy sh' - -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') - lemRankAppMapJust _ _ _ = unsafeCoerceRefl + lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh + -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing) + lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl + + lemRankAppRankEqMapJust :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') + lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl |