diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-05-15 23:22:37 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-05-15 23:24:11 +0200 | 
| commit | 5f1213fc9e464ec361e6543884968980dd28457d (patch) | |
| tree | 09d1684fc8ee64b3679c923142bf3184ed51056b | |
| parent | 707d4d015d8ecc0bda7f162e7f39b26556c54751 (diff) | |
Make mcast available in Castable
| -rw-r--r-- | src/Data/Array/Nested/Convert.hs | 18 | 
1 files changed, 15 insertions, 3 deletions
| diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 813155f..e9bc20e 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -32,9 +32,10 @@ rcastToShaped (Ranked arr) targetsh    , Refl <- lemRankMapJust targetsh    = mcastToShaped arr targetsh --- | The only constructor that performs runtime shape checking is 'CastXS''. --- For the other construtors, the types ensure that the shapes are already --- compatible. To convert between 'Ranked' and 'Shaped', go via 'Mixed'. +-- | The constructors that perform runtime shape checking are marked with a +-- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure +-- that the shapes are already compatible. To convert between 'Ranked' and +-- 'Shaped', go via 'Mixed'.  data Castable a b where    CastId  :: Castable a a    CastCmp :: Castable b c -> Castable a b -> Castable a c @@ -52,6 +53,9 @@ data Castable a b where    CastSS  :: Castable a b -> Castable (Shaped sh a)          (Shaped sh b)    CastXX  :: Castable a b -> Castable (Mixed sh a)           (Mixed sh b) +  CastXX' :: (Rank sh ~ Rank sh', Elt b) => StaticShX sh' +          -> Castable a b -> Castable (Mixed sh a)           (Mixed sh' b) +  instance Category Castable where    id = CastId    (.) = CastCmp @@ -83,6 +87,14 @@ castCastable = \c x -> munScalar (go c (mscalar x))      go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))      go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))      go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) +    go (CastXX' @sh @sh' ssx c) (M_Nest @esh esh x) +      | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') +      = M_Nest esh $ mcast (ssxFromShape esh `ssxAppend` ssx) (go c x) + +    lemRankAppRankEq :: Rank sh ~ Rank sh' +                     => Proxy esh -> Proxy sh -> Proxy sh' +                     -> Rank (esh ++ sh) :~: Rank (esh ++ sh') +    lemRankAppRankEq _ _ _ = unsafeCoerceRefl      lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh                            -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing) | 
