1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
module Interpreter (
interpret,
interpret',
Value,
NoAccum(..),
unAccum,
) where
import Data.Int (Int64)
import Data.Proxy
import AST
import Data
import Array
import Interpreter.Accum
import Interpreter.Rep
import Control.Monad (foldM)
-- | Evaluate a closed expression whose result type contains no accumulators.
--
-- 'interpret'' produces a @Rep' s t@, which mentions the 'AcM' state token
-- @s@. The 'NoAccum' evidence rewrites that to the token-free @Rep t@, which
-- is what allows the result to leave 'runAcM' (presumably rank-2 in @s@,
-- in the style of 'runST' -- confirm against "Interpreter.Accum").
interpret :: NoAccum t => Ex '[] t -> Rep t
interpret expr = runAcM (evalClosed expr)
  where
    evalClosed :: forall s u. NoAccum u => Ex '[] u -> AcM s (Rep u)
    evalClosed closed = case noAccum (Proxy @s) (Proxy @u) of
      Refl -> interpret' SNil closed
-- | One entry of the interpreter's environment: the runtime representation
-- (@Rep' s t@) of a value of object-language type @t@, where @s@ is the same
-- token as in the surrounding @AcM s@ computation. Wrapped in a newtype so it
-- can be partially applied as @SList (Value s)@ (presumably because 'Rep'' is
-- a type family and cannot itself be partially applied -- confirm).
newtype Value s t = Value (Rep' s t)
-- | Evaluate an expression under an environment of already-computed values.
--
-- The environment @vals@ holds one 'Value' per variable in scope. 'EVar'
-- reads it with 'slistIdx'. Every binder ('ELet', the two 'ECase' branches,
-- the bodies of 'EBuild1' / 'EBuild' / 'EFold1Inner', and 'EWith') evaluates
-- its body with the new value(s) pushed on the front via 'SCons'. This is
-- presumably de Bruijn indexing; confirm against 'slistIdx'.
--
-- @s@ is the state token of the 'AcM' computation. It is bound by the
-- explicit @forall@ so it can be passed through a 'Proxy' to the index/shape
-- conversion helpers and to 'interpretOp'.
--
-- 'EError' aborts evaluation with 'error'. 'EFold1Inner' over an empty
-- innermost dimension also ends in 'error', raised by 'foldl1M'.
interpret' :: forall env t s. SList (Value s) env -> Ex env t -> AcM s (Rep' s t)
interpret' vals = \case
  EVar _ _ ix -> case slistIdx vals ix of
    Value v -> pure v

  ELet _ rhs body -> do
    v <- interpret' vals rhs
    interpret' (Value v `SCons` vals) body

  EPair _ l r -> do
    lv <- interpret' vals l
    rv <- interpret' vals r
    pure (lv, rv)

  EFst _ p -> fmap fst (interpret' vals p)
  ESnd _ p -> fmap snd (interpret' vals p)

  ENil _ -> pure ()

  EInl _ _ x -> fmap Left (interpret' vals x)
  EInr _ _ x -> fmap Right (interpret' vals x)

  ECase _ scrut onLeft onRight -> do
    sv <- interpret' vals scrut
    case sv of
      Left lv -> interpret' (Value lv `SCons` vals) onLeft
      Right rv -> interpret' (Value rv `SCons` vals) onRight

  EConstArr _ _ _ arr -> pure arr

  EBuild1 _ lenE body -> do
    len <- fromIntegral @Int64 @Int <$> interpret' vals lenE
    -- The body sees its position as an Int64 index value.
    arrayGenerateLinM (ShNil `ShCons` len) $ \i ->
      interpret' (Value (fromIntegral @Int @Int64 i) `SCons` vals) body

  EBuild _ dim shE body -> do
    sh <- unTupRepIdx (Proxy @s) ShNil ShCons dim <$> interpret' vals shE
    arrayGenerateM sh $ \idx ->
      interpret' (Value (tupRepIdx (Proxy @s) ixUncons dim idx) `SCons` vals) body

  EFold1Inner _ combine arrE -> do
    -- Inside 'combine' the newest binding is the next element and the one
    -- below it is the accumulator so far.
    let step acc x = interpret' (Value x `SCons` Value acc `SCons` vals) combine
    arr <- interpret' vals arrE
    let sh `ShCons` n = arrayShape arr
    arrayGenerateM sh $ \idx ->
      foldl1M step [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]

  ESum1Inner _ arrE -> do
    arr <- interpret' vals arrE
    let STArr _ (STScal sty) = typeOf arrE
        sh `ShCons` n = arrayShape arr
    numericIsNum sty $ arrayGenerateM sh $ \idx ->
      pure (sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]])

  EUnit _ x -> arrayGenerateLinM ShNil $ \_ -> interpret' vals x

  EReplicate1Inner _ countE arrE -> do
    count <- fromIntegral @Int64 @Int <$> interpret' vals countE
    arr <- interpret' vals arrE
    -- The new innermost dimension is ignored when reading the source array.
    arrayGenerateM (arrayShape arr `ShCons` count)
      (\(idx `IxCons` _) -> pure (arrayIndex arr idx))

  EConst _ _ c -> pure c

  EIdx0 _ x -> (\arr -> arrayIndexLinear arr 0) <$> interpret' vals x

  EIdx1 _ arrE ixE -> do
    arr <- interpret' vals arrE
    i <- interpret' vals ixE
    pure (arrayIndex1 arr (fromIntegral @Int64 @Int i))

  EIdx _ n arrE ixE -> do
    arr <- interpret' vals arrE
    ixTup <- interpret' vals ixE
    pure (arrayIndex arr (unTupRepIdx (Proxy @s) IxNil IxCons n ixTup))

  EShape _ x
    | STArr n _ <- typeOf x -> do
        arr <- interpret' vals x
        pure (tupRepIdx (Proxy @s) shUncons n (arrayShape arr))

  EOp _ op x -> interpretOp (Proxy @s) op <$> interpret' vals x

  EWith initE body -> do
    initval <- interpret' vals initE
    withAccum (typeOf initE) initval $ \acc ->
      interpret' (Value acc `SCons` vals) body

  EAccum i idxE valE accE -> do
    idx <- interpret' vals idxE
    v <- interpret' vals valE
    acc <- interpret' vals accE
    accumAdd acc i idx v

  EError _ msg -> error $ "Interpreter: Program threw error: " ++ msg
-- | Apply a primitive operator to its already-evaluated argument.
--
-- Binary operators receive their two operands as a pair. The numeric
-- operators get their 'Num' / 'Ord' instances from 'numericIsNum'. 'OIf'
-- turns a boolean into a unit sum: 'True' becomes @Left ()@ and 'False'
-- becomes @Right ()@.
interpretOp :: Proxy s -> SOp a t -> Rep' s a -> Rep' s t
interpretOp _ op arg = case op of
  OAdd st -> numericIsNum st (fst arg + snd arg)
  OMul st -> numericIsNum st (fst arg * snd arg)
  ONeg st -> numericIsNum st (negate arg)
  OLt st -> numericIsNum st (fst arg < snd arg)
  OLe st -> numericIsNum st (fst arg <= snd arg)
  OEq st -> numericIsNum st (fst arg == snd arg)
  ONot -> not arg
  OIf
    | arg -> Left ()
    | otherwise -> Right ()
-- | Bring 'Num' and 'Ord' instances for the representation of a numeric
-- scalar type into scope for a continuation.
--
-- Only the four numeric scalar constructors are matched. The
-- @ScalIsNumeric st ~ True@ constraint presumably rules out every other
-- 'SScalTy' constructor. Each branch receives the continuation as an explicit
-- lambda parameter rather than returning 'id', so each branch type-checks
-- without relying on deep subsumption.
numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
numericIsNum sty = case sty of
  STI32 -> \k -> k
  STI64 -> \k -> k
  STF32 -> \k -> k
  STF64 -> \k -> k
-- | Convert a nested-pair index value (the representation of
-- @Tup (Replicate n TIx)@) into an @n@-dimensional structure. The structure
-- is built with the supplied nil and cons functions; this module passes
-- 'ShNil' / 'ShCons' and 'IxNil' / 'IxCons'.
--
-- The second component of the outermost pair is consed last, so it becomes
-- the last (innermost) entry. Each component is converted from 'Int64' to
-- 'Int' with 'fromIntegral'.
unTupRepIdx :: forall s f n. Proxy s -> f Z -> (forall m. f m -> Int -> f (S m))
            -> SNat n -> Rep' s (Tup (Replicate n TIx)) -> f n
unTupRepIdx _ nil cons = go
  where
    go :: forall k. SNat k -> Rep' s (Tup (Replicate k TIx)) -> f k
    go SZ _ = nil
    go (SS m) (rest, i) = cons (go m rest) (fromIntegral @Int64 @Int i)
-- | Convert an @n@-dimensional structure back into its nested-pair
-- representation (@Tup (Replicate n TIx)@).
--
-- The supplied uncons function peels off the last-consed entry (this module
-- passes 'ixUncons' and 'shUncons'). That entry becomes the second component
-- of the outermost pair. Entries are converted from 'Int' to 'Int64'.
tupRepIdx :: forall s f n. Proxy s -> (forall m. f (S m) -> (f m, Int))
          -> SNat n -> f n -> Rep' s (Tup (Replicate n TIx))
tupRepIdx _ uncons = go
  where
    go :: forall k. SNat k -> f k -> Rep' s (Tup (Replicate k TIx))
    go SZ _ = ()
    go (SS m) v = (go m rest, fromIntegral @Int @Int64 i)
      where
        -- A lazy pattern binding: the outer pair is produced without first
        -- forcing the result of 'uncons'.
        (rest, i) = uncons v
-- | Split a non-empty index into its prefix and its last-consed component.
ixUncons :: Index (S n) -> (Index n, Int)
ixUncons ix = case ix of
  IxCons rest i -> (rest, i)
-- | Split a non-empty shape into its prefix and its last-consed dimension.
shUncons :: Shape (S n) -> (Shape n, Int)
shUncons sh = case sh of
  ShCons rest len -> (rest, len)
-- | Types whose runtime representation does not depend on the 'AcM' state
-- token. 'noAccum' witnesses that @Rep' s t@ and @Rep t@ are the same type,
-- for any @s@.
class NoAccum t where
  noAccum :: Proxy s -> Proxy t -> Rep' s t :~: Rep t
-- Structural 'NoAccum' instances. A compound type qualifies when all of its
-- components do; the proof recurses into the components and combines the
-- resulting equalities. There is no instance for accumulator types.

instance NoAccum TNil where
  noAccum _ _ = Refl

instance (NoAccum a, NoAccum b) => NoAccum (TPair a b) where
  noAccum p _ = case (noAccum p (Proxy @a), noAccum p (Proxy @b)) of
    (Refl, Refl) -> Refl

instance (NoAccum a, NoAccum b) => NoAccum (TEither a b) where
  noAccum p _ = case (noAccum p (Proxy @a), noAccum p (Proxy @b)) of
    (Refl, Refl) -> Refl

instance NoAccum t => NoAccum (TArr n t) where
  noAccum p _ = case noAccum p (Proxy @t) of
    Refl -> Refl

instance NoAccum (TScal t) where
  noAccum _ _ = Refl
-- | Decide at runtime whether a type is accumulator-free. If it is, return
-- the corresponding 'NoAccum' dictionary.
--
-- 'STNil' and scalars always succeed. Pairs, sums and arrays succeed only
-- when every component does, checked left to right and stopping at the first
-- failure. Any 'STAccum' yields 'Nothing'.
unAccum :: Proxy s -> STy t -> Maybe (Dict (NoAccum t))
unAccum p ty = case ty of
  STNil -> Just Dict
  STPair t1 t2 -> do
    Dict <- unAccum p t1
    Dict <- unAccum p t2
    Just Dict
  STEither t1 t2 -> do
    Dict <- unAccum p t1
    Dict <- unAccum p t2
    Just Dict
  STArr _ elt -> do
    Dict <- unAccum p elt
    Just Dict
  STScal{} -> Just Dict
  STAccum{} -> Nothing
-- | Monadic left fold over a list, seeded with the list's first element
-- (the monadic counterpart of 'foldl1').
--
-- Calls 'error' when the list is empty.
foldl1M :: Monad m => (a -> a -> m a) -> [a] -> m a
foldl1M step xs = case xs of
  [] -> error "foldl1M: empty list"
  seed : rest -> foldM step seed rest
|