aboutsummaryrefslogtreecommitdiff
path: root/src/Fancy.hs
blob: 6b6d8d4ff4f4e4178bd38b59fbb804d9a6916c6d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Fancy where

import Control.Monad (forM_)
import Control.Monad.ST
import Data.Coerce (coerce)
import Data.Kind
import Data.Proxy
import Data.Type.Equality
import Data.Type.Ord
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM

import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))
import qualified Array as X
import Nats


-- | @Replicate n a@ is the type-level list consisting of @n@ copies of @a@.
type family Replicate n a where
  Replicate Z elt = '[]
  Replicate (S k) elt = elt ': Replicate k elt

-- | Wrap every element of a type-level list in 'Just'.
type family MapJust l where
  MapJust '[] = '[]
  MapJust (y ': ys) = 'Just y ': MapJust ys

-- | Absurdity lemma: no natural number @n@ satisfies both @0 < n@ and
-- @1 > n@, so the constraints can never be discharged and the body is an
-- unreachable 'error'.
lemCompareFalse1 :: (0 < n, 1 > n) => Proxy n -> a
lemCompareFalse1 = error "Incoherence"

-- | For a statically known @n@, a shape consisting of @n@ unknown
-- ('Nothing') dimensions is a 'KnownShapeX'.
lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
lemKnownReplicate _ = X.lemKnownShapeX (staticReplicate (knownNat @n))
  where
    -- Build the static shape witness one 'Nothing' dimension at a time.
    staticReplicate :: SNat m -> StaticShapeX (Replicate m Nothing)
    staticReplicate sn = case sn of
      SZ -> SZX
      SS sk -> () :$? staticReplicate sk


-- | An array with shape @sh@ and elements of type @a@. Each dimension in @sh@
-- is either statically sized ('Just') or only known at runtime ('Nothing').
-- The representation is chosen per element type (struct-of-arrays style).
type Mixed :: [Maybe Nat] -> Type -> Type
data family Mixed sh a

-- Scalar element types are stored directly in an 'XArray'.
newtype instance Mixed sh Int = M_Int (XArray sh Int)
newtype instance Mixed sh Double = M_Double (XArray sh Double)
-- etc. (further scalar types would follow the same pattern)

-- Arrays of units carry no element data; only the shape is stored.
newtype instance Mixed sh () = M_Nil (IxX sh)  -- store the shape
-- Arrays of tuples are stored as a tuple of arrays, one per component, all
-- with the same shape @sh@.
data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b)
data instance Mixed sh (a, b, c) = M_Tup3 (Mixed sh a) (Mixed sh b) (Mixed sh c)
data instance Mixed sh (a, b, c, d) = M_Tup4 (Mixed sh a) (Mixed sh b) (Mixed sh c) (Mixed sh d)

-- An array of arrays is stored as one array whose shape is the outer shape
-- followed by the inner shape.
newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a)


-- | Mutable vectors used while building a 'Mixed' array inside 'ST' (see
-- 'mvecsUnsafeNew' / 'mvecsWrite' / 'mvecsFreeze'). The structure mirrors the
-- 'Mixed' data family; @s@ is the 'ST' state token.
type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
data family MixedVecs s sh a

-- Scalar element types: one unboxed mutable vector holding all elements.
newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int)
newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double)
-- etc. (further scalar types would follow the same pattern)

-- Units need no storage; the shape is supplied again at freeze time.
data instance MixedVecs s sh () = MV_Nil
-- Tuples: one set of vectors per component.
data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b)
data instance MixedVecs s sh (a, b, c) = MV_Tup3 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c)
data instance MixedVecs s sh (a, b, c, d) = MV_Tup4 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c) (MixedVecs s sh d)

-- Nested arrays: the runtime inner shape, plus vectors for the concatenated
-- (outer ++ inner) shape.
data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a)


-- | Element types that can be stored in a 'Mixed' array. The @mvecs*@
-- methods implement construction: allocate vectors, write elements, freeze.
class GMixed a where
  -- | The shape of the array.
  mshape :: KnownShapeX sh => Mixed sh a -> IxX sh
  -- | Read the element at a full index.
  mindex :: Mixed sh a -> IxX sh -> a
  -- | Index with a prefix of the shape, returning the subarray at that
  -- position.
  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a

  -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
  memptyArray :: IxX sh -> Mixed sh a

  -- | Count the scalars in this value, used to detect empty values before
  -- allocating. Scalar and unit leaves count 1, and a nested array counts its
  -- number of elements times the count of its first element (0 if it has no
  -- elements). If @a@ contains no tuples, this is the total number of scalars
  -- in the value. For tuples, the components' counts are multiplied, so the
  -- result need not equal the size of any individual (SoA) array. It is,
  -- however, zero exactly when some component is zero.
  mvecsNumElts :: a -> Int

  -- | Create uninitialised vectors for this array type, given the shape of
  -- this vector and an example for the contents. The shape must not have size
  -- zero; an error may be thrown otherwise.
  mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a)

  -- | Given the shape of this array, an index and a value, write the value at
  -- that index in the vectors.
  mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s ()

  -- | Given the full shape of this array (a prefix @sh@ followed by @sh'@), a
  -- prefix index into @sh@ and a subarray of shape @sh'@, write all elements
  -- of the subarray at that prefix position in the vectors.
  mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()

  -- | Given the shape of this array, finalise the vectors into 'XArray's.
  mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a)

-- TODO: this use of toVector is suboptimal
-- | Copy the subarray @arr@ (shape @sh'@) into the vectors @v@ of a larger
-- array whose full shape is @sh@, at prefix index @i@. The subarray is written
-- as one contiguous slice. The slice starts at the linear index of @i@
-- followed by all-zero inner coordinates, and its length is the subarray's
-- element count.
mvecsWritePartialPrimitive
  :: forall sh' sh a s. (KnownShapeX sh', VU.Unbox a)
  => IxX (sh ++ sh') -> IxX sh -> XArray sh' a -> VU.MVector s a -> ST s ()
mvecsWritePartialPrimitive sh i arr v =
  let subShape = X.shape arr
      start = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' subShape))
      target = VUM.slice start (X.shapeSize subShape) v
  in VU.copy target (X.toVector arr)

-- | 'Int' elements are stored directly in a single unboxed 'XArray'.
instance GMixed Int where
  mshape (M_Int xarr) = X.shape xarr
  mindex (M_Int xarr) ix = X.index xarr ix
  mindexPartial (M_Int xarr) prefix = M_Int $ X.indexPartial xarr prefix
  memptyArray shp = M_Int $ X.generate shp (error "memptyArray Int: shape was not empty")

  mvecsNumElts _ = 1
  mvecsUnsafeNew shp _ = fmap MV_Int (VUM.unsafeNew (X.shapeSize shp))
  mvecsWrite shp ix val (MV_Int vec) = VUM.write vec (X.toLinearIdx shp ix) val
  mvecsWritePartial shp prefix (M_Int @sh' xarr) (MV_Int vec) = mvecsWritePartialPrimitive @sh' shp prefix xarr vec
  mvecsFreeze shp (MV_Int vec) = fmap (M_Int . X.fromVector shp) (VU.freeze vec)

-- | 'Double' elements are stored directly in a single unboxed 'XArray'.
instance GMixed Double where
  mshape (M_Double xarr) = X.shape xarr
  mindex (M_Double xarr) ix = X.index xarr ix
  mindexPartial (M_Double xarr) prefix = M_Double $ X.indexPartial xarr prefix
  memptyArray shp = M_Double $ X.generate shp (error "memptyArray Double: shape was not empty")

  mvecsNumElts _ = 1
  mvecsUnsafeNew shp _ = fmap MV_Double (VUM.unsafeNew (X.shapeSize shp))
  mvecsWrite shp ix val (MV_Double vec) = VUM.write vec (X.toLinearIdx shp ix) val
  mvecsWritePartial shp prefix (M_Double @sh' xarr) (MV_Double vec) = mvecsWritePartialPrimitive @sh' shp prefix xarr vec
  mvecsFreeze shp (MV_Double vec) = fmap (M_Double . X.fromVector shp) (VU.freeze vec)

-- | Arrays of units hold no element data: only the shape is kept, and the
-- mutable representation is empty.
instance GMixed () where
  mshape (M_Nil shp) = shp
  mindex _ _ = ()
  mindexPartial = \(M_Nil shp) prefix -> M_Nil (X.ixDrop shp prefix)
  memptyArray = M_Nil

  mvecsNumElts _ = 1
  mvecsUnsafeNew _ _ = pure MV_Nil
  mvecsWrite _ _ _ _ = pure ()
  mvecsWritePartial _ _ _ _ = pure ()
  mvecsFreeze shp _ = pure (M_Nil shp)

-- | Tuples are stored component-wise: every operation is applied to each
-- component array in turn, all sharing the same outer shape.
instance (GMixed a, GMixed b) => GMixed (a, b) where
  mshape (M_Tup2 a _) = mshape a
  mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
  mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
  memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)

  -- Product, so the result is zero iff some component is empty.
  mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y
  mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
  mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
    mvecsWrite sh i x a
    mvecsWrite sh i y b
  mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
    mvecsWritePartial sh i x a
    mvecsWritePartial sh i y b
  mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b

-- | Triples, mirroring the pair instance. Without this instance the 'M_Tup3'
-- and 'MV_Tup3' representations could not be used.
instance (GMixed a, GMixed b, GMixed c) => GMixed (a, b, c) where
  mshape (M_Tup3 a _ _) = mshape a
  mindex (M_Tup3 a b c) i = (mindex a i, mindex b i, mindex c i)
  mindexPartial (M_Tup3 a b c) i =
    M_Tup3 (mindexPartial a i) (mindexPartial b i) (mindexPartial c i)
  memptyArray sh = M_Tup3 (memptyArray sh) (memptyArray sh) (memptyArray sh)

  mvecsNumElts (x, y, z) = mvecsNumElts x * mvecsNumElts y * mvecsNumElts z
  mvecsUnsafeNew sh (x, y, z) =
    MV_Tup3 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y <*> mvecsUnsafeNew sh z
  mvecsWrite sh i (x, y, z) (MV_Tup3 a b c) = do
    mvecsWrite sh i x a
    mvecsWrite sh i y b
    mvecsWrite sh i z c
  mvecsWritePartial sh i (M_Tup3 x y z) (MV_Tup3 a b c) = do
    mvecsWritePartial sh i x a
    mvecsWritePartial sh i y b
    mvecsWritePartial sh i z c
  mvecsFreeze sh (MV_Tup3 a b c) =
    M_Tup3 <$> mvecsFreeze sh a <*> mvecsFreeze sh b <*> mvecsFreeze sh c

-- | Quadruples, mirroring the pair instance. Without this instance the
-- 'M_Tup4' and 'MV_Tup4' representations could not be used.
instance (GMixed a, GMixed b, GMixed c, GMixed d) => GMixed (a, b, c, d) where
  mshape (M_Tup4 a _ _ _) = mshape a
  mindex (M_Tup4 a b c d) i = (mindex a i, mindex b i, mindex c i, mindex d i)
  mindexPartial (M_Tup4 a b c d) i =
    M_Tup4 (mindexPartial a i) (mindexPartial b i) (mindexPartial c i) (mindexPartial d i)
  memptyArray sh = M_Tup4 (memptyArray sh) (memptyArray sh) (memptyArray sh) (memptyArray sh)

  mvecsNumElts (x, y, z, w) =
    mvecsNumElts x * mvecsNumElts y * mvecsNumElts z * mvecsNumElts w
  mvecsUnsafeNew sh (x, y, z, w) =
    MV_Tup4 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
            <*> mvecsUnsafeNew sh z <*> mvecsUnsafeNew sh w
  mvecsWrite sh i (x, y, z, w) (MV_Tup4 a b c d) = do
    mvecsWrite sh i x a
    mvecsWrite sh i y b
    mvecsWrite sh i z c
    mvecsWrite sh i w d
  mvecsWritePartial sh i (M_Tup4 x y z w) (MV_Tup4 a b c d) = do
    mvecsWritePartial sh i x a
    mvecsWritePartial sh i y b
    mvecsWritePartial sh i z c
    mvecsWritePartial sh i w d
  mvecsFreeze sh (MV_Tup4 a b c d) =
    M_Tup4 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
           <*> mvecsFreeze sh c <*> mvecsFreeze sh d

-- | Arrays of arrays. 'M_Nest' wraps a single flattened array whose shape is
-- the outer shape followed by the inner shape @sh'@.
-- NOTE(review): this representation can only hold inner arrays that all share
-- one shape. 'mvecsUnsafeNew' takes that shape from the example element, and
-- the writers below do not check later elements against it.
instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where
  -- TODO: this is quadratic in the nesting level
  -- Compute the full (outer ++ inner) shape, then keep only the outer prefix.
  mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh
  mshape (M_Nest arr)
    | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
    = ixAppPrefix (knownShapeX @sh) (mshape arr)
    where
      -- Take the leading components of an index over @sh1 ++ sh'@, guided
      -- by the static shape of @sh1@.
      ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1
      ixAppPrefix SZX _ = IZX
      ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx
      ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx

  -- A full outer index selects one inner array: a partial index into the
  -- flattened array.
  mindex (M_Nest arr) i = mindexPartial arr i

  mindexPartial :: forall sh1 sh2.
                   Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
  mindexPartial (M_Nest arr) i
    -- Associativity of '++' (via 'X.lemAppAssoc') lets the flattened array of
    -- shape (sh1 ++ sh2) ++ sh' be indexed with the prefix sh1.
    | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
    = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)

  -- The inner part of the flattened shape is 'X.zeroIdx' of the static inner
  -- shape (presumably all zeros, making the flattened array empty; confirm
  -- in Array).
  memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh'))))

  -- Here the argument is one inner array. Its count is its number of elements
  -- times the count of its first element, or 0 if it has no elements.
  mvecsNumElts arr =
    let n = X.shapeSize (mshape arr)
    in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh')))

  -- Allocate vectors for the concatenated shape (outer ++ example's shape),
  -- using the example's first element as the example for the next level.
  -- The example's shape is remembered in 'MV_Nest' for later writes.
  mvecsUnsafeNew sh example
    | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example"
    | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example))
                                                 (mindex example (X.zeroIdx (knownShapeX @sh')))
    where
      sh' = mshape example

  -- Writing one inner array at a full outer index is a partial write into the
  -- flattened vectors.
  mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs

  mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
                    => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
                    -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
                    -> ST s ()
  mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)
    -- Needs KnownShapeX for (sh2 ++ sh') and associativity of '++' to reuse
    -- the element type's partial write on the flattened vectors.
    | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
    , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
    = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs

  -- Freeze the flattened vectors using the concatenated shape.
  mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs

-- | Build an array of shape @sh@ by evaluating @f@ at every index.
--
-- The runtime shape is first checked against the static shape type (sizes of
-- 'Just' dimensions must match), and an 'error' is raised otherwise. Empty
-- shapes and empty elements short-circuit to 'memptyArray'. Otherwise the
-- element at the zero index serves as the example for 'mvecsUnsafeNew', and
-- every index is then written.
--
-- NOTE(review): 'mvecsNumElts' multiplies over tuple components. A tuple
-- element in which only some component is empty therefore also takes the
-- 'memptyArray' branch with a non-empty @sh@. The scalar instances'
-- 'memptyArray' builds an 'X.generate' whose element function is an 'error',
-- so that case may fail. Confirm whether such elements need support.
mgenerate :: forall sh a. (KnownShapeX sh, GMixed a) => IxX sh -> (IxX sh -> a) -> Mixed sh a
mgenerate sh f
  | not (checkBounds sh (knownShapeX @sh)) =
      error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
  -- We need to be very careful here to ensure that neither 'sh' nor
  -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty.
  | X.shapeSize sh == 0 = memptyArray sh
  | otherwise =
      let firstidx = X.zeroIdx' sh
          firstelem = f firstidx
      in if mvecsNumElts firstelem == 0
           then memptyArray sh
           else runST $ do
                  vecs <- mvecsUnsafeNew sh firstelem
                  -- The first element was already computed to size the
                  -- vectors; write it directly instead of evaluating f again.
                  mvecsWrite sh firstidx firstelem vecs
                  -- Relies on 'X.enumShape' listing the zero index first.
                  -- 'drop 1' (not 'tail') keeps this total.
                  forM_ (drop 1 (X.enumShape sh)) $ \idx ->
                    mvecsWrite sh idx (f idx) vecs
                  mvecsFreeze sh vecs
  where
    -- Does the runtime shape agree with the static one? 'Just' dimensions
    -- must match exactly; 'Nothing' dimensions accept any size.
    checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool
    checkBounds IZX SZX = True
    checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh'
    checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'


-- | A ranked array: the rank @n@ is static, while every dimension size is only
-- known at runtime (a 'Mixed' array with @n@ 'Nothing' dimensions).
type Ranked :: Nat -> Type -> Type
newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)

-- | A shaped array: every dimension size is static (a 'Mixed' array whose
-- dimensions are all 'Just').
type Shaped :: [Nat] -> Type -> Type
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)

-- Arrays of 'Ranked' / 'Shaped' elements reuse the nested-'Mixed'
-- representation of the underlying array type.
newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))

-- Likewise for their mutable vectors.
newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))


-- | Every method unwraps the 'Ranked' newtype (via the constructor or
-- 'coerce') and delegates to the nested-'Mixed' instance for
-- @Mixed (Replicate n Nothing) a@. 'lemKnownReplicate' supplies the
-- 'KnownShapeX' evidence that instance requires.
instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where
  mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr
  mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i)

  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IxX sh -> Mixed sh' (Ranked n a)
  mindexPartial (M_Ranked arr) i
    | Dict <- lemKnownReplicate (Proxy @n)
    = coerce @(Mixed sh' (Mixed (Replicate n 'Nothing) a)) @(Mixed sh' (Ranked n a)) $
        mindexPartial arr i

  memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a)
  memptyArray i
    | Dict <- lemKnownReplicate (Proxy @n)
    = coerce @(Mixed sh (Mixed (Replicate n 'Nothing) a)) @(Mixed sh (Ranked n a)) $
        memptyArray i

  mvecsNumElts (Ranked arr)
    | Dict <- lemKnownReplicate (Proxy @n)
    = mvecsNumElts arr

  mvecsUnsafeNew idx (Ranked arr)
    | Dict <- lemKnownReplicate (Proxy @n)
    = MV_Ranked <$> mvecsUnsafeNew idx arr

  -- The vectors are coerced to the underlying nested representation before
  -- writing.
  mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
  mvecsWrite sh idx (Ranked arr) vecs
    | Dict <- lemKnownReplicate (Proxy @n)
    = mvecsWrite sh idx arr
        (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
           vecs)

  -- Both the subarray and the vectors are coerced to the underlying nested
  -- representation.
  mvecsWritePartial :: forall sh sh' s. KnownShapeX sh'
                    => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Ranked n a)
                    -> MixedVecs s (sh ++ sh') (Ranked n a)
                    -> ST s ()
  mvecsWritePartial sh idx arr vecs
    | Dict <- lemKnownReplicate (Proxy @n)
    = mvecsWritePartial sh idx
        (coerce @(Mixed sh' (Ranked n a))
                @(Mixed sh' (Mixed (Replicate n Nothing) a))
           arr)
        (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
                @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
           vecs)

  -- Freeze through the underlying representation, then coerce the result back.
  mvecsFreeze :: forall sh s. IxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
  mvecsFreeze sh vecs
    | Dict <- lemKnownReplicate (Proxy @n)
    = coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
             @(Mixed sh (Ranked n a))
        <$> mvecsFreeze sh
              (coerce @(MixedVecs s sh (Ranked n a))
                      @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
                      vecs)


-- | Runtime witness of a fully static shape: one 'SNat' per dimension.
data SShape sh where
  ShNil :: SShape '[]
  ShCons :: SNat n -> SShape sh -> SShape (n : sh)
deriving instance Show (SShape sh)

-- | Shapes whose every dimension is a 'KnownNat'. 'knownShape' produces the
-- corresponding runtime 'SShape' witness.
class KnownShape sh where
  knownShape :: SShape sh

instance KnownShape '[] where
  knownShape = ShNil

instance (KnownNat n, KnownShape sh) => KnownShape (n ': sh) where
  knownShape = ShCons knownNat knownShape

-- TODO: GMixed instance for 'Shaped' is not written yet; intended head:
-- instance (KnownShape sh, GMixed a) => GMixed (Shaped sh a) where


-- | An index into a ranked array of rank @n@: one 'Int' per dimension.
type IxR :: Nat -> Type
data IxR n where
  IZR :: IxR Z
  (:::) :: Int -> IxR n -> IxR (S n)
-- Right-associative so that @i ::: j ::: IZR@ parses without parentheses.
-- Under the default fixity (infixl 9) that expression is ill-typed.
infixr 3 :::

-- | An index into a shaped array with static shape @sh@: one 'Int' per
-- dimension. The 'Int' components are not tied to the dimension sizes at the
-- type level.
type IxS :: [Nat] -> Type
data IxS sh where
  IZS :: IxS '[]
  (::$) :: Int -> IxS sh -> IxS (n : sh)
-- Right-associative so that @i ::$ j ::$ IZS@ parses without parentheses.
-- Under the default fixity (infixl 9) that expression is ill-typed.
infixr 3 ::$

-- | Convert a mixed index to a ranked index, keeping every component and
-- forgetting which dimensions were static.
ixCvtXR :: IxX sh -> IxR (X.Rank sh)
ixCvtXR ix = case ix of
  IZX -> IZR
  k ::@ rest -> k ::: ixCvtXR rest
  k ::? rest -> k ::: ixCvtXR rest

-- | Convert a ranked index to a mixed index in which every dimension is
-- dynamic ('Nothing').
ixCvtRX :: IxR n -> IxX (Replicate n Nothing)
ixCvtRX ix = case ix of
  IZR -> IZX
  k ::: rest -> k ::? ixCvtRX rest

-- | A shape made of @n@ dynamic dimensions has rank @n@.
lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
lemRankReplicate _ = induct (knownNat @n)
  where
    -- Structural induction on the singleton for @m@.
    induct :: SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
    induct sn = case sn of
      SZ -> Refl
      SS sk -> case induct sk of
        Refl -> Refl

-- | @n + m@ copies of @a@ are @n@ copies followed by @m@ copies.
lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a
                    -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
lemReplicatePlusApp _ _ _ = induct (knownNat @n)
  where
    -- Induction on the first summand; @m@ and @a@ stay fixed.
    induct :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
    induct sn = case sn of
      SZ -> Refl
      SS sk -> case induct sk of
        Refl -> Refl


-- | The runtime shape of a ranked array, as one size per dimension.
rshape :: forall n a. (KnownNat n, GMixed a) => Ranked n a -> IxR n
rshape (Ranked arr) =
  case lemKnownReplicate (Proxy @n) of
    Dict -> case lemRankReplicate (Proxy @n) of
      Refl -> ixCvtXR (mshape arr)

-- | Read the element of a ranked array at a full index.
rindex :: GMixed a => Ranked n a -> IxR n -> a
rindex (Ranked arr) = mindex arr . ixCvtRX

-- | Transport an array along a proof that two shape types are equal.
rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a
rewriteMixed proof x = case proof of
  Refl -> x

-- | Index a rank-@(n + m)@ array with a length-@n@ prefix index, returning the
-- rank-@m@ subarray at that position.
--
-- The shape equation @Replicate (n + m) ~ Replicate n ++ Replicate m@ is
-- established once, by 'rewriteMixed'. The original code also matched the
-- same lemma in a redundant pattern guard whose 'Nothing' kind was left
-- unconstrained.
rindexPartial :: forall n m a. (KnownNat n, GMixed a) => Ranked (n + m) a -> IxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
  Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
            (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @(Nothing @Nat))) arr)
            (ixCvtRX idx))

-- | Build a ranked array of the given shape by evaluating the function at
-- every index (see 'mgenerate').
rgenerate :: forall n a. (KnownNat n, GMixed a) => IxR n -> (IxR n -> a) -> Ranked n a
rgenerate sh f =
  case lemKnownReplicate (Proxy @n) of
    Dict -> case lemRankReplicate (Proxy @n) of
      Refl -> Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR))