diff options
Diffstat (limited to 'src/Data/Array/Nested')
| -rw-r--r-- | src/Data/Array/Nested/Internal/Convert.hs | 60 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 24 | 
2 files changed, 82 insertions, 2 deletions
| diff --git a/src/Data/Array/Nested/Internal/Convert.hs b/src/Data/Array/Nested/Internal/Convert.hs index e101981..183d62c 100644 --- a/src/Data/Array/Nested/Internal/Convert.hs +++ b/src/Data/Array/Nested/Internal/Convert.hs @@ -1,14 +1,19 @@  {-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE TypeApplications #-} -{-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeAbstractions #-}  {-# LANGUAGE TypeOperators #-}  module Data.Array.Nested.Internal.Convert where +import Control.Category +import Data.Proxy  import Data.Type.Equality  import Data.Array.Mixed.Lemmas  import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types  import Data.Array.Nested.Internal.Lemmas  import Data.Array.Nested.Internal.Mixed  import Data.Array.Nested.Internal.Ranked @@ -26,3 +31,56 @@ rcastToShaped (Ranked arr) targetsh    | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh))    , Refl <- lemRankMapJust targetsh    = mcastToShaped arr targetsh + +-- | The only constructor that performs runtime shape checking is 'CastXS''. +-- For the other construtors, the types ensure that the shapes are already +-- compatible. To convert between 'Ranked' and 'Shaped', go via 'Mixed'. +data Castable a b where +  CastId  :: Castable a a +  CastCmp :: Castable b c -> Castable a b -> Castable a c + +  CastRX  :: Castable a b -> Castable (Ranked n a)           (Mixed (Replicate n Nothing) b) +  CastSX  :: Castable a b -> Castable (Shaped sh a)          (Mixed (MapJust sh) b) + +  CastXR  :: Castable a b -> Castable (Mixed sh a)           (Ranked (Rank sh) b) +  CastXS  :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b) +  CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh' +          -> Castable a b -> Castable (Mixed sh a)           (Shaped sh' b) + +  CastRR  :: Castable a b -> Castable (Ranked n a)           (Ranked n b) +  CastSS  :: Castable a b -> Castable (Shaped sh a)          (Shaped sh b) +  CastXX  :: Castable a b -> Castable (Mixed sh a)           (Mixed sh b) + +instance Category Castable where +  id = CastId +  (.) = CastCmp + +castCastable :: (Elt a, Elt b) => Castable a b -> a -> b +castCastable = \c x -> munScalar (go c (mscalar x)) +  where +    -- The 'esh' is the extension shape: the casting happens under a whole +    -- bunch of additional dimensions that it does not touch. These dimensions +    -- are 'esh'. +    -- The strategy is to unwind step-by-step to a large Mixed array, and to +    -- perform the required checks and castings when re-nesting back up. +    go :: Castable a b -> Mixed esh a -> Mixed esh b +    go CastId x = x +    go (CastCmp c1 c2) x = go c1 (go c2 x) +    go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x) +    go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x) +    go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) = +      M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy +                                      (go c x))) +    go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x)) +    go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x) +      | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') +      = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh'))) +                                    (go c x))) +    go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) +    go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) +    go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) + +    lemRankAppMapJust :: Rank sh ~ Rank sh' +                      => Proxy esh -> Proxy sh -> Proxy sh' +                      -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') +    lemRankAppMapJust _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index d3e8088..3b3f196 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -12,6 +12,7 @@  {-# LANGUAGE StandaloneKindSignatures #-}  {-# LANGUAGE StrictData #-}  {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeData #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  {-# LANGUAGE UndecidableInstances #-} @@ -28,7 +29,7 @@ import Data.Bifunctor (bimap)  import Data.Coerce  import Data.Foldable (toList)  import Data.Int -import Data.Kind (Type) +import Data.Kind (Type, Constraint)  import Data.List.NonEmpty (NonEmpty(..))  import Data.List.NonEmpty qualified as NE  import Data.Proxy @@ -40,6 +41,7 @@ import Foreign.Storable (Storable)  import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp)  import GHC.Generics (Generic)  import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce)  import Data.Array.Mixed.XArray (XArray(..))  import Data.Array.Mixed.XArray qualified as X @@ -923,3 +925,23 @@ mcast ssh2 arr    | Refl <- lemAppNil @sh1    , Refl <- lemAppNil @sh2    = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr + +type data SafeMCastSpec +  = MCastId +  | MCastApp [Maybe Nat] [Maybe Nat] [Maybe Nat] [Maybe Nat] SafeMCastSpec SafeMCastSpec +  | MCastForget + +type SafeMCast :: SafeMCastSpec -> [Maybe Nat] -> [Maybe Nat] -> Constraint +type family SafeMCast spec sh1 sh2 where +  SafeMCast MCastId sh sh = () +  SafeMCast (MCastApp sh1A sh1B sh2A sh2B specA specB) sh1 sh2 = (sh1 ~ sh1A ++ sh1B, sh2 ~ sh2A ++ sh2B, SafeMCast specA sh1A sh2A, SafeMCast specB sh1B sh2B) +  SafeMCast MCastForget sh1 sh2 = sh2 ~ Replicate (Rank sh1) Nothing + +-- | This is an O(1) operation: the 'SafeMCast' constraint ensures that +-- type-level shape information can only be forgotten, not introduced, and thus +-- that no runtime shape checks are required. The @spec@ describes to +-- 'SafeMCast' how exactly you intend @sh2@ to be a weakening of @sh1@. +-- +-- To see how to construct the spec, read the equations of 'SafeMCast' closely. +mcastSafe :: forall spec sh1 sh2 a proxy. SafeMCast spec sh1 sh2 => proxy spec -> Mixed sh1 a -> Mixed sh2 a +mcastSafe _ = unsafeCoerce @(Mixed sh1 a) @(Mixed sh2 a) | 
