aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Data/Array/Nested.hs1
-rw-r--r--src/Data/Array/Nested/Convert.hs26
-rw-r--r--src/Data/Array/Nested/Mixed.hs24
3 files changed, 18 insertions, 33 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 9faf6d7..114fdc8 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -73,7 +73,6 @@ module Data.Array.Nested (
-- ** Conversions
mtoXArrayPrim, mfromXArrayPrim,
mcast,
- mcastSafe, SafeMCast, SafeMCastSpec(..),
mtoRanked, mcastToShaped,
castCastable, Castable(..),
-- ** Additional arithmetic operations
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index 639f5fd..813155f 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -42,7 +42,8 @@ data Castable a b where
CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b)
CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b)
- CastXR :: Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b)
+ CastXR :: Elt b
+ => Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b)
CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh'
-> Castable a b -> Castable (Mixed sh a) (Shaped sh' b)
@@ -68,19 +69,26 @@ castCastable = \c x -> munScalar (go c (mscalar x))
go (CastCmp c1 c2) x = go c1 (go c2 x)
go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)
go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x)
- go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) =
- M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy
- (go c x)))
+ go (CastXR @_ @_ @sh c) (M_Nest @esh esh x)
+ | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)
+ = let x' = go c x
+ ssx' = ssxAppend (ssxFromShape esh)
+ (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShape esh))))
+ in M_Ranked (M_Nest esh (mcast ssx' x'))
go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x))
go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x)
- | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
= M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh')))
(go c x)))
go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
- lemRankAppMapJust :: Rank sh ~ Rank sh'
- => Proxy esh -> Proxy sh -> Proxy sh'
- -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
- lemRankAppMapJust _ _ _ = unsafeCoerceRefl
+ lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh
+ -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing)
+ lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl
+
+ lemRankAppRankEqMapJust :: Rank sh ~ Rank sh'
+ => Proxy esh -> Proxy sh -> Proxy sh'
+ -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
+ lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 0a7eaba..c18db63 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -28,7 +28,7 @@ import Data.Bifunctor (bimap)
import Data.Coerce
import Data.Foldable (toList)
import Data.Int
-import Data.Kind (Constraint, Type)
+import Data.Kind (Type)
import Data.List.NonEmpty (NonEmpty(..))
import Data.List.NonEmpty qualified as NE
import Data.Proxy
@@ -40,7 +40,6 @@ import Foreign.Storable (Storable)
import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import GHC.TypeLits
-import Unsafe.Coerce (unsafeCoerce)
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
@@ -932,24 +931,3 @@ mcast ssh2 arr
| Refl <- lemAppNil @sh1
, Refl <- lemAppNil @sh2
= mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr
-
--- TODO: This should be `type data` but a bug in GHC 9.10 means that that throws linker errors
-data SafeMCastSpec
- = MCastId
- | MCastApp [Maybe Nat] [Maybe Nat] [Maybe Nat] [Maybe Nat] SafeMCastSpec SafeMCastSpec
- | MCastForget
-
-type SafeMCast :: SafeMCastSpec -> [Maybe Nat] -> [Maybe Nat] -> Constraint
-type family SafeMCast spec sh1 sh2 where
- SafeMCast MCastId sh sh = ()
- SafeMCast (MCastApp sh1A sh1B sh2A sh2B specA specB) sh1 sh2 = (sh1 ~ sh1A ++ sh1B, sh2 ~ sh2A ++ sh2B, SafeMCast specA sh1A sh2A, SafeMCast specB sh1B sh2B)
- SafeMCast MCastForget sh1 sh2 = sh2 ~ Replicate (Rank sh1) Nothing
-
--- | This is an O(1) operation: the 'SafeMCast' constraint ensures that
--- type-level shape information can only be forgotten, not introduced, and thus
--- that no runtime shape checks are required. The @spec@ describes to
--- 'SafeMCast' how exactly you intend @sh2@ to be a weakening of @sh1@.
---
--- To see how to construct the spec, read the equations of 'SafeMCast' closely.
-mcastSafe :: forall spec sh1 sh2 a proxy. SafeMCast spec sh1 sh2 => proxy spec -> Mixed sh1 a -> Mixed sh2 a
-mcastSafe _ = unsafeCoerce @(Mixed sh1 a) @(Mixed sh2 a)