aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Nested/Convert.hs18
1 files changed, 15 insertions, 3 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index 813155f..e9bc20e 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -32,9 +32,10 @@ rcastToShaped (Ranked arr) targetsh
, Refl <- lemRankMapJust targetsh
= mcastToShaped arr targetsh
--- | The only constructor that performs runtime shape checking is 'CastXS''.
--- For the other construtors, the types ensure that the shapes are already
--- compatible. To convert between 'Ranked' and 'Shaped', go via 'Mixed'.
+-- | The constructors that perform runtime shape checking are marked with a
+-- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure
+-- that the shapes are already compatible. To convert between 'Ranked' and
+-- 'Shaped', go via 'Mixed'.
data Castable a b where
CastId :: Castable a a
CastCmp :: Castable b c -> Castable a b -> Castable a c
@@ -52,6 +53,9 @@ data Castable a b where
CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b)
CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b)
+ CastXX' :: (Rank sh ~ Rank sh', Elt b) => StaticShX sh'
+ -> Castable a b -> Castable (Mixed sh a) (Mixed sh' b)
+
instance Category Castable where
id = CastId
(.) = CastCmp
@@ -83,6 +87,14 @@ castCastable = \c x -> munScalar (go c (mscalar x))
go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
+ go (CastXX' @sh @sh' ssx c) (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Nest esh $ mcast (ssxFromShape esh `ssxAppend` ssx) (go c x)
+
+ lemRankAppRankEq :: Rank sh ~ Rank sh'
+ => Proxy esh -> Proxy sh -> Proxy sh'
+ -> Rank (esh ++ sh) :~: Rank (esh ++ sh')
+ lemRankAppRankEq _ _ _ = unsafeCoerceRefl
lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh
-> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing)