aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-02-01 22:17:34 +0100
committerTom Smeding <tom@tomsmeding.com>2025-02-01 22:21:23 +0100
commitf578e36a8ed73268c3e1b91609baa76adfa0693a (patch)
treed0cd31b7f8f7fb8f1510e9ee843788cada2aebd3
parent03af9faf39e8872b5577e6f32d55b692c9a90d0e (diff)
mcastSafe, castCastable
-rw-r--r--src/Data/Array/Nested.hs2
-rw-r--r--src/Data/Array/Nested/Internal/Convert.hs60
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs24
3 files changed, 84 insertions, 2 deletions
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 77252dc..bef83d1 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -65,7 +65,9 @@ module Data.Array.Nested (
-- ** Conversions
mtoXArrayPrim, mfromXArrayPrim,
mcast,
+ mcastSafe, SafeMCast, SafeMCastSpec(..),
mtoRanked, mcastToShaped,
+ castCastable, Castable(..),
-- * Array elements
Elt,
diff --git a/src/Data/Array/Nested/Internal/Convert.hs b/src/Data/Array/Nested/Internal/Convert.hs
index e101981..183d62c 100644
--- a/src/Data/Array/Nested/Internal/Convert.hs
+++ b/src/Data/Array/Nested/Internal/Convert.hs
@@ -1,14 +1,19 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE GADTs #-}
+{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Nested.Internal.Convert where
+import Control.Category
+import Data.Proxy
import Data.Type.Equality
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Types
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Internal.Mixed
import Data.Array.Nested.Internal.Ranked
@@ -26,3 +31,56 @@ rcastToShaped (Ranked arr) targetsh
| Refl <- lemRankReplicate (shxRank (shCvtSX targetsh))
, Refl <- lemRankMapJust targetsh
= mcastToShaped arr targetsh
+
+-- | The only constructor that performs runtime shape checking is 'CastXS''.
+-- For the other construtors, the types ensure that the shapes are already
+-- compatible. To convert between 'Ranked' and 'Shaped', go via 'Mixed'.
+data Castable a b where
+ CastId :: Castable a a
+ CastCmp :: Castable b c -> Castable a b -> Castable a c
+
+ CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b)
+ CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b)
+
+ CastXR :: Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b)
+ CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
+ CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh'
+ -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b)
+
+ CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b)
+ CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b)
+ CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b)
+
+instance Category Castable where
+ id = CastId
+ (.) = CastCmp
+
+castCastable :: (Elt a, Elt b) => Castable a b -> a -> b
+castCastable = \c x -> munScalar (go c (mscalar x))
+ where
+ -- The 'esh' is the extension shape: the casting happens under a whole
+ -- bunch of additional dimensions that it does not touch. These dimensions
+ -- are 'esh'.
+ -- The strategy is to unwind step-by-step to a large Mixed array, and to
+ -- perform the required checks and castings when re-nesting back up.
+ go :: Castable a b -> Mixed esh a -> Mixed esh b
+ go CastId x = x
+ go (CastCmp c1 c2) x = go c1 (go c2 x)
+ go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)
+ go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x)
+ go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) =
+ M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy
+ (go c x)))
+ go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x))
+ go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x)
+ | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
+ = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh')))
+ (go c x)))
+ go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
+ go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
+ go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
+
+ lemRankAppMapJust :: Rank sh ~ Rank sh'
+ => Proxy esh -> Proxy sh -> Proxy sh'
+ -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
+ lemRankAppMapJust _ _ _ = unsafeCoerceRefl
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index d3e8088..3b3f196 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -12,6 +12,7 @@
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
@@ -28,7 +29,7 @@ import Data.Bifunctor (bimap)
import Data.Coerce
import Data.Foldable (toList)
import Data.Int
-import Data.Kind (Type)
+import Data.Kind (Type, Constraint)
import Data.List.NonEmpty (NonEmpty(..))
import Data.List.NonEmpty qualified as NE
import Data.Proxy
@@ -40,6 +41,7 @@ import Foreign.Storable (Storable)
import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp)
import GHC.Generics (Generic)
import GHC.TypeLits
+import Unsafe.Coerce (unsafeCoerce)
import Data.Array.Mixed.XArray (XArray(..))
import Data.Array.Mixed.XArray qualified as X
@@ -923,3 +925,23 @@ mcast ssh2 arr
| Refl <- lemAppNil @sh1
, Refl <- lemAppNil @sh2
= mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr
+
+type data SafeMCastSpec
+ = MCastId
+ | MCastApp [Maybe Nat] [Maybe Nat] [Maybe Nat] [Maybe Nat] SafeMCastSpec SafeMCastSpec
+ | MCastForget
+
+type SafeMCast :: SafeMCastSpec -> [Maybe Nat] -> [Maybe Nat] -> Constraint
+type family SafeMCast spec sh1 sh2 where
+ SafeMCast MCastId sh sh = ()
+ SafeMCast (MCastApp sh1A sh1B sh2A sh2B specA specB) sh1 sh2 = (sh1 ~ sh1A ++ sh1B, sh2 ~ sh2A ++ sh2B, SafeMCast specA sh1A sh2A, SafeMCast specB sh1B sh2B)
+ SafeMCast MCastForget sh1 sh2 = sh2 ~ Replicate (Rank sh1) Nothing
+
+-- | This is an O(1) operation: the 'SafeMCast' constraint ensures that
+-- type-level shape information can only be forgotten, not introduced, and thus
+-- that no runtime shape checks are required. The @spec@ describes to
+-- 'SafeMCast' how exactly you intend @sh2@ to be a weakening of @sh1@.
+--
+-- To see how to construct the spec, read the equations of 'SafeMCast' closely.
+mcastSafe :: forall spec sh1 sh2 a proxy. SafeMCast spec sh1 sh2 => proxy spec -> Mixed sh1 a -> Mixed sh2 a
+mcastSafe _ = unsafeCoerce @(Mixed sh1 a) @(Mixed sh2 a)