diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-02-05 21:22:16 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-05 21:22:16 +0100 |
commit | c78b827f530ff4d2dbe22632392251de300f5615 (patch) | |
tree | a3c00200059cc7c46f0dc168aecceba71856a138 /src/Data/Array | |
parent | f578e36a8ed73268c3e1b91609baa76adfa0693a (diff) |
WIP invertable Castableinvert-castable
Diffstat (limited to 'src/Data/Array')
-rw-r--r-- | src/Data/Array/Nested/Internal/Convert.hs | 93 |
1 files changed, 68 insertions, 25 deletions
diff --git a/src/Data/Array/Nested/Internal/Convert.hs b/src/Data/Array/Nested/Internal/Convert.hs index 183d62c..8458efe 100644 --- a/src/Data/Array/Nested/Internal/Convert.hs +++ b/src/Data/Array/Nested/Internal/Convert.hs @@ -1,15 +1,24 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeOperators #-} -module Data.Array.Nested.Internal.Convert where +{-# LANGUAGE ViewPatterns #-} +module Data.Array.Nested.Internal.Convert ( + stoRanked, + rcastToShaped, + castCastable, + Castable(.., CastXR, CastXS, CastRS), + castRR', castSS', +) where import Control.Category import Data.Proxy import Data.Type.Equality +import GHC.TypeNats import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Shape @@ -32,24 +41,41 @@ rcastToShaped (Ranked arr) targetsh , Refl <- lemRankMapJust targetsh = mcastToShaped arr targetsh --- | The only constructor that performs runtime shape checking is 'CastXS''. --- For the other construtors, the types ensure that the shapes are already --- compatible. To convert between 'Ranked' and 'Shaped', go via 'Mixed'. data Castable a b where CastId :: Castable a a CastCmp :: Castable b c -> Castable a b -> Castable a c + CastInv :: Castable b a -> Castable a b - CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b) - CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b) + CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b) + CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b) + CastSR :: Elt a => ShS sh -- ^ The singleton is required in case this constructor appears under 'CastInv'. + -> Castable a b -> Castable (Shaped sh a) (Ranked (Rank sh) b) - CastXR :: Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b) - CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b) - CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh' - -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b) + CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b) + CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b) + CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b) - CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b) - CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b) - CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b) + CastXX' :: (Rank sh ~ Rank sh', Elt a, Elt b) => IShX sh -> IShX sh' + -> Castable a b -> Castable (Mixed sh a) (Mixed sh' b) + +pattern CastXR :: Castable a b -> Castable (Mixed (Replicate n Nothing) a) (Ranked n b) +pattern CastXR c = CastInv (CastRX (CastInv c)) + +pattern CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b) +pattern CastXS c = CastInv (CastSX (CastInv c)) + +pattern CastRS :: Elt b => ShS sh -> Castable a b -> Castable (Ranked (Rank sh) a) (Shaped sh b) +pattern CastRS sh c = CastInv (CastSR sh (CastInv c)) + +castRR' :: SNat n -> SNat n' -> Castable a b -> Castable (Ranked n a) (Ranked n' b) +castRR' n@SNat n'@SNat c + | Just Refl <- sameNat n n' = CastRR c + | otherwise = error "castRR': Ranks unequal" + +castSS' :: ShS sh -> ShS sh' -> Castable a b -> Castable (Shaped sh a) (Shaped sh' b) +castSS' sh sh' c + | Just Refl <- testEquality sh sh' = CastSS c + | otherwise = error "castSS': Shapes unequal" instance Category Castable where id = CastId @@ -66,21 +92,38 @@ castCastable = \c x -> munScalar (go c (mscalar x)) go :: Castable a b -> Mixed esh a -> Mixed esh b go CastId x = x go (CastCmp c1 c2) x = go c1 (go c2 x) + go (CastInv c) x = goInv c x go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x) go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x) - go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) = - M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy - (go c x))) - go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x)) - go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x) - | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') - = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh'))) - (go c x))) + go (CastSR @_ @sh sh c) (M_Shaped (M_Nest @esh esh x)) + | Refl <- lemRankMapJust sh + = M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh (MapJust sh) esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy + (go c x))) go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) + go (CastXX' sh sh' c) (M_Nest esh x) + | Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh) + , Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh') + = M_Nest esh (mcast (ssxFromShape (shxAppend esh sh')) (go c x)) - lemRankAppMapJust :: Rank sh ~ Rank sh' - => Proxy esh -> Proxy sh -> Proxy sh' - -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') - lemRankAppMapJust _ _ _ = unsafeCoerceRefl + goInv :: Castable b a -> Mixed esh a -> Mixed esh b + goInv CastId x = x + goInv (CastCmp c1 c2) x = goInv c2 (goInv c1 x) + goInv (CastInv c) x = go c x + goInv (CastRX c) (M_Nest esh x) = M_Ranked (M_Nest esh (goInv c x)) + goInv (CastSX c) (M_Nest esh x) = M_Shaped (M_Nest esh (goInv c x)) + goInv (CastSR @sh sh c) (M_Ranked (M_Nest esh x)) + | Refl <- lemRankApp (ssxFromShape esh) (ssxFromSNat (shsRank sh)) + , Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape (shCvtSX sh)) + , Refl <- lemRankReplicate (shsRank sh) + , Refl <- lemRankMapJust sh + = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh))) + (goInv c x))) + goInv (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (goInv c x)) + goInv (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (goInv c x)) + goInv (CastXX c) (M_Nest esh x) = M_Nest esh (goInv c x) + goInv (CastXX' sh sh' c) (M_Nest esh x) + | Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh) + , Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh') + = M_Nest esh (mcast (ssxFromShape (shxAppend esh sh)) (goInv c x)) |