summaryrefslogtreecommitdiff
path: root/src/Interpreter/Accum.hs
blob: b0deaef72b6b9c685587dfa09afc81a24ead62cd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE BangPatterns #-}
module Interpreter.Accum where

import Control.Monad.ST
import Control.Monad.ST.Unsafe
import GHC.Exts
import GHC.Int
import GHC.ST (ST(..))

import AST
import Data
import Interpreter.Array


-- | Monad in which accumulator operations run. It is a newtype wrapper around
-- @'ST' s@; the 'Functor', 'Applicative' and 'Monad' instances are those of
-- 'ST', obtained through newtype deriving.
newtype AcM s a = AcM (ST s a)
  deriving newtype (Functor, Applicative, Monad)

-- | Run an accumulator computation to completion and return its result.
--
-- The argument has to be polymorphic in the state token @s@, exactly as for
-- 'runST'. The wrapper is removed by a helper applied /inside/ the argument
-- of 'runST' so that the unwrapped action stays polymorphic in @s@.
runAcM :: (forall s. AcM s a) -> a
runAcM action = runST (unwrap action)
  where
    -- Strip the 'AcM' newtype, exposing the underlying 'ST' action.
    unwrap :: AcM s' b -> ST s' b
    unwrap (AcM st) = st

-- | Haskell-level representation of a value of an object-language type:
-- nil is unit, pairs and eithers map to their Haskell counterparts
-- component-wise, arrays are 'Array's of represented elements, and scalars
-- are given by 'ScalRep'.
type family Rep t where
  Rep TNil = ()
  Rep (TPair a b) = (Rep a, Rep b)
  Rep (TEither a b) = Either (Rep a) (Rep b)
  Rep (TArr n t) = Array n (Rep t)
  Rep (TScal sty) = ScalRep sty
  -- There is no equation for accumulator types yet; the line below is a
  -- commented-out placeholder.
  -- Rep (TAccum t) = _

-- | An accumulator holding a value of type @t@: the singleton describing @t@
-- together with the mutable byte array (in state thread @s@) that
-- 'accumWrite' fills and 'accumRead' decodes.
data Accum s t = Accum (STy t) (MutableByteArray# s)

-- | Number of bytes needed to store the given top-level value of the given
-- type. This is 'tSize'' with the value supplied.
tSize :: STy t -> Rep t -> Int
tSize ty = tSize' ty . Just

-- | Size in bytes of the stored form of a value.
--
-- Passing Nothing as the value means "this is (inside) an array element":
-- the size then cannot depend on a concrete value, so an either takes the
-- size of its larger alternative and a (nested) array is rejected.
tSize' :: STy t -> Maybe (Rep t) -> Int
tSize' STNil _ = 0
tSize' (STPair l r) mval = tSize' l (fst <$> mval) + tSize' r (snd <$> mval)
-- In every either case the leading '1 +' is the tag byte, kept for runtime
-- sanity checking.
tSize' (STEither l r) Nothing = 1 + max (tSize' l Nothing) (tSize' r Nothing)
tSize' (STEither l _) (Just (Left x)) = 1 + tSize l x
tSize' (STEither _ r) (Just (Right y)) = 1 + tSize r y
tSize' (STArr _ _) Nothing = error "Nested arrays not supported in this implementation"
-- 8 bytes per dimension for the shape, then the fixed-size elements.
tSize' (STArr rank elt) (Just arr) = fromSNat rank * 8 + arraySize arr * tSize' elt Nothing
tSize' (STScal sty) _ = case sty of
  STI32 -> 4
  STI64 -> 8
  STF32 -> 4
  STF64 -> 8
  STBool -> 1
tSize' STAccum{} _ = error "Nested accumulators unsupported"

-- | Serialise a value into the accumulator's buffer, overwriting its contents.
--
-- This operation does not commute with 'accumAdd', so it must be used with
-- care. Furthermore it must be used on exactly the same value as tSize was
-- called on. Hence it lives in ST, not in AcM.
--
-- Buffer layout (shared with 'tSize'' and 'accumRead'); all offsets are
-- /byte/ offsets and fields are packed without alignment padding:
--
--   * nil: nothing;
--   * pair: first component immediately followed by the second;
--   * either: one tag byte (0 = Left, 1 = Right) followed by the payload;
--     inside an array element the either always occupies
--     @1 + max (size left) (size right)@ bytes so all elements are equally
--     large;
--   * array: one 64-bit extent per dimension, then the elements at a fixed
--     stride of @tSize' elt Nothing@ bytes;
--   * scalars: 4 (i32), 8 (i64), 4 (f32), 8 (f64), 1 (bool) bytes.
--
-- Multi-byte fields are written with the @writeWord8ArrayAs*#@ primops, which
-- are addressed in bytes. The @writeInt64Array#@ family is addressed in
-- elements and would scale these offsets by the element size.
accumWrite :: forall s t. Accum s t -> Rep t -> ST s ()
accumWrite (Accum topty mbarr) = \val -> () <$ go False topty val 0#
  where
    -- @go inarr ty val off#@ writes @val@ at byte offset @off#@ and returns
    -- the byte offset just past it. @inarr@ is True inside an array element.
    go :: Bool -> STy t' -> Rep t' -> Int# -> ST s Int
    go inarr ty val off# = case ty of
      STNil -> return (I# off#)
      STPair a b -> do
        I# off1# <- go inarr a (fst val) off#
        go inarr b (snd val) off1#
      STEither a b -> do
        end <- case val of
          Left x -> do
            writeTag 0 off#
            go inarr a x (off# +# 1#)
          Right y -> do
            writeTag 1 off#
            go inarr b y (off# +# 1#)
        -- Array elements have a fixed size, so skip over the unused tail of
        -- the larger alternative; outside arrays the either is exactly as
        -- large as the alternative that was written.
        return $
          if inarr
            then I# off# + 1 + max (tSize' a Nothing) (tSize' b Nothing)
            else end
      STArr _ t
        | inarr -> error "Nested arrays not supported in this implementation"
        | otherwise -> do
            I# off1# <- goShape (arrayShape val) off#
            let !(I# eltsize#) = tSize' t Nothing
                !(I# n#) = arraySize val
            -- Each element goes to its own fixed slot; the offsets returned
            -- by the element writes are not needed.
            traverseArray_ (\(I# lini#) x -> () <$ go True t x (off1# +# eltsize# *# lini#)) val
            return (I# (off1# +# eltsize# *# n#))
      STScal sty -> goScal sty val off#
      STAccum{} -> error "Nested accumulators unsupported"

    -- Write the one-byte either tag.
    writeTag :: Int8 -> Int# -> ST s ()
    writeTag (I8# tag#) off# = ST $ \s -> (# writeInt8Array# mbarr off# tag# s, () #)

    -- Write the extents of a shape as consecutive 64-bit integers.
    goShape :: Shape n -> Int# -> ST s Int
    goShape ShNil off# = return (I# off#)
    goShape (ShCons sh n) off# = do
      I# off1# <- goShape sh off#
      let !(I64# n') = fromIntegral n
      ST $ \s -> (# writeWord8ArrayAsInt64# mbarr off1# n' s, I# (off1# +# 8#) #)

    -- Write one scalar and return the byte offset following it.
    goScal :: SScalTy t' -> ScalRep t' -> Int# -> ST s Int
    goScal STI32 (I32# x) off# = ST $ \s -> (# writeWord8ArrayAsInt32# mbarr off# x s, I# (off# +# 4#) #)
    goScal STI64 (I64# x) off# = ST $ \s -> (# writeWord8ArrayAsInt64# mbarr off# x s, I# (off# +# 8#) #)
    goScal STF32 (F# x) off# = ST $ \s -> (# writeWord8ArrayAsFloat# mbarr off# x s, I# (off# +# 4#) #)
    goScal STF64 (D# x) off# = ST $ \s -> (# writeWord8ArrayAsDouble# mbarr off# x s, I# (off# +# 8#) #)
    goScal STBool b off# =
      let !(I8# i) = fromIntegral (fromEnum b)
      in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #)

-- | Decode the value currently stored in the accumulator's buffer.
--
-- This is the inverse of 'accumWrite' and uses the same layout: byte offsets,
-- packed fields, and byte-addressed @readWord8ArrayAs*#@ primops for
-- multi-byte fields.
--
-- Raises an 'error' if an either tag in the buffer is neither 0 nor 1.
accumRead :: forall s t. Accum s t -> AcM s (Rep t)
accumRead (Accum topty mbarr) = AcM $ snd <$> go False topty 0#
  where
    -- @go inarr ty off#@ reads a value of type @ty@ at byte offset @off#@ and
    -- returns the byte offset just past it together with the value.
    go :: Bool -> STy t' -> Int# -> ST s (Int, Rep t')
    go inarr ty off# = case ty of
      STNil -> return (I# off#, ())
      STPair a b -> do
        (I# off1#, x) <- go inarr a off#
        -- The offset returned for the second component is already absolute,
        -- so it is the end of the whole pair.
        (off2, y) <- go inarr b off1#
        return (off2, (x, y))
      STEither a b -> do
        tag <- ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', I8# i #)
        (end, val) <- case tag of
          0 -> fmap Left <$> go inarr a (off# +# 1#)
          1 -> fmap Right <$> go inarr b (off# +# 1#)
          _ -> error "Invalid tag in accum memory"
        -- Inside an array element the either has a fixed size: the tag byte
        -- plus the larger of the two alternatives.
        return
          ( if inarr
              then I# off# + 1 + max (tSize' a Nothing) (tSize' b Nothing)
              else end
          , val )
      STArr ndim t
        | inarr -> error "Nested arrays not supported in this implementation"
        | otherwise -> do
            (I# off1#, sh) <- readShape mbarr ndim off#
            let !(I# eltsize#) = tSize' t Nothing
                !(I# n#) = shapeSize sh
            arr <- arrayGenerateLinM sh (\(I# lini#) -> snd <$> go True t (off1# +# eltsize# *# lini#))
            return (I# (off1# +# eltsize# *# n#), arr)
      STScal sty -> goScal sty off#
      STAccum{} -> error "Nested accumulators unsupported"

    -- Read one scalar and return the byte offset following it.
    goScal :: SScalTy t' -> Int# -> ST s (Int, ScalRep t')
    goScal STI32 off# = ST $ \s -> case readWord8ArrayAsInt32# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 4#), I32# i) #)
    goScal STI64 off# = ST $ \s -> case readWord8ArrayAsInt64# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 8#), I64# i) #)
    goScal STF32 off# = ST $ \s -> case readWord8ArrayAsFloat# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 4#), F# f) #)
    goScal STF64 off# = ST $ \s -> case readWord8ArrayAsDouble# mbarr off# s of (# s', f #) -> (# s', (I# (off# +# 8#), D# f) #)
    goScal STBool off# = ST $ \s -> case readInt8Array# mbarr off# s of (# s', i #) -> (# s', (I# (off# +# 1#), toEnum (fromIntegral (I8# i))) #)

-- | Read a shape of the given rank stored as consecutive 64-bit extents
-- starting at the given /byte/ offset. Returns the byte offset following the
-- shape together with the shape.
readShape :: MutableByteArray# s -> SNat n -> Int# -> ST s (Int, Shape n)
readShape _ SZ off# = return (I# off#, ShNil)
readShape mbarr (SS ndim) off# = do
  (I# off1#, sh) <- readShape mbarr ndim off#
  n' <- ST $ \s -> case readWord8ArrayAsInt64# mbarr off1# s of (# s', i #) -> (# s', I64# i #)
  return (I# (off1# +# 8#), ShCons sh (fromIntegral n'))

-- | A list of 'Int's indexed by two type-level naturals: 'IShFull' is the
-- empty list and has both indices equal, while 'IShCons' adds one 'Int' and
-- lowers the second index from @S yet@ to @yet@.
--
-- NOTE(review): presumably a shape with its dimensions in reverse order (see
-- 'invertShape') — confirm once that function is finished.
data InvShape full yet where
  IShFull :: InvShape full full
  IShCons :: InvShape full (S yet) -> Int -> InvShape full yet

-- TODO(review): unfinished. The signature below is followed by a bare name
-- with no arguments and no right-hand side, so this definition does not parse
-- as it stands.
invertShape :: Shape n -> InvShape n Z
invertShape


-- | Intended to add a value into part of the accumulator: @depth@ says how
-- many levels of the accumulated type are indexed into, @index@ selects the
-- position and @value@ is what gets added there.
--
-- NOTE(review): this definition is unfinished. It contains typed holes (@_@)
-- and calls @performAdd@, which is not defined in the part of the file
-- visible here.
accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()
accumAdd (Accum topty mbarr) = \depth index value -> AcM $ () <$ go False topty depth index value 0#
  where
    -- Walk down the type following the index. At depth zero the value is
    -- handed to performAdd; otherwise one level of pair/either/array is peeled
    -- off. The Bool is True inside an array element; the Int# is the offset
    -- into the buffer.
    go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int# -> ST s Int
    go inarr ty SZ idx val off# = performAdd inarr ty val off#
    go inarr ty (SS dep) idx val off# = case (ty, idx, val) of
      -- NOTE(review): all four pair/either cases recurse at the unchanged
      -- offset off#; nothing here skips a first component or a tag byte.
      (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off#
      (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off#
      (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off#
      (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off#
      (STArr rank eltty, _, _)
        | inarr -> error "Nested arrays"
        | otherwise -> goArr (SS dep) rank eltty idx val off#
      -- (STArr SZ t1, _, _) -> go inarr t1 dep idx val off#
      -- (STArr (SS rankm1) t1, _, _) -> go inarr t1 dep idx val off#
      _ -> _

    -- Array case: reads the stored shape from the buffer; the remainder is
    -- still a typed hole.
    goArr :: SNat i' -> SNat n -> STy t'
          -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int# -> ST s Int
    goArr dep n t1 idx val off# = do
      (I# off1#, sh) <- readShape mbarr n off#
      _

    -- Peel one index component per array dimension off the index, building up
    -- an 'Index', and pass the rest to the continuation.
    -- NOTE(review): not called from anywhere in this where-block. The
    -- signature lists the continuation type twice while the equations bind a
    -- single @k@, and there is no equation for depth SZ yet.
    collectArrIndex :: SNat i' -> STy t' -> Shape n
                    -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i')
                    -> (forall i2. Index n -> SNat i2 -> Rep (AcIdx t' i2) -> Rep (AcVal t' i2) -> ST s r)
                    -> (forall i2. Index n -> SNat i2 -> Rep (AcIdx t' i2) -> Rep (AcVal t' i2) -> ST s r)
                    -> ST s r
    collectArrIndex (SS dep) eltty ShNil idx val k = k IxNil dep idx val
    collectArrIndex (SS dep) eltty (ShCons sh size) (i, idx) val k =
      collectArrIndex dep eltty sh idx val $ \index dep' idx' val' ->
        k (IxCons index (fromIntegral i)) dep' idx' val'
    -- collectArrIndex SZ eltty

    -- NOTE(review): goShape and goScal below are not called from anywhere in
    -- this where-block. They advance the offset in bytes (+8, +4, +1) but
    -- pass it to writeInt32Array#/writeInt64Array#/writeFloatArray#/
    -- writeDoubleArray#, which index in units of the element size; check
    -- this before putting them to use.
    goShape :: Shape n -> Int# -> ST s Int
    goShape ShNil off# = return (I# off#)
    goShape (ShCons sh n) off# = do
      I# off1# <- goShape sh off#
      let !(I64# n') = fromIntegral n
      ST $ \s -> (# writeInt64Array# mbarr off1# n' s, I# (off1# +# 8#) #)

    goScal :: SScalTy t' -> ScalRep t' -> Int# -> ST s Int
    goScal STI32 (I32# x) off# = ST $ \s -> (# writeInt32Array# mbarr off# x s, I# (off# +# 4#) #)
    goScal STI64 (I64# x) off# = ST $ \s -> (# writeInt64Array# mbarr off# x s, I# (off# +# 8#) #)
    goScal STF32 (F# x) off# = ST $ \s -> (# writeFloatArray# mbarr off# x s, I# (off# +# 4#) #)
    goScal STF64 (D# x) off# = ST $ \s -> (# writeDoubleArray# mbarr off# x s, I# (off# +# 8#) #)
    goScal STBool b off# =
      let !(I8# i) = fromIntegral (fromEnum b)
      in ST $ \s -> (# writeInt8Array# mbarr off# i s, I# (off# +# 1#) #)

-- | Run a computation with a fresh accumulator.
--
-- A buffer of 'tSize' bytes for the initial value is allocated and filled
-- with that value via 'accumWrite'; the supplied function is then run on the
-- accumulator. The result pairs the accumulator's final contents (obtained
-- with 'accumRead') with the function's own result.
withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b)
withAccum ty start fun = do
  let !(I# nbytes#) = tSize ty start
  acc <- AcM $ ST $ \st ->
    case newByteArray# nbytes# st of
      (# st', buf #) -> (# st', Accum ty buf #)
  AcM (accumWrite acc start)
  result <- fun acc
  final <- accumRead acc
  pure (final, result)