summaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
blob: f58cefbc12dfa5095207576df2b66675e6f7519a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TupleSections #-}
module Interpreter (
  interpret,
  interpret',
  Value,
) where

import Control.Monad (foldM, join)
import Data.Int (Int64)
import Data.IORef
import System.IO.Unsafe (unsafePerformIO)

import Array
import AST
import CHAD.Types
import Data
import Interpreter.Rep
import Data.Bifunctor (first)


-- | Monad in which interpretation runs: a thin wrapper around 'IO', which is
-- needed for the mutable accumulator references created by 'EWith'.
--
-- The type parameter @s@ does not appear in the representation (it is a
-- phantom); 'runAcM' quantifies over it in the style of @runST@.
newtype AcM s a = AcM { unAcM :: IO a }
  deriving newtype (Functor, Applicative, Monad)

-- | Run an 'AcM' computation and return its result as a pure value.
--
-- This is implemented with 'unsafePerformIO'. The rank-2 quantification over
-- @s@ mirrors @runST@. NOTE(review): the accumulator reference types
-- ('RepAcSparse') do not mention @s@, so the quantification alone does not
-- stop references from escaping. Soundness relies on the interpreted program
-- being deterministic.
runAcM :: (forall s. AcM s a) -> a
runAcM action = unsafePerformIO (unAcM action)

-- | Evaluate a closed expression, i.e. one with an empty environment, to its
-- runtime representation.
interpret :: Ex '[] t -> Rep t
interpret expr = runAcM (interpret' SNil expr)

-- | An environment entry: the runtime representation of a value of type @t@.
-- The wrapper presumably exists because 'Rep' is a type family, which cannot
-- be used unapplied as the element functor of 'SList'.
newtype Value t = Value (Rep t)

-- | Evaluate an expression in an environment of already-computed values.
--
-- Variables are looked up by position with 'slistIdx'. Binding forms ('ELet',
-- 'ECase', 'EMaybe', the build/fold forms and 'EWith') push their bound
-- value(s) onto the front of the environment with 'SCons'.
--
-- Runs in 'AcM' because 'EWith' and 'EAccum' allocate and update accumulator
-- references.
interpret' :: forall env t s. SList Value env -> Ex env t -> AcM s (Rep t)
interpret' env = \case
  EVar _ _ i -> case slistIdx env i of Value x -> return x
  ELet _ a b -> do
    x <- interpret' env a
    interpret' (Value x `SCons` env) b
  EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b
  EFst _ e -> fst <$> interpret' env e
  ESnd _ e -> snd <$> interpret' env e
  ENil _ -> return ()
  EInl _ _ e -> Left <$> interpret' env e
  EInr _ _ e -> Right <$> interpret' env e
  ECase _ e a b -> interpret' env e >>= \case
                     Left x -> interpret' (Value x `SCons` env) a
                     Right y -> interpret' (Value y `SCons` env) b
  ENothing _ _ -> return Nothing
  EJust _ e -> Just <$> interpret' env e
  EMaybe _ a b e -> maybe (interpret' env a) (\x -> interpret' (Value x `SCons` env) b) =<< interpret' env e
  EConstArr _ _ _ v -> return v
  -- One-dimensional build. The body sees the current position, as an Int64,
  -- as its newest variable.
  EBuild1 _ a b -> do
    n <- fromIntegral @Int64 @Int <$> interpret' env a
    arrayGenerateLinM (ShNil `ShCons` n)
                      (\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b)
  -- Multi-dimensional build. The shape comes in as a nested tuple of Int64s,
  -- and the body receives each index in the same tuple encoding.
  EBuild _ dim a b -> do
    sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a
    arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b)
  -- Reduce the innermost dimension with the binary function `a`. The running
  -- accumulator `x` sits one slot below the next element `y` in the
  -- environment. An empty inner dimension fails via 'foldl1M'.
  EFold1Inner _ a b -> do
    let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a
    arr <- interpret' env b
    let sh `ShCons` n = arrayShape arr
    arrayGenerateM sh $ \idx -> foldl1M f [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
  -- Sum the innermost dimension; an empty inner dimension gives 0. The scalar
  -- type is recovered from the expression's type so that 'numericIsNum' can
  -- supply the Num instance.
  ESum1Inner _ e -> do
    arr <- interpret' env e
    let STArr _ (STScal t) = typeOf e
        sh `ShCons` n = arrayShape arr
    numericIsNum t $ arrayGenerateM sh $ \idx -> return $ sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
  -- Wrap a single value in a rank-0 array.
  EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e)
  -- Add a new innermost dimension of size n; every position along it holds a
  -- copy of the original element.
  EReplicate1Inner _ a b -> do
    n <- fromIntegral @Int64 @Int <$> interpret' env a
    arr <- interpret' env b
    let sh = arrayShape arr
    arrayGenerateM (sh `ShCons` n) (\(idx `IxCons` _) -> return (arrayIndex arr idx))
  EConst _ _ v -> return v
  -- Read the single element of a rank-0 array (linear position 0).
  EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
  EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
  EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
  -- The shape is returned encoded as a nested tuple of Int64s.
  EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e
  EOp _ op e -> interpretOp op <$> interpret' env e
  -- Allocate an accumulator initialised from e1 and run e2 with it as the
  -- newest variable. Returns (result of e2, final accumulator contents).
  EWith e1 e2 -> do
    initval <- interpret' env e1
    withAccum (typeOf e1) (typeOf e2) initval $ \accum ->
      interpret' (Value accum `SCons` env) e2
  -- Add e2 into the accumulator e3 at index e1. `i` is the index depth that
  -- is passed on to 'accumAddSparse'.
  EAccum i e1 e2 e3 -> do
    let STAccum t = typeOf e3
    idx <- interpret' env e1
    val <- interpret' env e2
    accum <- interpret' env e3
    accumAddSparse t i accum idx val
  EZero t -> do
    return $ zeroD2 t
  EPlus t a b -> do
    a' <- interpret' env a
    b' <- interpret' env b
    return $ addD2s t a' b'
  -- A user-level error becomes a Haskell 'error' call.
  EError _ s -> error $ "Interpreter: Program threw error: " ++ s

-- | Apply a primitive operator to its (already evaluated) argument.
--
-- Binary operators receive their operands as a pair. Numeric operators obtain
-- their Num/Ord instances from 'numericIsNum'. 'OIf' turns a Bool into
-- @Left ()@ (true) or @Right ()@ (false).
interpretOp :: SOp a t -> Rep a -> Rep t
interpretOp (OAdd st) arg = numericIsNum st $ fst arg + snd arg
interpretOp (OMul st) arg = numericIsNum st $ fst arg * snd arg
interpretOp (ONeg st) arg = numericIsNum st $ negate arg
interpretOp (OLt st) arg = numericIsNum st $ fst arg < snd arg
interpretOp (OLe st) arg = numericIsNum st $ fst arg <= snd arg
interpretOp (OEq st) arg = numericIsNum st $ fst arg == snd arg
interpretOp ONot arg = not arg
interpretOp OIf arg = if arg then Left () else Right ()

-- | The zero cotangent for type @t@.
--
-- Pairs and Eithers use @Left ()@ as their zero, Maybe uses 'Nothing', and
-- arrays use an empty array of the right rank. Floating scalars are @0.0@.
-- Integers and Bool have the unit cotangent @()@. Accumulators have no zero.
zeroD2 :: STy t -> Rep (D2 t)
zeroD2 STNil = ()
zeroD2 STPair{} = Left ()
zeroD2 STEither{} = Left ()
zeroD2 STMaybe{} = Nothing
zeroD2 (STArr n _) = emptyArray n
zeroD2 (STScal STI32) = ()
zeroD2 (STScal STI64) = ()
zeroD2 (STScal STF32) = 0.0
zeroD2 (STScal STF64) = 0.0
zeroD2 (STScal STBool) = ()
zeroD2 STAccum{} = error "Zero of Accum"

-- | Add two cotangent values of type @t@.
--
-- The zero representations act as identities:
--
-- * @Left ()@ for pairs and Eithers;
-- * 'Nothing' for Maybe;
-- * a zero-size array for arrays.
--
-- Non-zero arrays must have equal shapes and are added element by element.
-- Either values must be on the same side. Violating either condition raises
-- 'error'.
addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t)
addD2s STNil _ _ = ()
addD2s (STPair ta tb) a b = case (a, b) of
  (Left (), _) -> b
  (_, Left ()) -> a
  (Right (a1, a2), Right (b1, b2)) -> Right (addD2s ta a1 b1, addD2s tb a2 b2)
addD2s (STEither ta tb) a b = case (a, b) of
  (Left (), _) -> b
  (_, Left ()) -> a
  (Right (Left l1), Right (Left l2)) -> Right (Left (addD2s ta l1 l2))
  (Right (Right r1), Right (Right r2)) -> Right (Right (addD2s tb r1 r2))
  _ -> error "Plus of inconsistent Eithers"
addD2s (STMaybe te) a b = case (a, b) of
  (Nothing, _) -> b
  (_, Nothing) -> a
  (Just x, Just y) -> Just (addD2s te x y)
addD2s (STArr _ te) a b
  | shapeSize shA == 0 = b
  | shapeSize shB == 0 = a
  | shA == shB = arrayGenerateLin shA (\i -> addD2s te (arrayIndexLinear a i) (arrayIndexLinear b i))
  | otherwise = error "Plus of inconsistently shaped arrays"
  where
    shA = arrayShape a
    shB = arrayShape b
addD2s (STScal STI32) _ _ = ()
addD2s (STScal STI64) _ _ = ()
addD2s (STScal STF32) a b = a + b
addD2s (STScal STF64) a b = a + b
addD2s (STScal STBool) _ _ = ()
addD2s STAccum{} _ _ = error "Plus of Accum"

-- | Create an accumulator for type @t@ initialised to @initval@ and run the
-- body with it. Returns the body's result together with the accumulator's
-- final contents.
--
-- The @STy a@ argument (the body's result type) is currently unused.
withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
withAccum t _ initval f = AcM $ do
  ref <- newAcSparse t initval
  result <- unAcM (f ref)
  final <- readAcSparse t ref
  pure (result, final)

-- | Build a sparse accumulator structure holding the given initial value.
--
-- Pairs, Maybes, arrays and scalars each get their own 'IORef'. A pair's ref
-- holds the dense form, i.e. a pair of sparse sub-accumulators. Unit needs no
-- storage. Bare Eithers and nested accumulators are rejected with 'error'.
newAcSparse :: STy t -> Rep t -> IO (RepAcSparse t)
newAcSparse STNil _ = return ()
newAcSparse typ@STPair{} val = newAcDense typ val >>= newIORef
newAcSparse (STMaybe te) val = traverse (newAcDense te) val >>= newIORef
newAcSparse (STArr _ te) val = traverse (newAcSparse te) val >>= newIORef
newAcSparse STScal{} val = newIORef val
newAcSparse STAccum{} _ = error "Nested accumulators"
newAcSparse STEither{} _ = error "Bare Either in accumulator"

-- | Build the dense (unreferenced) layer of an accumulator.
--
-- The children of pairs, Eithers, Maybes and arrays are built as sparse
-- accumulators. Scalars are stored directly. Nested accumulators are
-- rejected with 'error'.
newAcDense :: STy t -> Rep t -> IO (RepAcDense t)
newAcDense STNil _ = return ()
newAcDense (STPair ta tb) val = (,) <$> newAcSparse ta (fst val) <*> newAcSparse tb (snd val)
newAcDense (STEither ta tb) val = either (fmap Left . newAcSparse ta) (fmap Right . newAcSparse tb) val
newAcDense (STMaybe te) val = traverse (newAcSparse te) val
newAcDense (STArr _ te) val = traverse (newAcSparse te) val
newAcDense STScal{} val = return val
newAcDense STAccum{} _ = error "Nested accumulators"

-- | Read the current contents of a sparse accumulator back into a plain value.
--
-- Each 'IORef' in the structure is read with its own 'readIORef' call. This
-- is the inverse of 'newAcSparse'.
readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t)
readAcSparse STNil _ = return ()
readAcSparse (STPair ta tb) ref = do
  (ra, rb) <- readIORef ref
  (,) <$> readAcSparse ta ra <*> readAcSparse tb rb
readAcSparse (STMaybe te) ref = readIORef ref >>= traverse (readAcDense te)
readAcSparse (STArr _ te) ref = readIORef ref >>= traverse (readAcSparse te)
readAcSparse STScal{} ref = readIORef ref
readAcSparse STAccum{} _ = error "Nested accumulators"
readAcSparse STEither{} _ = error "Bare Either in accumulator"

-- | Read the dense layer of an accumulator back into a plain value, reading
-- every sparse child with 'readAcSparse'. This is the inverse of
-- 'newAcDense'.
readAcDense :: STy t -> RepAcDense t -> IO (Rep t)
readAcDense STNil _ = return ()
readAcDense (STPair ta tb) contents = (,) <$> readAcSparse ta (fst contents) <*> readAcSparse tb (snd contents)
readAcDense (STEither ta tb) contents = either (fmap Left . readAcSparse ta) (fmap Right . readAcSparse tb) contents
readAcDense (STMaybe te) contents = traverse (readAcSparse te) contents
readAcDense (STArr _ te) contents = traverse (readAcSparse te) contents
readAcDense STScal{} contents = return contents
readAcDense STAccum{} _ = error "Nested accumulators"

-- | Add a value into a sparse accumulator.
--
-- With index depth 'SZ' the index is @()@ and the value is added at the
-- root. With depth @'SS' dep@ an index is supplied; presumably it selects the
-- sub-part to add into (compare the commented-out @accumAddVal@ draft below).
--
-- NOTE(review): this function is unfinished. Several branches are GHC typed
-- holes (@_@), so the module does not compile as-is.
accumAddSparse :: STy t -> SNat i -> RepAcSparse t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()
accumAddSparse typ SZ ref () val = case typ of
  STNil -> return ()
  -- Read the pair of component refs once, then add each half of the value
  -- into its component independently.
  STPair t1 t2 -> AcM $ do
    (r1, r2) <- readIORef ref
    unAcM $ accumAddSparse t1 SZ r1 () (fst val)
    unAcM $ accumAddSparse t2 SZ r2 () (snd val)
  -- Update the Maybe cell atomically. The pure step decides the new contents
  -- and a follow-up action, which 'join' runs after the atomic update.
  -- Adding Nothing to Just is a no-op. Just + Just delegates to
  -- 'accumAddDense'.
  -- TODO: the (Nothing, _) case is a typed hole; presumably it should
  -- initialise the cell from val, cf. the commented-out initAccumOrTryAgainM.
  STMaybe t ->
    join $ AcM $ atomicModifyIORef' ref $ \ac -> case (ac, val) of
                   (Nothing, _) -> (ac, _)
                   (Just{}, Nothing) -> (ac, return ())
                   (Just ac', Just val') -> first Just (accumAddDense t SZ ac' () val')
  -- TODO: unimplemented (typed holes).
  STArr _ t -> _ ref val
  STScal{} -> _ ref val
  STAccum{} -> error "Nested accumulators"
  STEither{} -> error "Bare Either in accumulator"
-- TODO: every non-trivial case at non-zero depth is still a typed hole.
accumAddSparse typ (SS dep) ref idx val = case typ of
  STNil -> return ()
  STPair t1 t2 -> _ ref idx val
  STMaybe t -> _ ref idx val
  STArr _ t -> _ ref idx val
  STScal{} -> _ ref idx val
  STAccum{} -> error "Nested accumulators"
  STEither{} -> error "Bare Either in accumulator"

-- | Pure step for adding a value into the dense contents of an accumulator.
--
-- Returns the updated dense contents plus an 'AcM' action to run afterwards.
-- This shape fits the step function of 'atomicModifyIORef'', which is how
-- 'accumAddSparse' uses it.
--
-- NOTE(review): not yet implemented (typed hole); the module does not compile
-- until this is filled in.
accumAddDense :: forall t i s. STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ())
accumAddDense = _

-- accumAddVal :: forall t i s. STy t -> SNat i -> RepAc t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAc t, AcM s ())
-- accumAddVal typ SZ ac () val = case typ of
--   STNil -> ((), return ())
--   STPair t1 t2 ->
--     let (ac1', m1) = accumAddVal t1 SZ (fst ac) () (fst val)
--         (ac2', m2) = accumAddVal t2 SZ (snd ac) () (snd val)
--     in ((ac1', ac2'), m1 >> m2)
--   STMaybe t -> case t of
--     STEither t1 t2 -> (ac, accumAddValME t1 t2 ac val)
--     STNil -> def ; STPair{} -> def ; STMaybe{} -> def ; STArr{} -> def ; STScal{} -> def ; STAccum{} -> def
--     where def :: (t ~ TMaybe a, RepAc (TMaybe a) ~ IORef (Maybe (RepAc a))) => (RepAc t, AcM s ())
--           def = (ac, accumAddValM t ac val)
--   STArr n t
--     | shapeSize (arrayShape ac) == 0 -> makeRepAc (STArr n t) val
--   STEither{} -> error "Bare Either in accumulator"
--   _ -> _
-- accumAddVal typ (SS dep) ac idx val = case typ of
--   STNil -> ((), return ())
--   STPair t1 t2 ->
--     case (idx, val) of
--       (Left idx', Left val') -> first (,snd ac) $ accumAddVal t1 dep (fst ac) idx' val'
--       (Right idx', Right val') -> first (fst ac,) $ accumAddVal t2 dep (snd ac) idx' val'
--       _ -> error "Inconsistent idx and val in accumulator add operation"
--   _ -> _

-- accumAddValME :: STy a -> STy b
--               -> IORef (Maybe (Either (RepAc a) (RepAc b)))
--               -> Maybe (Either (Rep a) (Rep b))
--               -> AcM s ()
-- accumAddValME t1 t2 ac val =
--   case val of
--     Nothing -> return ()
--     Just val' ->
--       join $ AcM $ atomicModifyIORef' ac $ \ac' -> case (ac', val') of
--                      (Nothing, _) -> (Nothing, AcM $ initAccumOrTryAgainME t1 t2 ac val' (unAcM $ accumAddValME t1 t2 ac val))
--                      (Just (Left x), Left val'1) -> first (Just . Left) $ accumAddVal t1 SZ x () val'1
--                      (Just (Right y), Right val'2) -> first (Just . Right) $ accumAddVal t2 SZ y () val'2
--                      _ -> error "Inconsistent accumulator and value in add operation on Maybe Either"

-- initAccumOrTryAgainME :: STy a -> STy b
--                       -> IORef (Maybe (Either (RepAc a) (RepAc b)))
--                       -> Either (Rep a) (Rep b)
--                       -> IO ()
--                       -> IO ()
-- initAccumOrTryAgainME t1 t2 ac val onRace = do
--   newContents <- case val of Left x -> Left <$> makeRepAc t1 x
--                              Right y -> Right <$> makeRepAc t2 y
--   join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ())
--                                       value@Just{} -> (value, onRace))

-- accumAddValM :: STy t
--              -> IORef (Maybe (RepAc t))
--              -> Maybe (Rep t)
--              -> AcM s ()
-- accumAddValM t ac val =
--   case val of
--     Nothing -> return ()
--     Just val' ->
--       join $ AcM $ atomicModifyIORef' ac $ \ac' -> case ac' of
--                      Nothing -> (Nothing, AcM $ initAccumOrTryAgainM t ac val' (unAcM $ accumAddValM t ac val))
--                      Just x -> first Just $ accumAddVal t SZ x () val'

-- initAccumOrTryAgainM :: STy t
--                      -> IORef (Maybe (RepAc t))
--                      -> Rep t
--                      -> IO ()
--                      -> IO ()
-- initAccumOrTryAgainM t ac val onRace = do
--   newContents <- makeRepAc t val
--   join $ atomicModifyIORef' ac (\case Nothing -> (Just newContents, return ())
--                                       value@Just{} -> (value, onRace))

-- | Bring the Num and Ord instances of a numeric scalar type into scope.
--
-- Matching on the singleton fixes the concrete representation type, so the
-- constrained continuation @k@ can be used directly. 'STBool' is excluded by
-- the @ScalIsNumeric st ~ True@ constraint.
numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
numericIsNum sty k = case sty of
  STI32 -> k
  STI64 -> k
  STF32 -> k
  STF64 -> k

-- | Convert a nested left-leaning tuple of Int64 components into an
-- index-like structure (for example 'Shape' or 'Index').
--
-- The caller supplies the structure's empty value and its snoc-style
-- constructor. The innermost tuple element becomes the first component.
unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m))
            -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n
unTupRepIdx nil _ SZ _ = nil
unTupRepIdx nil cons (SS k) (rest, i) =
  cons (unTupRepIdx nil cons k rest) (fromIntegral @Int64 @Int i)

-- | Convert an index-like structure into the nested tuple-of-Int64
-- representation. This is the inverse of 'unTupRepIdx'. The caller supplies
-- an uncons function that splits off the last component.
tupRepIdx :: (forall m. f (S m) -> (f m, Int))
          -> SNat n -> f n -> Rep (Tup (Replicate n TIx))
tupRepIdx _ SZ _ = ()
tupRepIdx uncons (SS k) xs = (tupRepIdx uncons k rest, fromIntegral @Int @Int64 i)
  where
    -- Lazy pattern binding, as in a @let@: uncons is only forced on demand.
    (rest, i) = uncons xs

-- | Split a non-empty index into its prefix and its last component.
ixUncons :: Index (S n) -> (Index n, Int)
ixUncons (rest `IxCons` i) = (rest, i)

-- | Split a non-empty shape into its prefix and its last dimension.
shUncons :: Shape (S n) -> (Shape n, Int)
shUncons (rest `ShCons` i) = (rest, i)

-- | Monadic left fold with no initial value: the first element seeds the
-- accumulator. Calls 'error' on an empty list.
foldl1M :: Monad m => (a -> a -> m a) -> [a] -> m a
foldl1M step = \case
  [] -> error "foldl1M: empty list"
  seed : rest -> foldM step seed rest