aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Data/Array/Nested/Internal/Convert.hs93
1 files changed, 68 insertions, 25 deletions
diff --git a/src/Data/Array/Nested/Internal/Convert.hs b/src/Data/Array/Nested/Internal/Convert.hs
index 183d62c..8458efe 100644
--- a/src/Data/Array/Nested/Internal/Convert.hs
+++ b/src/Data/Array/Nested/Internal/Convert.hs
@@ -1,15 +1,24 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeOperators #-}
-module Data.Array.Nested.Internal.Convert where
+{-# LANGUAGE ViewPatterns #-}
+module Data.Array.Nested.Internal.Convert (
+ stoRanked,
+ rcastToShaped,
+ castCastable,
+ Castable(.., CastXR, CastXS, CastRS),
+ castRR', castSS',
+) where
import Control.Category
import Data.Proxy
import Data.Type.Equality
+import GHC.TypeNats
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Shape
@@ -32,24 +41,41 @@ rcastToShaped (Ranked arr) targetsh
, Refl <- lemRankMapJust targetsh
= mcastToShaped arr targetsh
--- | The only constructor that performs runtime shape checking is 'CastXS''.
--- For the other construtors, the types ensure that the shapes are already
--- compatible. To convert between 'Ranked' and 'Shaped', go via 'Mixed'.
data Castable a b where
CastId :: Castable a a
CastCmp :: Castable b c -> Castable a b -> Castable a c
+ CastInv :: Castable b a -> Castable a b
- CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b)
- CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b)
+ CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b)
+ CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b)
+ CastSR :: Elt a => ShS sh -- ^ The singleton is required in case this constructor appears under 'CastInv'.
+ -> Castable a b -> Castable (Shaped sh a) (Ranked (Rank sh) b)
- CastXR :: Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b)
- CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
- CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh'
- -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b)
+ CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b)
+ CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b)
+ CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b)
- CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b)
- CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b)
- CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b)
+ CastXX' :: (Rank sh ~ Rank sh', Elt a, Elt b) => IShX sh -> IShX sh'
+ -> Castable a b -> Castable (Mixed sh a) (Mixed sh' b)
+
+pattern CastXR :: Castable a b -> Castable (Mixed (Replicate n Nothing) a) (Ranked n b)
+pattern CastXR c = CastInv (CastRX (CastInv c))
+
+pattern CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
+pattern CastXS c = CastInv (CastSX (CastInv c))
+
+pattern CastRS :: Elt b => ShS sh -> Castable a b -> Castable (Ranked (Rank sh) a) (Shaped sh b)
+pattern CastRS sh c = CastInv (CastSR sh (CastInv c))
+
+castRR' :: SNat n -> SNat n' -> Castable a b -> Castable (Ranked n a) (Ranked n' b)
+castRR' n@SNat n'@SNat c
+ | Just Refl <- sameNat n n' = CastRR c
+ | otherwise = error "castRR': Ranks unequal"
+
+castSS' :: ShS sh -> ShS sh' -> Castable a b -> Castable (Shaped sh a) (Shaped sh' b)
+castSS' sh sh' c
+ | Just Refl <- testEquality sh sh' = CastSS c
+ | otherwise = error "castSS': Shapes unequal"
instance Category Castable where
id = CastId
@@ -66,21 +92,38 @@ castCastable = \c x -> munScalar (go c (mscalar x))
go :: Castable a b -> Mixed esh a -> Mixed esh b
go CastId x = x
go (CastCmp c1 c2) x = go c1 (go c2 x)
+ go (CastInv c) x = goInv c x
go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)
go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x)
- go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) =
- M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy
- (go c x)))
- go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x))
- go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x)
- | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
- = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh')))
- (go c x)))
+ go (CastSR @_ @sh sh c) (M_Shaped (M_Nest @esh esh x))
+ | Refl <- lemRankMapJust sh
+ = M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh (MapJust sh) esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy
+ (go c x)))
go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
+ go (CastXX' sh sh' c) (M_Nest esh x)
+ | Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh)
+ , Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh')
+ = M_Nest esh (mcast (ssxFromShape (shxAppend esh sh')) (go c x))
- lemRankAppMapJust :: Rank sh ~ Rank sh'
- => Proxy esh -> Proxy sh -> Proxy sh'
- -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
- lemRankAppMapJust _ _ _ = unsafeCoerceRefl
+ goInv :: Castable b a -> Mixed esh a -> Mixed esh b
+ goInv CastId x = x
+ goInv (CastCmp c1 c2) x = goInv c2 (goInv c1 x)
+ goInv (CastInv c) x = go c x
+ goInv (CastRX c) (M_Nest esh x) = M_Ranked (M_Nest esh (goInv c x))
+ goInv (CastSX c) (M_Nest esh x) = M_Shaped (M_Nest esh (goInv c x))
+ goInv (CastSR @sh sh c) (M_Ranked (M_Nest esh x))
+ | Refl <- lemRankApp (ssxFromShape esh) (ssxFromSNat (shsRank sh))
+ , Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape (shCvtSX sh))
+ , Refl <- lemRankReplicate (shsRank sh)
+ , Refl <- lemRankMapJust sh
+ = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh)))
+ (goInv c x)))
+ goInv (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (goInv c x))
+ goInv (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (goInv c x))
+ goInv (CastXX c) (M_Nest esh x) = M_Nest esh (goInv c x)
+ goInv (CastXX' sh sh' c) (M_Nest esh x)
+ | Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh)
+ , Refl <- lemRankApp (ssxFromShape esh) (ssxFromShape sh')
+ = M_Nest esh (mcast (ssxFromShape (shxAppend esh sh)) (goInv c x))