1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
|
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}
-- | Mutable accumulators for the interpreter.
--
-- An 'Accum' is a raw, @malloc@-allocated buffer into which a value is
-- serialised ('withAccum'), and to which contributions can subsequently be
-- added with atomic memory operations ('accumAdd'). Computations that use
-- accumulators run in the 'AcM' monad, and several of them can be run on
-- separate threads with 'inParallel'.
module Interpreter.Accum (
  -- * The accumulation monad
  AcM,
  runAcM,
  -- * Value representation
  Rep',
  -- * Accumulators
  Accum,
  withAccum,
  accumAdd,
  inParallel,
) where
import Control.Concurrent
import Control.Exception (SomeException, throwIO, try)
import Control.Monad (when, forM_)
import Data.Bifunctor (second)
import Data.Proxy
import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr)
import Foreign.Storable (sizeOf)
import GHC.Exts
import GHC.Float
import GHC.Int
import GHC.IO (IO(..))
import GHC.Word
import System.IO.Unsafe (unsafePerformIO)

import Array
import AST
import Data
-- | The monad in which accumulator operations run. It is a plain newtype
-- around 'IO'; the phantom parameter @s@ is the same one that 'Accum' carries,
-- and 'runAcM' requires the computation to be polymorphic in it.
newtype AcM s a = AcM (IO a)
  -- The instances are taken directly from 'IO'.
  deriving newtype (Functor, Applicative, Monad)
-- | Run an accumulator computation and return its result as a pure value.
--
-- The underlying 'IO' action is executed with 'unsafePerformIO'.
-- NOTE(review): the rank-2 @s@ presumably plays the same role as in
-- 'Control.Monad.ST.runST' (keeping 'Accum's from escaping), which is what
-- would justify treating the result as pure -- confirm.
runAcM :: (forall s. AcM s a) -> a
runAcM action = unsafePerformIO (unwrap action)
  where
    unwrap :: AcM s b -> IO b
    unwrap (AcM io) = io
-- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined.
--
-- Maps a type of the embedded language to the Haskell type of its values.
-- The parameter @s@ is only used to tag the accumulators ('Accum') that may
-- occur inside the value.
type family Rep' s t where
  Rep' s TNil = ()
  Rep' s (TPair a b) = (Rep' s a, Rep' s b)
  Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b)
  Rep' s (TMaybe t) = Maybe (Rep' s t)
  Rep' s (TArr n t) = Array n (Rep' s t)
  Rep' s (TScal sty) = ScalRep sty
  Rep' s (TAccum t) = Accum s t
-- | An accumulator for a value of type @t@: the singleton describing @t@,
-- together with a pointer to the raw buffer that holds the serialised value
-- (see 'accumWrite' for the layout).
--
-- Floats and integers are accumulated; booleans are left as-is.
data Accum s t = Accum (STy t) (ForeignPtr ())
-- | Number of bytes needed to store the given top-level value in an
-- accumulator buffer. This is 'tSize'' applied to a value that is known to be
-- present (i.e. not an array element).
tSize :: Proxy s -> STy t -> Rep' s t -> Int
tSize proxy ty = tSize' proxy ty . Just
-- | Number of bytes that a value of the given type occupies in an accumulator
-- buffer.
--
-- The third argument is the value being stored, if there is one. Passing
-- 'Nothing' means "this is (inside) an array element": all elements of an
-- array must have the same size, so in that case an @Either@ is sized for its
-- larger alternative, and arrays (whose size depends on their shape) are not
-- supported. With @'Just' x@ the size is exactly what 'accumWrite' will write
-- for @x@.
--
-- Calls 'error' for nested arrays and nested accumulators.
tSize' :: Proxy s -> STy t -> Maybe (Rep' s t) -> Int
tSize' p typ val = case typ of
  STNil -> 0
  STPair a b -> tSize' p a (fst <$> val) + tSize' p b (snd <$> val)
  -- One tag byte, followed by the payload.
  STEither a b -> case val of
    -- Unknown alternative: reserve room for the larger of the two.
    Nothing -> 1 + max (tSize' p a Nothing) (tSize' p b Nothing)
    Just (Left x) -> 1 + tSize' p a (Just x)
    Just (Right y) -> 1 + tSize' p b (Just y)
  -- Representation of Maybe t is the same as Either () t; the add operation is different, however.
  STMaybe t -> tSize' p (STEither STNil t) (maybe (Left ()) Right <$> val)
  STArr ndim t ->
    case val of
      Nothing -> error "Nested arrays not supported in this implementation"
      -- The shape (one 64-bit integer per dimension), then the elements.
      Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing
  STScal sty -> goScal sty
  STAccum{} -> error "Nested accumulators unsupported"
  where
    -- Size in bytes of a single scalar.
    goScal :: SScalTy t -> Int
    goScal STI32 = 4
    goScal STI64 = 8
    goScal STF32 = 4
    goScal STF64 = 8
    goScal STBool = 1
-- | Serialise a value into the buffer of an accumulator, starting at byte
-- offset 0.
--
-- This operation does not commute with 'accumAdd', so it must be used with
-- care. Furthermore it must be used on exactly the same value as tSize was
-- called on. Hence it lives in IO, not in AcM.
--
-- Layout, as produced by the code below:
--
-- * @TNil@ takes no space.
-- * @TPair a b@ is @a@ immediately followed by @b@.
-- * @TEither a b@ is a one-byte tag (0 for @Left@, 1 for @Right@) followed by
--   the payload. Inside an array element the slot always has the size of the
--   larger alternative, so that all elements have equal size.
-- * @TMaybe t@ is stored like @TEither TNil t@.
-- * @TArr n t@ is one 64-bit integer per dimension, followed by the elements
--   at a fixed stride of @tSize' t Nothing@ bytes.
-- * Scalars take 4 (i32, f32), 8 (i64, f64) or 1 (bool) bytes.
--
-- Calls 'error' for nested arrays and nested accumulators.
accumWrite :: forall s t. Accum s t -> Rep' s t -> IO ()
accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
  let
    -- Write one value at the given byte offset and return the offset just
    -- past it. The 'Bool' tells whether we are inside an array element.
    go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
    go inarr ty val off = case ty of
      STNil -> return off
      STPair a b -> do
        off1 <- go inarr a (fst val) off
        go inarr b (snd val) off1
      STEither a b -> do
        let !(I# off#) = off
        -- Tag byte first, then the payload of the chosen alternative.
        off1 <- case val of
          Left x -> do
            let !(I8# tag#) = 0
            writeInt8# addr# off# tag#
            go inarr a x (off + 1)
          Right y -> do
            let !(I8# tag#) = 1
            writeInt8# addr# off# tag#
            go inarr b y (off + 1)
        -- In an array, every element slot has the size of the larger
        -- alternative; elsewhere the value is stored compactly.
        if inarr
          then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing))
          else return off1
      -- Representation is the same, but add operation is different
      STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off
      STArr _ t
        | inarr -> error "Nested arrays not supported in this implementation"
        | otherwise -> do
            -- Shape header first, then the elements at a fixed stride.
            off1 <- goShape (arrayShape val) off
            let eltsize = tSize' (Proxy @s) t Nothing
                n = arraySize val
            traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val
            return (off1 + eltsize * n)
      STScal sty -> goScal sty val off
      STAccum{} -> error "Nested accumulators unsupported"

    -- Write the dimensions of a shape as consecutive 64-bit integers; the
    -- dimensions of the tail shape are written before the last dimension.
    goShape :: Shape n -> Int -> IO Int
    goShape ShNil off = return off
    goShape (ShCons sh n) off = do
      off1@(I# off1#) <- goShape sh off
      let !(I64# n'#) = fromIntegral n
      writeInt64# addr# off1# n'#
      return (off1 + 8)

    -- Write a single scalar and return the offset just past it.
    goScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int
    goScal STI32 (I32# x) off@(I# off#) = off + 4 <$ writeInt32# addr# off# x
    goScal STI64 (I64# x) off@(I# off#) = off + 8 <$ writeInt64# addr# off# x
    goScal STF32 (F# x) off@(I# off#) = off + 4 <$ writeFloat# addr# off# x
    goScal STF64 (D# x) off@(I# off#) = off + 8 <$ writeDouble# addr# off# x
    goScal STBool b off@(I# off#) = do
      -- Booleans are stored as one byte holding 0 or 1.
      let !(I8# i) = fromIntegral (fromEnum b)
      off + 1 <$ writeInt8# addr# off# i

  in () <$ go False topty top_value 0
-- | Read back the complete value currently stored in an accumulator. This is
-- the inverse of 'accumWrite' and follows the same memory layout.
--
-- Calls 'error' for nested arrays, nested accumulators, and when a tag byte
-- in memory is neither 0 nor 1.
accumRead :: forall s t. Accum s t -> AcM s (Rep' s t)
accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
  let
    -- Decode one value starting at the given byte offset. Returns the
    -- (absolute) offset just past the value, together with the value. The
    -- 'Bool' tells whether we are inside an array element.
    go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t')
    go inarr ty off = case ty of
      STNil -> return (off, ())
      STPair a b -> do
        (off1, x) <- go inarr a off
        (off2, y) <- go inarr b off1
        -- Offsets are absolute, so the pair ends exactly where its second
        -- component ends.
        return (off2, (x, y))
      STEither a b -> do
        let !(I# off#) = off
        tag <- readInt8 addr# off#
        (off1, val) <- case tag of
          0 -> fmap Left <$> go inarr a (off + 1)
          1 -> fmap Right <$> go inarr b (off + 1)
          _ -> error "Invalid tag in accum memory"
        -- In an array, every element slot has the size of the larger
        -- alternative; elsewhere the value is stored compactly.
        if inarr
          then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val)
          else return (off1, val)
      -- Representation is the same, but add operation is different
      STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t) off
      STArr ndim t
        | inarr -> error "Nested arrays not supported in this implementation"
        | otherwise -> do
            -- Shape header first, then the elements at a fixed stride.
            (off1, sh) <- readShape addr# ndim off
            let eltsize = tSize' (Proxy @s) t Nothing
                n = shapeSize sh
            arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini))
            return (off1 + eltsize * n, arr)
      STScal sty -> goScal sty off
      STAccum{} -> error "Nested accumulators unsupported"

    -- Decode a single scalar; returns the offset just past it and the value.
    goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t')
    goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off#
    goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off#
    goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off#
    goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off#
    goScal STBool off@(I# off#) = do
      -- Booleans are stored as one byte holding 0 or 1.
      i8 <- readInt8 addr# off#
      return (off + 1, toEnum (fromIntegral i8))

  in snd <$> go False topty 0
-- | Read an array shape of the given rank from memory, starting at the given
-- byte offset. Every dimension is stored as a 64-bit integer; the dimensions
-- of the tail shape come before the last dimension. Returns the offset just
-- past the shape together with the shape itself.
readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n)
readShape base# rank start = case rank of
  SZ -> return (start, ShNil)
  SS rank' -> do
    (dimOff, inner) <- readShape base# rank' start
    let !(I# dimOff#) = dimOff
    dim <- readInt64 base# dimOff#
    return (dimOff + 8, ShCons inner (fromIntegral dim))
-- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of
-- the list.
--
-- Besides the extent of each dimension, every level also caches the total
-- number of elements at that level, so that 'ishSize' is a constant-time
-- lookup.
data InvShape n where
  IShNil :: InvShape Z
  IShCons :: Int  -- ^ How many subarrays are there?
          -> Int  -- ^ What is the size of all subarrays together?
          -> InvShape n  -- ^ Sub array inverted shape
          -> InvShape (S n)
-- | Total number of elements described by an inverted shape: 1 for the empty
-- shape, and otherwise the size that is cached in the outermost constructor.
ishSize :: InvShape n -> Int
ishSize ish = case ish of
  IShNil -> 1
  IShCons _ total _ -> total
-- | Convert a 'Shape' to the corresponding 'InvShape', computing the cached
-- sizes along the way. The two arithmetic lemmas are only needed to convince
-- the type checker that the rank stays the same.
invertShape :: forall n. Shape n -> InvShape n
invertShape shape | Refl <- lemPlusZero @n = go shape IShNil
  where
    -- Move dimensions one at a time from the shape onto the accumulated
    -- inverted shape.
    go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m)
    go ShNil acc = acc
    go (ShCons rest dim) acc
      | Refl <- lemPlusSuccRight @n' @m
      = go rest (IShCons dim (dim * ishSize acc) acc)
-- | Add a value to (a part of) the value stored in an accumulator.
--
-- The depth and the index together select a sub-value of the accumulator; the
-- given value is then added to that sub-value component-wise. Every scalar is
-- updated with an atomic fetch-and-add or a compare-and-swap loop, so
-- individual scalar updates from different threads do not get lost. Booleans
-- are left untouched.
--
-- Calls 'error' when:
--
-- * the index and the value select different components of a pair or
--   different alternatives of an @Either@;
-- * the alternative of an @Either@ in memory differs from the selected one;
-- * an array index is out of range;
-- * nested arrays or nested accumulators are encountered;
-- * a @Maybe@ is encountered (adding to @Maybe@ is not implemented yet).
accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s ()
accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
  let
    -- Descend along the index to the selected sub-value, which starts at the
    -- given byte offset, and add the value there. The 'Bool' tells whether we
    -- are inside an array element.
    go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO ()
    go inarr ty SZ () val off = () <$ performAdd inarr ty val off
    go inarr ty (SS dep) idx val off = case (ty, idx, val) of
      (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
      (STPair t1 t2, Right idx2, Right val2) -> do
        -- The second component is stored right after the first one.
        off1 <- skip inarr t1 off
        go inarr t2 dep idx2 val2 off1
      (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd"
      (STEither t1 _, Left idx1, Left val1) -> do
        -- The payload is stored after the tag byte.
        expectTag 0 off
        go inarr t1 dep idx1 val1 (off + 1)
      (STEither _ t2, Right idx2, Right val2) -> do
        expectTag 1 off
        go inarr t2 dep idx2 val2 (off + 1)
      (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd"
      (STMaybe{}, _, _) -> error "accumAdd: indexing into Maybe is not yet implemented"
      (STArr rank eltty, _, _)
        | inarr -> error "Nested arrays"
        | otherwise -> do
            (off1, ish) <- second invertShape <$> readShape addr# rank off
            goArr (SS dep) ish eltty idx val off1
      (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth"
      (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth"
      (STAccum{}, _, _) -> error "Nested accumulators unsupported"

    -- Descend into an array, one dimension at a time. The offset points at
    -- the first element of the (sub)array that is described by the inverted
    -- shape.
    goArr :: SNat i' -> InvShape n -> STy t'
          -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO ()
    goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off
    goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off
    goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do
      let i' = fromIntegral @(Rep' s TIx) @Int i
          eltsize = tSize' (Proxy @s) t1 Nothing
      when (i' < 0 || i' >= n) $
        error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")"
      -- Sub-array number i' starts (i' * number of elements per sub-array)
      -- elements further on; the offset is in bytes.
      goArr depm1 ish t1 idx val (off + i' * ishSize ish * eltsize)

    -- Check that the tag byte at the given offset has the expected value.
    expectTag :: Int8 -> Int -> IO ()
    expectTag expected (I# off#) = do
      tag <- readInt8 addr# off#
      when (tag /= expected) $
        error "accumAdd: Tag mismatch for Either"

    -- Offset just past the value of the given type that is stored at the
    -- given offset. Follows the layout that 'accumWrite' produces: inside an
    -- array element every type has a fixed size; elsewhere the size depends
    -- on the tags and array shapes that are stored in memory.
    skip :: Bool -> STy t' -> Int -> IO Int
    skip True ty off = return (off + tSize' (Proxy @s) ty Nothing)
    skip False ty off = case ty of
      STNil -> return off
      STPair t1 t2 -> skip False t1 off >>= skip False t2
      STEither t1 t2 -> do
        let !(I# off#) = off
        tag <- readInt8 addr# off#
        case tag of
          0 -> skip False t1 (off + 1)
          1 -> skip False t2 (off + 1)
          _ -> error "Invalid tag in accum memory"
      STMaybe t1 -> skip False (STEither STNil t1) off
      STArr rank eltty -> do
        (off1, sh) <- readShape addr# rank off
        return (off1 + shapeSize sh * tSize' (Proxy @s) eltty Nothing)
      STScal{} -> return (off + tSize' (Proxy @s) ty Nothing)
      STAccum{} -> error "Nested accumulators unsupported"

    -- Add an array of values element-wise to the elements stored from the
    -- given offset onwards. Returns the offset just past the last element.
    performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int
    performAddArr arraySz eltty val off = do
      let eltsize = tSize' (Proxy @s) eltty Nothing
      forM_ [0 .. arraySz - 1] $ \lini ->
        performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize)
      return (off + arraySz * eltsize)

    -- Add a complete value to the value stored at the given offset. Returns
    -- the offset just past the stored value.
    performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
    performAdd inarr ty val off = case ty of
      STNil -> return off
      STPair t1 t2 -> do
        off1 <- performAdd inarr t1 (fst val) off
        performAdd inarr t2 (snd val) off1
      STEither t1 t2 -> do
        let !(I# off#) = off
        tag <- readInt8 addr# off#
        off1 <- case (val, tag) of
          (Left val1, 0) -> performAdd inarr t1 val1 (off + 1)
          (Right val2, 1) -> performAdd inarr t2 val2 (off + 1)
          _ -> error "accumAdd: Tag mismatch for Either"
        -- In an array, every element slot has the size of the larger
        -- alternative; elsewhere the value is stored compactly.
        if inarr
          then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing))
          else return off1
      STMaybe{} -> error "accumAdd: adding to Maybe is not yet implemented"
      STArr n ty'
        | inarr -> error "Nested array"
        | otherwise -> do
            (off1, sh) <- readShape addr# n off
            performAddArr (shapeSize sh) ty' val off1
      STScal ty' -> performAddScal ty' val off
      STAccum{} -> error "Nested accumulators unsupported"

    -- Atomically add a scalar to the scalar stored at the given offset.
    -- Integers of the machine word size use fetch-and-add directly; all other
    -- numeric types go through a compare-and-swap loop on their bit pattern.
    performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int
    performAddScal STI32 (I32# x#) off@(I# off#)
      | sizeOf (undefined :: Int) == 4
      = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#))
      | otherwise
      = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#)
                     (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#))
    performAddScal STI64 (I64# x#) off@(I# off#)
      | sizeOf (undefined :: Int) == 8
      = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#))
      | otherwise
      = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#)
                     (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#))
    performAddScal STF32 x off@(I# off#) =
      off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#)
                   (\w -> castFloatToWord32 (x + castWord32ToFloat w))
    performAddScal STF64 x off@(I# off#) =
      off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#)
                   (\w -> castDoubleToWord64 (x + castWord64ToDouble w))
    performAddScal STBool _ off = return (off + 1)  -- don't do anything with booleans

    -- Atomically apply a function to the value at an address: keep trying to
    -- swap in the modified value until no other write came in between.
    casLoop :: Eq w
            => (Addr# -> Int# -> IO w)    -- ^ read value (from a given byte offset; will get 0#)
            -> (Addr# -> w -> w -> IO w)  -- ^ CAS value at address (expected -> desired -> IO observed)
            -> Addr#                      -- ^ Address to attempt to modify
            -> (w -> w)                   -- ^ Operation to apply to the value
            -> IO ()
    casLoop readOp casOp addr modify = readOp addr 0# >>= loop
      where
        loop value = do
          value' <- casOp addr value (modify value)
          -- The swap succeeded iff the observed value is the expected one.
          if value == value'
            then return ()
            else loop value'

  in go False topty top_depth top_index top_value 0
-- | Run a computation with a fresh accumulator that is initialised to the
-- given value. Returns the result of the computation together with the final
-- contents of the accumulator.
withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t)
withAccum ty start fun = do
  accum <- AcM allocateAndFill
  result <- fun accum
  final <- accumRead accum
  return (result, final)
  where
    -- The initial write must precede all adds and reads, which is why it is
    -- done in IO, in one go with the allocation, rather than in AcM.
    allocateAndFill = do
      buffer <- mallocBytes (tSize (Proxy @s) ty start)
      ptr <- newForeignPtr finalizerFree buffer
      let fresh = Accum ty ptr
      accumWrite fresh start
      return fresh
-- | Run a number of computations concurrently, each on its own thread, and
-- collect their results in the order of the input list.
--
-- All threads are always waited for. If one or more of the computations threw
-- an exception, the exception of the first such computation (in list order)
-- is then rethrown in the calling thread.
inParallel :: [AcM s t] -> AcM s [t]
inParallel actions = AcM $ do
  mvars <- mapM (const newEmptyMVar) actions
  forM_ (zip actions mvars) $ \(AcM action, var) ->
    -- Always fill the MVar, also on failure; otherwise the 'takeMVar' below
    -- would block forever on a thread that died with an exception.
    forkIO (try action >>= putMVar var)
  outcomes <- mapM takeMVar mvars
  mapM (either rethrow return) outcomes
  where
    rethrow :: SomeException -> IO a
    rethrow = throwIO
-- Lifted wrappers around the primitive memory reads. In all of them, the
-- offset is in bytes, independent of the size of the type being read.

-- | Read an 'Int8'. Offset is in bytes.
readInt8 :: Addr# -> Int# -> IO Int8
readInt8 base# off# = IO $ \st ->
  case readInt8OffAddr# (plusAddr# base# off#) 0# st of
    (# st', x #) -> (# st', I8# x #)

-- | Read an 'Int32'. Offset is in bytes.
readInt32 :: Addr# -> Int# -> IO Int32
readInt32 base# off# = IO $ \st ->
  case readInt32OffAddr# (plusAddr# base# off#) 0# st of
    (# st', x #) -> (# st', I32# x #)

-- | Read an 'Int64'. Offset is in bytes.
readInt64 :: Addr# -> Int# -> IO Int64
readInt64 base# off# = IO $ \st ->
  case readInt64OffAddr# (plusAddr# base# off#) 0# st of
    (# st', x #) -> (# st', I64# x #)

-- | Read a 'Word32'. Offset is in bytes.
readWord32 :: Addr# -> Int# -> IO Word32
readWord32 base# off# = IO $ \st ->
  case readWord32OffAddr# (plusAddr# base# off#) 0# st of
    (# st', x #) -> (# st', W32# x #)

-- | Read a 'Word64'. Offset is in bytes.
readWord64 :: Addr# -> Int# -> IO Word64
readWord64 base# off# = IO $ \st ->
  case readWord64OffAddr# (plusAddr# base# off#) 0# st of
    (# st', x #) -> (# st', W64# x #)

-- | Read a 'Float'. Offset is in bytes.
readFloat :: Addr# -> Int# -> IO Float
readFloat base# off# = IO $ \st ->
  case readFloatOffAddr# (plusAddr# base# off#) 0# st of
    (# st', x #) -> (# st', F# x #)

-- | Read a 'Double'. Offset is in bytes.
readDouble :: Addr# -> Int# -> IO Double
readDouble base# off# = IO $ \st ->
  case readDoubleOffAddr# (plusAddr# base# off#) 0# st of
    (# st', x #) -> (# st', D# x #)
-- Lifted wrappers around the primitive memory writes. In all of them, the
-- offset is in bytes, independent of the size of the type being written.

-- | Write an 8-bit integer at the given byte offset.
writeInt8# :: Addr# -> Int# -> Int8# -> IO ()
writeInt8# base# off# x = IO $ \st ->
  case writeInt8OffAddr# (plusAddr# base# off#) 0# x st of
    st' -> (# st', () #)

-- | Write a 32-bit integer at the given byte offset.
writeInt32# :: Addr# -> Int# -> Int32# -> IO ()
writeInt32# base# off# x = IO $ \st ->
  case writeInt32OffAddr# (plusAddr# base# off#) 0# x st of
    st' -> (# st', () #)

-- | Write a 64-bit integer at the given byte offset.
writeInt64# :: Addr# -> Int# -> Int64# -> IO ()
writeInt64# base# off# x = IO $ \st ->
  case writeInt64OffAddr# (plusAddr# base# off#) 0# x st of
    st' -> (# st', () #)

-- | Write a single-precision float at the given byte offset.
writeFloat# :: Addr# -> Int# -> Float# -> IO ()
writeFloat# base# off# x = IO $ \st ->
  case writeFloatOffAddr# (plusAddr# base# off#) 0# x st of
    st' -> (# st', () #)

-- | Write a double-precision float at the given byte offset.
writeDouble# :: Addr# -> Int# -> Double# -> IO ()
writeDouble# base# off# x = IO $ \st ->
  case writeDoubleOffAddr# (plusAddr# base# off#) 0# x st of
    st' -> (# st', () #)
-- | Atomically add the given word to the machine word stored at the given
-- byte offset. The previous contents, which the primitive returns, are
-- discarded.
fetchAddWord# :: Addr# -> Int# -> Word# -> IO ()
fetchAddWord# base# off# delta# = IO $ \st ->
  case fetchAddWordAddr# (plusAddr# base# off#) delta# st of
    (# st', _previous #) -> (# st', () #)
-- | Atomic compare-and-swap on a 32-bit word: store the desired value at the
-- address if it currently holds the expected value. Returns the value that
-- the primitive observed at the address.
atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32
atomicCasWord32Addr addr (W32# expected) (W32# desired) = IO $ \st ->
  case atomicCasWord32Addr# addr expected desired st of
    (# st', observed #) -> (# st', W32# observed #)

-- | Atomic compare-and-swap on a 64-bit word: store the desired value at the
-- address if it currently holds the expected value. Returns the value that
-- the primitive observed at the address.
atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64
atomicCasWord64Addr addr (W64# expected) (W64# desired) = IO $ \st ->
  case atomicCasWord64Addr# addr expected desired st of
    (# st', observed #) -> (# st', W64# observed #)
|