diff options
| -rw-r--r-- | src/Data/Array/Nested/Convert.hs | 72 | 
1 files changed, 45 insertions, 27 deletions
| diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index b3a2c63..fe590d1 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -108,21 +108,33 @@ data Castable a b where    CastId  :: Castable a a    CastCmp :: Castable b c -> Castable a b -> Castable a c -  CastRX  :: Castable a b -> Castable (Ranked n a)           (Mixed (Replicate n Nothing) b) -  CastSX  :: Castable a b -> Castable (Shaped sh a)          (Mixed (MapJust sh) b) +  CastRX  :: Castable (Ranked n a) (Mixed (Replicate n Nothing) a) +  CastSX  :: Castable (Shaped sh a) (Mixed (MapJust sh) a) -  CastXR  :: Elt b -          => Castable a b -> Castable (Mixed sh a)           (Ranked (Rank sh) b) -  CastXS  :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b) -  CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh' -          -> Castable a b -> Castable (Mixed sh a)           (Shaped sh' b) +  CastXR  :: Elt a +          => Castable (Mixed sh a) (Ranked (Rank sh) a) +  CastXS  :: Castable (Mixed (MapJust sh) a) (Shaped sh a) +  CastXS' :: (Rank sh ~ Rank sh', Elt a) +          => ShS sh' +          -> Castable (Mixed sh a) (Shaped sh' a) -  CastRR  :: Castable a b -> Castable (Ranked n a)           (Ranked n b) -  CastSS  :: Castable a b -> Castable (Shaped sh a)          (Shaped sh b) -  CastXX  :: Castable a b -> Castable (Mixed sh a)           (Mixed sh b) +  CastXX' :: (Rank sh ~ Rank sh', Elt a) +          => StaticShX sh' +          -> Castable (Mixed sh a) (Mixed sh' a) -  CastXX' :: (Rank sh ~ Rank sh', Elt b) => StaticShX sh' -          -> Castable a b -> Castable (Mixed sh a)           (Mixed sh' b) +  CastRR  :: Castable a b +          -> Castable (Ranked n a) (Ranked n b) +  CastSS  :: Castable a b +          -> Castable (Shaped sh a) (Shaped sh b) +  CastXX  :: Castable a b +          -> Castable (Mixed sh a) (Mixed sh b) +  CastT2  :: Castable a a' +          -> Castable b b' +          -> Castable (a, b) (a', b') + +  Cast0X  :: Elt a +          => Castable a (Mixed '[] a) +  CastX0  :: Castable (Mixed '[] a) a  instance Category Castable where    id = CastId @@ -139,25 +151,31 @@ castCastable = \c x -> munScalar (go c (mscalar x))      go :: Castable a b -> Mixed esh a -> Mixed esh b      go CastId x = x      go (CastCmp c1 c2) x = go c1 (go c2 x) -    go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x) -    go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x) -    go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) +    go CastRX (M_Ranked x) = x +    go CastSX (M_Shaped x) = x +    go (CastXR @_ @sh) (M_Nest @esh esh x)        | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) -      = let x' = go c x -            ssx' = ssxAppend (ssxFromShX esh) -                             (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShX esh)))) -        in M_Ranked (M_Nest esh (mcast ssx' x')) -    go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x)) -    go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x) +      = let ssx' = ssxAppend (ssxFromShX esh) +                             (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x) (ssxFromShX esh)))) +        in M_Ranked (M_Nest esh (mcast ssx' x)) +    go CastXS (M_Nest esh x) = M_Shaped (M_Nest esh x) +    go (CastXS' @sh @sh' sh') (M_Nest @esh esh x)        | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')        = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh'))) -                                    (go c x))) +                                    x)) +    go (CastXX' @sh @sh' ssx) (M_Nest @esh esh x) +      | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') +      = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x      go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))      go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))      go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) -    go (CastXX' @sh @sh' ssx c) (M_Nest @esh esh x) -      | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') -      = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) (go c x) +    go (CastT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2) +    go Cast0X (x :: Mixed esh a) +      | Refl <- lemAppNil @esh +      = M_Nest (mshape x) x +    go CastX0 (M_Nest @esh _ x) +      | Refl <- lemAppNil @esh +      = x      lemRankAppRankEq :: Rank sh ~ Rank sh'                       => Proxy esh -> Proxy sh -> Proxy sh' @@ -184,7 +202,7 @@ mcast ssh2 arr    = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr  mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked = castCastable (CastXR CastId) +mtoRanked = castCastable CastXR  rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a  rtoMixed (Ranked arr) = arr @@ -198,7 +216,7 @@ rcastToMixed sshx rarr@(Ranked arr)  mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')                => ShS sh' -> Mixed sh a -> Shaped sh' a -mcastToShaped targetsh = castCastable (CastXS' targetsh CastId) +mcastToShaped targetsh = castCastable (CastXS' targetsh)  stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a  stoMixed (Shaped arr) = arr | 
