1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module Eval (
eval,
) where
import Data.List (foldl')
import qualified Data.Vector as V
import AST
-- | Typed evaluation environment: a heterogeneous stack of values whose
-- element types are tracked by the type-level list @env@.
data Val env where
  -- | The empty environment.
  Top :: Val '[]
  -- | Extend an environment with one more value, placed at the head.
  Push :: Val env -> a -> Val (a ': env)
-- | Project the value selected by a typed index out of an environment.
--
-- 'Zero' selects the most recently pushed value; each 'Succ' skips one
-- entry and continues the lookup in the remainder of the environment.
prj :: Val env -> Idx env a -> a
prj (Push rest top) ix = case ix of
  Zero -> top
  Succ deeper -> prj rest deeper
-- | Evaluate a closed expression, i.e. one typed in the empty environment.
--
-- This is 'eval'' started with the empty environment 'Top'.
eval :: Exp '[] a -> a
eval closed = eval' Top closed
-- | Evaluate an expression under the given environment of values.
--
-- Every 'Exp' constructor is interpreted directly as the corresponding
-- Haskell value. Binders ('Lam', 'Let') extend the environment with 'Push';
-- all other sub-expressions are evaluated in the unchanged environment via
-- the local helper @go@.
--
-- Evaluating 'Undef' calls 'error'.
eval' :: forall env a. Val env -> Exp env a -> a
eval' env expr = case expr of
  App f a -> go f (go a)
  Lam _ e -> \x -> eval' (Push env x) e
  Var _ i -> prj env i
  Let a e -> eval' (Push env (go a)) e
  Lit l -> evalL l
  Cond c a b -> if go c then go a else go b
  Const c -> evalC c
  Pair a b -> (go a, go b)
  Fst e -> fst (go e)
  Snd e -> snd (go e)
  Build sht she fe ->
    -- The element type of the resulting array is the result type of the
    -- generator function's type.
    let TFun _ ty = typeof fe
    in build ty sht (go she) (go fe)
  Ifold sht fe e0 she -> ifold sht (go fe) (go e0) (go she)
  Index a i -> index (go a) (go i)
  Shape a -> shape (go a)
  Undef t -> error ("eval: Undef of type " ++ show t)
  where
    -- Evaluate a sub-expression in the same environment as the parent.
    go :: Exp env t -> t
    go = eval' env
-- | Interpret a literal as the Haskell value it denotes.
--
-- Scalar and array literals carry their value directly. A shape literal is
-- converted to its tuple representation with 'unshape'; this also covers the
-- empty shape, since @unshape Z@ is @()@, so no separate equation is needed
-- for it. Pair literals are evaluated component-wise.
evalL :: Literal a -> a
evalL = \case
  LInt n -> n
  LBool b -> b
  LDouble d -> d
  LArray a -> a
  LShape sh -> unshape sh
  LNil -> ()
  LPair a b -> (evalL a, evalL b)
-- | Interpret a built-in constant as the Haskell function it denotes.
--
-- Binary operators take their two operands as a pair, hence the 'uncurry'
-- around each Prelude operator. Integer division uses 'div'.
--
-- Equality ('CEq') is only available when 'typeHasEq' yields evidence for
-- the operand type; otherwise evaluation calls 'error'.
evalC :: Constant a -> a
evalC = \case
  CAddI -> uncurry (+)
  CSubI -> uncurry (-)
  CMulI -> uncurry (*)
  CDivI -> uncurry div
  CAddF -> uncurry (+)
  CSubF -> uncurry (-)
  CMulF -> uncurry (*)
  CDivF -> uncurry (/)
  CLog -> log
  CExp -> exp
  CtoF -> fromIntegral
  CRound -> round
  CLtI -> uncurry (<)
  CLeI -> uncurry (<=)
  CLtF -> uncurry (<)
  CEq t
    | Just Has <- typeHasEq t -> uncurry (==)
    | otherwise -> error ("eval: Cannot Eq compare values of type " ++ show t)
  CAnd -> uncurry (&&)
  COr -> uncurry (||)
  CNot -> not
-- | Construct an array of the given extent by calling the generator function
-- once per element.
--
-- The tuple-form extent is converted with 'toshape'; the vector is then
-- generated over all linear positions, each translated back into a
-- multi-dimensional index with 'fromlinear' before being passed to @f@.
build :: Type a -> ShapeType sh -> sh -> (sh -> a) -> Array sh a
build ty sht sh f = Array extent ty elems
  where
    extent = toshape sht sh
    elems = V.generate (shapesize extent) (f . fromlinear extent)
-- | Strict left fold over every index of a shape.
--
-- The step function receives the accumulator paired with the current index.
-- Indices are visited in the order produced by 'enumshape'.
ifold :: ShapeType sh -> ((a, sh) -> a) -> a -> sh -> a
ifold sht f x0 sh = foldl' step x0 indices
  where
    indices = enumshape (toshape sht sh)
    step acc ix = f (acc, ix)
-- | Look up the element of an array at a multi-dimensional index.
--
-- The index is flattened with 'tolinear' against the array's own shape and
-- the result is used to index the underlying vector with 'V.!'.
index :: Array sh a -> sh -> a
index arr idx = case arr of
  Array extent _ elems -> elems V.! tolinear extent idx
-- | Return the extent of an array in tuple form (see 'unshape').
shape :: Array sh a -> sh
shape arr = case arr of
  Array extent _ _ -> unshape extent
-- | List every index of a shape, starting from the all-zero index, with the
-- right-most component varying fastest.
--
-- Exactly @shapesize sh@ indices are produced, so a shape with a zero-sized
-- dimension yields the empty list.
enumshape :: Shape sh -> [sh]
enumshape sh = take (shapesize sh) (iterate (next sh) (zeroshape sh))
  where
    -- Advance an index by one position, carrying into the component to the
    -- left when the right-most component is exhausted.
    next :: Shape s -> s -> s
    next Z () = ()
    next (sh' :. n) (idx, i)
      -- Valid components are 0 .. n - 1 (cf. 'tolinear' / 'fromlinear'), so
      -- only step within this dimension while the successor is still in
      -- range. Comparing @i < n@ here would let the component reach @n@ and
      -- delay the carry by one step, yielding out-of-range indices.
      | i + 1 < n = (idx, i + 1)
      | otherwise = (next sh' idx, 0)
-- | The index of a shape in which every component is zero.
zeroshape :: Shape sh -> sh
zeroshape = \case
  Z -> ()
  rest :. _ -> (zeroshape rest, 0)
-- | Convert a 'Shape' into its nested-tuple representation: 'Z' becomes @()@
-- and each ':.' contributes one pair layer holding that dimension's size.
unshape :: Shape sh -> sh
unshape = \case
  Z -> ()
  rest :. n -> (unshape rest, n)
-- | Convert a nested-tuple extent into a 'Shape', guided by the singleton
-- 'ShapeType' that describes its rank. This is the inverse direction of
-- 'unshape'.
toshape :: ShapeType sh -> sh -> Shape sh
toshape sht tup = case sht of
  STZ -> case tup of
    () -> Z
  STC inner -> case tup of
    (rest, n) -> toshape inner rest :. n
-- | Flatten a multi-dimensional index into a linear offset for the given
-- shape. The right-most component is the fastest-varying one: each step
-- multiplies the offset of the outer dimensions by this dimension's size
-- and adds this dimension's component.
tolinear :: Shape sh -> sh -> Int
tolinear shp ix = case shp of
  Z -> case ix of
    () -> 0
  outer :. n -> case ix of
    (rest, i) -> n * tolinear outer rest + i
-- | Expand a linear offset into a multi-dimensional index for the given
-- shape. Each dimension takes the remainder of the offset modulo its size
-- and passes the quotient on to the outer dimensions.
fromlinear :: Shape sh -> Int -> sh
fromlinear shp lin = case shp of
  Z -> ()
  outer :. n ->
    -- Bound lazily, so the division is only performed when a component of
    -- the resulting index is actually demanded.
    let split = lin `divMod` n
    in (fromlinear outer (fst split), snd split)
-- | Total number of elements described by a shape: the product of all
-- dimension sizes, with the empty shape 'Z' counting as a single element.
shapesize :: Shape sh -> Int
shapesize = \case
  Z -> 1
  rest :. n -> n * shapesize rest
|