1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Ranked (
Ranked(Ranked),
rquotArray, rremArray, ratan2Array,
rshape, rrank,
module Data.Array.Nested.Ranked,
liftRanked1, liftRanked2,
) where
import Prelude hiding (mappend, mconcat)
import Data.Array.RankedS qualified as S
import Data.Bifunctor (first)
import Data.Coerce (coerce)
import Data.List.NonEmpty (NonEmpty)
import Data.Proxy
import Data.Type.Equality
import Data.Vector.Storable qualified as VS
import Foreign.Storable (Storable)
import GHC.TypeLits
import GHC.TypeNats qualified as TN
import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Types
import Data.Array.XArray (XArray(..))
import Data.Array.XArray qualified as X
import Data.Array.Nested.Convert
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Base
import Data.Array.Nested.Ranked.Shape
import Data.Array.Strided.Arith
-- | A rank-1 array obtained from 'memptyArray' at the empty shape 'ZSX',
-- converted to 'Ranked' with 'mtoRanked'.
remptyArray :: KnownElt a => Ranked 1 a
remptyArray = mtoRanked $ memptyArray ZSX
-- | The total number of elements in the array: 'shrSize' applied to the
-- array's shape ('rshape').
rsize :: Elt a => Ranked n a -> Int
rsize arr = shrSize (rshape arr)
-- | Look up a single element using an index whose length equals the rank
-- @n@. The index is converted with 'ixCvtRX' and passed to 'mindex'.
rindex :: Elt a => Ranked n a -> IIxR n -> a
rindex (Ranked marr) ix = mindex marr $ ixCvtRX ix
-- | Index the outer @n@ dimensions of an array of rank @n + m@, yielding a
-- rank-@m@ array. Delegates to 'mindexPartial'.
rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
  Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) split (ixCvtRX idx))
  where
    -- Re-view the shape @Replicate (n + m) Nothing@ as
    -- @Replicate n Nothing ++ Replicate m Nothing@ so that 'mindexPartial'
    -- can split it into the indexed prefix and the remaining suffix.
    split = castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @(Nothing @Nat)))) arr
-- | Generate an array of the given shape from a function of the index,
-- delegating to 'mgenerate'.
--
-- __WARNING__: All values returned from the function must have equal shape.
-- See the documentation of 'mgenerate' for more details.
rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a
rgenerate sh f =
  case shrRank sh of
    sn@SNat ->
      -- Both proofs are needed to go between the ranked and mixed views of
      -- the shape and the index.
      case (lemKnownReplicate sn, lemRankReplicate sn) of
        (Dict, Refl) -> Ranked (mgenerate (shCvtRX sh) (\ix -> f (ixCvtXR ix)))
-- | Lift a function on the underlying 'XArray's to ranked arrays. The
-- function must be polymorphic in the trailing shape @sh'@ and the element
-- type, and the result has rank @n2@. See the documentation of 'mlift'.
rlift :: forall n1 n2 a. Elt a
      => SNat n2
      -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
      -> Ranked n1 a -> Ranked n2 a
rlift sn2 f (Ranked arr) = Ranked $ mlift ssh2 f arr
  where
    ssh2 = ssxFromSNat sn2
-- | Binary version of 'rlift': lift a shape-polymorphic function on two
-- 'XArray's to ranked arrays, producing a result of rank @n3@. See the
-- documentation of 'mlift2'.
rlift2 :: forall n1 n2 n3 a. Elt a
       => SNat n3
       -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b)
       -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a
rlift2 sn3 f (Ranked lhs) (Ranked rhs) = Ranked $ mlift2 ssh3 f lhs rhs
  where
    ssh3 = ssxFromSNat sn3
-- | Outer-dimension sum on the 'Primitive' representation, going from rank
-- @n + 1@ to rank @n@. Delegates to 'msumOuter1P'.
rsumOuter1P :: forall n a. (Storable a, NumElt a)
            => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
rsumOuter1P (Ranked arr) =
  -- Expose the outermost 'Nothing' of the shape so 'msumOuter1P' applies.
  case lemReplicateSucc @(Nothing @Nat) @n of
    Refl -> Ranked (msumOuter1P arr)
-- | 'rsumOuter1P' lifted to non-'Primitive' element types: convert to the
-- primitive representation, reduce, and convert back.
rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
           => Ranked (n + 1) a -> Ranked n a
rsumOuter1 arr = rfromPrimitive (rsumOuter1P (rtoPrimitive arr))
-- | Sum of all elements, delegating to 'msumAllPrim' on the underlying
-- mixed array.
rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a
rsumAllPrim (Ranked marr) = msumAllPrim marr
-- | Transpose the array according to the permutation, using
-- 'X.transposeUntyped'. Calls 'error' when the permutation is longer than
-- the rank of the array.
rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a
rtranspose perm arr =
  case rrank arr of
    -- Matching 'SNat' brings @KnownNat n@ into scope for 'natVal' / 'natSing'.
    sn@SNat
      | Dict <- lemKnownReplicate sn
      , length perm <= fromIntegral (natVal (Proxy @n))
      -> rlift sn (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm) arr
    _ -> error "Data.Array.Nested.rtranspose: Permutation longer than rank of array"
-- | Concatenate a non-empty list of arrays of rank @n + 1@ using 'mconcat'
-- (presumably along the outermost dimension — see 'mconcat').
rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a
rconcat =
  case lemReplicateSucc @(Nothing @Nat) @n of
    Refl -> coerce mconcat
-- | Append two arrays of rank @n + 1@ using the mixed-array 'mappend'
-- (presumably along the outermost dimension — see 'mappend').
rappend :: forall n a. Elt a
        => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
rappend arr1 arr2 =
  case rrank arr1 of
    sn@SNat ->
      case lemKnownReplicate sn of
        Dict ->
          -- Split the shape into its outermost dimension and the rest.
          case lemReplicateSucc @(Nothing @Nat) @n of
            Refl ->
              coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) arr1 arr2
-- | Wrap a single value as a rank-0 array via 'mscalar'.
rscalar :: Elt a => a -> Ranked 0 a
rscalar = Ranked . mscalar
-- | Build a 'Primitive' array of the given shape from a storable vector,
-- delegating to 'mfromVectorP'.
rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)
rfromVectorP sh v =
  case lemKnownReplicate (shrRank sh) of
    Dict -> Ranked $ mfromVectorP (shCvtRX sh) v
-- | 'rfromVectorP' followed by conversion out of the 'Primitive'
-- representation.
rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a
rfromVector sh = rfromPrimitive . rfromVectorP sh
-- | Extract the storable vector of a 'Primitive' array via 'mtoVectorP'.
rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
rtoVectorP (Ranked marr) = mtoVectorP marr
-- | Extract the storable vector of an array via 'mtoVector'.
rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
rtoVector (Ranked marr) = mtoVector marr
-- | Stack a non-empty list of rank-@n@ arrays into one array of rank
-- @n + 1@, delegating to 'mfromListOuter'.
rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
rfromListOuter l
  | Refl <- lemReplicateSucc @(Nothing @Nat) @n
  = Ranked (mfromListOuter unwrapped)
  where
    -- Strip the 'Ranked' newtype from every element at zero cost.
    unwrapped :: NonEmpty (Mixed (Replicate n Nothing) a)
    unwrapped = coerce l
-- | Build a rank-1 array from a non-empty list via 'mfromList1'.
rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
rfromList1 = Ranked . mfromList1
-- | Build a rank-1 array of primitive elements from a list via
-- 'mfromList1Prim'.
rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a
rfromList1Prim = Ranked . mfromList1Prim
-- | Split an array of rank @n + 1@ into the list of its rank-@n@ subarrays,
-- delegating to 'mtoListOuter'.
rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
rtoListOuter (Ranked marr) =
  case lemReplicateSucc @(Nothing @Nat) @n of
    Refl -> coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) marr)
-- | The elements of a rank-1 array as a list: each rank-0 subarray from
-- 'rtoListOuter' is unwrapped with 'runScalar'.
rtoList1 :: Elt a => Ranked 1 a -> [a]
rtoList1 arr = [runScalar x | x <- rtoListOuter arr]
-- | Build a rank-1 array from a list of primitive elements. The list is
-- loaded with 'X.fromList1' at the one-dimensional static shape
-- @SUnknown () :!% ZKX@ and wrapped in the 'Primitive' representation.
rfromListPrim :: PrimElt a => [a] -> Ranked 1 a
rfromListPrim l = Ranked (fromPrimitive (M_Primitive shape1 flat))
  where
    ssh1 = SUnknown () :!% ZKX
    flat = X.fromList1 ssh1 l
    shape1 = X.shape ssh1 flat
-- | Build an array of shape @sh@ from a flat list of primitive elements. The
-- list is first made into a rank-1 array by 'mfromListPrim', then its
-- 'XArray' is reshaped to @sh@ with 'X.reshape'.
rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a
rfromListPrimLinear sh l =
  Ranked $ fromPrimitive $ M_Primitive shx (X.reshape (SUnknown () :!% ZKX) shx flat)
  where
    shx = shCvtRX sh
    -- Lazy pattern binding, as in a @let@: only the 'XArray' is used.
    M_Primitive _ flat = toPrimitive (mfromListPrim l)
-- | Build an array of shape @sh@ from a flat non-empty list: first as a
-- rank-1 array ('rfromList1'), then reshaped with 'rreshape'.
rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a
rfromListLinear sh = rreshape sh . rfromList1
-- | All elements of the array as a flat list, via 'mtoListLinear'.
rtoListLinear :: Elt a => Ranked n a -> [a]
rtoListLinear (Ranked marr) = mtoListLinear marr
-- | Wrap an @orthotope@ 'S.Array' of rank @n@ as a 'Ranked' array, computing
-- its shape with 'X.shape'.
rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a
rfromOrthotope sn arr =
  case lemRankReplicate sn of
    Refl ->
      let xarr = XArray arr
          shx = X.shape (ssxFromSNat sn) xarr
      in Ranked (fromPrimitive (M_Primitive shx xarr))
-- | Unwrap the underlying @orthotope@ 'S.Array' after converting to the
-- 'Primitive' representation.
rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a
rtoOrthotope arr =
  case rtoPrimitive arr of
    Ranked (M_Primitive sh (XArray sarr))
      | Refl <- lemRankReplicate (shrRank (shCvtXR' sh))
      -> sarr
-- | The single element of a rank-0 array: 'rindex' at the empty index 'ZIR'.
runScalar :: Elt a => Ranked 0 a -> a
runScalar arr = arr `rindex` ZIR
-- | View an array of rank @n + m@ as a rank-@n@ array whose elements are
-- rank-@m@ arrays, delegating to 'mnest'.
rnest :: forall n m a. Elt a => SNat n -> Ranked (n + m) a -> Ranked n (Ranked m a)
rnest n arr =
  case lemReplicatePlusApp n (Proxy @m) (Proxy @(Nothing @Nat)) of
    Refl -> coerce (mnest (ssxFromSNat n) (coerce arr))
-- | Merge a rank-@n@ array of rank-@m@ arrays into a single array of rank
-- @n + m@ by unwrapping the nested representation.
runNest :: forall n m a. Elt a => Ranked n (Ranked m a) -> Ranked (n + m) a
runNest rarr@(Ranked (M_Ranked (M_Nest _ inner))) =
  case lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat)) of
    Refl -> Ranked inner
-- | Combine two arrays of the same rank into an array of pairs via 'mzip'.
rzip :: Ranked n a -> Ranked n b -> Ranked n (a, b)
rzip (Ranked xs) (Ranked ys) = Ranked (mzip xs ys)
-- | Split an array of pairs into two arrays via 'munzip'.
runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b)
runzip (Ranked pairs) =
  -- Strict match on the tuple, exactly like the result of 'munzip' itself.
  case munzip pairs of
    (lefts, rights) -> (Ranked lefts, Ranked rights)
-- | 'Primitive'-representation worker for 'rrerank': @f@ maps rank-@n1@
-- arrays to rank-@n2@ arrays, and the input/output have the extra @n@-rank
-- prefix given by @sn@. Delegates to 'mrerankP' with @sh2@ as the inner
-- output shape.
rrerankP :: forall n1 n2 n a b. (Storable a, Storable b)
         => SNat n -> IShR n2
         -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
         -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b)
rrerankP sn sh2 f (Ranked arr)
  | Refl <- lemReplicatePlusApp sn (Proxy @n1) nothing
  , Refl <- lemReplicatePlusApp sn (Proxy @n2) nothing
  = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2) inner arr)
  where
    nothing = Proxy @(Nothing @Nat)
    -- @f@ on the unwrapped mixed arrays; 'Ranked' is a newtype, so this
    -- match forces nothing.
    inner sub = case f (Ranked sub) of Ranked r -> r
-- | Rerank: the input (rank @n + n1@) is treated as an @n@-dimensional array
-- of rank-@n1@ subarrays, @f@ maps each of those to a rank-@n2@ array, and
-- the result has rank @n + n2@. @sh2@ is presumably the shape of every
-- result of @f@. This converts to and from the 'Primitive' representation
-- around 'rrerankP'.
--
-- If there is a zero-sized dimension in the @n@-prefix of the shape of the
-- input array, then there are no subarrays to call @f@ on, so the @n2@ part
-- of the output shape cannot be learned by calling @f@. @orthotope@ errors
-- out in this case; we choose to fill the @n2@ part of the output shape with
-- zeros.
--
-- NOTE(review): @sh2@ is passed explicitly and forwarded (via 'rrerankP') to
-- 'mrerankP'; confirm whether the zero-filling described here still holds,
-- or whether the @n2@ part of the output shape is now taken from @sh2@.
--
-- For example, if:
--
-- @
-- arr :: Ranked 5 Int   -- of shape [3, 0, 4, 2, 21]
-- f :: Ranked 2 Int -> Ranked 3 Float
-- sh2 :: IShR 3
-- @
--
-- then:
--
-- @
-- rrerank (SNat \@3) sh2 f arr :: Ranked 6 Float
-- @
--
-- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the
-- "reranked" part (the last 3 entries) is all zeros; we don't know if @f@
-- intended to return an array with shape all-0 here (it probably didn't),
-- but there is no better number to put here absent a subarray of the input
-- to pass to @f@.
rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
        => SNat n -> IShR n2
        -> (Ranked n1 a -> Ranked n2 b)
        -> Ranked (n + n1) a -> Ranked (n + n2) b
rrerank sn sh2 f (rtoPrimitive -> arr) =
  rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr
-- | Replicate a rank-@m@ array along @n@ new outer dimensions of shape @sh@,
-- giving rank @n + m@. Delegates to 'mreplicate'.
rreplicate :: forall n m a. Elt a
           => IShR n -> Ranked m a -> Ranked (n + m) a
rreplicate sh (Ranked arr) =
  case lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat)) of
    Refl -> Ranked $ mreplicate (shCvtRX sh) arr
-- | A 'Primitive' array of shape @sh@ built from the single value @x@,
-- delegating to 'mreplicateScalP'.
rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
rreplicateScalP sh x =
  case lemKnownReplicate (shrRank sh) of
    Dict -> Ranked (mreplicateScalP (shCvtRX sh) x)
-- | 'rreplicateScalP' followed by conversion out of the 'Primitive'
-- representation.
rreplicateScal :: forall n a. PrimElt a
               => IShR n -> a -> Ranked n a
rreplicateScal sh = rfromPrimitive . rreplicateScalP sh
-- | Slice the outermost dimension with 'X.sliceU', passing the two 'Int'
-- arguments through unchanged (presumably start offset and length — see
-- 'X.sliceU').
rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
rslice i len arr =
  case lemReplicateSucc @(Nothing @Nat) @n of
    Refl -> rlift (rrank arr) (\_ -> X.sliceU i len) arr
-- | Apply 'X.rev1' to the outermost dimension (presumably reversing it — see
-- 'X.rev1'), keeping the rank @n + 1@.
rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
rrev1 arr =
  -- The proof exposes the outermost 'Nothing', which 'X.rev1' needs.
  case lemReplicateSucc @(Nothing @Nat) @n of
    Refl ->
      rlift (rrank arr)
            (\(_ :: StaticShX sh') -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
            arr
-- | Give the array the new shape @sh'@ (rank @n'@) via 'mreshape'.
-- NOTE(review): presumably the total element counts must agree — see
-- 'mreshape'.
rreshape :: forall n n' a. Elt a
         => IShR n' -> Ranked n a -> Ranked n' a
rreshape sh' rarr@(Ranked arr) =
  case (lemKnownReplicate (rrank rarr), lemKnownReplicate (shrRank sh')) of
    (Dict, Dict) -> Ranked (mreshape (shCvtRX sh') arr)
-- | Flatten an array of any rank to rank 1 via 'mflatten' and 'mtoRanked'.
rflatten :: Elt a => Ranked n a -> Ranked 1 a
rflatten (Ranked marr) = mtoRanked $ mflatten marr
-- | A rank-1 array of length @n@ produced by 'miota' (presumably the values
-- enumerated from zero via 'Enum' — see 'miota').
--
-- Calls 'error' with a descriptive message when @n@ is negative. Without the
-- check, the failure would come from converting a negative 'Int' to
-- 'Natural' for 'TN.withSomeSNat'.
riota :: (Enum a, PrimElt a) => Int -> Ranked 1 a
riota n
  | n < 0 = error "Data.Array.Nested.riota: negative length"
  | otherwise = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota
-- | Index of a minimal element, computed by 'mminIndexPrim' and converted
-- back to a ranked index.
--
-- Throws if the array is empty.
rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
rminIndexPrim rarr@(Ranked arr) =
  case lemRankReplicate (rrank (rtoPrimitive rarr)) of
    Refl -> ixCvtXR (mminIndexPrim arr)
-- | Index of a maximal element, computed by 'mmaxIndexPrim' and converted
-- back to a ranked index.
--
-- Throws if the array is empty.
rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n
rmaxIndexPrim rarr@(Ranked arr) =
  case lemRankReplicate (rrank (rtoPrimitive rarr)) of
    Refl -> ixCvtXR (mmaxIndexPrim arr)
-- | Dot product along the innermost dimension of two rank-@n + 1@ arrays,
-- giving a rank-@n@ array. Delegates to 'mdot1Inner'.
rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a
rdot1Inner arr1 arr2 =
  case rrank arr1 of
    SNat ->
      -- Split @Replicate (n + 1) Nothing@ as @Replicate n Nothing ++ '[Nothing]@.
      case lemReplicatePlusApp (SNat @n) (Proxy @1) (Proxy @(Nothing @Nat)) of
        Refl -> coerce (mdot1Inner (Proxy @(Nothing @Nat))) arr1 arr2
-- | Full dot product of two arrays of equal rank, delegating to 'mdot'.
--
-- This has a temporary, suboptimal implementation in terms of 'mflatten'.
-- Prefer 'rdot1Inner' if applicable.
rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a
rdot (Ranked xs) (Ranked ys) = mdot xs ys
-- | The ranked shape and underlying 'XArray' of a 'Primitive' array, via
-- 'mtoXArrayPrimP'.
rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
rtoXArrayPrimP (Ranked marr) =
  -- Lazy pattern binding, matching the lazy tuple match of 'first'.
  let (shx, xarr) = mtoXArrayPrimP marr
  in (shCvtXR' shx, xarr)
-- | The ranked shape and underlying 'XArray' of an array, via
-- 'mtoXArrayPrim'.
rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a)
rtoXArrayPrim (Ranked marr) =
  -- Lazy pattern binding, matching the lazy tuple match of 'first'.
  let (shx, xarr) = mtoXArrayPrim marr
  in (shCvtXR' shx, xarr)
-- | Wrap an 'XArray' as a 'Primitive' ranked array. The static shape passed
-- to 'mfromXArrayPrimP' is derived from the array's own shape.
rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP ssh arr)
  where
    ssh = ssxFromShape (X.shape (ssxFromSNat sn) arr)
-- | Wrap an 'XArray' as a ranked array. The static shape passed to
-- 'mfromXArrayPrim' is derived from the array's own shape.
rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a
rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim ssh arr)
  where
    ssh = ssxFromShape (X.shape (ssxFromSNat sn) arr)
-- | Convert out of the 'Primitive' representation with 'fromPrimitive'.
rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a
rfromPrimitive (Ranked parr) = Ranked $ fromPrimitive parr
-- | Convert into the 'Primitive' representation with 'toPrimitive'.
rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a)
rtoPrimitive (Ranked marr) = Ranked $ toPrimitive marr
|