diff options
-rw-r--r-- | src/Data/Array/Nested.hs | 1 | ||||
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 26 | ||||
-rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 24 |
3 files changed, 18 insertions, 33 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 9faf6d7..114fdc8 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -73,7 +73,6 @@ module Data.Array.Nested ( -- ** Conversions mtoXArrayPrim, mfromXArrayPrim, mcast, - mcastSafe, SafeMCast, SafeMCastSpec(..), mtoRanked, mcastToShaped, castCastable, Castable(..), -- ** Additional arithmetic operations diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 639f5fd..813155f 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -42,7 +42,8 @@ data Castable a b where CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b) CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b) - CastXR :: Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b) + CastXR :: Elt b + => Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b) CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b) CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh' -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b) @@ -68,19 +69,26 @@ castCastable = \c x -> munScalar (go c (mscalar x)) go (CastCmp c1 c2) x = go c1 (go c2 x) go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x) go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x) - go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) = - M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy - (go c x))) + go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) + | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) + = let x' = go c x + ssx' = ssxAppend (ssxFromShape esh) + (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShape esh)))) + in M_Ranked (M_Nest esh (mcast ssx' x')) go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x)) go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x) - | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') + | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh'))) (go c x))) go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) - lemRankAppMapJust :: Rank sh ~ Rank sh' - => Proxy esh -> Proxy sh -> Proxy sh' - -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') - lemRankAppMapJust _ _ _ = unsafeCoerceRefl + lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh + -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing) + lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl + + lemRankAppRankEqMapJust :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') + lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 0a7eaba..c18db63 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -28,7 +28,7 @@ import Data.Bifunctor (bimap) import Data.Coerce import Data.Foldable (toList) import Data.Int -import Data.Kind (Constraint, Type) +import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty(..)) import Data.List.NonEmpty qualified as NE import Data.Proxy @@ -40,7 +40,6 @@ import Foreign.Storable (Storable) import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce) import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation @@ -932,24 +931,3 @@ mcast ssh2 arr | Refl <- lemAppNil @sh1 , Refl <- lemAppNil @sh2 = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr - --- TODO: This should be `type data` but a bug in GHC 9.10 means that that throws linker errors -data SafeMCastSpec - = MCastId - | MCastApp [Maybe Nat] [Maybe Nat] [Maybe Nat] [Maybe Nat] SafeMCastSpec SafeMCastSpec - | MCastForget - -type SafeMCast :: SafeMCastSpec -> [Maybe Nat] -> [Maybe Nat] -> Constraint -type family SafeMCast spec sh1 sh2 where - SafeMCast MCastId sh sh = () - SafeMCast (MCastApp sh1A sh1B sh2A sh2B specA specB) sh1 sh2 = (sh1 ~ sh1A ++ sh1B, sh2 ~ sh2A ++ sh2B, SafeMCast specA sh1A sh2A, SafeMCast specB sh1B sh2B) - SafeMCast MCastForget sh1 sh2 = sh2 ~ Replicate (Rank sh1) Nothing - --- | This is an O(1) operation: the 'SafeMCast' constraint ensures that --- type-level shape information can only be forgotten, not introduced, and thus --- that no runtime shape checks are required. The @spec@ describes to --- 'SafeMCast' how exactly you intend @sh2@ to be a weakening of @sh1@. --- --- To see how to construct the spec, read the equations of 'SafeMCast' closely. -mcastSafe :: forall spec sh1 sh2 a proxy. SafeMCast spec sh1 sh2 => proxy spec -> Mixed sh1 a -> Mixed sh2 a -mcastSafe _ = unsafeCoerce @(Mixed sh1 a) @(Mixed sh2 a) |