aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Convert.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Convert.hs')
-rw-r--r--src/Data/Array/Nested/Convert.hs26
1 files changed, 17 insertions, 9 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index 639f5fd..813155f 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -42,7 +42,8 @@ data Castable a b where
CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b)
CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b)
- CastXR :: Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b)
+ CastXR :: Elt b
+ => Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b)
CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh'
-> Castable a b -> Castable (Mixed sh a) (Shaped sh' b)
@@ -68,19 +69,26 @@ castCastable = \c x -> munScalar (go c (mscalar x))
go (CastCmp c1 c2) x = go c1 (go c2 x)
go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)
go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x)
- go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) =
- M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy
- (go c x)))
+ go (CastXR @_ @_ @sh c) (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)
+ = let x' = go c x
+ ssx' = ssxAppend (ssxFromShape esh)
+ (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShape esh))))
+ in M_Ranked (M_Nest esh (mcast ssx' x'))
go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x))
go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x)
- | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
= M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh')))
(go c x)))
go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
- lemRankAppMapJust :: Rank sh ~ Rank sh'
- => Proxy esh -> Proxy sh -> Proxy sh'
- -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
- lemRankAppMapJust _ _ _ = unsafeCoerceRefl
+ lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh
+ -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing)
+ lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl
+
+ lemRankAppRankEqMapJust :: Rank sh ~ Rank sh'
+ => Proxy esh -> Proxy sh -> Proxy sh'
+ -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
+ lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl