diff options
author | Tom Smeding <tom@tomsmeding.com> | 2021-06-27 18:34:35 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2021-06-27 18:34:35 +0200 |
commit | d4abcc3b2dfefbbcb7cd4a182eec64f1da42d951 (patch) | |
tree | 1ab301617043ac6df228ef617afa22633a01a671 /Eval.hs | |
parent | 0fefe4822c65bde81ec4c0da1b5b32a9b411ca79 (diff) |
Diffstat (limited to 'Eval.hs')
-rw-r--r-- | Eval.hs | 128 |
1 files changed, 128 insertions, 0 deletions
@@ -0,0 +1,128 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} +module Eval ( + eval, +) where + +import Data.List (foldl') +import qualified Data.Vector as V + +import AST + + +data Val env where + Top :: Val '[] + Push :: Val env -> a -> Val (a ': env) + +prj :: Val env -> Idx env a -> a +prj (Push _ x) Zero = x +prj (Push env _) (Succ i) = prj env i + +eval :: Exp '[] a -> a +eval = eval' Top + +eval' :: forall env a. Val env -> Exp env a -> a +eval' env = \case + App f a -> rec f (rec a) + Lam _ e -> \x -> eval' (Push env x) e + Var _ i -> prj env i + Let a e -> eval' (Push env (rec a)) e + Lit l -> evalL l + Cond c a b -> if rec c then rec a else rec b + Const c -> evalC c + Pair a b -> (rec a, rec b) + Fst e -> fst (rec e) + Snd e -> snd (rec e) + Build sht she fe -> + let TFun _ ty = typeof fe + in build ty sht (rec she) (rec fe) + Ifold sht fe e0 she -> ifold sht (rec fe) (rec e0) (rec she) + Index a i -> index (rec a) (rec i) + Shape a -> shape (rec a) + Undef t -> error ("eval: Undef of type " ++ show t) + where rec :: Exp env t -> t + rec = eval' env + +evalL :: Literal a -> a +evalL (LInt n) = n +evalL (LBool b) = b +evalL (LDouble d) = d +evalL (LArray a) = a +evalL (LShape Z) = () +evalL (LShape sh) = unshape sh +evalL LNil = () +evalL (LPair a b) = (evalL a, evalL b) + +evalC :: Constant a -> a +evalC CAddI = uncurry (+) +evalC CSubI = uncurry (-) +evalC CMulI = uncurry (*) +evalC CDivI = uncurry div +evalC CAddF = uncurry (+) +evalC CSubF = uncurry (-) +evalC CMulF = uncurry (*) +evalC CDivF = uncurry (/) +evalC CLog = log +evalC CExp = exp +evalC CtoF = fromIntegral +evalC CRound = round +evalC CLtI = uncurry (<) +evalC CLeI = uncurry (<=) +evalC CLtF = uncurry (<) +evalC (CEq t) | Just Has <- typeHasEq t = uncurry (==) + | otherwise = error ("eval: Cannot Eq compare values of type " ++ show t) +evalC CAnd = uncurry (&&) +evalC COr = uncurry (||) +evalC CNot = not + +build :: Type a -> ShapeType sh -> sh -> (sh -> a) -> Array sh a +build ty sht sh f = + let sh' = toshape sht sh + in Array sh' ty (V.generate (shapesize sh') (\i -> f (fromlinear sh' i))) + +ifold :: ShapeType sh -> ((a, sh) -> a) -> a -> sh -> a +ifold sht f x0 sh = foldl' (curry f) x0 (enumshape (toshape sht sh)) + +index :: Array sh a -> sh -> a +index (Array sh _ v) idx = v V.! tolinear sh idx + +shape :: Array sh a -> sh +shape (Array sh _ _) = unshape sh + +enumshape :: Shape sh -> [sh] +enumshape sh = take (shapesize sh) (iterate (next sh) (zeroshape sh)) + where + next :: Shape sh -> sh -> sh + next Z () = () + next (sh' :. n) (idx, i) + | i < n = (idx, i + 1) + | otherwise = (next sh' idx, 0) + + zeroshape :: Shape sh -> sh + zeroshape Z = () + zeroshape (sh' :. _) = (zeroshape sh', 0) + +unshape :: Shape sh -> sh +unshape Z = () +unshape (sh :. n) = (unshape sh, n) + +toshape :: ShapeType sh -> sh -> Shape sh +toshape STZ () = Z +toshape (STC sht) (sh, n) = toshape sht sh :. n + +tolinear :: Shape sh -> sh -> Int +tolinear Z () = 0 +tolinear (sh :. n) (idx, i) = n * tolinear sh idx + i + +fromlinear :: Shape sh -> Int -> sh +fromlinear Z _ = () +fromlinear (sh :. n) i = + let (q, r) = i `divMod` n + in (fromlinear sh q, r) + +shapesize :: Shape sh -> Int +shapesize Z = 1 +shapesize (sh :. n) = n * shapesize sh |