diff options
Diffstat (limited to 'src/Data/Array/Nested/Convert.hs')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 86 |
1 files changed, 86 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs new file mode 100644 index 0000000..639f5fd --- /dev/null +++ b/src/Data/Array/Nested/Convert.hs @@ -0,0 +1,86 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +module Data.Array.Nested.Convert where + +import Control.Category +import Data.Proxy +import Data.Type.Equality + +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Types +import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked +import Data.Array.Nested.Shaped +import Data.Array.Nested.Shaped.Shape + + +stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a +stoRanked sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mtoRanked arr + +rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a +rcastToShaped (Ranked arr) targetsh + | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh)) + , Refl <- lemRankMapJust targetsh + = mcastToShaped arr targetsh + +-- | The only constructor that performs runtime shape checking is 'CastXS''. +-- For the other construtors, the types ensure that the shapes are already +-- compatible. To convert between 'Ranked' and 'Shaped', go via 'Mixed'. +data Castable a b where + CastId :: Castable a a + CastCmp :: Castable b c -> Castable a b -> Castable a c + + CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b) + CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b) + + CastXR :: Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b) + CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b) + CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh' + -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b) + + CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b) + CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b) + CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b) + +instance Category Castable where + id = CastId + (.) = CastCmp + +castCastable :: (Elt a, Elt b) => Castable a b -> a -> b +castCastable = \c x -> munScalar (go c (mscalar x)) + where + -- The 'esh' is the extension shape: the casting happens under a whole + -- bunch of additional dimensions that it does not touch. These dimensions + -- are 'esh'. + -- The strategy is to unwind step-by-step to a large Mixed array, and to + -- perform the required checks and castings when re-nesting back up. + go :: Castable a b -> Mixed esh a -> Mixed esh b + go CastId x = x + go (CastCmp c1 c2) x = go c1 (go c2 x) + go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x) + go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x) + go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) = + M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy + (go c x))) + go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x)) + go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x) + | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh'))) + (go c x))) + go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) + go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) + go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) + + lemRankAppMapJust :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') + lemRankAppMapJust _ _ _ = unsafeCoerceRefl |