1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE FlexibleInstances #-}
module Fancy where
import Control.Monad (forM_)
import Control.Monad.ST
import Data.Kind
import Data.Proxy
import Data.Type.Equality
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
import GHC.TypeLits
import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))
import qualified Array as X
-- | @Replicate n x@ is the type-level list holding @n@ copies of @x@.
--
-- This is a closed family: the second equation only fires once @n@ is known
-- to be apart from @0@, so it reduces for literal @n@ but is stuck on an
-- unknown @n@.
type Replicate :: Nat -> k -> [k]
type family Replicate n a where
  Replicate 0 a = '[]
  Replicate n a = a : Replicate (n - 1) a
-- | Wrap every element of a type-level list in 'Just'.
type MapJust :: [k] -> [Maybe k]
type family MapJust l where
  MapJust '[] = '[]
  MapJust (x : xs) = Just x : MapJust xs
-- | Arrays whose shape is a type-level list of @Maybe Nat@, one entry per
-- dimension. 'Ranked' instantiates every entry to 'Nothing'; 'Shaped'
-- instantiates every entry to @Just n@.
--
-- The representation is chosen by the element type (structure of arrays):
--
-- * scalars are backed by an 'XArray';
-- * '()' keeps only the shape;
-- * a tuple is a tuple of arrays over the same shape;
-- * an array of arrays is flattened into one array over the concatenated
--   shape.
type Mixed :: [Maybe Nat] -> Type -> Type
data family Mixed sh a
newtype instance Mixed sh Int = M_Int (XArray sh Int)
newtype instance Mixed sh Double = M_Double (XArray sh Double)
-- etc. (presumably further scalar element types in the same style)
newtype instance Mixed sh () = M_Nil (IxX sh) -- store the shape; there is no element data to recover it from
data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b)
data instance Mixed sh (a, b, c) = M_Tup3 (Mixed sh a) (Mixed sh b) (Mixed sh c)
data instance Mixed sh (a, b, c, d) = M_Tup4 (Mixed sh a) (Mixed sh b) (Mixed sh c) (Mixed sh d)
-- Nested arrays: an outer shape sh1 of inner arrays of shape sh2 is stored
-- as a single array over sh1 ++ sh2.
newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a)
-- | Mutable counterpart of 'Mixed'. It is filled in inside 'ST' (the @s@
-- parameter is the 'ST' state token) and then turned into a 'Mixed' array
-- by 'mvecsFreeze'.
--
-- The layout follows 'Mixed':
--
-- * one unboxed mutable vector per scalar component;
-- * tuples component-wise;
-- * nested arrays as the vectors for the flattened shape.
type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
data family MixedVecs s sh a
newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int)
newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double)
-- etc.
-- '()' needs no storage; 'mvecsFreeze' rebuilds the shape from its argument.
data instance MixedVecs s sh () = MV_Nil
data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b)
data instance MixedVecs s sh (a, b, c) = MV_Tup3 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c)
data instance MixedVecs s sh (a, b, c, d) = MV_Tup4 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c) (MixedVecs s sh d)
-- Nested arrays keep the inner shape sh2 next to the vectors for the
-- flattened shape sh1 ++ sh2, because writing and freezing need it to
-- rebuild that flattened shape.
data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a)
-- | Element types that can be stored in a 'Mixed' array.
--
-- The methods form two groups:
--
-- * read access to existing arrays: 'mshape', 'mindex', 'mindexPartial'
--   and 'memptyArray';
-- * building a new array through the mutable 'MixedVecs' representation in
--   'ST': 'mvecsNumElts', 'mvecsUnsafeNew', 'mvecsWrite',
--   'mvecsWritePartial' and 'mvecsFreeze'. 'mgenerate' drives these.
class GMixed a where
  -- | The shape of the array.
  mshape :: KnownShapeX sh => Mixed sh a -> IxX sh
  -- | Index the array with a full index, returning a single element.
  mindex :: Mixed sh a -> IxX sh -> a
  -- | Index the array with an index into a prefix @sh@ of its shape,
  -- returning the sub-array over the remaining dimensions @sh'@.
  --
  -- The nested-array instance calls this with explicit type applications in
  -- the order @a@, @sh@, @sh'@, so keep the order of this @forall@.
  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a
  -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
  memptyArray :: IxX sh -> Mixed sh a
  -- | Return the size of the individual (SoA) arrays needed to hold this
  -- value:
  --
  -- * 1 for a scalar or '()';
  -- * for a nested array, the number of elements of its shape times the
  --   count of its first element (0 if it is empty).
  --
  -- NOTE(review): the pair instance returns the product of its components'
  -- counts. That is not the size of any single array when the components
  -- differ in size. It is zero exactly when some component is zero, which
  -- is the only thing 'mgenerate' checks.
  mvecsNumElts :: a -> Int
  -- | Create uninitialised vectors for this array type, given the shape of
  -- this vector and an example for the contents. The shape must not have size
  -- zero; an error may be thrown otherwise.
  mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a)
  -- | Given the shape of this array, an index and a value, write the value at
  -- that index in the vectors.
  mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s ()
  -- | Given the full shape @sh ++ sh'@ of this array, an index into the
  -- prefix @sh@ and a sub-array of shape @sh'@, write that whole sub-array
  -- at that position in the vectors.
  --
  -- The type variables of this signature are implicitly quantified in the
  -- order @sh'@, @sh@, @s@ (after the class variable @a@). @sh'@ comes first
  -- because the constraint mentions it first. The nested-array instance
  -- relies on this order via explicit type applications, so do not reorder.
  mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
  -- | Given the shape of this array, finalise the vectors into 'XArray's.
  mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
-- | Copy the whole sub-array @arr@ (shape @sh'@) into the flat vector @v@
-- that backs an array of shape @sh ++ sh'@, at the position selected by
-- the prefix index @i@.
--
-- The sub-array is written as one contiguous run of @shapeSize sh'@
-- elements. The run starts at the linear index of @i@ extended with a zero
-- index over @sh'@.
--
-- TODO: this use of toVector is suboptimal
mvecsWritePartialPrimitive
  :: forall sh' sh a s. (KnownShapeX sh', VU.Unbox a)
  => IxX (sh ++ sh') -> IxX sh -> XArray sh' a -> VU.MVector s a -> ST s ()
mvecsWritePartialPrimitive sh i arr v =
  VU.copy target (X.toVector arr)
  where
    subShape = X.shape arr
    start = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' subShape))
    target = VUM.slice start (X.shapeSize subShape) v
-- | 'Int' elements: an 'XArray' when frozen, and a single unboxed mutable
-- vector indexed by linear position while being built.
instance GMixed Int where
  mshape (M_Int arr) = X.shape arr
  mindex (M_Int arr) = X.index arr
  mindexPartial (M_Int arr) ix = M_Int (X.indexPartial arr ix)
  memptyArray sh = M_Int $ X.generate sh (error "memptyArray Int: shape was not empty")
  mvecsNumElts = const 1
  mvecsUnsafeNew sh _ = fmap MV_Int (VUM.unsafeNew (X.shapeSize sh))
  mvecsWrite sh ix val (MV_Int vec) = VUM.write vec (X.toLinearIdx sh ix) val
  mvecsWritePartial sh ix (M_Int @sh' arr) (MV_Int vec) =
    mvecsWritePartialPrimitive @sh' sh ix arr vec
  mvecsFreeze sh (MV_Int vec) = fmap (M_Int . X.fromVector sh) (VU.freeze vec)
-- | 'Double' elements: an 'XArray' when frozen, and a single unboxed mutable
-- vector indexed by linear position while being built.
instance GMixed Double where
  mshape (M_Double arr) = X.shape arr
  mindex (M_Double arr) = X.index arr
  mindexPartial (M_Double arr) ix = M_Double (X.indexPartial arr ix)
  memptyArray sh = M_Double $ X.generate sh (error "memptyArray Double: shape was not empty")
  mvecsNumElts = const 1
  mvecsUnsafeNew sh _ = fmap MV_Double (VUM.unsafeNew (X.shapeSize sh))
  mvecsWrite sh ix val (MV_Double vec) = VUM.write vec (X.toLinearIdx sh ix) val
  mvecsWritePartial sh ix (M_Double @sh' arr) (MV_Double vec) =
    mvecsWritePartialPrimitive @sh' sh ix arr vec
  mvecsFreeze sh (MV_Double vec) = fmap (M_Double . X.fromVector sh) (VU.freeze vec)
-- | The unit element type carries no data. An array of '()' is just its
-- shape, and building one allocates and writes nothing.
instance GMixed () where
  mshape (M_Nil sh) = sh
  mindex _ _ = ()
  mindexPartial arr ix = case arr of
    M_Nil sh -> M_Nil (X.ixDrop sh ix)
  memptyArray = M_Nil
  mvecsNumElts = const 1
  mvecsUnsafeNew _ _ = pure MV_Nil
  mvecsWrite _ _ _ _ = pure ()
  mvecsWritePartial _ _ _ _ = pure ()
  mvecsFreeze sh _ = pure (M_Nil sh)
-- | Pairs are stored component-wise (structure of arrays). Every method
-- delegates to the two component arrays/vectors, which share one shape.
instance (GMixed a, GMixed b) => GMixed (a, b) where
  mshape (M_Tup2 a _) = mshape a
  mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
  mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
  memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
  -- Zero exactly when some component is zero, which is what 'mgenerate'
  -- tests for.
  mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y
  mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
  mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
    mvecsWrite sh i x a
    mvecsWrite sh i y b
  mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
    mvecsWritePartial sh i x a
    mvecsWritePartial sh i y b
  mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b

-- | Triples, stored component-wise like pairs (via 'M_Tup3' / 'MV_Tup3').
instance (GMixed a, GMixed b, GMixed c) => GMixed (a, b, c) where
  mshape (M_Tup3 xs _ _) = mshape xs
  mindex (M_Tup3 xs ys zs) i = (mindex xs i, mindex ys i, mindex zs i)
  mindexPartial (M_Tup3 xs ys zs) i =
    M_Tup3 (mindexPartial xs i) (mindexPartial ys i) (mindexPartial zs i)
  memptyArray sh = M_Tup3 (memptyArray sh) (memptyArray sh) (memptyArray sh)
  mvecsNumElts (x, y, z) = mvecsNumElts x * mvecsNumElts y * mvecsNumElts z
  mvecsUnsafeNew sh (x, y, z) =
    MV_Tup3 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y <*> mvecsUnsafeNew sh z
  mvecsWrite sh i (x, y, z) (MV_Tup3 xv yv zv) = do
    mvecsWrite sh i x xv
    mvecsWrite sh i y yv
    mvecsWrite sh i z zv
  mvecsWritePartial sh i (M_Tup3 x y z) (MV_Tup3 xv yv zv) = do
    mvecsWritePartial sh i x xv
    mvecsWritePartial sh i y yv
    mvecsWritePartial sh i z zv
  mvecsFreeze sh (MV_Tup3 xv yv zv) =
    M_Tup3 <$> mvecsFreeze sh xv <*> mvecsFreeze sh yv <*> mvecsFreeze sh zv

-- | Quadruples, stored component-wise like pairs (via 'M_Tup4' / 'MV_Tup4').
instance (GMixed a, GMixed b, GMixed c, GMixed d) => GMixed (a, b, c, d) where
  mshape (M_Tup4 xs _ _ _) = mshape xs
  mindex (M_Tup4 xs ys zs ws) i = (mindex xs i, mindex ys i, mindex zs i, mindex ws i)
  mindexPartial (M_Tup4 xs ys zs ws) i =
    M_Tup4 (mindexPartial xs i) (mindexPartial ys i) (mindexPartial zs i) (mindexPartial ws i)
  memptyArray sh = M_Tup4 (memptyArray sh) (memptyArray sh) (memptyArray sh) (memptyArray sh)
  mvecsNumElts (x, y, z, w) =
    mvecsNumElts x * mvecsNumElts y * mvecsNumElts z * mvecsNumElts w
  mvecsUnsafeNew sh (x, y, z, w) =
    MV_Tup4 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
            <*> mvecsUnsafeNew sh z <*> mvecsUnsafeNew sh w
  mvecsWrite sh i (x, y, z, w) (MV_Tup4 xv yv zv wv) = do
    mvecsWrite sh i x xv
    mvecsWrite sh i y yv
    mvecsWrite sh i z zv
    mvecsWrite sh i w wv
  mvecsWritePartial sh i (M_Tup4 x y z w) (MV_Tup4 xv yv zv wv) = do
    mvecsWritePartial sh i x xv
    mvecsWritePartial sh i y yv
    mvecsWritePartial sh i z zv
    mvecsWritePartial sh i w wv
  mvecsFreeze sh (MV_Tup4 xv yv zv wv) =
    M_Tup4 <$> mvecsFreeze sh xv <*> mvecsFreeze sh yv
           <*> mvecsFreeze sh zv <*> mvecsFreeze sh wv
-- | Nested arrays. An array of shape @sh@ whose elements are arrays of
-- shape @sh'@ is stored as one flattened array of shape @sh ++ sh'@ (see
-- 'M_Nest' / 'MV_Nest'). The inner shape must be statically known through
-- 'KnownShapeX'.
instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where
  -- TODO: this is quadratic in the nesting level
  -- Compute the shape of the flattened array and keep only its outer
  -- prefix.
  mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh
  mshape (M_Nest arr)
    -- Calling 'mshape' on the flattened array needs KnownShapeX (sh ++ sh').
    | X.Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
    = ixAppPrefix (knownShapeX @sh) (mshape arr)
    where
      -- Take the prefix of an appended index. Walk the static shape of the
      -- prefix to know how many components to keep.
      ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1
      ixAppPrefix SZX _ = IZX
      ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx
      ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx
  -- A full outer index selects an inner array, i.e. a partial index into
  -- the flattened array.
  mindex (M_Nest arr) i = mindexPartial arr i
  -- The flattened array has shape (sh1 ++ sh2) ++ sh'. The lemma equates
  -- that with sh1 ++ (sh2 ++ sh'), so the prefix index sh1 can be applied to
  -- the inner array type directly.
  mindexPartial :: forall sh1 sh2.
                   Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
  mindexPartial (M_Nest arr) i
    | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
    = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
  -- An empty outer array has no element to take the inner shape from, so
  -- the index built by 'X.zeroIdx' from the static shape of sh' is used as
  -- the inner shape.
  -- NOTE(review): presumably this is all zeros, even for statically-sized
  -- inner dimensions; confirm that is intended.
  memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh'))))
  -- 0 for an empty outer array. Otherwise it is the outer size times the
  -- count of the element at the zero index.
  -- NOTE(review): this assumes every element has the same count as the
  -- first one.
  mvecsNumElts arr =
    let n = X.shapeSize (mshape arr)
    in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh')))
  -- Allocate vectors for the flattened shape sh ++ sh', with sh' taken from
  -- the example. The example's element at the zero index serves as the
  -- example for the next level down. The inner shape is remembered in
  -- 'MV_Nest' for 'mvecsWrite' and 'mvecsFreeze'.
  mvecsUnsafeNew sh example
    | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example"
    | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example))
                                                 (mindex example (X.zeroIdx (knownShapeX @sh')))
    where
      sh' = mshape example
  -- Writing one inner array is a partial write into the flattened vectors.
  -- Their full shape is the outer shape followed by the stored inner shape.
  -- NOTE(review): the shape of 'val' is not compared with the stored sh';
  -- an element with a different shape would be written using the stored one.
  mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs
  mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
                    => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
                    -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
                    -> ST s ()
  mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)
    -- The recursive call needs KnownShapeX (sh2 ++ sh') for its constraint.
    -- It also needs the re-association of the flattened shape.
    | X.Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
    , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
    -- Type applications follow the class method's implicit order: the class
    -- variable, then the suffix shape (mentioned first in its constraint),
    -- then the prefix shape.
    = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs
  -- Freeze the vectors over the flattened shape (outer shape followed by
  -- the stored inner shape) and re-wrap them as a nested array.
  mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs
-- | Build a 'Mixed' array of the given shape by evaluating the function at
-- every index.
--
-- The element at the zero index is computed first. It is used as the
-- example for allocating the vectors with 'mvecsUnsafeNew', and it is
-- written first. The remaining indices from 'X.enumShape' are then written
-- in order, and the vectors are frozen.
--
-- 'memptyArray' is returned instead in either of these cases:
--
-- * the shape has size zero;
-- * the first element reports zero SoA elements ('mvecsNumElts').
--
-- NOTE(review): in the second case 'memptyArray' receives a shape of
-- non-zero size, which violates its precondition. One example is an element
-- type @(Int, Mixed sh' b)@ whose first inner array is empty. The 'Int'
-- component is then built by the 'GMixed Int' instance with an 'error'
-- element function, so those elements are lost. A fix needs class support,
-- for example an emptiness-aware 'mvecsUnsafeNew'.
mgenerate :: GMixed a => IxX sh -> (IxX sh -> a) -> Mixed sh a
mgenerate sh f
  -- We need to be very careful here to ensure that neither 'sh' nor
  -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty.
  | X.shapeSize sh == 0 = memptyArray sh
  | mvecsNumElts firstelem == 0 = memptyArray sh
  | otherwise = runST $ do
      vecs <- mvecsUnsafeNew sh firstelem
      mvecsWrite sh firstidx firstelem vecs
      -- Skip the already-written first index. 'drop 1' is total, unlike
      -- 'tail'.
      forM_ (drop 1 (X.enumShape sh)) $ \idx ->
        mvecsWrite sh idx (f idx) vecs
      mvecsFreeze sh vecs
  where
    -- Lazily bound: not evaluated when the shape is empty.
    firstidx = X.zeroIdx' sh
    firstelem = f firstidx
-- | Arrays of statically known rank @n@ whose dimension sizes do not appear
-- in the type: a 'Mixed' array with @n@ 'Nothing' entries in its shape.
type Ranked :: Nat -> Type -> Type
newtype Ranked n a where
  Ranked :: Mixed (Replicate n Nothing) a -> Ranked n a
-- | Arrays whose full shape @sh@ is given in the type: a 'Mixed' array
-- with a @Just@ entry for every dimension.
type Shaped :: [Nat] -> Type -> Type
newtype Shaped sh a where
  Shaped :: Mixed (MapJust sh) a -> Shaped sh a
-- | An index of rank @n@: a list of exactly @n@ 'Int' components.
type IxR :: Nat -> Type
data IxR n where
  IZR :: IxR 0
  (:::) :: Int -> IxR n -> IxR (n + 1)
-- Right-associative, so @i ::: j ::: IZR@ parses as @i ::: (j ::: IZR)@.
-- Without a fixity declaration the operator defaults to @infixl 9@, and such
-- chains do not type-check.
-- NOTE(review): presumably this should agree with the fixity of the IxX
-- constructors ('::@', '::?') in the Array module; confirm.
infixr 3 :::
-- | An index for a type-level shape @sh :: [Nat]@: one 'Int' per entry of
-- @sh@. The types do not relate the 'Int' values to the sizes listed in
-- @sh@.
type IxS :: [Nat] -> Type
data IxS sh where
  IZS :: IxS '[]
  (::$) :: Int -> IxS sh -> IxS (n : sh)
-- Right-associative, so @i ::$ j ::$ IZS@ parses as @i ::$ (j ::$ IZS)@.
-- Without a fixity declaration the operator defaults to @infixl 9@, and such
-- chains do not type-check.
-- NOTE(review): presumably this should agree with the fixity of the IxX
-- constructors ('::@', '::?') in the Array module; confirm.
infixr 3 ::$
|