1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Fancy where
import Control.Monad (forM_)
import Control.Monad.ST
import Data.Coerce (coerce)
import Data.Kind
import Data.Proxy
import Data.Type.Equality
import Data.Type.Ord
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))
import qualified Array as X
import Nats
-- | Type-level list holding @n@ copies of @a@, where @n@ is built from the
-- constructors 'Z' and 'S'.
type family Replicate n a where
  Replicate Z      x = '[]
  Replicate (S k)  x = x : Replicate k x
-- | Wrap every element of a type-level list in 'Just'.
type family MapJust l where
  MapJust '[]        = '[]
  MapJust (hd : tl)  = Just hd : MapJust tl
-- | Given both @0 < n@ and @1 > n@, produce a value of any type.  The body is
-- an 'error' call, so evaluating the result diverges with the message
-- @"Incoherence"@.
-- NOTE(review): presumably the two constraints are meant to be jointly
-- unsatisfiable, so that this is never actually reached -- confirm.
lemCompareFalse1 :: (0 < n, 1 > n) => Proxy n -> a
lemCompareFalse1 = error "Incoherence"
-- | Produce the 'KnownShapeX' dictionary for the shape consisting of @n@
-- 'Nothing' entries.  The matching 'StaticShapeX' is built by recursion on
-- the singleton for @n@ and then passed to 'X.lemKnownShapeX'.
lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
lemKnownReplicate _ = X.lemKnownShapeX (unknownDims (knownNat @n))
  where
    -- One @() :$?@ cell for every 'SS' in the singleton.
    unknownDims :: SNat m -> StaticShapeX (Replicate m Nothing)
    unknownDims sn = case sn of
      SZ -> SZX
      SS pre -> () :$? unknownDims pre
-- | Wrapper type used purely as a tag to attach instances on: the 'GMixed'
-- instances for 'Int', 'Double' and '()' in this module are derived via
-- @Primitive@.
newtype Primitive a = Primitive a
-- | Array with element type @a@ and shape @sh@, where every entry of @sh@ is
-- either @Just n@ or @Nothing@.  This is a data family: the representation is
-- chosen per element type by the instances below.
type Mixed :: [Maybe Nat] -> Type -> Type
data family Mixed sh a
-- Scalar-like element types are each backed by one 'XArray'.
newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a)
newtype instance Mixed sh Int = M_Int (XArray sh Int)
newtype instance Mixed sh Double = M_Double (XArray sh Double)
newtype instance Mixed sh () = M_Nil (XArray sh ()) -- no content, orthotope optimises this (via Vector)
-- etc.
-- Tuples are stored as one 'Mixed' array per component, all with the same
-- shape index @sh@.
data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b)
data instance Mixed sh (a, b, c) = M_Tup3 (Mixed sh a) (Mixed sh b) (Mixed sh c)
data instance Mixed sh (a, b, c, d) = M_Tup4 (Mixed sh a) (Mixed sh b) (Mixed sh c) (Mixed sh d)
-- An array of arrays is stored as a single array over the concatenated shape.
newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a)
-- | Mutable storage, in state thread @s@, from which a 'Mixed' array is
-- produced by 'mvecsFreeze'.  The structure mirrors the instances of 'Mixed',
-- with unboxed mutable vectors in place of 'XArray's.
type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
data family MixedVecs s sh a
-- Scalar-like element types: a single unboxed mutable vector.
newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VU.MVector s a)
newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int)
newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double)
newtype instance MixedVecs s sh () = MV_Nil (VU.MVector s ()) -- no content, MVector optimises this
-- etc.
-- Tuples: one set of vectors per component.
data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b)
data instance MixedVecs s sh (a, b, c) = MV_Tup3 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c)
data instance MixedVecs s sh (a, b, c, d) = MV_Tup4 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c) (MixedVecs s sh d)
-- Nested arrays: the vectors for the concatenated shape, together with the
-- shape of the inner arrays (taken from the example element when the vectors
-- are allocated by 'mvecsUnsafeNew').
data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a)
-- | Element types that can be stored in a 'Mixed' array.  The first group of
-- methods operates on finished arrays; the @mvecs*@ methods operate on the
-- mutable 'MixedVecs' used while an array is being filled in (see
-- 'mgenerate').
--
-- Note that the order of the type variables in these signatures matters: the
-- instances in this module call some methods with visible type applications.
class GMixed a where
-- | The shape of an array.
mshape :: KnownShapeX sh => Mixed sh a -> IxX sh
-- | The element at a full index.
mindex :: Mixed sh a -> IxX sh -> a
-- | Index with only the leading @sh@ part of the index, yielding the
-- sub-array of shape @sh'@ at that position.
mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a
-- | Apply a function on the underlying 'XArray's that turns a leading shape
-- @sh1@ into @sh2@.  The function must work for every trailing shape @sh'@
-- and every unboxed element type @b@.
mlift :: forall sh1 sh2. KnownShapeX sh2
=> (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
-> Mixed sh1 a -> Mixed sh2 a
-- | Create an empty array. The given shape must have size zero; this may or may not be checked.
memptyArray :: IxX sh -> Mixed sh a
-- | Return the size of the individual (SoA) arrays in this value. If @a@
-- does not contain tuples, this coincides with the total number of scalars
-- in the given value; if @a@ contains tuples, then it is some multiple of
-- this number of scalars.
mvecsNumElts :: a -> Int
-- | Create uninitialised vectors for this array type, given the shape of
-- this vector and an example for the contents. The shape must not have size
-- zero; an error may be thrown otherwise.
mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a)
-- | Given the shape of this array, an index and a value, write the value at
-- that index in the vectors.
mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s ()
-- | Given the shape of this array (@sh ++ sh'@), an index into its leading
-- @sh@ part and a sub-array of shape @sh'@, write that whole sub-array into
-- the vectors at that index.
mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
-- | Given the shape of this array, finalise the vectors into 'XArray's.
mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
-- | Arrays of unboxed scalars.  Every method unwraps the single 'XArray' (or,
-- while an array is being built, the single unboxed mutable vector) and
-- forwards to the corresponding array or vector operation.
instance VU.Unbox a => GMixed (Primitive a) where
  mshape (M_Primitive a) = X.shape a
  mindex (M_Primitive a) i = Primitive (X.index a i)
  mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i)

  -- A scalar contributes no trailing shape, so the lifted function is only
  -- needed at the empty trailing shape; the 'X.lemAppNil' matches let @a@ be
  -- used at @sh1 ++ '[]@ and the result at @sh2@.
  mlift :: forall sh1 sh2.
           (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
        -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
  mlift f (M_Primitive a)
    | Refl <- X.lemAppNil @sh1
    , Refl <- X.lemAppNil @sh2
    = M_Primitive (f Proxy a)

  -- The element function is an 'error' call.  Its message names the instance
  -- generically, because this instance is also used for 'Int', 'Double' and
  -- '()' through @deriving via@.
  -- NOTE(review): presumably 'X.generate' never calls the element function
  -- for a shape of size zero -- confirm.
  memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Primitive: shape was not empty"))

  -- A single scalar.
  mvecsNumElts _ = 1

  -- The example element is not needed: only the total size matters.
  mvecsUnsafeNew sh _ = MV_Primitive <$> VUM.unsafeNew (X.shapeSize sh)

  mvecsWrite sh i (Primitive x) (MV_Primitive v) = VUM.write v (X.toLinearIdx sh i) x

  -- Copies the flat contents of @arr@ into the slice of @v@ that starts at
  -- the linear position of index @i@ extended with zeros, and is as long as
  -- @arr@ has elements.
  -- NOTE(review): this assumes the sub-array occupies a contiguous range of
  -- the linearised vector -- confirm against 'X.toLinearIdx'.
  -- TODO: this use of toVector is suboptimal
  mvecsWritePartial
    :: forall sh' sh s. (KnownShapeX sh', VU.Unbox a)
    => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
  mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do
    let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr)))
    VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr)

  mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VU.freeze v
-- 'Int', 'Double' and '()' reuse the 'GMixed' instance of 'Primitive' through
-- @DerivingVia@.
deriving via Primitive Int instance GMixed Int
deriving via Primitive Double instance GMixed Double
deriving via Primitive () instance GMixed ()
-- | Arrays of pairs, stored as a pair of arrays.  Every operation is carried
-- out on both component arrays.
instance (GMixed a, GMixed b) => GMixed (a, b) where
  -- Only the first component is consulted for the shape.
  mshape (M_Tup2 l _) = mshape l

  mindex (M_Tup2 l r) ix = (mindex l ix, mindex r ix)

  mindexPartial (M_Tup2 l r) ix = M_Tup2 (mindexPartial l ix) (mindexPartial r ix)

  mlift g (M_Tup2 l r) = M_Tup2 (mlift g l) (mlift g r)

  memptyArray shp = M_Tup2 (memptyArray shp) (memptyArray shp)

  -- Product of the component counts, so the result is zero as soon as either
  -- component reports zero.
  mvecsNumElts (p, q) = mvecsNumElts p * mvecsNumElts q

  mvecsUnsafeNew shp (p, q) = do
    vl <- mvecsUnsafeNew shp p
    vr <- mvecsUnsafeNew shp q
    pure (MV_Tup2 vl vr)

  mvecsWrite shp ix (p, q) (MV_Tup2 vl vr) =
    mvecsWrite shp ix p vl >> mvecsWrite shp ix q vr

  mvecsWritePartial shp ix (M_Tup2 p q) (MV_Tup2 vl vr) =
    mvecsWritePartial shp ix p vl >> mvecsWritePartial shp ix q vr

  mvecsFreeze shp (MV_Tup2 vl vr) = do
    l <- mvecsFreeze shp vl
    r <- mvecsFreeze shp vr
    pure (M_Tup2 l r)
-- | Arrays whose elements are themselves 'Mixed' arrays of shape @sh'@.  Such
-- an array is represented ('M_Nest') as one array over the concatenated
-- shape, so each method re-expresses the operation on that flattened array,
-- using lemmas from the array module to supply the 'KnownShapeX' dictionaries
-- and the associativity of @++@ that the types require.
instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where
-- The outer shape is the leading part of the flattened array's shape;
-- 'ixAppPrefix' keeps exactly as many leading components as the static shape
-- @sh1@ has entries.
mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh
mshape (M_Nest arr)
| Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
= ixAppPrefix (knownShapeX @sh) (mshape arr)
where
ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1
ixAppPrefix SZX _ = IZX
ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx
ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx
-- An element of a nested array is a sub-array of the flattened array.
mindex (M_Nest arr) i = mindexPartial arr i
-- Partial indexing on the flattened array, after reassociating
-- @(sh1 ++ sh2) ++ sh'@ so that the trailing shape is @sh2 ++ sh'@.
mindexPartial :: forall sh1 sh2.
Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
mindexPartial (M_Nest arr) i
| Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
= M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
-- Lift over the flattened array: @f'@ is @f@ used at the trailing shape
-- @sh' ++ sh3@, with its type reassociated to fit the flattened shapes.
mlift :: forall sh1 sh2. KnownShapeX sh2
=> (forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
mlift f (M_Nest arr)
| Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
= M_Nest (mlift f' arr)
where
f' :: forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b
f' _
| Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @sh3)
, Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @sh3)
, Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @sh3))
= f (Proxy @(sh' ++ sh3))
-- The flattened empty array gets the given outer shape extended with the
-- zero index of @sh'@.
memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh'))))
-- Zero for an array of size zero; otherwise the array's size times the count
-- reported for its element at the zero index.
mvecsNumElts arr =
let n = X.shapeSize (mshape arr)
in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh')))
-- The shape of the example array is recorded in 'MV_Nest' and appended to the
-- outer shape; the example's element at the zero index serves as the example
-- for the flattened vectors.  An example of size zero is rejected.
mvecsUnsafeNew sh example
| X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example"
| otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example))
(mindex example (X.zeroIdx (knownShapeX @sh')))
where
sh' = mshape example
-- Writing one (array-valued) element is a partial write into the flattened
-- vectors, whose full shape is the outer shape extended with the recorded
-- inner shape.
mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs
-- A partial write of nested arrays is a partial write on the flattened
-- vectors with trailing shape @sh2 ++ sh'@.
mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
=> IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
-> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
-> ST s ()
mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)
| Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
, Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
= mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs
-- Freeze the flattened vectors at the outer shape extended with the recorded
-- inner shape.
mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs
-- | Build an array of shape @sh@ whose element at every index @idx@ is
-- @f idx@.
--
-- Calls 'error' if @sh@ disagrees with the statically known dimensions of the
-- shape type.  Returns 'memptyArray' if @sh@ has size zero, or if the element
-- at the zero index reports 'mvecsNumElts' equal to zero.
mgenerate :: forall sh a. (KnownShapeX sh, GMixed a) => IxX sh -> (IxX sh -> a) -> Mixed sh a
mgenerate sh f
  | not (checkBounds sh (knownShapeX @sh)) =
      error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
  -- We need to be very careful here to ensure that neither 'sh' nor
  -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty.
  | X.shapeSize sh == 0 = memptyArray sh
  | otherwise =
      let firstidx = X.zeroIdx' sh
          -- The first element doubles as the example from which the vectors
          -- are sized.
          firstelem = f firstidx
      in if mvecsNumElts firstelem == 0
           then memptyArray sh
           else runST $ do
                  vecs <- mvecsUnsafeNew sh firstelem
                  mvecsWrite sh firstidx firstelem vecs
                  -- The first index has already been written above.  'drop 1'
                  -- is the total equivalent of 'tail' on the non-empty list
                  -- expected here.
                  -- NOTE(review): this relies on 'X.enumShape' listing the
                  -- zero index first -- confirm.
                  forM_ (drop 1 (X.enumShape sh)) $ \idx ->
                    mvecsWrite sh idx (f idx) vecs
                  mvecsFreeze sh vecs
  where
    -- A dimension given in the shape type must equal the corresponding entry
    -- of the run-time shape; dimensions not given in the type always pass.
    checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool
    checkBounds IZX SZX = True
    checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh'
    checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
-- | Array indexed by its rank @n@ only: a newtype over a 'Mixed' array whose
-- shape consists of @n@ 'Nothing' entries.
type Ranked :: Nat -> Type -> Type
newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
-- | Array indexed by a full type-level shape @sh@: a newtype over a 'Mixed'
-- array whose shape has every entry of @sh@ wrapped in 'Just'.
type Shaped :: [Nat] -> Type -> Type
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
-- Arrays (and their mutable vectors) with 'Ranked' or 'Shaped' elements are
-- represented exactly like arrays of the underlying nested 'Mixed' arrays.
-- Being newtypes, they can be converted to that representation with 'coerce'.
newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))
newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))
-- | Arrays whose elements are 'Ranked' arrays.  'Ranked' is a newtype over a
-- 'Mixed' array of shape @Replicate n Nothing@, so every method first obtains
-- the 'KnownShapeX' dictionary for that shape from 'lemKnownReplicate' and
-- then delegates to the instance for nested 'Mixed' arrays, converting
-- between the two representations with the newtype constructors or 'coerce'.
instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where
  mshape (M_Ranked arr) =
    case lemKnownReplicate (Proxy @n) of
      Dict -> mshape arr

  mindex (M_Ranked arr) i =
    case lemKnownReplicate (Proxy @n) of
      Dict -> Ranked (mindex arr i)

  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IxX sh -> Mixed sh' (Ranked n a)
  mindexPartial (M_Ranked arr) i =
    case lemKnownReplicate (Proxy @n) of
      Dict ->
        coerce @(Mixed sh' (Mixed (Replicate n 'Nothing) a)) @(Mixed sh' (Ranked n a))
          (mindexPartial arr i)

  mlift :: forall sh1 sh2. KnownShapeX sh2
        => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
        -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
  mlift f (M_Ranked arr) =
    case lemKnownReplicate (Proxy @n) of
      Dict ->
        coerce @(Mixed sh2 (Mixed (Replicate n 'Nothing) a)) @(Mixed sh2 (Ranked n a))
          (mlift f arr)

  memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a)
  memptyArray i =
    case lemKnownReplicate (Proxy @n) of
      Dict ->
        coerce @(Mixed sh (Mixed (Replicate n 'Nothing) a)) @(Mixed sh (Ranked n a))
          (memptyArray i)

  mvecsNumElts (Ranked arr) =
    case lemKnownReplicate (Proxy @n) of
      Dict -> mvecsNumElts arr

  mvecsUnsafeNew idx (Ranked arr) =
    case lemKnownReplicate (Proxy @n) of
      Dict -> MV_Ranked <$> mvecsUnsafeNew idx arr

  mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
  mvecsWrite sh idx (Ranked arr) vecs =
    case lemKnownReplicate (Proxy @n) of
      Dict ->
        mvecsWrite sh idx arr
          (coerce @(MixedVecs s sh (Ranked n a))
                  @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
                  vecs)

  mvecsWritePartial :: forall sh sh' s. KnownShapeX sh'
                    => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Ranked n a)
                    -> MixedVecs s (sh ++ sh') (Ranked n a)
                    -> ST s ()
  mvecsWritePartial sh idx arr vecs =
    case lemKnownReplicate (Proxy @n) of
      Dict ->
        mvecsWritePartial sh idx
          (coerce @(Mixed sh' (Ranked n a))
                  @(Mixed sh' (Mixed (Replicate n Nothing) a))
                  arr)
          (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
                  @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
                  vecs)

  mvecsFreeze :: forall sh s. IxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
  mvecsFreeze sh vecs =
    case lemKnownReplicate (Proxy @n) of
      Dict ->
        coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
               @(Mixed sh (Ranked n a))
          <$> mvecsFreeze sh
                (coerce @(MixedVecs s sh (Ranked n a))
                        @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
                        vecs)
-- | Value-level witness of a type-level shape (a list of naturals): one
-- 'SNat' per dimension.
data SShape sh where
-- The empty shape.
ShNil :: SShape '[]
-- A singleton for the first dimension followed by the rest of the shape.
ShCons :: SNat n -> SShape sh -> SShape (n : sh)
deriving instance Show (SShape sh)
-- | Type-level shapes for which an 'SShape' witness can be produced.  The
-- instances build the witness from 'knownNat' for each dimension.
class KnownShape sh where knownShape :: SShape sh
instance KnownShape '[] where knownShape = ShNil
instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons knownNat knownShape
-- instance (KnownShape sh, GMixed a) => GMixed (Shaped sh a) where
-- | List of exactly @n@ 'Int's.  In this module it is used both as an index
-- into and as the shape of a 'Ranked' array (see 'rindex' and 'rshape').
type IxR :: Nat -> Type
data IxR n where
IZR :: IxR Z
(:::) :: Int -> IxR n -> IxR (S n)
-- | List of 'Int's with one entry per element of the type-level shape @sh@.
-- NOTE(review): presumably intended as the index type for 'Shaped' arrays;
-- nothing in this part of the file uses it yet.
type IxS :: [Nat] -> Type
data IxS sh where
IZS :: IxS '[]
(::$) :: Int -> IxS sh -> IxS (n : sh)
-- | Convert a mixed index into a ranked index of the same length.  Every
-- component keeps its 'Int'; whether it was built with '::@' or '::?' is
-- forgotten.
ixCvtXR :: IxX sh -> IxR (X.Rank sh)
ixCvtXR ix = case ix of
  IZX -> IZR
  i ::@ rest -> i ::: ixCvtXR rest
  i ::? rest -> i ::: ixCvtXR rest
-- | Convert a ranked index into the mixed index for a shape of @n@ 'Nothing'
-- entries; every component becomes a '::?' cell.
ixCvtRX :: IxR n -> IxX (Replicate n Nothing)
ixCvtRX ix = case ix of
  IZR -> IZX
  i ::: rest -> i ::? ixCvtRX rest
-- | Proof that a shape made of @n@ 'Nothing' entries has rank @n@, by
-- induction on the singleton for @n@.
lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
lemRankReplicate _ = induct (knownNat @n)
  where
    induct :: SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
    induct sn = case sn of
      SZ -> Refl
      SS pre -> case induct pre of
        Refl -> Refl
-- | Proof that replicating @n + m@ times equals replicating @n@ times
-- followed by replicating @m@ times, by induction on the singleton for @n@.
lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a
                    -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
lemReplicatePlusApp _ _ _ = induct (knownNat @n)
  where
    induct :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
    induct sn = case sn of
      SZ -> Refl
      SS pre -> case induct pre of
        Refl -> Refl
-- | The shape of a ranked array, as an 'IxR' of length @n@.  The two lemmas
-- supply the 'KnownShapeX' dictionary needed by 'mshape' and the fact that
-- the underlying mixed shape has rank @n@.
rshape :: forall n a. (KnownNat n, GMixed a) => Ranked n a -> IxR n
rshape (Ranked arr) =
  case lemKnownReplicate (Proxy @n) of
    Dict -> case lemRankReplicate (Proxy @n) of
      Refl -> ixCvtXR (mshape arr)
-- | The element of a ranked array at a full index: the index is converted to
-- a mixed index and looked up in the underlying 'Mixed' array.
rindex :: GMixed a => Ranked n a -> IxR n -> a
rindex (Ranked arr) idx =
  let mixedIdx = ixCvtRX idx
   in mindex arr mixedIdx
-- | Change the shape index of a 'Mixed' array along a proof that the two
-- shapes are equal.  The array itself is returned unchanged.
rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a
rewriteMixed eq x = case eq of
  Refl -> x
-- | Index a ranked array of rank @n + m@ with an index of length @n@,
-- yielding the sub-array of rank @m@ at that position.  The underlying mixed
-- shape is first split into its @n@-part and @m@-part using
-- 'lemReplicatePlusApp', then 'mindexPartial' is applied.
rindexPartial :: forall n m a. (KnownNat n, GMixed a) => Ranked (n + m) a -> IxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
  case lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing) of
    Refl ->
      Ranked
        (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
           (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr)
           (ixCvtRX idx))
-- | Build a ranked array of the given shape whose element at every index is
-- given by @f@.  Implemented with 'mgenerate' on the underlying mixed array,
-- converting the shape and the indices between 'IxR' and 'IxX'.
rgenerate :: forall n a. (KnownNat n, GMixed a) => IxR n -> (IxR n -> a) -> Ranked n a
rgenerate sh f =
  case lemKnownReplicate (Proxy @n) of
    Dict -> case lemRankReplicate (Proxy @n) of
      Refl -> Ranked (mgenerate (ixCvtRX sh) (\ix -> f (ixCvtXR ix)))
-- | Apply a function on the underlying 'XArray's that turns a leading shape
-- of @n1@ 'Nothing' entries into one of @n2@ 'Nothing' entries, producing a
-- ranked array of rank @n2@.  Delegates to 'mlift', for which the
-- 'KnownShapeX' dictionary of the result shape comes from
-- 'lemKnownReplicate'.
rlift :: forall n1 n2 a. (KnownNat n2, GMixed a)
      => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
      -> Ranked n1 a -> Ranked n2 a
rlift f (Ranked arr) =
  case lemKnownReplicate (Proxy @n2) of
    Dict -> Ranked (mlift f arr)
|