1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Interpreter (
interpret,
interpret',
Value,
) where
import Control.Monad (foldM)
import Data.Int (Int64)
import Data.Proxy
import System.IO.Unsafe (unsafePerformIO)
import Array
import AST
import CHAD.Types
import Data
import Interpreter.Rep
-- | Monad in which the interpreter runs. It is a wrapper around 'IO' whose
-- 'Functor', 'Applicative' and 'Monad' instances are exactly those of 'IO'.
--
-- The parameter @s@ does not occur in the representation; 'runAcM'
-- quantifies over it, in the style of the state-thread parameter of
-- 'Control.Monad.ST.ST'. NOTE(review): presumably this keeps accumulators
-- created inside one run from escaping it — confirm against 'withAccum'.
newtype AcM s a = AcM (IO a)
  deriving newtype Functor
  deriving newtype Applicative
  deriving newtype Monad
-- | Run an interpreter computation to a pure result.
--
-- The argument must be polymorphic in @s@ (as with
-- 'Control.Monad.ST.runST'); the wrapped 'IO' action is then executed via
-- 'unsafePerformIO'. NOTE(review): this is only sound if the computation's
-- result is deterministic and its effects are not observable outside the
-- run — presumably guaranteed by the rank-2 @s@; confirm.
runAcM :: (forall s. AcM s a) -> a
runAcM action = case action of
  AcM io -> unsafePerformIO io
-- | Evaluate a closed expression (one whose environment is empty) to the
-- runtime representation of its value.
interpret :: Ex '[] t -> Rep t
interpret = \expr -> runAcM (interpret' SNil expr)
-- | An evaluated variable in the interpreter's environment: a wrapper
-- around @'Rep' t@. The wrapper is needed so it can serve as the element
-- type of an 'SList' indexed by the environment, because a type family such
-- as 'Rep' cannot be passed unsaturated.
newtype Value t where
  Value :: Rep t -> Value t
-- | Evaluate an expression in an environment of already-evaluated variables.
--
-- Binders ('ELet', the branches of 'ECase' and 'EMaybe', the index variable
-- of 'EBuild1' / 'EBuild', the operands of the 'EFold1Inner' combining
-- function and the accumulator of 'EWith') extend the environment by consing
-- the new 'Value' onto its head.
interpret' :: forall env t s. SList Value env -> Ex env t -> AcM s (Rep t)
interpret' env = \case
  EVar _ _ i -> case slistIdx env i of Value x -> return x
  ELet _ a b -> do
    x <- interpret' env a
    interpret' (Value x `SCons` env) b
  EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b
  EFst _ e -> fst <$> interpret' env e
  ESnd _ e -> snd <$> interpret' env e
  ENil _ -> return ()
  EInl _ _ e -> Left <$> interpret' env e
  EInr _ _ e -> Right <$> interpret' env e
  ECase _ e a b -> interpret' env e >>= \case
    Left x -> interpret' (Value x `SCons` env) a
    Right y -> interpret' (Value y `SCons` env) b
  ENothing _ _ -> return Nothing
  EJust _ e -> Just <$> interpret' env e
  -- NOTE(review): assumes the field order (ext, nothing-branch, just-branch,
  -- scrutinee), mirroring the argument order of 'maybe'; the just-branch
  -- binds the payload as a new variable.
  EMaybe _ a b e -> interpret' env e >>= \case
    Nothing -> interpret' env a
    Just x -> interpret' (Value x `SCons` env) b
  EConstArr _ _ _ v -> return v
  EBuild1 _ a b -> do
    n <- fromIntegral @Int64 @Int <$> interpret' env a
    arrayGenerateLinM (ShNil `ShCons` n)
      (\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b)
  EBuild _ dim a b -> do
    sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a
    arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b)
  EFold1Inner _ a b -> do
    let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a
    arr <- interpret' env b
    let sh `ShCons` n = arrayShape arr
    -- Folds along the innermost dimension; when that dimension has length 0
    -- the list is empty and 'foldl1M' raises its error.
    arrayGenerateM sh $ \idx -> foldl1M f [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
  ESum1Inner _ e -> do
    arr <- interpret' env e
    let STArr _ (STScal t) = typeOf e
        sh `ShCons` n = arrayShape arr
    numericIsNum t $ arrayGenerateM sh $ \idx -> return $ sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
  EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e)
  EReplicate1Inner _ a b -> do
    n <- fromIntegral @Int64 @Int <$> interpret' env a
    arr <- interpret' env b
    let sh = arrayShape arr
    arrayGenerateM (sh `ShCons` n) (\(idx `IxCons` _) -> return (arrayIndex arr idx))
  EConst _ _ v -> return v
  EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
  EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
  EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
  EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e
  EOp _ op e -> interpretOp op <$> interpret' env e
  EWith e1 e2 -> do
    initval <- interpret' env e1
    withAccum (typeOf e1) initval $ \accum ->
      interpret' (Value accum `SCons` env) e2
  EAccum i e1 e2 e3 -> do
    idx <- interpret' env e1
    val <- interpret' env e2
    accum <- interpret' env e3
    accumAdd accum i idx val
  EZero t -> return (makeZero t)
  -- Operands are evaluated left to right, as before.
  EPlus t a b -> makePlus t <$> interpret' env a <*> interpret' env b
  EError _ s -> error $ "Interpreter: Program threw error: " ++ s
-- | Apply a primitive operator to its already-evaluated argument. Binary
-- operators receive their two operands as a pair; the numeric ones obtain
-- their 'Num' / 'Ord' instances through 'numericIsNum'.
interpretOp :: SOp a t -> Rep a -> Rep t
interpretOp (OAdd st) arg = numericIsNum st (uncurry (+) arg)
interpretOp (OMul st) arg = numericIsNum st (uncurry (*) arg)
interpretOp (ONeg st) arg = numericIsNum st (negate arg)
interpretOp (OLt st) arg = numericIsNum st (uncurry (<) arg)
interpretOp (OLe st) arg = numericIsNum st (uncurry (<=) arg)
interpretOp (OEq st) arg = numericIsNum st (uncurry (==) arg)
interpretOp ONot arg = not arg
interpretOp OIf arg
  | arg = Left ()
  | otherwise = Right ()
-- | The zero cotangent ('D2') value for the given type.
--
-- Structured types use a compact representation of zero: @Left ()@ for
-- pairs and eithers, 'Nothing' for maybes, and @'emptyArray' n@ for arrays.
-- Floating-point scalars are @0.0@; integral and boolean scalars have the
-- unit cotangent @()@. Asking for the zero of an accumulator type is an
-- 'error'.
makeZero :: STy t -> Rep (D2 t)
makeZero STNil = ()
makeZero (STPair _ _) = Left ()
makeZero (STEither _ _) = Left ()
makeZero (STMaybe _) = Nothing
makeZero (STArr n _) = emptyArray n
makeZero (STScal STI32) = ()
makeZero (STScal STI64) = ()
makeZero (STScal STF32) = 0.0
makeZero (STScal STF64) = 0.0
makeZero (STScal STBool) = ()
makeZero STAccum{} = error "Zero of Accum"
-- | Add two cotangent ('D2') values of the given type.
--
-- Compact zeros are handled without allocation:
--
-- * @Left ()@ for pairs and eithers, and 'Nothing' for maybes, are treated
--   as zero, so adding one of them returns the other operand.
-- * An array with no elements is likewise treated as zero.
--
-- Integral and boolean scalars have the unit cotangent @()@. Calls 'error'
-- when the operands' structure disagrees (different 'Either' sides,
-- non-empty arrays of different shapes) and for accumulator types.
makePlus :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t)
makePlus typ a b = case typ of
  STNil -> ()
  STPair t1 t2 -> case (a, b) of
    (Left (), _) -> b
    (_, Left ()) -> a
    (Right (x1, x2), Right (y1, y2)) -> Right (makePlus t1 x1 y1, makePlus t2 x2 y2)
  STEither t1 t2 -> case (a, b) of
    (Left (), _) -> b
    (_, Left ()) -> a
    (Right (Left x), Right (Left y)) -> Right (Left (makePlus t1 x y))
    (Right (Right x), Right (Right y)) -> Right (Right (makePlus t2 x y))
    _ -> error "Plus of inconsistent Eithers"
  -- 'makeZero' produces 'Nothing' for maybes, so this case is reachable.
  -- NOTE(review): assumes the cotangent of a maybe is represented as
  -- @Maybe (Rep (D2 t))@ for element type @t@.
  STMaybe t -> case (a, b) of
    (Nothing, _) -> b
    (_, Nothing) -> a
    (Just x, Just y) -> Just (makePlus t x y)
  STArr _ t ->
    let sh1 = arrayShape a
        sh2 = arrayShape b
    in if | shapeSize sh1 == 0 -> b
          | shapeSize sh2 == 0 -> a
          | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> makePlus t (arrayIndexLinear a i) (arrayIndexLinear b i))
          | otherwise -> error "Plus of inconsistently shaped arrays"
  STScal sty -> case sty of
    STI32 -> ()
    STI64 -> ()
    STF32 -> a + b
    STF64 -> a + b
    STBool -> ()
  STAccum{} -> error "Plus of Accum"
-- | Bring 'Num' and 'Ord' instances for the runtime representation of a
-- numeric scalar type into scope for the given continuation.
--
-- The @ScalIsNumeric st ~ True@ constraint is what allows the match below
-- to omit the remaining scalar type(s) (presumably just 'STBool').
numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
numericIsNum sty k = case sty of
  STI32 -> k
  STI64 -> k
  STF32 -> k
  STF64 -> k
-- | Convert the runtime representation of an @n@-component index tuple
-- (nested pairs whose right-hand components are 'Int64') into an @f n@,
-- built with the given nil and cons functions.
--
-- Used with 'ShNil' / 'ShCons' to build shapes and with 'IxNil' / 'IxCons'
-- to build indices. The right-hand component of the outermost pair is the
-- one consed last.
unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m))
            -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n
unTupRepIdx nil _ SZ _ = nil
unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i
-- | Convert in the opposite direction to 'unTupRepIdx': take an @f n@ apart
-- with the given uncons function (e.g. 'ixUncons' or 'shUncons') and
-- rebuild it as the runtime representation of a nested tuple, converting
-- each component from 'Int' to 'Int64'.
tupRepIdx :: (forall m. f (S m) -> (f m, Int))
          -> SNat n -> f n -> Rep (Tup (Replicate n TIx))
tupRepIdx _ SZ _ = ()
tupRepIdx uncons (SS n) tup = (tupRepIdx uncons n rest, fromIntegral @Int @Int64 component)
  where
    -- Lazy pattern binding, so 'uncons' is only run when a field is needed.
    (rest, component) = uncons tup
-- | Split a non-empty index into its prefix and the component consed last.
ixUncons :: Index (S n) -> (Index n, Int)
ixUncons (prefix `IxCons` lastComponent) = (prefix, lastComponent)
-- | Split a non-empty shape into its prefix and the dimension consed last.
shUncons :: Shape (S n) -> (Shape n, Int)
shUncons (prefix `ShCons` lastDim) = (prefix, lastDim)
-- | Monadic left fold seeded with the first list element, analogous to
-- 'foldl1'. Calls 'error' on an empty list.
foldl1M :: Monad m => (a -> a -> m a) -> [a] -> m a
foldl1M step xs = case xs of
  [] -> error "foldl1M: empty list"
  seed : rest -> foldM step seed rest
|