diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-02-01 22:17:34 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-01 22:21:23 +0100 |
commit | f578e36a8ed73268c3e1b91609baa76adfa0693a (patch) | |
tree | d0cd31b7f8f7fb8f1510e9ee843788cada2aebd3 /src/Data/Array/Nested/Internal/Mixed.hs | |
parent | 03af9faf39e8872b5577e6f32d55b692c9a90d0e (diff) |
mcastSafe, castCastable
Diffstat (limited to 'src/Data/Array/Nested/Internal/Mixed.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 24 |
1 files changed, 23 insertions, 1 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index d3e8088..3b3f196 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -12,6 +12,7 @@ {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE StrictData #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} @@ -28,7 +29,7 @@ import Data.Bifunctor (bimap) import Data.Coerce import Data.Foldable (toList) import Data.Int -import Data.Kind (Type) +import Data.Kind (Type, Constraint) import Data.List.NonEmpty (NonEmpty(..)) import Data.List.NonEmpty qualified as NE import Data.Proxy @@ -40,6 +41,7 @@ import Foreign.Storable (Storable) import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp) import GHC.Generics (Generic) import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce) import Data.Array.Mixed.XArray (XArray(..)) import Data.Array.Mixed.XArray qualified as X @@ -923,3 +925,23 @@ mcast ssh2 arr | Refl <- lemAppNil @sh1 , Refl <- lemAppNil @sh2 = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr + +type data SafeMCastSpec + = MCastId + | MCastApp [Maybe Nat] [Maybe Nat] [Maybe Nat] [Maybe Nat] SafeMCastSpec SafeMCastSpec + | MCastForget + +type SafeMCast :: SafeMCastSpec -> [Maybe Nat] -> [Maybe Nat] -> Constraint +type family SafeMCast spec sh1 sh2 where + SafeMCast MCastId sh sh = () + SafeMCast (MCastApp sh1A sh1B sh2A sh2B specA specB) sh1 sh2 = (sh1 ~ sh1A ++ sh1B, sh2 ~ sh2A ++ sh2B, SafeMCast specA sh1A sh2A, SafeMCast specB sh1B sh2B) + SafeMCast MCastForget sh1 sh2 = sh2 ~ Replicate (Rank sh1) Nothing + +-- | This is an O(1) operation: the 'SafeMCast' constraint ensures that +-- type-level shape information can only be forgotten, not introduced, and thus +-- that no runtime shape checks are required. The @spec@ describes to +-- 'SafeMCast' how exactly you intend @sh2@ to be a weakening of @sh1@. +-- +-- To see how to construct the spec, read the equations of 'SafeMCast' closely. +mcastSafe :: forall spec sh1 sh2 a proxy. SafeMCast spec sh1 sh2 => proxy spec -> Mixed sh1 a -> Mixed sh2 a +mcastSafe _ = unsafeCoerce @(Mixed sh1 a) @(Mixed sh2 a) |