{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-- | Conversions between the 'Ranked', 'Shaped' and 'Mixed' array types, and
-- between 'Mixed' arrays whose static shape information differs.
--
-- The general mechanism is 'castCastable', driven by a 'Castable' value that
-- describes the conversion. The functions in the "Special cases" section are
-- convenience wrappers for common conversions.
module Data.Array.Nested.Convert (
  castCastable,
  Castable(..),

  -- * Special cases
  --
  -- | These functions can all be implemented using 'castCastable' in some way,
  -- but some have fewer constraints. For example, 'rtoMixed' and 'stoMixed'
  -- need no 'Elt' instance, whereas 'castCastable' requires one for both its
  -- source and its target type.
  rtoMixed, rcastToMixed, rcastToShaped,
  stoMixed, scastToMixed, stoRanked,
  mcast, mcastToShaped, mtoRanked,
) where

-- NOTE: this unqualified import brings Control.Category's 'id' and '(.)' into
-- scope next to the Prelude ones. The 'Category' instance below is unaffected,
-- but any unqualified use of 'id' or '.' elsewhere in this module would be an
-- ambiguous occurrence. That is why none appear below.
import Control.Category
import Data.Proxy
import Data.Type.Equality

import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Types
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Base
import Data.Array.Nested.Shaped.Base
import Data.Array.Nested.Shaped.Shape


-- | Change the static shape information of a 'Mixed' array to the given
-- 'StaticShX', which must have the same rank.
--
-- Compatibility of the array's actual shape with the target is checked at
-- runtime by 'mcastPartial'. This is the check used by 'rcastToMixed',
-- 'scastToMixed' and, inside 'castCastable', the 'CastXR', 'CastXS'' and
-- 'CastXX'' cases.
mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
      => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a
mcast ssh2 arr
  -- 'mcastPartial' is called with an empty extra shape (the @Proxy @'[]@
  -- argument). 'lemAppNil' presumably provides @sh ++ '[] :~: sh@, so that
  -- its argument and result types line up with plain @sh1@ / @sh2@.
  -- TODO confirm against the definitions of 'mcastPartial' and 'lemAppNil'.
  | Refl <- lemAppNil @sh1
  , Refl <- lemAppNil @sh2
  = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr

-- | Forget all static shape information of a 'Mixed' array, giving a 'Ranked'
-- array of the same rank. Implemented as the 'CastXR' cast with no element
-- conversion.
mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
mtoRanked = castCastable (CastXR CastId)

-- | View a 'Ranked' array as a 'Mixed' array in which every dimension is
-- 'Nothing', i.e. not statically known. This only unwraps the 'Ranked'
-- constructor: there is no runtime check and no 'Elt' constraint.
rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
rtoMixed (Ranked arr) = arr

-- | A more weakly-typed version of 'rtoMixed' that does a runtime shape
-- compatibility check. The check is performed by 'mcast'.
rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a
rcastToMixed sshx rarr@(Ranked arr)
  -- The unwrapped array has shape @Replicate n Nothing@. This lemma relates
  -- its rank to @n@, as needed for the rank equality that 'mcast' requires.
  | Refl <- lemRankReplicate (rrank rarr)
  = mcast sshx arr

-- | Convert a 'Mixed' array to a 'Shaped' array with the given target shape
-- of the same rank. The array's shape is checked at runtime for compatibility
-- with the target. This is the 'CastXS'' cast with no element conversion.
mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
              => ShS sh' -> Mixed sh a -> Shaped sh' a
mcastToShaped targetsh = castCastable (CastXS' targetsh CastId)

-- | View a 'Shaped' array as a 'Mixed' array in which every dimension is
-- statically known ('MapJust'). This only unwraps the 'Shaped' constructor:
-- there is no runtime check and no 'Elt' constraint.
stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
stoMixed (Shaped arr) = arr

-- | A more weakly-typed version of 'stoMixed' that does a runtime shape
-- compatibility check. The check is performed by 'mcast'.
scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
             => StaticShX sh' -> Shaped sh a -> Mixed sh' a
scastToMixed sshx sarr@(Shaped arr)
  -- The unwrapped array has shape @MapJust sh@. This lemma relates its rank to
  -- @Rank sh@, as needed for the rank equality that 'mcast' requires.
  | Refl <- lemRankMapJust (sshape sarr)
  = mcast sshx arr

-- | Forget the static shape of a 'Shaped' array, keeping only its rank.
-- Goes via 'Mixed' using 'mtoRanked'.
stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
stoRanked sarr@(Shaped arr)
  -- 'mtoRanked' yields @Ranked (Rank (MapJust sh))@. This lemma equates that
  -- rank with @Rank sh@, the rank in the signature.
  | Refl <- lemRankMapJust (sshape sarr)
  = mtoRanked arr

-- | Convert a 'Ranked' array to a 'Shaped' array of the given shape, with a
-- runtime compatibility check (via 'mcastToShaped').
--
-- Note that here the array comes before the target shape, unlike
-- 'mcastToShaped', which takes the target shape first.
rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
rcastToShaped (Ranked arr) targetsh
  -- The unwrapped array has shape @Replicate (Rank sh) Nothing@, and
  -- 'mcastToShaped' needs its rank to equal @Rank sh@. The rank index is
  -- obtained through @shCvtSX targetsh@, so it appears as @Rank (MapJust sh)@.
  -- 'lemRankMapJust' bridges that back to @Rank sh@.
  | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh))
  , Refl <- lemRankMapJust targetsh
  = mcastToShaped targetsh arr

-- | A first-class description of a conversion from @a@ to @b@, executed by
-- 'castCastable'.
--
-- * Conversions compose via 'CastCmp' or the 'Category' instance.
-- * The constructor names encode source and target array kind ('Ranked',
--   'Shaped', miXed), e.g. 'CastRX' converts 'Ranked' to 'Mixed'.
-- * The nested @'Castable' a b@ argument converts the element type.
--
-- The constructors that perform runtime shape checking are marked with a
-- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure
-- that the shapes are already compatible. To convert between 'Ranked' and
-- 'Shaped', go via 'Mixed'.
--
-- 'CastXR', 'CastXS'' and 'CastXX'' carry an @'Elt' b@ constraint because
-- their implementation in 'castCastable' calls 'mcast', which needs one.
data Castable a b where
  -- | The identity conversion.
  CastId  :: Castable a a
  -- | Composition: @'CastCmp' f g@ applies @g@ first, then @f@.
  CastCmp :: Castable b c -> Castable a b -> Castable a c

  -- | 'Ranked' to 'Mixed' with every dimension statically unknown.
  CastRX  :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b)
  -- | 'Shaped' to 'Mixed' with every dimension statically known.
  CastSX  :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b)

  -- | 'Mixed' to 'Ranked', forgetting all static shape information.
  CastXR  :: Elt b => Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b)
  -- | 'Mixed' to 'Shaped', when the type already says every dimension is known.
  CastXS  :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
  -- | 'Mixed' to 'Shaped' with the given target shape; checked at runtime.
  CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh' -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b)

  -- | Convert only the elements of a 'Ranked' array.
  CastRR  :: Castable a b -> Castable (Ranked n a) (Ranked n b)
  -- | Convert only the elements of a 'Shaped' array.
  CastSS  :: Castable a b -> Castable (Shaped sh a) (Shaped sh b)
  -- | Convert only the elements of a 'Mixed' array.
  CastXX  :: Castable a b -> Castable (Mixed sh a) (Mixed sh b)
  -- | 'Mixed' to 'Mixed' with different static shape information of the same
  -- rank; checked at runtime.
  CastXX' :: (Rank sh ~ Rank sh', Elt b) => StaticShX sh' -> Castable a b -> Castable (Mixed sh a) (Mixed sh' b)

-- | Identity is 'CastId' and composition is 'CastCmp'. Casts can therefore be
-- chained with the "Control.Category" combinators such as
-- 'Control.Category.>>>'.
instance Category Castable where
  id = CastId
  (.) = CastCmp

-- | Run a 'Castable' conversion on a value.
--
-- The value is first wrapped with 'mscalar' and finally unwrapped with
-- 'munScalar'. This presumably makes it a rank-0 'Mixed' array, so that every
-- step below can work on a 'Mixed' array under some outer shape.
castCastable :: (Elt a, Elt b) => Castable a b -> a -> b
castCastable = \c x -> munScalar (go c (mscalar x))
  where
    -- The 'esh' is the extension shape: the casting happens under a whole
    -- bunch of additional dimensions that it does not touch. These dimensions
    -- are 'esh'.
    -- The strategy is to unwind step-by-step to a large Mixed array, and to
    -- perform the required checks and castings when re-nesting back up.
    -- 'go' itself has no 'Elt' constraint. The cases that call 'mcast'
    -- ('CastXR', 'CastXS'', 'CastXX'') obtain the needed 'Elt' dictionary from
    -- their constructor.
    go :: Castable a b -> Mixed esh a -> Mixed esh b
    go CastId x = x
    go (CastCmp c1 c2) x = go c1 (go c2 x)
    -- Ranked/Shaped -> Mixed: peel off the wrapper, convert the nested inner
    -- array, and re-nest it under the same outer shape 'esh'.
    go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)
    go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x)
    -- Mixed -> Ranked: after converting the elements, recast the inner part of
    -- the shape to all-'Nothing' of the same rank, keeping the 'esh' prefix.
    go (CastXR @_ @_ @sh c) (M_Nest @esh esh x)
      | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)
      = let x' = go c x
            -- Target static shape: the prefix 'esh' unchanged, followed by
            -- one unknown dimension per dimension of 'sh'. The count is
            -- computed from the converted array's runtime shape with the
            -- 'esh' prefix dropped.
            ssx' = ssxAppend (ssxFromShape esh)
                             (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShape esh))))
        in M_Ranked (M_Nest esh (mcast ssx' x'))
    -- Mixed (MapJust sh) -> Shaped: the types already match; just re-wrap.
    go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x))
    -- Mixed -> Shaped with a target shape: 'mcast' the whole (esh ++ sh)
    -- array to (esh ++ MapJust sh'), which performs the runtime check.
    go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x)
      | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
      = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh')))
                                    (go c x)))
    -- Element-only conversions: the shape is untouched.
    go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
    go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
    go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
    -- Mixed -> Mixed with new static shape info: 'mcast' to (esh ++ sh'),
    -- which performs the runtime check.
    go (CastXX' @sh @sh' ssx c) (M_Nest @esh esh x)
      | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
      = M_Nest esh $ mcast (ssxFromShape esh `ssxAppend` ssx) (go c x)

    -- Rank lemmas for the 'esh'-prefixed cases above. GHC cannot prove them,
    -- so they are asserted with 'unsafeCoerceRefl'. They are sound assuming
    -- 'Rank' is the length of a type-level list:
    --
    -- * the rank of @esh ++ sh@ depends on @sh@ only through @Rank sh@;
    -- * @Replicate (Rank sh) Nothing@ has rank @Rank sh@;
    -- * @MapJust sh'@ has rank @Rank sh'@.
    --
    -- TODO confirm against the definition of 'Rank'.
    lemRankAppRankEq :: Rank sh ~ Rank sh'
                     => Proxy esh -> Proxy sh -> Proxy sh'
                     -> Rank (esh ++ sh) :~: Rank (esh ++ sh')
    lemRankAppRankEq _ _ _ = unsafeCoerceRefl

    lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh
                          -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing)
    lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl

    lemRankAppRankEqMapJust :: Rank sh ~ Rank sh'
                            => Proxy esh -> Proxy sh -> Proxy sh'
                            -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
    lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl