{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
-- | Conversions between the 'Ranked', 'Shaped' and 'Mixed' array
-- representations. It also provides 'Castable', a first-order description of
-- such conversions, which can also cast the element type. 'castCastable'
-- interprets a 'Castable'.
module Data.Array.Nested.Internal.Convert (
  stoRanked,
  rcastToShaped,
  castCastable,
  Castable(.., CastXR, CastXS, CastRS),
  castRR',
  castSS',
) where

import Control.Category
import Data.Proxy
import Data.Type.Equality
import GHC.TypeNats

import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Internal.Mixed
import Data.Array.Nested.Internal.Ranked
import Data.Array.Nested.Internal.Shape
import Data.Array.Nested.Internal.Shaped


-- | Convert a shaped array to a ranked array whose rank is the length of its
-- shape. The wrapped 'Mixed' array is handed to 'mtoRanked'.
stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
stoRanked sarr@(Shaped arr)
  -- The 'Refl' match only brings a rank equation into scope for 'mtoRanked'.
  | Refl <- lemRankMapJust (sshape sarr)
  = mtoRanked arr

-- | Reinterpret a ranked array as a shaped array with the given target shape.
-- The types fix the rank. The dimension check is delegated to
-- 'mcastToShaped'; presumably it is a runtime check (NOTE(review): confirm).
rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
rcastToShaped (Ranked arr) targetsh
  -- Rank equations needed to type-check the call to 'mcastToShaped'.
  | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh))
  , Refl <- lemRankMapJust targetsh
  = mcastToShaped arr targetsh

-- | A first-order description of a conversion from @a@ to @b@, interpreted by
-- 'castCastable'. Most constructors change the outer array representation
-- while applying an inner 'Castable' to the element type. Casts compose with
-- 'CastCmp' (also exposed through the 'Category' instance). They can be run
-- backwards with 'CastInv'.
data Castable a b where
  -- | The identity cast.
  CastId :: Castable a a
  -- | Composition: @CastCmp f g@ performs @g@ first, then @f@.
  CastCmp :: Castable b c -> Castable a b -> Castable a c
  -- | Run a cast in the reverse direction.
  CastInv :: Castable b a -> Castable a b

  -- | 'Ranked' to 'Mixed' with shape @Replicate n Nothing@, casting the
  -- elements with the inner cast.
  CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b)
  -- | 'Shaped' to 'Mixed' with shape @MapJust sh@, casting the elements with
  -- the inner cast.
  CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b)
  -- | 'Shaped' to 'Ranked' of rank @Rank sh@, casting the elements with the
  -- inner cast.
  CastSR :: Elt a
         => ShS sh  -- ^ The singleton is required in case this constructor appears under 'CastInv'.
         -> Castable a b
         -> Castable (Shaped sh a) (Ranked (Rank sh) b)

  -- | Keep the 'Ranked' wrapper; cast only the elements.
  CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b)
  -- | Keep the 'Shaped' wrapper; cast only the elements.
  CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b)
  -- | Keep the 'Mixed' wrapper and its shape type; cast only the elements.
  CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b)
  -- | Move a 'Mixed' array between two shape types of equal rank while casting
  -- the elements. Both shapes are stored: the forward direction casts to the
  -- second one and the reverse direction ('CastInv') casts to the first one.
  CastXX' :: (Rank sh ~ Rank sh', Elt a, Elt b)
          => IShX sh -> IShX sh' -> Castable a b -> Castable (Mixed sh a) (Mixed sh' b)

-- | Inverse of 'CastRX': 'Mixed' with shape @Replicate n Nothing@ to 'Ranked'.
-- The element cast is still applied in the forward direction.
pattern CastXR :: Castable a b -> Castable (Mixed (Replicate n Nothing) a) (Ranked n b)
pattern CastXR c = CastInv (CastRX (CastInv c))

-- | Inverse of 'CastSX': 'Mixed' with shape @MapJust sh@ to 'Shaped'.
-- The element cast is still applied in the forward direction.
pattern CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
pattern CastXS c = CastInv (CastSX (CastInv c))

-- | Inverse of 'CastSR': 'Ranked' to 'Shaped' with the given shape.
-- The element cast is still applied in the forward direction.
pattern CastRS :: Elt b => ShS sh -> Castable a b -> Castable (Ranked (Rank sh) a) (Shaped sh b)
pattern CastRS sh c = CastInv (CastSR sh (CastInv c))

-- | Build a 'CastRR' between two ranks that are only known to be equal at
-- runtime. Calls 'error' if the ranks differ.
castRR' :: SNat n -> SNat n' -> Castable a b -> Castable (Ranked n a) (Ranked n' b)
castRR' n@SNat n'@SNat c
  | Just Refl <- sameNat n n' = CastRR c
  | otherwise = error "castRR': Ranks unequal"

-- | Build a 'CastSS' between two shapes that are only known to be equal at
-- runtime. Calls 'error' if the shapes differ.
castSS' :: ShS sh -> ShS sh' -> Castable a b -> Castable (Shaped sh a) (Shaped sh' b)
castSS' sh sh' c
  | Just Refl <- testEquality sh sh' = CastSS c
  | otherwise = error "castSS': Shapes unequal"

-- | The identity is 'CastId' and composition is 'CastCmp'.
instance Category Castable where
  id = CastId
  (.) = CastCmp

-- | Apply the conversion described by a 'Castable' to a value.
--
-- The value is wrapped with 'mscalar', transformed as a 'Mixed' array, and
-- unwrapped with 'munScalar'. Steps that change the shape type go through
-- 'mcast' or 'mcastSafe'. Presumably 'mcast' fails at runtime when a dimension
-- does not fit the target shape (NOTE(review): confirm in its implementation).
castCastable :: (Elt a, Elt b) => Castable a b -> a -> b
castCastable = \c x -> munScalar (go c (mscalar x))
  where
    -- The 'esh' is the extension shape: the casting happens under a number of
    -- additional outer dimensions that it does not touch. These dimensions
    -- are 'esh'.
    -- The strategy is to unwind step by step into one large 'Mixed' array, and
    -- to perform the required checks and casts when re-nesting back up.

    -- Run a cast forwards.
    go :: Castable a b -> Mixed esh a -> Mixed esh b
    go CastId x = x
    go (CastCmp c1 c2) x = go c1 (go c2 x)
    go (CastInv c) x = goInv c x
    go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)
    go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x)
    -- '@_' skips 'CastSR''s first type argument (the element type), so that
    -- '@sh' binds the shape. The 'MCastApp' index appears to keep 'esh'
    -- ('MCastId') and turn 'MapJust sh' into 'Replicate (Rank sh) Nothing'
    -- ('MCastForget').
    go (CastSR @_ @sh sh c) (M_Shaped (M_Nest @esh esh x))
      | Refl <- lemRankMapJust sh
      = M_Ranked
          (M_Nest esh
             (mcastSafe
                @(MCastApp esh (MapJust sh) esh (Replicate (Rank sh) Nothing) MCastId MCastForget)
                Proxy
                (go c x)))
    go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
    go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
    go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
    go (CastXX' sh sh' c) (M_Nest esh x)
      | Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh)
      , Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh')
      = M_Nest esh (mcast (ssxFromShape (shxAppend esh sh')) (go c x))

    -- Run a cast backwards: given @Castable b a@, map @a@ to @b@.
    -- Compositions are therefore unwound in the opposite order.
    goInv :: Castable b a -> Mixed esh a -> Mixed esh b
    goInv CastId x = x
    goInv (CastCmp c1 c2) x = goInv c2 (goInv c1 x)
    goInv (CastInv c) x = go c x
    goInv (CastRX c) (M_Nest esh x) = M_Ranked (M_Nest esh (goInv c x))
    goInv (CastSX c) (M_Nest esh x) = M_Shaped (M_Nest esh (goInv c x))
    -- Only the term-level shape singleton is used here, so no type
    -- abstraction is bound. The first type argument of 'CastSR' is its
    -- element type, not its shape.
    goInv (CastSR sh c) (M_Ranked (M_Nest esh x))
      | Refl <- lemRankApp (ssxFromShape esh) (ssxFromSNat (shsRank sh))
      , Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape (shCvtSX sh))
      , Refl <- lemRankReplicate (shsRank sh)
      , Refl <- lemRankMapJust sh
      = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh))) (goInv c x)))
    goInv (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (goInv c x))
    goInv (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (goInv c x))
    goInv (CastXX c) (M_Nest esh x) = M_Nest esh (goInv c x)
    goInv (CastXX' sh sh' c) (M_Nest esh x)
      | Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh)
      , Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh')
      = M_Nest esh (mcast (ssxFromShape (shxAppend esh sh)) (goInv c x))