1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Nested.Convert (
castCastable,
Castable(..),
-- * Special cases
--
-- | These functions can all be implemented using 'castCastable' in some way,
-- but some have fewer constraints.
rtoMixed, rcastToMixed, rcastToShaped,
stoMixed, scastToMixed, stoRanked,
mcast, mcastToShaped, mtoRanked,
) where
import Control.Category
import Data.Proxy
import Data.Type.Equality
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Types
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Base
import Data.Array.Nested.Shaped.Base
import Data.Array.Nested.Shaped.Shape
-- | Change the static shape of a mixed array to another one of the same
-- rank, e.g. between the all-'Nothing' shape used by 'Ranked' arrays and a
-- shape with some 'Just' sizes. The work is delegated to 'mcastPartial'
-- with an empty trailing shape.
mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
      => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a
mcast ssh2 arr =
  -- 'mcastPartial' is called with an empty suffix (@'[]@), so both shapes
  -- need the right-identity equation for '++' that 'lemAppNil' provides.
  case (lemAppNil @sh1, lemAppNil @sh2) of
    (Refl, Refl) ->
      mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr
-- | Convert a mixed array to a ranked array of the same rank, using a
-- single 'CastXR' step that leaves the elements unchanged ('CastId').
mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
mtoRanked mixed = castCastable (CastXR CastId) mixed
-- | View a ranked array as a mixed array whose dimensions are all 'Nothing'.
-- This only unwraps the 'Ranked' constructor; no checks are performed.
rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
rtoMixed ranked = case ranked of Ranked inner -> inner
-- | A more weakly-typed version of 'rtoMixed' that does a runtime shape
-- compatibility check (via 'mcast') against the requested static shape.
rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a
rcastToMixed target ranked@(Ranked inner) =
  -- Relate the rank of the all-'Nothing' shape back to @n@, so that 'mcast'
  -- can see that source and target shapes have the same rank.
  case lemRankReplicate (rrank ranked) of
    Refl -> mcast target inner
-- | Convert a mixed array to a shaped array with the given static shape.
-- Both shapes must have the same rank; the individual dimensions are checked
-- at runtime by the 'CastXS'' step. Elements are left unchanged ('CastId').
mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
              => Mixed sh a -> ShS sh' -> Shaped sh' a
mcastToShaped arr targetsh =
  let conversion = CastXS' targetsh CastId
  in castCastable conversion arr
-- | View a shaped array as a mixed array whose shape is @'MapJust' sh@.
-- This only unwraps the 'Shaped' constructor; no checks are performed.
stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
stoMixed shaped = case shaped of Shaped inner -> inner
-- | A more weakly-typed version of 'stoMixed' that does a runtime shape
-- compatibility check (via 'mcast') against the requested static shape.
scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
             => StaticShX sh' -> Shaped sh a -> Mixed sh' a
scastToMixed target shaped@(Shaped inner) =
  -- 'mcast' needs the rank of @MapJust sh@, which equals @Rank sh@.
  case lemRankMapJust (sshape shaped) of
    Refl -> mcast target inner
-- | Convert a shaped array to a ranked array of the same rank, going via the
-- underlying 'Mixed' array and 'mtoRanked'.
stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
stoRanked shaped@(Shaped inner) =
  -- 'mtoRanked' yields rank @Rank (MapJust sh)@; rewrite it to @Rank sh@.
  case lemRankMapJust (sshape shaped) of
    Refl -> mtoRanked inner
-- | Convert a ranked array to a shaped array with the given shape, going via
-- the underlying 'Mixed' array and 'mcastToShaped' (which checks the
-- dimensions at runtime).
rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
rcastToShaped (Ranked inner) target =
  -- Together these two equations show that the all-'Nothing' source shape
  -- has rank @Rank sh@, as 'mcastToShaped' requires.
  case ( lemRankReplicate (shxRank (shCvtSX target))
       , lemRankMapJust target ) of
    (Refl, Refl) -> mcastToShaped inner target
-- | A description of a (possibly nested) conversion between array types,
-- interpreted by 'castCastable'.
--
-- The constructors that perform runtime shape checking are marked with a
-- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure
-- that the shapes are already compatible. To convert between 'Ranked' and
-- 'Shaped', go via 'Mixed'.
--
-- 'Castable' is also a 'Category', with 'CastId' as identity and 'CastCmp'
-- as composition.
--
-- NOTE(review): 'castCastable' binds the implicit type variables of 'CastXR',
-- 'CastXS'' and 'CastXX'' positionally with type applications, so the order
-- in which type variables first occur in those signatures is load-bearing.
data Castable a b where
  -- | No conversion.
  CastId :: Castable a a
  -- | @CastCmp f g@ applies @g@ first and then @f@.
  CastCmp :: Castable b c -> Castable a b -> Castable a c
  -- | View a 'Ranked' array as a 'Mixed' array with an all-'Nothing' shape,
  -- converting the elements with the inner cast.
  CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b)
  -- | View a 'Shaped' array as a 'Mixed' array of shape @MapJust sh@,
  -- converting the elements with the inner cast.
  CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b)
  -- | Convert a 'Mixed' array to a 'Ranked' array of the same rank.
  CastXR :: Elt b
         => Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b)
  -- | View a 'Mixed' array that already has shape @MapJust sh@ as a
  -- 'Shaped' array.
  CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
  -- | Convert a 'Mixed' array to a 'Shaped' array with the given shape; the
  -- dimensions are checked at runtime.
  CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh'
          -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b)
  -- | Convert only the elements of a 'Ranked' array.
  CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b)
  -- | Convert only the elements of a 'Shaped' array.
  CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b)
  -- | Convert only the elements of a 'Mixed' array.
  CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b)
  -- | Convert a 'Mixed' array to another 'Mixed' shape of the same rank; the
  -- dimensions are checked at runtime.
  CastXX' :: (Rank sh ~ Rank sh', Elt b) => StaticShX sh'
          -> Castable a b -> Castable (Mixed sh a) (Mixed sh' b)
-- | 'Castable' conversions compose like functions: 'CastId' is the identity
-- and @outer . inner@ builds a 'CastCmp' that runs @inner@ first.
instance Category Castable where
  id = CastId
  outer . inner = CastCmp outer inner
-- | Run a 'Castable' conversion on a value.
--
-- The input is wrapped with 'mscalar', converted by the local worker @go@,
-- and unwrapped again with 'munScalar'. The steps 'CastXR', 'CastXS'' and
-- 'CastXX'' go through 'mcast'; all other steps only rewrap the underlying
-- array.
castCastable :: (Elt a, Elt b) => Castable a b -> a -> b
castCastable = \c x -> munScalar (go c (mscalar x))
  where
    -- The 'esh' is the extension shape: the casting happens under a whole
    -- bunch of additional dimensions that it does not touch. These dimensions
    -- are 'esh'.
    -- The strategy is to unwind step-by-step to a large Mixed array, and to
    -- perform the required checks and castings when re-nesting back up.
    go :: Castable a b -> Mixed esh a -> Mixed esh b
    go CastId x = x
    -- Right-to-left composition: run 'c2' first, then 'c1'.
    go (CastCmp c1 c2) x = go c1 (go c2 x)
    -- Ranked/Shaped to Mixed: peel off the wrapper and convert the nested
    -- array under the same 'esh'; no 'mcast' is needed here.
    go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)
    go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x)
    -- Mixed to Ranked: after converting the elements, 'mcast' the inner
    -- @esh ++ sh@ array to @esh ++ Replicate (Rank sh) Nothing@. The number
    -- of 'Nothing's is taken from the runtime shape with the 'esh' prefix
    -- dropped; the lemma supplies the rank equality that 'mcast' requires.
    go (CastXR @_ @_ @sh c) (M_Nest @esh esh x)
      | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)
      = let x' = go c x
            ssx' = ssxAppend (ssxFromShape esh)
                     (ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShape esh))))
        in M_Ranked (M_Nest esh (mcast ssx' x'))
    -- Mixed (already of shape @MapJust sh@) to Shaped: just rewrap.
    go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x))
    -- Mixed to Shaped with a runtime check: 'mcast' the inner array to the
    -- static shape built from 'esh' followed by the target 'ShS', then rewrap.
    go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x)
      | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
      = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh')))
                                    (go c x)))
    -- Element-only conversions: the wrapper and shape stay as they are.
    go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
    go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
    go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
    -- Mixed to Mixed with a runtime check: 'mcast' the inner array to
    -- @esh ++ ssx@.
    go (CastXX' @sh @sh' ssx c) (M_Nest @esh esh x)
      | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
      = M_Nest esh $ mcast (ssxFromShape esh `ssxAppend` ssx) (go c x)
-- | Prefixing two equal-rank shapes with the same extension shape @esh@
-- gives equal total ranks.
--
-- Asserted with 'unsafeCoerceRefl' because GHC cannot derive this equation
-- for abstract type-level lists on its own. Justification (assuming 'Rank'
-- is list length and '++' is list append): @Rank (esh ++ sh)@ is
-- @Rank esh + Rank sh@ and likewise for @sh'@, so the given
-- @Rank sh ~ Rank sh'@ makes both sides equal.
lemRankAppRankEq :: Rank sh ~ Rank sh'
                 => Proxy esh -> Proxy sh -> Proxy sh'
                 -> Rank (esh ++ sh) :~: Rank (esh ++ sh')
lemRankAppRankEq = \_ _ _ -> unsafeCoerceRefl
-- | Replacing @sh@ by an all-'Nothing' shape of the same rank, behind the
-- same prefix @esh@, leaves the total rank unchanged.
--
-- Asserted with 'unsafeCoerceRefl'. Justification (assuming 'Rank' is list
-- length, '++' is list append and @Replicate n x@ has length @n@): both
-- sides equal @Rank esh + Rank sh@.
lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh
                      -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing)
lemRankAppRankEqRepNo = \_ _ -> unsafeCoerceRefl
-- | Replacing @sh@ by @MapJust sh'@ for an @sh'@ of the same rank, behind the
-- same prefix @esh@, leaves the total rank unchanged.
--
-- Asserted with 'unsafeCoerceRefl'. Justification (assuming 'Rank' is list
-- length, '++' is list append and 'MapJust' preserves length): both sides
-- equal @Rank esh + Rank sh@ given @Rank sh ~ Rank sh'@.
lemRankAppRankEqMapJust :: Rank sh ~ Rank sh'
                        => Proxy esh -> Proxy sh -> Proxy sh'
                        -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
lemRankAppRankEqMapJust = \_ _ _ -> unsafeCoerceRefl
|