1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Nested.Convert (
-- * Shape/index/list casting functions
ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2,
ixsFromIxX, shsFromShX,
ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS,
-- * Array conversions
castCastable,
Castable(..),
-- * Special cases of array conversions
--
-- | These functions can all be implemented using 'castCastable' in some way,
-- but some have fewer constraints.
rtoMixed, rcastToMixed, rcastToShaped,
stoMixed, scastToMixed, stoRanked,
mcast, mcastToShaped, mtoRanked,
) where
import Control.Category
import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits
import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Base
import Data.Array.Nested.Ranked.Shape
import Data.Array.Nested.Shaped.Base
import Data.Array.Nested.Shaped.Shape
import Data.Array.Nested.Types
-- * Shape/index/list casting functions
-- * To ranked
-- | Turn a shape-typed index into a rank-typed index. The components are
-- copied over one by one, in the same order.
ixrFromIxS :: IxS sh i -> IxR (Rank sh) i
ixrFromIxS index = case index of
  ZIS -> ZIR
  hd :.$ tl -> hd :.: ixrFromIxS tl
-- | Turn a mixed index into a rank-typed index. The components are copied
-- over one by one, in the same order.
ixrFromIxX :: IxX sh i -> IxR (Rank sh) i
ixrFromIxX index = case index of
  ZIX -> ZIR
  hd :.% tl -> hd :.: ixrFromIxX tl
-- | Turn a shaped shape into a ranked shape: every dimension singleton is
-- converted to a value with 'fromSNat''.
shrFromShS :: ShS sh -> IShR (Rank sh)
shrFromShS shape = case shape of
  ZSS -> ZSR
  dim :$$ rest -> fromSNat' dim :$: shrFromShS rest
-- shrFromShX re-exported
-- shrFromShX2 re-exported
-- * To shaped
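-- NOTE: the following is an unfinished sketch of a ranked-to-shaped index
-- conversion (it still contains a typed hole, @_@); it is kept commented out
-- for reference only.
-- ixsFromIxR :: IIxR (Rank sh) -> IIxS sh
-- ixsFromIxR = \ix -> go ix _
-- where
-- go :: IIxR n -> (forall sh. KnownShS sh => IIxS sh -> r) -> r
-- go ZIR k = k ZIS
-- go (i :.: ix) k = go ix (i :.$)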
-- | Turn a mixed index whose shape is fully known (@MapJust sh@) into a
-- shaped index. The 'ShS' argument is only used to drive the recursion; the
-- dimension sizes themselves are ignored.
ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i
ixsFromIxX shape index = case shape of
  ZSS -> case index of
    ZIX -> ZIS
  _ :$$ restShape -> case index of
    hd :.% tl -> hd :.$ ixsFromIxX restShape tl
-- shsFromShX re-exported
-- * To mixed
-- | Turn a rank-typed index into a mixed index over an all-'Nothing' shape.
--
-- In the cons case the result is naturally typed with shape
-- @Nothing : Replicate m Nothing@; 'lemReplicateSucc' (lifted through 'IxX'
-- with 'subst2') is used to retype it to the shape demanded by the signature.
ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i
ixxFromIxR index = case index of
  ZIR -> ZIX
  -- The pattern signature brings @m@ (rank of the tail) and @i@ into scope.
  hd :.: (tl :: IxR m i) ->
    castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m)) $
      hd :.% ixxFromIxR tl
-- | Turn a shaped index into a mixed index over the fully known shape
-- @MapJust sh@. The components are copied over in the same order.
ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i
ixxFromIxS index = case index of
  ZIS -> ZIX
  hd :.$ tl -> hd :.% ixxFromIxS tl
-- | Turn a ranked shape into a mixed shape in which every dimension is
-- wrapped in 'SUnknown'.
--
-- In the cons case the result is naturally typed with shape
-- @Nothing : Replicate m Nothing@; 'lemReplicateSucc' (lifted through 'ShX'
-- with 'subst2') is used to retype it to the shape demanded by the signature.
shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i
shxFromShR shape = case shape of
  ZSR -> ZSX
  -- The pattern signature brings @m@ (rank of the tail) and @i@ into scope.
  dim :$: (rest :: ShR m i) ->
    castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m)) $
      SUnknown dim :$% shxFromShR rest
-- | Turn a shaped shape into a mixed shape in which every dimension is
-- wrapped in 'SKnown'.
shxFromShS :: ShS sh -> IShX (MapJust sh)
shxFromShS shape = case shape of
  ZSS -> ZSX
  dim :$$ rest -> SKnown dim :$% shxFromShS rest
-- * Array conversions
-- | A first-order description of a conversion from @a@ to @b@; it is
-- executed by 'castCastable'.
--
-- The constructors that perform runtime shape checking are marked with a
-- @'@: 'CastXS'' and 'CastXX''. For the other constructors, the types ensure
-- that the shapes are already compatible. To convert between 'Ranked' and
-- 'Shaped', go via 'Mixed'.
--
-- Every constructor that takes a @Castable a b@ argument uses it to convert
-- the elements of the array.
--
-- Maintenance note: 'castCastable' binds the type variables of 'CastXR',
-- 'CastXS'' and 'CastXX'' positionally (type abstractions in patterns), so
-- the order in which type variables first occur in these signatures matters.
data Castable a b where
-- | The identity conversion.
CastId :: Castable a a
-- | Composition; the second argument is applied first.
CastCmp :: Castable b c -> Castable a b -> Castable a c
-- | Ranked to mixed, with an all-'Nothing' shape of the same rank.
CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b)
-- | Shaped to mixed, with the fully known shape @MapJust sh@.
CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b)
-- | Mixed to ranked of the same rank.
CastXR :: Elt b
=> Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b)
-- | Mixed with a fully known shape to shaped; no runtime check is needed.
CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b)
-- | Mixed of any shape to shaped with the given target shape of equal rank
-- (runtime-checked).
CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh'
-> Castable a b -> Castable (Mixed sh a) (Shaped sh' b)
-- | Ranked to ranked; only the elements are converted.
CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b)
-- | Shaped to shaped; only the elements are converted.
CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b)
-- | Mixed to mixed with the same shape; only the elements are converted.
CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b)
-- | Mixed to mixed with the given target static shape of equal rank
-- (runtime-checked).
CastXX' :: (Rank sh ~ Rank sh', Elt b) => StaticShX sh'
-> Castable a b -> Castable (Mixed sh a) (Mixed sh' b)
-- | Conversions form a category: 'CastId' is the identity and 'CastCmp' is
-- composition.
instance Category Castable where
  id = CastId
  outer . inner = CastCmp outer inner
-- | Execute a 'Castable' conversion on a value.
--
-- The value is wrapped with 'mscalar', converted by the worker @go@, and
-- unwrapped again with 'munScalar'.
castCastable :: (Elt a, Elt b) => Castable a b -> a -> b
castCastable = \c x -> munScalar (go c (mscalar x))
where
-- The 'esh' is the extension shape: the casting happens under a whole
-- bunch of additional dimensions that it does not touch. These dimensions
-- are 'esh'.
-- The strategy is to unwind step-by-step to a large Mixed array, and to
-- perform the required checks and castings when re-nesting back up.
go :: Castable a b -> Mixed esh a -> Mixed esh b
-- Identity: return the array unchanged.
go CastId x = x
-- Composition: run the second conversion first, then the first one.
go (CastCmp c1 c2) x = go c1 (go c2 x)
-- Ranked/Shaped to Mixed: drop the wrapper constructor and convert the
-- elements; the recursive call works on the un-nested array, so the
-- dimensions of the array itself become part of the extension shape.
go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x)
go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x)
-- Mixed to Ranked: convert the elements, then 'mcast' the un-nested array to
-- a static shape made of 'esh' followed by as many 'ssxReplicate'd entries
-- as the converted array has dimensions beyond 'esh'.
-- The type abstractions skip the first two type variables of 'CastXR' and
-- bind the third (the source shape) to @sh@.
go (CastXR @_ @_ @sh c) (M_Nest @esh esh x)
| Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh)
= let x' = go c x
ssx' = ssxAppend (ssxFromShX esh)
(ssxReplicate (shxRank (shxDropSSX @esh @sh (mshape x') (ssxFromShX esh))))
in M_Ranked (M_Nest esh (mcast ssx' x'))
-- Mixed (fully known shape) to Shaped: only re-wrap, no 'mcast' needed.
go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x))
-- Mixed to Shaped with an explicit target shape: 'mcast' the un-nested array
-- to the static shape of 'esh' followed by the target shape.
go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x)
| Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh')
= M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh')))
(go c x)))
-- Same array kind on both sides: keep the wrappers and convert the elements.
go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x))
go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x))
go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x)
-- Mixed to Mixed with an explicit target static shape: 'mcast' the un-nested
-- array to the static shape of 'esh' followed by the target static shape.
go (CastXX' @sh @sh' ssx c) (M_Nest @esh esh x)
| Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh')
= M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) (go c x)
-- The three lemmas below supply the rank equalities that the 'mcast' calls in
-- 'go' need. They are asserted with 'unsafeCoerceRefl', not proved.
-- NOTE(review): they presumably hold because the rank of an appended shape is
-- the sum of the ranks of its parts -- confirm against the definitions of
-- 'Rank' and '++'.
--
-- Equal-rank suffixes give equal-rank appended shapes.
lemRankAppRankEq :: Rank sh ~ Rank sh'
=> Proxy esh -> Proxy sh -> Proxy sh'
-> Rank (esh ++ sh) :~: Rank (esh ++ sh')
lemRankAppRankEq _ _ _ = unsafeCoerceRefl
-- Replacing the suffix by an all-'Nothing' shape of the same rank keeps the
-- rank of the appended shape.
lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh
-> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing)
lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl
-- Replacing the suffix by @MapJust sh'@, where @sh'@ has the same rank, keeps
-- the rank of the appended shape.
lemRankAppRankEqMapJust :: Rank sh ~ Rank sh'
=> Proxy esh -> Proxy sh -> Proxy sh'
-> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh')
lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl
-- * Special cases of array conversions
-- | Retype a mixed array to another static shape of the same rank.
--
-- This is 'mcastPartial' with an empty (@'[]@) untouched suffix; the two
-- 'lemAppNil' proofs are what allow the shapes to be passed to it in that
-- form.
mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a)
      => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a
mcast targetSsh array =
  case lemAppNil @sh1 of
    Refl -> case lemAppNil @sh2 of
      Refl ->
        mcastPartial (ssxFromShX (mshape array)) targetSsh (Proxy @'[]) array
-- | Convert a mixed array to a ranked array of the same rank. Implemented as
-- the 'CastXR' conversion with the elements left as they are ('CastId').
mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
mtoRanked array = castCastable (CastXR CastId) array
-- | View a ranked array as a mixed array with an all-'Nothing' shape, by
-- removing the 'Ranked' wrapper.
rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a
rtoMixed ranked = case ranked of
  Ranked inner -> inner
-- | Like 'rtoMixed', but with a caller-chosen target static shape of the
-- same rank; compatibility of the shapes is checked at runtime.
--
-- The array under the 'Ranked' wrapper is handed to 'mcast'; the
-- 'lemRankReplicate' proof provides the rank equality 'mcast' asks for.
rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a
rcastToMixed targetSsx ranked@(Ranked inner) =
  case lemRankReplicate (rrank ranked) of
    Refl -> mcast targetSsx inner
-- | Convert a mixed array to a shaped array with the given target shape of
-- the same rank. Implemented as the 'CastXS'' conversion with the elements
-- left as they are ('CastId').
mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
              => ShS sh' -> Mixed sh a -> Shaped sh' a
mcastToShaped targetsh array = castCastable (CastXS' targetsh CastId) array
-- | View a shaped array as a mixed array with the fully known shape
-- @MapJust sh@, by removing the 'Shaped' wrapper.
stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a
stoMixed shaped = case shaped of
  Shaped inner -> inner
-- | Like 'stoMixed', but with a caller-chosen target static shape of the
-- same rank; compatibility of the shapes is checked at runtime.
--
-- The array under the 'Shaped' wrapper is handed to 'mcast'; the
-- 'lemRankMapJust' proof provides the rank equality 'mcast' asks for.
scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
             => StaticShX sh' -> Shaped sh a -> Mixed sh' a
scastToMixed targetSsx shaped@(Shaped inner) =
  case lemRankMapJust (sshape shaped) of
    Refl -> mcast targetSsx inner
-- | Convert a shaped array to a ranked array of the same rank, by running
-- 'mtoRanked' on the array under the 'Shaped' wrapper.
--
-- The 'lemRankMapJust' proof lets the result rank be stated as @Rank sh@.
stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
stoRanked shaped@(Shaped inner) =
  case lemRankMapJust (sshape shaped) of
    Refl -> mtoRanked inner
-- | Convert a ranked array to a shaped array with the given target shape, by
-- running 'mcastToShaped' on the array under the 'Ranked' wrapper.
--
-- The two proofs establish the rank equality that 'mcastToShaped' requires
-- between the all-'Nothing' shape of the wrapped array and the target shape.
rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
rcastToShaped (Ranked inner) targetsh =
  case lemRankReplicate (shxRank (shxFromShS targetsh)) of
    Refl -> case lemRankMapJust targetsh of
      Refl -> mcastToShaped targetsh inner
|