summaryrefslogtreecommitdiff
path: root/src/Fancy.hs
blob: 8019393e7202081355254d102a71109211aa4d7e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Fancy where

import Control.Monad (forM_)
import Control.Monad.ST
import Data.Kind
import Data.Proxy
import Data.Type.Equality
import Data.Type.Ord
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM

import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))
import qualified Array as X
import Nats


-- | @Replicate n a@ is the type-level list consisting of @n@ copies of @a@.
type family Replicate n a where
  Replicate Z x = '[]
  Replicate (S k) x = x ': Replicate k x

-- | Wrap every element of a type-level list in 'Just'.
type family MapJust l where
  MapJust '[] = '[]
  MapJust (y ': ys) = 'Just y ': MapJust ys

-- | Absurdity lemma: the context demands both @0 < n@ and @1 > n@, which no
-- natural number satisfies, so a call site that type checks is presumably
-- unreachable and may produce a value of any type. If the result is ever
-- evaluated anyway, it is an 'error' with the message @"Incoherence"@.
lemCompareFalse1 :: (0 < n, 1 > n) => Proxy n -> a
lemCompareFalse1 = error "Incoherence"

-- | Evidence that a shape made of @n@ 'Nothing' dimensions is a known shape
-- whenever @n@ itself is known. The proxy argument only fixes @n@; it is never
-- inspected.
lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
lemKnownReplicate _ = X.lemKnownShapeX $ unknownDims (knownNat @n)
  where
    -- Build the static shape by recursion on the singleton natural: one
    -- '()'-tagged (':$?') dimension per successor.
    unknownDims :: SNat k -> StaticShapeX (Replicate k Nothing)
    unknownDims sn = case sn of
      SZ -> SZX
      SS rest -> () :$? unknownDims rest


-- | An array with shape type @sh@ and element type @a@. The representation
-- depends on the element type (hence a data family): tuples of elements are
-- stored as tuples of arrays, and arrays of arrays are stored as one flat
-- array.
--
-- NOTE(review): presumably a @Just n@ entry in @sh@ is a dimension whose size
-- is known statically and a @Nothing@ entry is a dimension whose size is only
-- known at run time -- confirm against the "Array" module.
type Mixed :: [Maybe Nat] -> Type -> Type
data family Mixed sh a

-- Primitive element types are backed directly by an 'XArray'.
newtype instance Mixed sh Int = M_Int (XArray sh Int)
newtype instance Mixed sh Double = M_Double (XArray sh Double)
-- etc.

-- An array of units carries no element data, only its shape.
newtype instance Mixed sh () = M_Nil (IxX sh)  -- store the shape
-- An array of tuples is a tuple of arrays, one per component, all indexed by
-- the same shape type.
data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b)
data instance Mixed sh (a, b, c) = M_Tup3 (Mixed sh a) (Mixed sh b) (Mixed sh c)
data instance Mixed sh (a, b, c, d) = M_Tup4 (Mixed sh a) (Mixed sh b) (Mixed sh c) (Mixed sh d)

-- An array (shape @sh1@) of arrays (shape @sh2@) is stored as a single array
-- whose shape is the concatenation @sh1 ++ sh2@.
newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a)


-- | The mutable counterpart of 'Mixed': the collection of mutable unboxed
-- vectors (in state thread @s@) that an array of shape type @sh@ and element
-- type @a@ is assembled in before it is frozen with 'mvecsFreeze'. The
-- structure mirrors the 'Mixed' instances one to one.
type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
data family MixedVecs s sh a

-- Primitive element types: a single mutable vector.
newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int)
newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double)
-- etc.

-- Units need no storage at all.
data instance MixedVecs s sh () = MV_Nil
-- Tuples: one set of vectors per component.
data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b)
data instance MixedVecs s sh (a, b, c) = MV_Tup3 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c)
data instance MixedVecs s sh (a, b, c, d) = MV_Tup4 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c) (MixedVecs s sh d)

-- Nested arrays: the vectors for the flattened array of shape @sh1 ++ sh2@,
-- together with the shape of the inner arrays (taken from the example element
-- when the vectors are allocated).
data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a)


-- | Element types that can be stored in a 'Mixed' array. Besides shape and
-- indexing operations, the class provides the @mvecs*@ methods with which an
-- array is built up element by element in mutable vectors ('MixedVecs') and
-- then frozen.
class GMixed a where
  -- | Return the shape of the given array.
  mshape :: KnownShapeX sh => Mixed sh a -> IxX sh
  -- | Return the element at the given index.
  mindex :: Mixed sh a -> IxX sh -> a
  -- | Index only the leading dimensions @sh@ of the array, returning the
  -- subarray (of shape type @sh'@) found at that position.
  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a

  -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
  memptyArray :: IxX sh -> Mixed sh a

  -- | Return the size of the individual (SoA) arrays in this value. If @a@
  -- does not contain tuples, this coincides with the total number of scalars
  -- in the given value; if @a@ contains tuples, then it is some multiple of
  -- this number of scalars.
  mvecsNumElts :: a -> Int

  -- | Create uninitialised vectors for this array type, given the shape of
  -- this array and an example for the contents. The shape must not have size
  -- zero; an error may be thrown otherwise.
  mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a)

  -- | Given the shape of this array, an index and a value, write the value at
  -- that index in the vectors.
  mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s ()

  -- | Given the shape of the full array (of shape type @sh ++ sh'@), an index
  -- into its leading dimensions @sh@ and a subarray of shape type @sh'@, write
  -- all elements of that subarray into the vectors at the position designated
  -- by the index.
  mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()

  -- | Given the shape of this array, finalise the vectors into 'XArray's.
  mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a)

-- | Shared implementation of 'mvecsWritePartial' for primitive element types:
-- copy the contents of the subarray @arr@ into the slice of the mutable vector
-- that starts at the linear position of index @i@ (extended with zeros for
-- the dimensions of @arr@) and is as long as @arr@ has elements.
--
-- TODO: this use of toVector is suboptimal
mvecsWritePartialPrimitive
  :: forall sh' sh a s. (KnownShapeX sh', VU.Unbox a)
  => IxX (sh ++ sh') -> IxX sh -> XArray sh' a -> VU.MVector s a -> ST s ()
mvecsWritePartialPrimitive sh i arr v =
  VU.copy target (X.toVector arr)
  where
    subShape = X.shape arr
    -- Linear position of the first element of the subarray.
    start = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' subShape))
    target = VUM.slice start (X.shapeSize subShape) v

-- | 'Int' elements live in a single unboxed array, so every method is a thin
-- wrapper around the corresponding 'XArray' or vector operation.
instance GMixed Int where
  mshape (M_Int xarr) = X.shape xarr
  mindex (M_Int xarr) ix = X.index xarr ix
  mindexPartial (M_Int xarr) ix = M_Int $ X.indexPartial xarr ix
  -- The element function is an error thunk: it is only hit if the shape turns
  -- out not to be empty after all.
  memptyArray shp =
    M_Int $ X.generate shp (error "memptyArray Int: shape was not empty")

  mvecsNumElts _ = 1
  mvecsUnsafeNew shp _ = fmap MV_Int (VUM.unsafeNew (X.shapeSize shp))
  mvecsWrite shp ix val (MV_Int mvec) =
    let pos = X.toLinearIdx shp ix
    in VUM.write mvec pos val
  mvecsWritePartial shp ix (M_Int @sh' xarr) (MV_Int mvec) =
    mvecsWritePartialPrimitive @sh' shp ix xarr mvec
  mvecsFreeze shp (MV_Int mvec) = do
    frozen <- VU.freeze mvec
    return (M_Int (X.fromVector shp frozen))

-- | 'Double' elements live in a single unboxed array, so every method is a
-- thin wrapper around the corresponding 'XArray' or vector operation.
instance GMixed Double where
  mshape (M_Double xarr) = X.shape xarr
  mindex (M_Double xarr) ix = X.index xarr ix
  mindexPartial (M_Double xarr) ix = M_Double $ X.indexPartial xarr ix
  -- The element function is an error thunk: it is only hit if the shape turns
  -- out not to be empty after all.
  memptyArray shp =
    M_Double $ X.generate shp (error "memptyArray Double: shape was not empty")

  mvecsNumElts _ = 1
  mvecsUnsafeNew shp _ = fmap MV_Double (VUM.unsafeNew (X.shapeSize shp))
  mvecsWrite shp ix val (MV_Double mvec) =
    let pos = X.toLinearIdx shp ix
    in VUM.write mvec pos val
  mvecsWritePartial shp ix (M_Double @sh' xarr) (MV_Double mvec) =
    mvecsWritePartialPrimitive @sh' shp ix xarr mvec
  mvecsFreeze shp (MV_Double mvec) = do
    frozen <- VU.freeze mvec
    return (M_Double (X.fromVector shp frozen))

-- | Arrays of units hold no element data: the array value is just its shape,
-- and the mutable "vectors" are an empty placeholder, so all writes are
-- no-ops.
instance GMixed () where
  mshape (M_Nil shp) = shp
  mindex _ _ = ()
  mindexPartial = \(M_Nil shp) ix -> M_Nil (X.ixDrop shp ix)
  memptyArray = M_Nil

  mvecsNumElts _ = 1
  mvecsUnsafeNew _ _ = pure MV_Nil
  mvecsWrite _ _ _ _ = pure ()
  mvecsWritePartial _ _ _ _ = pure ()
  mvecsFreeze shp _ = pure (M_Nil shp)

-- | Arrays of pairs are stored as a pair of arrays; each method acts on both
-- component arrays in turn (first component first).
instance (GMixed a, GMixed b) => GMixed (a, b) where
  -- The shape is read off the first component array.
  mshape (M_Tup2 lhs _) = mshape lhs
  mindex (M_Tup2 lhs rhs) ix = (mindex lhs ix, mindex rhs ix)
  mindexPartial (M_Tup2 lhs rhs) ix =
    M_Tup2 (mindexPartial lhs ix) (mindexPartial rhs ix)
  memptyArray shp = M_Tup2 (memptyArray shp) (memptyArray shp)

  mvecsNumElts (p, q) = mvecsNumElts p * mvecsNumElts q
  mvecsUnsafeNew shp (p, q) = do
    vecsP <- mvecsUnsafeNew shp p
    vecsQ <- mvecsUnsafeNew shp q
    return (MV_Tup2 vecsP vecsQ)
  mvecsWrite shp ix (p, q) (MV_Tup2 vecsP vecsQ) =
    mvecsWrite shp ix p vecsP >> mvecsWrite shp ix q vecsQ
  mvecsWritePartial shp ix (M_Tup2 p q) (MV_Tup2 vecsP vecsQ) =
    mvecsWritePartial shp ix p vecsP >> mvecsWritePartial shp ix q vecsQ
  mvecsFreeze shp (MV_Tup2 vecsP vecsQ) = do
    frozenP <- mvecsFreeze shp vecsP
    frozenQ <- mvecsFreeze shp vecsQ
    return (M_Tup2 frozenP frozenQ)

-- | Arrays whose elements are themselves arrays of (known) shape type @sh'@.
-- Such an array is stored flattened, as a single array whose shape is the
-- outer shape followed by the inner shape ('M_Nest'), so most methods below
-- translate an operation on the nested view into an operation on the flat
-- array.
instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where
  -- The outer shape is the prefix of the flat array's shape that corresponds
  -- to @sh@; 'ixAppPrefix' walks the static outer shape to cut it off.
  -- TODO: this is quadratic in the nesting level
  mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh
  mshape (M_Nest arr)
    | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
    = ixAppPrefix (knownShapeX @sh) (mshape arr)
    where
      ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1
      ixAppPrefix SZX _ = IZX
      ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx
      ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx

  -- Selecting one element (an inner array) is a partial index into the flat
  -- array.
  mindex (M_Nest arr) i = mindexPartial arr i

  -- Partially indexing the nested array is partially indexing the flat array;
  -- the 'X.lemAppAssoc' evidence is needed to line up the shape types
  -- (presumably it re-brackets @(sh1 ++ sh2) ++ sh'@ as @sh1 ++ (sh2 ++ sh')@).
  mindexPartial :: forall sh1 sh2.
                   Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
  mindexPartial (M_Nest arr) i
    | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
    = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)

  -- The empty flat array gets the given outer shape extended with an all-zero
  -- inner shape.
  memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh'))))

  -- Size of this inner array times the 'mvecsNumElts' of its first element;
  -- zero if the inner array itself is empty (in which case there is no first
  -- element to look at).
  mvecsNumElts arr =
    let n = X.shapeSize (mshape arr)
    in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh')))

  -- Allocate vectors for the flat array: the outer shape extended with the
  -- shape of the example inner array, using the example's first element as the
  -- example one level down. The example's shape is remembered in 'MV_Nest'.
  -- An empty example is rejected because it has no first element.
  mvecsUnsafeNew sh example
    | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example"
    | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example))
                                                 (mindex example (X.zeroIdx (knownShapeX @sh')))
    where
      sh' = mshape example

  -- Writing one element (an inner array) is a partial write into the flat
  -- vectors, whose full shape is the outer shape plus the stored inner shape.
  mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs

  -- A partial write of nested arrays is a partial write of the corresponding
  -- flat subarray; the pattern guards supply the 'KnownShapeX' evidence for
  -- @sh2 ++ sh'@ and the shape-type equality that the recursive call needs.
  mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
                    => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
                    -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
                    -> ST s ()
  mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)
    | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
    , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
    = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs

  -- Freeze the flat vectors using the outer shape extended with the stored
  -- inner shape.
  mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs

-- | Build an array of the given shape by calling the function once for every
-- index.
--
-- The element at the all-zero index is computed first and serves as the
-- example from which the backing vectors are allocated; the remaining
-- elements are then written in the order given by 'X.enumShape'. If the shape
-- is empty, or if the first element turns out to contain no scalars
-- ('mvecsNumElts' is zero), the result is 'memptyArray' and no vectors are
-- allocated.
--
-- NOTE(review): this relies on 'X.enumShape' listing the all-zero index
-- first -- confirm against the "Array" module.
--
-- NOTE(review): when the shape is non-empty but the first element is empty,
-- 'memptyArray' is called with a shape of non-zero size; for tuple elements
-- with a primitive component this plants the "shape was not empty" error in
-- that component -- confirm that this is intended.
mgenerate :: GMixed a => IxX sh -> (IxX sh -> a) -> Mixed sh a
mgenerate sh f
  -- We need to be very careful here to ensure that neither 'sh' nor
  -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty.
  | X.shapeSize sh == 0 = memptyArray sh
  | otherwise =
      let firstidx = X.zeroIdx' sh
          firstelem = f firstidx
      in if mvecsNumElts firstelem == 0
           then memptyArray sh
           else runST $ do
                  vecs <- mvecsUnsafeNew sh firstelem
                  mvecsWrite sh firstidx firstelem vecs
                  -- 'drop 1' rather than the partial 'tail': the first index
                  -- has already been written above.
                  forM_ (drop 1 (X.enumShape sh)) $ \idx ->
                    mvecsWrite sh idx (f idx) vecs
                  mvecsFreeze sh vecs


-- | An array described only by its number of dimensions @n@: a 'Mixed' array
-- whose shape type consists of @n@ 'Nothing' entries.
type Ranked :: Nat -> Type -> Type
newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)

-- | An array whose shape type is a list of naturals @sh@: a 'Mixed' array in
-- which every entry of the shape type is wrapped in 'Just'.
type Shaped :: [Nat] -> Type -> Type
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)

-- Arrays with 'Ranked' or 'Shaped' elements are represented exactly like
-- arrays with the underlying 'Mixed' arrays as elements.
newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))

-- Likewise for the mutable vectors used while such an array is being built.
newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))


-- | 'Ranked' arrays as array elements. An array of 'Ranked' arrays is stored
-- exactly like an array of the underlying 'Mixed' arrays, so every method
-- strips the newtype wrappers and delegates to the
-- @GMixed (Mixed (Replicate n Nothing) a)@ instance. That instance requires
-- @KnownShapeX (Replicate n Nothing)@, which is not implied by @KnownNat n@
-- directly; matching on the 'Dict' from 'lemKnownReplicate' brings it into
-- scope in each method.
instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where
  mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr

  mindex (M_Ranked arr) i
    | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i)

  mindexPartial (M_Ranked arr) i
    | Dict <- lemKnownReplicate (Proxy @n) = M_Ranked (mindexPartial arr i)

  memptyArray sh
    | Dict <- lemKnownReplicate (Proxy @n) = M_Ranked (memptyArray sh)

  mvecsNumElts (Ranked arr)
    | Dict <- lemKnownReplicate (Proxy @n) = mvecsNumElts arr

  mvecsUnsafeNew sh (Ranked arr)
    | Dict <- lemKnownReplicate (Proxy @n) = MV_Ranked <$> mvecsUnsafeNew sh arr

  mvecsWrite sh idx (Ranked arr) (MV_Ranked vecs)
    | Dict <- lemKnownReplicate (Proxy @n) = mvecsWrite sh idx arr vecs

  mvecsWritePartial sh idx (M_Ranked arr) (MV_Ranked vecs)
    | Dict <- lemKnownReplicate (Proxy @n) = mvecsWritePartial sh idx arr vecs

  mvecsFreeze sh (MV_Ranked vecs)
    | Dict <- lemKnownReplicate (Proxy @n) = M_Ranked <$> mvecsFreeze sh vecs


-- | A list of exactly @n@ 'Int's, with the length tracked in the type;
-- judging by the name, the index type for 'Ranked' arrays of rank @n@.
--
-- NOTE(review): no fixity declaration for '(:::)' is visible here, so it
-- defaults to @infixl 9@ and chains such as @i ::: (j ::: IZR)@ need explicit
-- parentheses -- confirm whether an @infixr@ declaration is intended.
type IxR :: Nat -> Type
data IxR n where
  -- | The empty index (rank zero).
  IZR :: IxR Z
  -- | Prepend one 'Int' component, increasing the rank by one.
  (:::) :: Int -> IxR n -> IxR (S n)

-- | A list of 'Int's with one component per entry of the type-level shape
-- @sh@; judging by the name, the index type for 'Shaped' arrays.
--
-- NOTE(review): no fixity declaration for '(::$)' is visible here, so it
-- defaults to @infixl 9@ and chains such as @i ::$ (j ::$ IZS)@ need explicit
-- parentheses -- confirm whether an @infixr@ declaration is intended.
type IxS :: [Nat] -> Type
data IxS sh where
  -- | The empty index, for the empty shape.
  IZS :: IxS '[]
  -- | Prepend one 'Int' component for a dimension @n@. The component is not
  -- related to @n@ by the type, so any bounds check must happen elsewhere.
  (::$) :: Int -> IxS sh -> IxS (n : sh)