diff options
| -rw-r--r-- | src/Data/Array/Nested.hs | 1 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Convert.hs | 26 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 24 | 
3 files changed, 18 insertions, 33 deletions
| diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 9faf6d7..114fdc8 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -73,7 +73,6 @@ module Data.Array.Nested (    -- ** Conversions    mtoXArrayPrim, mfromXArrayPrim,    mcast, -  mcastSafe, SafeMCast, SafeMCastSpec(..),    mtoRanked, mcastToShaped,    castCastable, Castable(..),    -- ** Additional arithmetic operations diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 639f5fd..813155f 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -42,7 +42,8 @@ data Castable a b where    CastRX  :: Castable a b -> Castable (Ranked n a)           (Mixed (Replicate n Nothing) b)    CastSX  :: Castable a b -> Castable (Shaped sh a)          (Mixed (MapJust sh) b) -  CastXR  :: Castable a b -> Castable (Mixed sh a)           (Ranked (Rank sh) b) +  CastXR  :: Elt b +          => Castable a b -> Castable (Mixed sh a)           (Ranked (Rank sh) b)    CastXS  :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)    CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh'            -> Castable a b -> Castable (Mixed sh a)           (Shaped sh' b) @@ -68,19 +69,26 @@ castCastable = \c x -> munScalar (go c (mscalar x))      go (CastCmp c1 c2) x = go c1 (go c2 x)      go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)      go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x) -    go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) = -      M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy -                                      (go c x))) +    go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) +      | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) +      = let x' = go c x +            ssx' = ssxAppend (ssxFromShape esh) +                             (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShape esh)))) +        in M_Ranked (M_Nest esh (mcast ssx' x'))      go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x))      go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x) -      | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') +      | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')        = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh')))                                      (go c x)))      go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))      go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))      go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) -    lemRankAppMapJust :: Rank sh ~ Rank sh' -                      => Proxy esh -> Proxy sh -> Proxy sh' -                      -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') -    lemRankAppMapJust _ _ _ = unsafeCoerceRefl +    lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh +                          -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing) +    lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl + +    lemRankAppRankEqMapJust :: Rank sh ~ Rank sh' +                            => Proxy esh -> Proxy sh -> Proxy sh' +                            -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') +    lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 0a7eaba..c18db63 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -28,7 +28,7 @@ import Data.Bifunctor (bimap)  import Data.Coerce  import Data.Foldable (toList)  import Data.Int -import Data.Kind (Constraint, Type) +import Data.Kind (Type)  import Data.List.NonEmpty (NonEmpty(..))  import Data.List.NonEmpty qualified as NE  import Data.Proxy @@ -40,7 +40,6 @@ import Foreign.Storable (Storable)  import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)  import GHC.Generics (Generic)  import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce)  import Data.Array.Mixed.Lemmas  import Data.Array.Mixed.Permutation @@ -932,24 +931,3 @@ mcast ssh2 arr    | Refl <- lemAppNil @sh1    , Refl <- lemAppNil @sh2    = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr - --- TODO: This should be `type data` but a bug in GHC 9.10 means that that throws linker errors -data SafeMCastSpec -  = MCastId -  | MCastApp [Maybe Nat] [Maybe Nat] [Maybe Nat] [Maybe Nat] SafeMCastSpec SafeMCastSpec -  | MCastForget - -type SafeMCast :: SafeMCastSpec -> [Maybe Nat] -> [Maybe Nat] -> Constraint -type family SafeMCast spec sh1 sh2 where -  SafeMCast MCastId sh sh = () -  SafeMCast (MCastApp sh1A sh1B sh2A sh2B specA specB) sh1 sh2 = (sh1 ~ sh1A ++ sh1B, sh2 ~ sh2A ++ sh2B, SafeMCast specA sh1A sh2A, SafeMCast specB sh1B sh2B) -  SafeMCast MCastForget sh1 sh2 = sh2 ~ Replicate (Rank sh1) Nothing - --- | This is an O(1) operation: the 'SafeMCast' constraint ensures that --- type-level shape information can only be forgotten, not introduced, and thus --- that no runtime shape checks are required. The @spec@ describes to --- 'SafeMCast' how exactly you intend @sh2@ to be a weakening of @sh1@. --- --- To see how to construct the spec, read the equations of 'SafeMCast' closely. -mcastSafe :: forall spec sh1 sh2 a proxy. SafeMCast spec sh1 sh2 => proxy spec -> Mixed sh1 a -> Mixed sh2 a -mcastSafe _ = unsafeCoerce @(Mixed sh1 a) @(Mixed sh2 a) | 
