summaryrefslogtreecommitdiff
path: root/src/Interpreter.hs
blob: 8728ec0213bc91e9476cdbf6f63721289b301275 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Interpreter (
  interpret,
  interpret',
  Value,
) where

import Control.Monad (foldM)
import Data.Int (Int64)
import Data.Proxy
import System.IO.Unsafe (unsafePerformIO)

import Array
import AST
import CHAD.Types
import Data
import Interpreter.Rep


-- | The monad the interpreter evaluates in: a newtype over 'IO' whose
-- instances are borrowed directly from 'IO'. The @s@ parameter is phantom
-- and is quantified over by 'runAcM', like 'Control.Monad.ST.runST'.
-- NOTE(review): presumably 'IO' is used so that the accumulator operations
-- ('withAccum', 'accumAdd') can perform effects — confirm in Interpreter.Rep.
newtype AcM s a = AcM (IO a)
  deriving newtype Functor
  deriving newtype Applicative
  deriving newtype Monad

-- | Run an 'AcM' computation and return its result as a pure value.
--
-- NOTE(review): this uses 'unsafePerformIO'. That is only sound if every
-- effect performed inside the computation is local to it (nothing observable
-- outside, deterministic result). The rank-2 @forall s@ is presumably meant to
-- keep run-specific values from escaping, as with 'Control.Monad.ST.runST' —
-- confirm against the accumulator implementation.
runAcM :: (forall s. AcM s a) -> a
runAcM action = case action of
  AcM io -> unsafePerformIO io

-- | Evaluate a closed expression (one with an empty environment) to its
-- runtime representation.
interpret :: Ex '[] t -> Rep t
interpret closedExpr = runAcM (interpret' SNil closedExpr)

-- | A runtime value of object-language type @t@. It is wrapped in a newtype so
-- that it can serve as the element type constructor of the interpreter's
-- @'SList' Value env@ environment.
newtype Value t = Value (Rep t)

-- | Evaluate an expression in an environment of already-computed values.
--
-- Variables ('EVar') are looked up by index in @env@. Every binding form
-- ('ELet', 'ECase', 'EMaybe', 'EBuild1', 'EBuild', 'EFold1Inner', 'EWith')
-- pushes its bound value(s) onto the front of @env@ before evaluating the body.
-- Evaluation runs in 'AcM' so that the accumulator forms ('EWith', 'EAccum')
-- can go through 'withAccum' and 'accumAdd'. 'EError' aborts via 'error'.
interpret' :: forall env t s. SList Value env -> Ex env t -> AcM s (Rep t)
interpret' env = \case
  EVar _ _ i -> case slistIdx env i of Value x -> return x
  ELet _ a b -> do
    x <- interpret' env a
    interpret' (Value x `SCons` env) b
  EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b
  EFst _ e -> fst <$> interpret' env e
  ESnd _ e -> snd <$> interpret' env e
  ENil _ -> return ()
  EInl _ _ e -> Left <$> interpret' env e
  EInr _ _ e -> Right <$> interpret' env e
  ECase _ e a b -> interpret' env e >>= \case
                     Left x -> interpret' (Value x `SCons` env) a
                     Right y -> interpret' (Value y `SCons` env) b
  ENothing _ _ -> return Nothing
  EJust _ e -> Just <$> interpret' env e
  -- Scrutinise the Maybe: @a@ is the Nothing branch, and @b@ is the Just
  -- branch, which sees the payload as its innermost (index 0) variable.
  EMaybe _ a b e -> interpret' env e >>= \case
                      Nothing -> interpret' env a
                      Just x -> interpret' (Value x `SCons` env) b
  EConstArr _ _ _ v -> return v
  EBuild1 _ a b -> do
    n <- fromIntegral @Int64 @Int <$> interpret' env a
    -- The body receives the loop index as an Int64 value.
    arrayGenerateLinM (ShNil `ShCons` n)
                      (\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b)
  EBuild _ dim a b -> do
    sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a
    arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b)
  EFold1Inner _ a b -> do
    -- The combining function sees the accumulator at index 1 and the next
    -- element at index 0.
    let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a
    arr <- interpret' env b
    let sh `ShCons` n = arrayShape arr
    -- Fold each row of the innermost dimension; foldl1M errors on an empty row.
    arrayGenerateM sh $ \idx -> foldl1M f [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
  ESum1Inner _ e -> do
    arr <- interpret' env e
    let STArr _ (STScal t) = typeOf e
        sh `ShCons` n = arrayShape arr
    numericIsNum t $ arrayGenerateM sh $ \idx -> return $ sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
  EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e)
  EReplicate1Inner _ a b -> do
    n <- fromIntegral @Int64 @Int <$> interpret' env a
    arr <- interpret' env b
    let sh = arrayShape arr
    -- The new innermost dimension has size n; every position copies the
    -- element at the remaining (outer) index.
    arrayGenerateM (sh `ShCons` n) (\(idx `IxCons` _) -> return (arrayIndex arr idx))
  EConst _ _ v -> return v
  EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
  EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
  EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
  EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e
  EOp _ op e -> interpretOp op <$> interpret' env e
  EWith e1 e2 -> do
    initval <- interpret' env e1
    withAccum (typeOf e1) initval $ \accum ->
      interpret' (Value accum `SCons` env) e2
  EAccum i e1 e2 e3 -> do
    idx <- interpret' env e1
    val <- interpret' env e2
    accum <- interpret' env e3
    accumAdd accum i idx val
  EZero t -> do
    return $ makeZero t
  EPlus t a b -> do
    a' <- interpret' env a
    b' <- interpret' env b
    return $ makePlus t a' b'
  EError _ s -> error $ "Interpreter: Program threw error: " ++ s

-- | Apply a primitive operator to its already-evaluated argument. Binary
-- operators receive their two operands as a pair. Numeric operators use
-- 'numericIsNum' to obtain 'Num'/'Ord' for the scalar representation.
interpretOp :: SOp a t -> Rep a -> Rep t
interpretOp (OAdd st) arg = numericIsNum st (uncurry (+) arg)
interpretOp (OMul st) arg = numericIsNum st (uncurry (*) arg)
interpretOp (ONeg st) arg = numericIsNum st (negate arg)
interpretOp (OLt st) arg = numericIsNum st (uncurry (<) arg)
interpretOp (OLe st) arg = numericIsNum st (uncurry (<=) arg)
interpretOp (OEq st) arg = numericIsNum st (uncurry (==) arg)
interpretOp ONot arg = not arg
-- True maps to Left (), False to Right ().
interpretOp OIf arg
  | arg = Left ()
  | otherwise = Right ()

-- | The zero value of @'D2' t@ for the given type.
--
-- Pairs and eithers use @Left ()@ as zero, Maybe uses 'Nothing', arrays use an
-- empty array of the right rank, floating-point scalars use @0.0@, and
-- integer/boolean scalars (whose D2 is represented as @()@) use @()@.
-- Accumulator types have no zero and raise an error.
makeZero :: STy t -> Rep (D2 t)
makeZero STNil = ()
makeZero (STPair _ _) = Left ()
makeZero (STEither _ _) = Left ()
makeZero (STMaybe _) = Nothing
makeZero (STArr n _) = emptyArray n
makeZero (STScal STI32) = ()
makeZero (STScal STI64) = ()
makeZero (STScal STF32) = 0.0
makeZero (STScal STF64) = 0.0
makeZero (STScal STBool) = ()
makeZero STAccum{} = error "Zero of Accum"

-- | Add two values of type @'D2' t@.
--
-- The zeros produced by 'makeZero' act as identities: @Left ()@ for pairs and
-- eithers, 'Nothing' for Maybe, and zero-element arrays for arrays. Otherwise
-- pairs are added componentwise, eithers must agree on the side, Maybe values
-- are added under 'Just', and arrays must have equal shapes and are added
-- elementwise. Integer and boolean scalars have @()@ as their D2
-- representation. Accumulator types cannot be added and raise an error.
makePlus :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t)
makePlus typ a b = case typ of
  STNil -> ()
  STPair t1 t2 -> case (a, b) of
    (Left (), _) -> b
    (_, Left ()) -> a
    (Right (x1, x2), Right (y1, y2)) -> Right (makePlus t1 x1 y1, makePlus t2 x2 y2)
  STEither t1 t2 -> case (a, b) of
    (Left (), _) -> b
    (_, Left ()) -> a
    (Right (Left x), Right (Left y)) -> Right (Left (makePlus t1 x y))
    (Right (Right x), Right (Right y)) -> Right (Right (makePlus t2 x y))
    _ -> error "Plus of inconsistent Eithers"
  -- 'makeZero' yields Nothing for Maybe, so Nothing is the identity here.
  STMaybe t -> case (a, b) of
    (Nothing, _) -> b
    (_, Nothing) -> a
    (Just x, Just y) -> Just (makePlus t x y)
  STArr _ t ->
    let sh1 = arrayShape a
        sh2 = arrayShape b
    in if | shapeSize sh1 == 0 -> b
          | shapeSize sh2 == 0 -> a
          | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> makePlus t (arrayIndexLinear a i) (arrayIndexLinear b i))
          | otherwise -> error "Plus of inconsistently shaped arrays"
  STScal sty -> case sty of
    STI32 -> ()
    STI64 -> ()
    STF32 -> a + b
    STF64 -> a + b
    STBool -> ()
  STAccum{} -> error "Plus of Accum"

-- | Bring 'Num' and 'Ord' for the runtime representation of a numeric scalar
-- type into scope for the continuation @k@.
--
-- The @ScalIsNumeric st ~ True@ constraint is what lets this match list only
-- the four numeric scalar types. NOTE(review): this assumes the constraint
-- excludes 'STBool'.
numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
numericIsNum sty k = case sty of
  STI32 -> k
  STI64 -> k
  STF32 -> k
  STF64 -> k

-- | Convert the runtime representation of an @n@-component index tuple into
-- an @f n@ built with the supplied @nil@ and @cons@.
--
-- The tuple is left-nested, @((((), i1), i2), ...)@, with 'Int64'
-- components, and each component is converted to 'Int'. The interpreter uses
-- this with 'ShNil'/'ShCons' to build shapes and with 'IxNil'/'IxCons' to
-- build indices.
unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m))
            -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n
unTupRepIdx nil _    SZ _ = nil
unTupRepIdx nil cons (SS n) (idx, i) = cons (unTupRepIdx nil cons n idx) (fromIntegral @Int64 @Int i)

-- | Flatten an @f n@ (a shape or an index) into the left-nested 'Int64' tuple
-- representation of @Tup (Replicate n TIx)@. This is the counterpart of
-- 'unTupRepIdx'.
--
-- At each step @splitLast@ splits off one 'Int' component together with the
-- remaining @f m@. The split is bound lazily, so it is only evaluated when the
-- resulting tuple is inspected.
tupRepIdx :: (forall m. f (S m) -> (f m, Int))
          -> SNat n -> f n -> Rep (Tup (Replicate n TIx))
tupRepIdx splitLast rank shapeLike = case rank of
  SZ -> ()
  SS rank' ->
    let (rest, component) = splitLast shapeLike
    in (tupRepIdx splitLast rank' rest, fromIntegral @Int @Int64 component)

-- | Split a non-empty index into its remaining index and its outermost
-- 'IxCons' component.
ixUncons :: Index (S n) -> (Index n, Int)
ixUncons ix = case ix of
  rest `IxCons` component -> (rest, component)

-- | Split a non-empty shape into its remaining shape and its outermost
-- 'ShCons' dimension size.
shUncons :: Shape (S n) -> (Shape n, Int)
shUncons sh = case sh of
  rest `ShCons` dimSize -> (rest, dimSize)

-- | Monadic left fold that starts from the first element of the list.
-- It calls 'error' on an empty list.
foldl1M :: Monad m => (a -> a -> m a) -> [a] -> m a
foldl1M step xs = case xs of
  [] -> error "foldl1M: empty list"
  (firstElem : remaining) -> foldM step firstElem remaining