aboutsummaryrefslogtreecommitdiff
path: root/Eval.hs
blob: 19265bcc7f0a1da673c328f6703b9e36cb5887e7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module Eval (
    eval,
) where

import Data.List (foldl')
import qualified Data.Vector as V

import AST


-- | A runtime environment: a heterogeneous stack of values whose element
-- types are recorded in the type-level list @env@.
--
-- 'Top' is the empty environment; 'Push' places a new value on top, so its
-- type becomes the head of @env@. Binders in 'eval'' extend the environment
-- with 'Push', and variables are looked up with 'prj'.
data Val env where
    -- | The empty environment (no variables in scope).
    Top :: Val '[]
    -- | Extend an environment with one more value (the newest binding).
    Push :: Val env -> a -> Val (a ': env)

-- | Fetch the value an index refers to in an environment.
--
-- 'Zero' denotes the most recently pushed value; 'Succ' skips the newest
-- entry and continues the lookup in the rest of the environment.
prj :: Val env -> Idx env a -> a
prj (Push rest newest) ix = case ix of
    Zero -> newest
    Succ ix' -> prj rest ix'

-- | Evaluate an expression that has no variables in scope, starting from
-- the empty environment.
eval :: Exp '[] a -> a
eval expr = eval' Top expr

-- | Evaluate an expression under an environment that holds a value for every
-- variable in scope.
--
-- This is a direct interpretation into Haskell values: object-level lambdas
-- become Haskell functions, pairs become Haskell tuples, arrays become
-- 'Array' values, and primitives are delegated to 'evalL' / 'evalC'.
-- An 'Undef' node raises an 'error' when its value is demanded.
eval' :: forall env a. Val env -> Exp env a -> a
eval' env expr = case expr of
    -- Variables and binders: binders push onto the environment, variables
    -- project out of it.
    Var _ idx -> prj env idx
    Lam _ body -> \arg -> eval' (Push env arg) body
    Let rhs body -> eval' (Push env (go rhs)) body
    App fun arg -> go fun (go arg)
    -- Leaves.
    Lit lit -> evalL lit
    Const con -> evalC con
    Undef t -> error ("eval: Undef of type " ++ show t)
    -- Control flow and tuples.
    Cond c thenE elseE -> if go c then go thenE else go elseE
    Pair l r -> (go l, go r)
    Fst p -> fst (go p)
    Snd p -> snd (go p)
    -- Arrays.
    Build sht she fe ->
        -- The element type stored in the array is recovered from the type
        -- of the generator function.
        let TFun _ elemTy = typeof fe
        in build elemTy sht (go she) (go fe)
    Ifold sht fe e0 she -> ifold sht (go fe) (go e0) (go she)
    Index arr ix -> index (go arr) (go ix)
    Shape arr -> shape (go arr)
  where
    -- Evaluate a subexpression in the same environment.
    go :: Exp env t -> t
    go = eval' env

-- | Interpret a literal as the Haskell value it denotes.
evalL :: Literal a -> a
evalL = \case
    LInt n -> n
    LBool b -> b
    LDouble d -> d
    LArray arr -> arr
    -- 'unshape' already maps the empty shape 'Z' to (), so no separate
    -- clause for 'Z' is needed.
    LShape sh -> unshape sh
    LNil -> ()
    LPair l r -> (evalL l, evalL r)

-- | Interpret a primitive constant as a Haskell function.
--
-- Binary operators receive their two operands as a single pair, hence the
-- 'uncurry' wrappers.
evalC :: Constant a -> a
evalC = \case
    -- Integer arithmetic ('div' rounds toward negative infinity).
    CAddI -> uncurry (+)
    CSubI -> uncurry (-)
    CMulI -> uncurry (*)
    CDivI -> uncurry div
    -- Floating-point arithmetic and transcendental functions.
    CAddF -> uncurry (+)
    CSubF -> uncurry (-)
    CMulF -> uncurry (*)
    CDivF -> uncurry (/)
    CLog -> log
    CExp -> exp
    -- Numeric conversions ('round' sends exact halves to the even integer).
    CtoF -> fromIntegral
    CRound -> round
    -- Comparisons.
    CLtI -> uncurry (<)
    CLeI -> uncurry (<=)
    CLtF -> uncurry (<)
    -- Equality is only available when 'typeHasEq' produces an Eq witness
    -- for the operand type.
    CEq t
      | Just Has <- typeHasEq t -> uncurry (==)
      | otherwise -> error ("eval: Cannot Eq compare values of type " ++ show t)
    -- Boolean logic.
    CAnd -> uncurry (&&)
    COr -> uncurry (||)
    CNot -> not

-- | Materialise an array by applying the generator to every index of the
-- requested shape.
--
-- Elements are laid out in row-major order (last index component varies
-- fastest), as determined by 'fromlinear'.
build :: Type a -> ShapeType sh -> sh -> (sh -> a) -> Array sh a
build ty sht sh f = Array extent ty elems
  where
    extent = toshape sht sh
    elems = V.generate (shapesize extent) (f . fromlinear extent)

-- | Strict left fold over every index of a shape.
--
-- The step function receives the current accumulator and the index as a
-- pair; indices are visited in the order produced by 'enumshape'.
ifold :: ShapeType sh -> ((a, sh) -> a) -> a -> sh -> a
ifold sht step initial sh = foldl' (\acc ix -> step (acc, ix)) initial indices
  where
    indices = enumshape (toshape sht sh)

-- | Read the element stored at a multi-dimensional index.
--
-- NOTE(review): only the flattened offset is bounds-checked (by 'V.!'); an
-- out-of-range component in a non-outermost dimension is folded by
-- 'tolinear' into the offset of a different, in-range element.
index :: Array sh a -> sh -> a
index arr ix = case arr of
    Array extent _ elems -> elems V.! tolinear extent ix

-- | The extent of an array, returned as a nested-tuple value.
shape :: Array sh a -> sh
shape arr = case arr of
    Array extent _ _ -> unshape extent

-- | Every valid index of a shape, in row-major order (the last component
-- varies fastest). This is the same order in which 'build' lays out elements
-- via 'fromlinear', so the k-th index produced here is @fromlinear sh k@.
--
-- Example: for extents 2 x 3 the result is
-- (0,0) (0,1) (0,2) (1,0) (1,1) (1,2).
enumshape :: Shape sh -> [sh]
enumshape sh = take (shapesize sh) (iterate (next sh) (zeroshape sh))
  where
    -- Advance an index by one position, odometer-style: bump the last
    -- component; when it would reach its extent, reset it to 0 and carry
    -- into the enclosing dimensions. The guard must be @i + 1 < n@ (not
    -- @i < n@), otherwise the component is allowed to reach @n@, producing
    -- an out-of-range index and skipping a valid one.
    next :: Shape sh -> sh -> sh
    next Z () = ()
    next (sh' :. n) (idx, i)
      | i + 1 < n = (idx, i + 1)
      | otherwise = (next sh' idx, 0)

    -- The all-zeros index of a shape.
    zeroshape :: Shape sh -> sh
    zeroshape Z = ()
    zeroshape (sh' :. _) = (zeroshape sh', 0)

-- | Convert a shape descriptor into the equivalent nested-tuple value:
-- 'Z' becomes () and each @:.@ becomes a pair.
unshape :: Shape sh -> sh
unshape = \case
    Z -> ()
    outer :. len -> (unshape outer, len)

-- | Rebuild a shape descriptor from a nested-tuple value. The 'ShapeType'
-- argument says how deeply the tuple is nested: 'STZ' for (), 'STC' for one
-- more pair layer.
toshape :: ShapeType sh -> sh -> Shape sh
toshape sht val = case sht of
    STZ -> case val of
        () -> Z
    STC inner -> case val of
        (outerVal, len) -> toshape inner outerVal :. len

-- | Flatten a multi-dimensional index into a row-major offset (the last
-- component is the fastest-varying). No per-component range check is done.
tolinear :: Shape sh -> sh -> Int
tolinear Z () = 0
tolinear (outer :. len) (outerIdx, pos) = base + pos
  where
    base = len * tolinear outer outerIdx

-- | Split a row-major offset into a multi-dimensional index for the given
-- shape. For in-range offsets this inverts 'tolinear'.
fromlinear :: Shape sh -> Int -> sh
fromlinear Z _ = ()
fromlinear (outer :. len) flat = (fromlinear outer rest, pos)
  where
    -- 'divMod' keeps pos in [0, len) for positive len.
    (rest, pos) = flat `divMod` len

-- | Number of elements described by a shape: the product of its extents
-- (1 for the rank-0 shape 'Z').
shapesize :: Shape sh -> Int
shapesize = go 1
  where
    go :: Int -> Shape s -> Int
    go acc Z = acc
    go acc (outer :. len) =
        let acc' = acc * len
        in acc' `seq` go acc' outer