From 3c8f13c8310de646b15c6f2745cfe190db7610db Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Wed, 14 May 2025 19:43:21 +0200 Subject: Move Arith, XArray and Convert --- src/Data/Array/Nested/Convert.hs | 86 +++++++++++++++++++++++++++++++ src/Data/Array/Nested/Internal/Convert.hs | 86 ------------------------------- src/Data/Array/Nested/Mixed.hs | 6 +-- src/Data/Array/Nested/Ranked.hs | 4 +- src/Data/Array/Nested/Shaped.hs | 4 +- 5 files changed, 93 insertions(+), 93 deletions(-) create mode 100644 src/Data/Array/Nested/Convert.hs delete mode 100644 src/Data/Array/Nested/Internal/Convert.hs (limited to 'src/Data/Array/Nested') diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs new file mode 100644 index 0000000..639f5fd --- /dev/null +++ b/src/Data/Array/Nested/Convert.hs @@ -0,0 +1,86 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +module Data.Array.Nested.Convert where + +import Control.Category +import Data.Proxy +import Data.Type.Equality + +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Types +import Data.Array.Nested.Internal.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked +import Data.Array.Nested.Shaped +import Data.Array.Nested.Shaped.Shape + + +stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a +stoRanked sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mtoRanked arr + +rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a +rcastToShaped (Ranked arr) targetsh + | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh)) + , Refl <- lemRankMapJust targetsh + = mcastToShaped arr targetsh + +-- | The only constructor that performs runtime shape checking is 'CastXS''. +-- For the other construtors, the types ensure that the shapes are already +-- compatible. To convert between 'Ranked' and 'Shaped', go via 'Mixed'. +data Castable a b where + CastId :: Castable a a + CastCmp :: Castable b c -> Castable a b -> Castable a c + + CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b) + CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b) + + CastXR :: Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b) + CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b) + CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh' + -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b) + + CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b) + CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b) + CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b) + +instance Category Castable where + id = CastId + (.) = CastCmp + +castCastable :: (Elt a, Elt b) => Castable a b -> a -> b +castCastable = \c x -> munScalar (go c (mscalar x)) + where + -- The 'esh' is the extension shape: the casting happens under a whole + -- bunch of additional dimensions that it does not touch. These dimensions + -- are 'esh'. + -- The strategy is to unwind step-by-step to a large Mixed array, and to + -- perform the required checks and castings when re-nesting back up. + go :: Castable a b -> Mixed esh a -> Mixed esh b + go CastId x = x + go (CastCmp c1 c2) x = go c1 (go c2 x) + go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x) + go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x) + go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) = + M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy + (go c x))) + go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x)) + go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x) + | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh'))) + (go c x))) + go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) + go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) + go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) + + lemRankAppMapJust :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') + lemRankAppMapJust _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Nested/Internal/Convert.hs b/src/Data/Array/Nested/Internal/Convert.hs deleted file mode 100644 index 611b45e..0000000 --- a/src/Data/Array/Nested/Internal/Convert.hs +++ /dev/null @@ -1,86 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeAbstractions #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module Data.Array.Nested.Internal.Convert where - -import Control.Category -import Data.Proxy -import Data.Type.Equality - -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Types -import Data.Array.Nested.Internal.Lemmas -import Data.Array.Nested.Mixed -import Data.Array.Nested.Ranked -import Data.Array.Nested.Shaped -import Data.Array.Nested.Mixed.Shape -import Data.Array.Nested.Shaped.Shape - - -stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a -stoRanked sarr@(Shaped arr) - | Refl <- lemRankMapJust (sshape sarr) - = mtoRanked arr - -rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a -rcastToShaped (Ranked arr) targetsh - | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh)) - , Refl <- lemRankMapJust targetsh - = mcastToShaped arr targetsh - --- | The only constructor that performs runtime shape checking is 'CastXS''. --- For the other construtors, the types ensure that the shapes are already --- compatible. To convert between 'Ranked' and 'Shaped', go via 'Mixed'. -data Castable a b where - CastId :: Castable a a - CastCmp :: Castable b c -> Castable a b -> Castable a c - - CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b) - CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b) - - CastXR :: Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b) - CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b) - CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh' - -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b) - - CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b) - CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b) - CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b) - -instance Category Castable where - id = CastId - (.) = CastCmp - -castCastable :: (Elt a, Elt b) => Castable a b -> a -> b -castCastable = \c x -> munScalar (go c (mscalar x)) - where - -- The 'esh' is the extension shape: the casting happens under a whole - -- bunch of additional dimensions that it does not touch. These dimensions - -- are 'esh'. - -- The strategy is to unwind step-by-step to a large Mixed array, and to - -- perform the required checks and castings when re-nesting back up. - go :: Castable a b -> Mixed esh a -> Mixed esh b - go CastId x = x - go (CastCmp c1 c2) x = go c1 (go c2 x) - go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x) - go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x) - go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) = - M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy - (go c x))) - go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x)) - go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x) - | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') - = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh'))) - (go c x))) - go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) - go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) - go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) - - lemRankAppMapJust :: Rank sh ~ Rank sh' - => Proxy esh -> Proxy sh -> Proxy sh' - -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') - lemRankAppMapJust _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 50a1b71..ec19c21 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -42,12 +42,12 @@ import GHC.Generics (Generic) import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) -import Data.Array.Mixed.Internal.Arith +import Data.Array.Arith import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray(..)) -import Data.Array.Mixed.XArray qualified as X +import Data.Array.XArray (XArray(..)) +import Data.Array.XArray qualified as X import Data.Array.Nested.Mixed.Shape import Data.Bag diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index fb5caa9..e2074ac 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -41,8 +41,8 @@ import GHC.TypeNats qualified as TN import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray(..)) -import Data.Array.Mixed.XArray qualified as X +import Data.Array.XArray (XArray(..)) +import Data.Array.XArray qualified as X import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Ranked.Shape diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index ba767cd..4bccbc4 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -41,8 +41,8 @@ import GHC.TypeLits import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray) -import Data.Array.Mixed.XArray qualified as X +import Data.Array.XArray (XArray) +import Data.Array.XArray qualified as X import Data.Array.Nested.Internal.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape -- cgit v1.2.3-70-g09d2