blob: 183d62ca4d3b5eb880da26e99df7eb012b565ab5 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Nested.Internal.Convert where
import Control.Category
import Data.Proxy
import Data.Type.Equality
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Internal.Mixed
import Data.Array.Nested.Internal.Ranked
import Data.Array.Nested.Internal.Shape
import Data.Array.Nested.Internal.Shaped
-- | Convert a 'Shaped' array to a 'Ranked' array of the same rank.
--
-- The underlying 'Mixed' array is handed to 'mtoRanked' unchanged; matching on
-- the 'Refl' returned by 'lemRankMapJust' brings into scope the rank equality
-- that makes the result type line up with @Rank sh@.
stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
stoRanked whole@(Shaped inner) =
  case lemRankMapJust (sshape whole) of
    Refl -> mtoRanked inner
-- | Reinterpret a 'Ranked' array as a 'Shaped' array with the given static
-- shape.
--
-- The ranks already agree by the types. The two rank lemmas below only supply
-- the type equalities that 'mcastToShaped' needs when it is applied to the
-- underlying 'Mixed' array. Whether the actual sizes are validated is up to
-- 'mcastToShaped'.
rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
rcastToShaped (Ranked mixedArr) target =
  case lemRankReplicate (shxRank (shCvtSX target)) of
    Refl -> case lemRankMapJust target of
      Refl -> mcastToShaped mixedArr target
-- | A witness that values of type @a@ can be converted to values of type @b@
-- by moving between the 'Ranked', 'Shaped' and 'Mixed' array types (and,
-- recursively, their element types). 'castCastable' interprets it.
--
-- The only constructor that performs runtime shape checking is 'CastXS''.
-- For the other constructors, the types ensure that the shapes are already
-- compatible. To convert between 'Ranked' and 'Shaped', go via 'Mixed'.
--
-- NOTE: 'castCastable' binds type variables of 'CastXR' and 'CastXS'' with
-- type-application patterns. The order in which type variables first appear
-- in those two constructor signatures must therefore not change.
data Castable a b where
  -- | Identity cast.
  CastId :: Castable a a
  -- | Composition: @'CastCmp' f g@ applies @g@ first, then @f@.
  CastCmp :: Castable b c -> Castable a b -> Castable a c
  -- | View a 'Ranked' array as a 'Mixed' array whose dimensions are all
  -- 'Nothing', casting the elements with the inner witness.
  CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b)
  -- | View a 'Shaped' array as a 'Mixed' array whose dimensions are all
  -- 'Just' (via 'MapJust'), casting the elements with the inner witness.
  CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b)
  -- | View a 'Mixed' array as a 'Ranked' array of the same rank, casting the
  -- elements with the inner witness.
  CastXR :: Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b)
  -- | View a 'Mixed' array whose dimensions are all 'Just' as the
  -- corresponding 'Shaped' array, casting the elements with the inner witness.
  CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
  -- | Convert an arbitrary 'Mixed' array of the same rank to a 'Shaped' array
  -- with the given shape @sh'@. This is the constructor whose shape is checked
  -- at runtime.
  CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh'
          -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b)
  -- | Cast only the elements of a 'Ranked' array.
  CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b)
  -- | Cast only the elements of a 'Shaped' array.
  CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b)
  -- | Cast only the elements of a 'Mixed' array.
  CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b)
-- | 'Castable' witnesses form a category. Composition builds a 'CastCmp' node
-- (the right-hand witness is applied first) and 'CastId' is the identity, so
-- casts can be chained with the "Control.Category" operators.
instance Category Castable where
  outer . inner = CastCmp outer inner
  id = CastId
-- | Convert a value of type @a@ to type @b@ as described by a 'Castable'
-- witness.
--
-- The value is wrapped with 'mscalar'. The worker @go@ then performs the cast
-- on the resulting 'Mixed' array, and the result is unwrapped with
-- 'munScalar'.
castCastable :: (Elt a, Elt b) => Castable a b -> a -> b
castCastable = \c x -> munScalar (go c (mscalar x))
  where
    -- The 'esh' is the extension shape: the casting happens under a whole
    -- bunch of additional dimensions that it does not touch. These dimensions
    -- are 'esh'.
    -- The strategy is to unwind step-by-step to a large Mixed array, and to
    -- perform the required checks and castings when re-nesting back up.
    --
    -- The 'a' and 'b' in this signature are fresh type variables. They are not
    -- scoped from 'castCastable', whose signature has no explicit 'forall'.
    go :: Castable a b -> Mixed esh a -> Mixed esh b
    go CastId x = x
    -- Composition: run the inner cast 'c2' first, then 'c1'.
    go (CastCmp c1 c2) x = go c1 (go c2 x)
    -- Ranked/Shaped -> Mixed: drop the 'M_Ranked' / 'M_Shaped' wrapper, keep
    -- the 'M_Nest' layer, and cast the inner array.
    go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)
    go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x)
    -- Mixed -> Ranked: after casting the inner array, re-type its shape with
    -- 'mcastSafe'. The 'MCastApp' arguments keep the 'esh' part as is
    -- ('MCastId') and turn the 'sh' part into 'Replicate (Rank sh) Nothing'
    -- ('MCastForget'), which matches the shape of a 'Ranked' array.
    go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) =
      M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy
                                      (go c x)))
    -- Mixed (MapJust sh) -> Shaped: the types already line up, so no check.
    go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x))
    -- Mixed -> Shaped with an explicit target shape.
    -- * The lemma brings into scope the rank equality that 'mcast' needs.
    -- * 'mcast' then converts the inner array to the static shape built from
    --   'esh' and the requested 'sh''.
    -- Per the 'Castable' documentation, this is the case that checks shapes at
    -- runtime.
    go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x)
      | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
      = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh')))
                                    (go c x)))
    -- Element-only casts: keep the outer wrapper(s) and recurse into the inner
    -- array.
    go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
    go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
    go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
-- | Rank equality needed by the 'CastXS'' case of 'castCastable'.
--
-- Prefixing the same extension shape @esh@ to @sh@ and to @MapJust sh'@ yields
-- equal ranks whenever @sh@ and @sh'@ have equal rank.
--
-- The equality is asserted with 'unsafeCoerceRefl' rather than proved, because
-- only 'Proxy' values (no runtime shape singletons) are passed in. The intended
-- reasoning assumes that 'Rank' counts list elements, '++' appends, and
-- 'MapJust' maps element-wise (TODO: confirm against their definitions):
--
-- > Rank (esh ++ sh) = Rank esh + Rank sh
-- >                  = Rank esh + Rank sh'            -- by the constraint
-- >                  = Rank esh + Rank (MapJust sh')
-- >                  = Rank (esh ++ MapJust sh')
lemRankAppMapJust :: Rank sh ~ Rank sh'
                  => Proxy esh -> Proxy sh -> Proxy sh'
                  -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
lemRankAppMapJust = \_ _ _ -> unsafeCoerceRefl
|