summaryrefslogtreecommitdiff
path: root/src/Interpreter/Accum.hs
blob: 45f507b61e17c876d789d16b326b161925a1eea2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}
module Interpreter.Accum where

import Control.Monad.ST
import Control.Monad.ST.Unsafe
import GHC.Exts
import GHC.Int
import GHC.ST (ST(..))

import AST
import Data
import Interpreter.Array
import Data.Bifunctor (second)
import Control.Monad (when, forM_)
import Foreign.Storable (sizeOf)


-- | The accumulator monad: a thin wrapper over strict 'ST', so that
-- accumulator operations can only be sequenced with each other. Its
-- 'Functor', 'Applicative' and 'Monad' instances are exactly those of 'ST'.
newtype AcM s a where
    AcM :: ST s a -> AcM s a
  deriving newtype (Functor, Applicative, Monad)

-- | Run an accumulator computation to completion. As with 'runST', the rank-2
-- argument type prevents the state thread from escaping.
runAcM :: (forall s. AcM s a) -> a
runAcM action = runST (unwrap action)
  where
    unwrap :: AcM s' b -> ST s' b
    unwrap (AcM st) = st

-- | The Haskell type used to hold a value of object-language type @t@.
-- The accumulator case is not implemented yet (left commented out below).
type family Rep t where
  Rep (TScal sty) = ScalRep sty
  Rep TNil = ()
  Rep (TPair l r) = (Rep l, Rep r)
  Rep (TEither l r) = Either (Rep l) (Rep r)
  Rep (TArr n elt) = Array n (Rep elt)
  -- Rep (TAccum t) = _

-- | A mutable accumulator: the type descriptor of the stored value together
-- with the raw byte buffer holding its serialised form.
data Accum s t where
  Accum :: STy t -> MutableByteArray# s -> Accum s t

-- | Number of bytes a concrete top-level value occupies in accumulator memory.
-- This is the buffer size to allocate before calling 'accumWrite' on that
-- same value.
tSize :: STy t -> Rep t -> Int
tSize ty = tSize' ty . Just

-- | Number of bytes a value of type @t@ occupies in accumulator memory.
--
-- Passing 'Nothing' as the value means "this is (inside) an array element".
-- All elements of an array must have the same size, so in that case an
-- 'Either' is sized by its larger branch rather than by the branch actually
-- present. Every 'Either' carries a one-byte tag. An array carries one 8-byte
-- extent per dimension in front of its elements. Nested arrays and
-- accumulators are rejected.
tSize' :: STy t -> Maybe (Rep t) -> Int
tSize' typ val = case typ of
  STScal sty -> scalarBytes sty
  STNil -> 0
  STPair a b ->
    let leftBytes = tSize' a (fst <$> val)
        rightBytes = tSize' b (snd <$> val)
    in leftBytes + rightBytes
  STEither a b ->
    -- the leading 1 is the tag byte (also used for runtime sanity checking)
    1 + case val of
          Nothing -> max (tSize' a Nothing) (tSize' b Nothing)
          Just (Left x) -> tSize a x
          Just (Right y) -> tSize b y
  STArr ndim t
    | Just arr <- val -> 8 * fromSNat ndim + tSize' t Nothing * arraySize arr
    | otherwise -> error "Nested arrays not supported in this implementation"
  STAccum{} -> error "Nested accumulators unsupported"
  where
    scalarBytes :: SScalTy s -> Int
    scalarBytes STI32 = 4
    scalarBytes STF32 = 4
    scalarBytes STI64 = 8
    scalarBytes STF64 = 8
    scalarBytes STBool = 1

-- | Serialise a value into the accumulator's memory, starting at byte 0.
--
-- This operation does not commute with 'accumAdd', so it must be used with
-- care. Furthermore it must be used on exactly the same value as 'tSize' was
-- called on, because the buffer was sized for that value. Hence it lives in
-- 'ST', not in 'AcM'.
--
-- All offsets are byte offsets, matching the sizes computed by 'tSize''.
--
-- * An 'Either' is a one-byte tag (0 = 'Left', 1 = 'Right') followed by its
--   payload. Inside an array element the payload is padded to the larger
--   branch.
-- * An array is one 8-byte extent per dimension, followed by its elements,
--   each @tSize' eltty Nothing@ bytes wide.
--
-- Multi-byte scalars use the @…Word8ArrayAs…#@ primops, which take byte
-- offsets. The plain @write<Ty>Array#@ primops index in units of the element
-- size, and would write to the wrong location (possibly out of bounds).
-- NOTE(review): offsets here are generally unaligned (e.g. just after a tag
-- byte); confirm the target platforms tolerate unaligned access.
accumWrite :: forall s t. Accum s t -> Rep t -> ST s ()
accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0
  where
    -- Write @val@ at byte offset @off@ and return the offset just past it.
    -- @inarr@ is True while writing (inside) an array element.
    go :: Bool -> STy t' -> Rep t' -> Int -> ST s Int
    go inarr ty val off = case ty of
      STNil -> return off
      STPair a b -> do
        off1 <- go inarr a (fst val) off
        go inarr b (snd val) off1
      STEither a b -> do
        let !(I# off#) = off
        payloadEnd <- case val of
          Left x -> do
            let !(I8# tag#) = 0
            ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #)
            go inarr a x (off + 1)
          Right y -> do
            let !(I8# tag#) = 1
            ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #)
            go inarr b y (off + 1)
        -- Array elements have a fixed size (tSize' _ Nothing), so inside one
        -- the Either always spans its larger branch. accumRead and
        -- accumAdd's performAdd skip the same padding. Without it, whatever
        -- follows this Either in the element would be written at a different
        -- offset than the one it is read from.
        return $
          if inarr
            then off + 1 + max (tSize' a Nothing) (tSize' b Nothing)
            else payloadEnd
      STArr _ t
        | inarr -> error "Nested arrays not supported in this implementation"
        | otherwise -> do
            off1 <- goShape (arrayShape val) off
            let eltsize = tSize' t Nothing
                n = arraySize val
            traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val
            return (off1 + eltsize * n)
      STScal sty -> goScal sty val off
      STAccum{} -> error "Nested accumulators unsupported"

    -- Write each extent of the shape as an 8-byte integer, in the same order
    -- 'readShape' reads them back.
    goShape :: Shape n -> Int -> ST s Int
    goShape ShNil off = return off
    goShape (ShCons sh n) off = do
      off1@(I# off1#) <- goShape sh off
      let !(I64# n'#) = fromIntegral n
      ST $ \s -> (# writeWord8ArrayAsInt64# mbarr off1# n'# s, off1 + 8 #)

    -- Write one scalar at byte offset @off@ and return the offset after it.
    goScal :: SScalTy t' -> ScalRep t' -> Int -> ST s Int
    goScal STI32 (I32# x) (I# off#) = ST $ \s -> (# writeWord8ArrayAsInt32# mbarr off# x s, I# (off# +# 4#) #)
    goScal STI64 (I64# x) (I# off#) = ST $ \s -> (# writeWord8ArrayAsInt64# mbarr off# x s, I# (off# +# 8#) #)
    goScal STF32 (F# x) (I# off#) = ST $ \s -> (# writeWord8ArrayAsFloat# mbarr off# x s, I# (off# +# 4#) #)
    goScal STF64 (D# x) (I# off#) = ST $ \s -> (# writeWord8ArrayAsDouble# mbarr off# x s, I# (off# +# 8#) #)
    goScal STBool b (I# off#) =
      let !(I8# i) = fromIntegral (fromEnum b)
      in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #)

-- | Deserialise the accumulator's current contents, starting at byte 0.
-- Uses the same byte layout that 'accumWrite' produces.
accumRead :: forall s t. Accum s t -> AcM s (Rep t)
accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0
  where
    -- Read a value at byte offset @off@. Return the offset just past it
    -- together with the value. @inarr@ is True inside an array element.
    go :: Bool -> STy t' -> Int -> ST s (Int, Rep t')
    go inarr ty off = case ty of
      STNil -> return (off, ())
      STPair a b -> do
        (off1, x) <- go inarr a off
        (off2, y) <- go inarr b off1
        -- off2 is already an absolute offset (reading started at off1), so
        -- it is returned as-is rather than added to off1.
        return (off2, (x, y))
      STEither a b -> do
        let !(I# off#) = off
        tag <- ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', I8# i #)
        (off1, val) <- case tag of
                         0 -> fmap Left <$> go inarr a (off + 1)
                         1 -> fmap Right <$> go inarr b (off + 1)
                         _ -> error "Invalid tag in accum memory"
        -- Inside an array element the Either is padded to its larger branch.
        if inarr
          then return (off + 1 + max (tSize' a Nothing) (tSize' b Nothing), val)
          else return (off1, val)
      STArr ndim t
        | inarr -> error "Nested arrays not supported in this implementation"
        | otherwise -> do
            (off1, sh) <- readShape mbarr ndim off
            let eltsize = tSize' t Nothing
                n = shapeSize sh
            arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini))
            return (off1 + eltsize * n, arr)
      STScal sty -> goScal sty off
      STAccum{} -> error "Nested accumulators unsupported"

    -- Read one scalar at byte offset @off@ (byte-offset primops, see accumWrite).
    goScal :: SScalTy t' -> Int -> ST s (Int, ScalRep t')
    goScal STI32 (I# off#) = ST $ \s -> case readWord8ArrayAsInt32# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 4#), I32# i) #)
    goScal STI64 (I# off#) = ST $ \s -> case readWord8ArrayAsInt64# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 8#), I64# i) #)
    goScal STF32 (I# off#) = ST $ \s -> case readWord8ArrayAsFloat# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 4#), F# f) #)
    goScal STF64 (I# off#) = ST $ \s -> case readWord8ArrayAsDouble# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 8#), D# f) #)
    goScal STBool (I# off#) = ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 1#), toEnum (fromIntegral (I8# i))) #)

-- | Read an @n@-dimensional shape stored at byte offset @off@ as @n@
-- consecutive 8-byte extents. The extents of the inner 'Shape' come first and
-- the extent in the outermost 'ShCons' comes last, matching the order
-- accumWrite's goShape writes them. Returns the byte offset just past the
-- shape together with the shape.
readShape :: MutableByteArray# s -> SNat n -> Int -> ST s (Int, Shape n)
readShape _ SZ off = return (off, ShNil)
readShape mbarr (SS ndim) off = do
  (off1@(I# off1#), sh) <- readShape mbarr ndim off
  n' <- ST $ \s -> case readWord8ArrayAsInt64# mbarr off1# s of (# s', i #) -> (# s', I64# i #)
  return (off1 + 8, ShCons sh (fromIntegral n'))

-- | Inverted ('reverse'd) form of a 'Shape': the /outer/ dimension is on the
-- left, at the head of the list. Each level also records the combined size of
-- its subarrays. That way an index into that dimension becomes an offset with
-- a single multiplication by the inner level's 'ishSize'.
data InvShape n where
  IShNil :: InvShape Z
  IShCons
    :: Int         -- ^ Number of subarrays at this level (extent of this dimension)
    -> Int         -- ^ Combined size of all those subarrays together
    -> InvShape n  -- ^ Inverted shape of a single subarray
    -> InvShape (S n)

-- | Total number of elements covered by an inverted shape.
--
-- A rank-0 shape ('IShNil') covers exactly one element. 'invertShape' builds
-- each level's size as @n * ishSize inner@, so returning 0 here (as before)
-- made every recorded size 0 and collapsed every index offset to 0.
ishSize :: InvShape n -> Int
ishSize IShNil = 1
ishSize (IShCons _ sz _) = sz

-- | Convert a 'Shape' to its inverted form (outer dimension first). Each
-- level's combined subarray size is filled in as @n * 'ishSize' inner@.
-- The 'Refl' pattern guards discharge the type-level arithmetic that the
-- accumulator-passing helper needs. NOTE(review): presumably @n + Z ~ n@ and
-- a successor-on-the-right step; the lemma types are not visible here.
invertShape :: forall n. Shape n -> InvShape n
invertShape | Refl <- lemPlusZero @n = flip go IShNil
  where
    -- Move dimensions from the end of the Shape onto the front of the
    -- accumulated InvShape. The extent in the outermost 'ShCons' is pushed
    -- first, so it ends up innermost in the result.
    go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m)
    go ShNil ish = ish
    go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish)

-- | Add @value@ into the accumulator at the position selected by @depth@ and
-- @index@. The index and value types are given by 'AcIdx' and 'AcVal'.
--
-- At depth 'SZ' the value is added to the whole stored value. Each 'SS' step
-- descends one level:
--
-- * into the 'Left'/'Right' side of a pair or either (index and value must
--   pick the same side), or
-- * into an array, where each index component is bounds-checked against
--   that dimension's extent.
--
-- NOTE(review): unfinished. performAddScal and casLoop still contain typed
-- holes (@_@), and performAddArr is not defined in this file as shown, so
-- the module does not compile until those are filled in.
accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()
accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty depth index value 0
  where
    -- Navigate down @depth@ levels, then add. @off@ is a byte offset and
    -- @inarr@ is True once inside an array element.
    go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int -> ST s ()
    go inarr ty SZ () val off = () <$ performAdd inarr ty val off
    go inarr ty (SS dep) idx val off = case (ty, idx, val) of
      -- NOTE(review): accumWrite stores the second pair component after the
      -- first, but both pair cases recurse at the same @off@. The Right case
      -- seems to need to skip over the first component first; confirm.
      (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
      (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off
      (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd"
      -- NOTE(review): accumWrite puts a 1-byte tag before an Either payload,
      -- but these cases neither check the stored tag nor skip it; confirm.
      (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
      (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off
      (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd"
      (STArr rank eltty, _, _)
        | inarr -> error "Nested arrays"
        | otherwise -> do
            (off1, ish) <- second invertShape <$> readShape mbarr rank off
            goArr (SS dep) ish eltty idx val off1
      (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth"
      (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth"
      (STAccum{}, _, _) -> error "Nested accumulators unsupported"

    -- Consume one array dimension per index component (bounds-checked),
    -- advancing the offset by @i * ishSize ish@ each time. Once the
    -- dimensions run out (IShNil), continue into the element via 'go'.
    -- NOTE(review): this offset has no element-size factor
    -- (tSize' t1 Nothing), although @off@ is a byte offset. Confirm the
    -- intended unit of ishSize.
    goArr :: SNat i' -> InvShape n -> STy t'
          -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int -> ST s ()
    goArr SZ ish t1 () val off = performAddArr ish t1 val off
    goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off
    goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do
      let i' = fromIntegral @(Rep TIx) @Int i
      when (i' < 0 || i' >= n) $
        error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")"
      goArr depm1 ish t1 idx val (off + i' * ishSize ish)

    -- Add @val@ into the value stored at byte offset @off@ and return the
    -- offset just past it. Uses the same layout as accumWrite: checked
    -- Either tags, padding inside array elements, and array shapes.
    performAdd :: Bool -> STy t' -> Rep t' -> Int -> ST s Int
    performAdd inarr ty val off = case ty of
      STNil -> return off
      STPair t1 t2 -> do
        off1 <- performAdd inarr t1 (fst val) off
        performAdd inarr t2 (snd val) off1
      STEither t1 t2 -> do
        let !(I# off#) = off
        tag <- ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', I8# i #)
        off1 <- case (val, tag) of
                  (Left val1, 0) -> performAdd inarr t1 val1 (off + 1)
                  (Right val2, 1) -> performAdd inarr t2 val2 (off + 1)
                  _ -> error "accumAdd: Tag mismatch for Either"
        if inarr
          then return (off + 1 + max (tSize' t1 Nothing) (tSize' t2 Nothing))
          else return off1
      STArr n ty'
        | inarr -> error "Nested array"
        | otherwise -> do
            (off1, sh) <- readShape mbarr n off
            let sz = shapeSize sh
                eltsize = tSize' ty' Nothing
            forM_ [0 .. sz - 1] $ \lini ->
              performAdd True ty' (arrayIndexLinear val lini) (off1 + lini * eltsize)
            return (off1 + sz * eltsize)
      STScal ty' -> performAddScal ty' val off
      STAccum{} -> error "Nested accumulators unsupported"

    -- Add one scalar at byte offset @off@ and return the offset after it.
    -- NOTE(review): the offsets passed here are byte offsets, but
    -- fetchAddIntArray#, writeFloatArray# and writeDoubleArray# index in
    -- units of their element size (machine word, 4 and 8 bytes). Confirm.
    -- NOTE(review): the STF32, STF64 and STBool cases overwrite the stored
    -- value instead of adding to it. They are presumably placeholders until
    -- casLoop is implemented.
    performAddScal :: SScalTy t' -> ScalRep t' -> Int -> ST s Int
    performAddScal STI32 (I32# x#) (I# off#)
      | sizeOf (undefined :: Int) == 4
      = ST $ \s -> case fetchAddIntArray# mbarr off# (int32ToInt# x#) s of
                     (# s', _ #) -> (# s', I# (off# +# 4#) #)
      | otherwise
      = _
    performAddScal STI64 (I64# x#) (I# off#)
      | sizeOf (undefined :: Int) == 8
      = ST $ \s -> case fetchAddIntArray# mbarr off# (int64ToInt# x#) s of
                     (# s', _ #) -> (# s', I# (off# +# 8#) #)
      | otherwise
      = _
    performAddScal STF32 (F# x) (I# off#) = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #)
    performAddScal STF64 (D# x) (I# off#) = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #)
    performAddScal STBool b (I# off#) =
      let !(I8# i) = fromIntegral (fromEnum b)
      in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #)

    -- Presumably the intended compare-and-swap retry loop for adds that have
    -- no fetch-add primop (e.g. floats). The body is not written yet (typed
    -- hole).
    casLoop :: (Addr# -> w -> w -> State# d -> (# State# d, w #))
            -> (w -> w)
            -> r
            -> State# d -> (# State# d, r #)
    casLoop casOp add ret = _

-- | Allocate a fresh accumulator initialised to @start@, run @fun@ on it, and
-- return the accumulator's final contents alongside @fun@'s own result.
withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b)
withAccum ty start fun = do
  accum <- AcM (allocate (tSize ty start))
  AcM (accumWrite accum start)
  result <- fun accum
  final <- accumRead accum
  pure (final, result)
  where
    -- A byte buffer of exactly the requested size, wrapped as an Accum.
    allocate (I# nbytes) = ST $ \s -> case newByteArray# nbytes s of
      (# s', mbarr #) -> (# s', Accum ty mbarr #)