diff options
Diffstat (limited to 'src/Data/Array/Nested/Convert.hs')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 106 |
1 files changed, 76 insertions, 30 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index b3a2c63..cea2489 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -104,25 +104,52 @@ shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh -- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure -- that the shapes are already compatible. To convert between 'Ranked' and -- 'Shaped', go via 'Mixed'. +-- +-- The guiding principle behind 'Castable' is that it should represent the +-- array restructurings, or perhaps re-presentations, that do not change the +-- underlying 'XArray's. This leads to the inclusion of some operations that do +-- not look like a cast at first glance, like 'CastZip'; with the underlying +-- representation in mind, however, they are very much like a cast. data Castable a b where CastId :: Castable a a CastCmp :: Castable b c -> Castable a b -> Castable a c - CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b) - CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b) - - CastXR :: Elt b - => Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b) - CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b) - CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh' - -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b) - - CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b) - CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b) - CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b) - - CastXX' :: (Rank sh ~ Rank sh', Elt b) => StaticShX sh' - -> Castable a b -> Castable (Mixed sh a) (Mixed sh' b) + CastRX :: Castable (Ranked n a) (Mixed (Replicate n Nothing) a) + CastSX :: Castable (Shaped sh a) (Mixed (MapJust sh) a) + + CastXR :: Elt a + => Castable (Mixed sh a) (Ranked (Rank sh) a) + CastXS :: Castable (Mixed (MapJust sh) a) (Shaped sh a) + CastXS' :: (Rank sh ~ Rank sh', Elt a) + => ShS sh' + -> Castable (Mixed sh a) (Shaped sh' a) + + CastXX' :: (Rank sh ~ Rank sh', Elt a) + => StaticShX sh' + -> Castable (Mixed sh a) (Mixed sh' a) + + CastRR :: Castable a b + -> Castable (Ranked n a) (Ranked n b) + CastSS :: Castable a b + -> Castable (Shaped sh a) (Shaped sh b) + CastXX :: Castable a b + -> Castable (Mixed sh a) (Mixed sh b) + CastT2 :: Castable a a' + -> Castable b b' + -> Castable (a, b) (a', b') + + Cast0X :: Elt a + => Castable a (Mixed '[] a) + CastX0 :: Castable (Mixed '[] a) a + + CastNest :: Elt a => StaticShX sh + -> Castable (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a)) + CastUnnest :: Castable (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a) + + CastZip :: (Elt a, Elt b) + => Castable (Mixed sh a, Mixed sh b) (Mixed sh (a, b)) + CastUnzip :: (Elt a, Elt b) + => Castable (Mixed sh (a, b)) (Mixed sh a, Mixed sh b) instance Category Castable where id = CastId @@ -139,25 +166,44 @@ castCastable = \c x -> munScalar (go c (mscalar x)) go :: Castable a b -> Mixed esh a -> Mixed esh b go CastId x = x go (CastCmp c1 c2) x = go c1 (go c2 x) - go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x) - go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x) - go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) + go CastRX (M_Ranked x) = x + go CastSX (M_Shaped x) = x + go (CastXR @_ @sh) (M_Nest @esh esh x) | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) - = let x' = go c x - ssx' = ssxAppend (ssxFromShX esh) - (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShX esh)))) - in M_Ranked (M_Nest esh (mcast ssx' x')) - go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x)) - go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x) + = let ssx' = ssxAppend (ssxFromShX esh) + (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x) (ssxFromShX esh)))) + in M_Ranked (M_Nest esh (mcast ssx' x)) + go CastXS (M_Nest esh x) = M_Shaped (M_Nest esh x) + go (CastXS' @sh @sh' sh') (M_Nest @esh esh x) | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh'))) - (go c x))) + x)) + go (CastXX' @sh @sh' ssx) (M_Nest @esh esh x) + | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) - go (CastXX' @sh @sh' ssx c) (M_Nest @esh esh x) - | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') - = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) (go c x) + go (CastT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2) + go Cast0X (x :: Mixed esh a) + | Refl <- lemAppNil @esh + = M_Nest (mshape x) x + go CastX0 (M_Nest @esh _ x) + | Refl <- lemAppNil @esh + = x + go (CastNest @_ @sh @sh' ssh) (M_Nest @esh esh x) + | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (mshape x) (ssxFromShX esh `ssxAppend` ssh)) x) + go (CastUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x)) + | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh x + go CastZip x = + -- no need to check that the two esh's are equal because they were zipped previously + let (M_Nest esh x1, M_Nest _ x2) = munzip x + in M_Nest esh (mzip x1 x2) + go CastUnzip (M_Nest esh x) = + let (x1, x2) = munzip x + in mzip (M_Nest esh x1) (M_Nest esh x2) lemRankAppRankEq :: Rank sh ~ Rank sh' => Proxy esh -> Proxy sh -> Proxy sh' @@ -184,7 +230,7 @@ mcast ssh2 arr = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked = castCastable (CastXR CastId) +mtoRanked = castCastable CastXR rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a rtoMixed (Ranked arr) = arr @@ -198,7 +244,7 @@ rcastToMixed sshx rarr@(Ranked arr) mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') => ShS sh' -> Mixed sh a -> Shaped sh' a -mcastToShaped targetsh = castCastable (CastXS' targetsh CastId) +mcastToShaped targetsh = castCastable (CastXS' targetsh) stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a stoMixed (Shaped arr) = arr |