{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}

-- | Direct interpreter for the typed expression language from "AST":
-- expressions are evaluated straight to the Haskell values their types denote.
module Eval
  ( eval,
  )
where

import Data.List (foldl')
import qualified Data.Vector as V

import AST

-- | Runtime environment. The type-level list @env@ records the types of the
-- bound values; the most recently bound value ('Push') is at the head.
data Val env where
  Top :: Val '[]
  Push :: Val env -> a -> Val (a ': env)

-- | Look up a variable by its typed index: 'Zero' is the most recently pushed
-- value, @'Succ' i@ skips one binding.
prj :: Val env -> Idx env a -> a
prj (Push _ x) Zero = x
prj (Push env _) (Succ i) = prj env i

-- | Evaluate a closed expression (one with an empty environment).
eval :: Exp '[] a -> a
eval = eval' Top

-- | Evaluate an expression whose free variables are supplied by @env@.
-- 'Lam' and 'Let' extend the environment with 'Push'.
eval' :: forall env a. Val env -> Exp env a -> a
eval' env = \case
  App f a -> rec f (rec a)
  Lam _ e -> \x -> eval' (Push env x) e
  Var _ i -> prj env i
  Let a e -> eval' (Push env (rec a)) e
  Lit l -> evalL l
  Cond c a b -> if rec c then rec a else rec b
  Const c -> evalC c
  Pair a b -> (rec a, rec b)
  Fst e -> fst (rec e)
  Snd e -> snd (rec e)
  Build sht she fe ->
    -- The element type comes from the result type of the element function,
    -- because 'build' stores it in the resulting 'Array'.
    let TFun _ ty = typeof fe
     in build ty sht (rec she) (rec fe)
  Ifold sht fe e0 she -> ifold sht (rec fe) (rec e0) (rec she)
  Index a i -> index (rec a) (rec i)
  Shape a -> shape (rec a)
  Undef t -> error ("eval: Undef of type " ++ show t)
  where
    -- Evaluate a subexpression in the same environment.
    rec :: Exp env t -> t
    rec = eval' env

-- | Interpret a literal as its value.
evalL :: Literal a -> a
evalL (LInt n) = n
evalL (LBool b) = b
evalL (LDouble d) = d
evalL (LArray a) = a
-- 'unshape' already maps 'Z' to (), so no separate 'Z' case is needed.
evalL (LShape sh) = unshape sh
evalL LNil = ()
evalL (LPair a b) = (evalL a, evalL b)

-- | Interpret a primitive constant. Binary operators take their arguments as
-- a pair, hence the 'uncurry'.
evalC :: Constant a -> a
evalC CAddI = uncurry (+)
evalC CSubI = uncurry (-)
evalC CMulI = uncurry (*)
evalC CDivI = uncurry div
evalC CAddF = uncurry (+)
evalC CSubF = uncurry (-)
evalC CMulF = uncurry (*)
evalC CDivF = uncurry (/)
evalC CLog = log
evalC CExp = exp
evalC CtoF = fromIntegral
evalC CRound = round
evalC CLtI = uncurry (<)
evalC CLeI = uncurry (<=)
evalC CLtF = uncurry (<)
evalC (CEq t)
  -- Equality is only available when 'typeHasEq' provides the evidence.
  | Just Has <- typeHasEq t = uncurry (==)
  | otherwise = error ("eval: Cannot Eq compare values of type " ++ show t)
evalC CAnd = uncurry (&&)
evalC COr = uncurry (||)
evalC CNot = not

-- | Build an array of the given shape. The element at linear position @i@ is
-- @f (fromlinear sh' i)@, i.e. elements are laid out row-major.
build :: Type a -> ShapeType sh -> sh -> (sh -> a) -> Array sh a
build ty sht sh f =
  let sh' = toshape sht sh
   in Array sh' ty (V.generate (shapesize sh') (\i -> f (fromlinear sh' i)))

-- | Strict left fold of @f@ over every index of the shape, starting from @x0@.
-- Indices are visited in the order produced by 'enumshape' (row-major).
ifold :: ShapeType sh -> ((a, sh) -> a) -> a -> sh -> a
ifold sht f x0 sh = foldl' (curry f) x0 (enumshape (toshape sht sh))

-- | Read the element at a multi-dimensional index.
-- NOTE(review): only the linear position is bounds-checked (by 'V.!'); a
-- component that is out of range in its own dimension aliases another element.
index :: Array sh a -> sh -> a
index (Array sh _ v) idx = v V.! tolinear sh idx

-- | The extent of an array as a plain shape value.
shape :: Array sh a -> sh
shape (Array sh _ _) = unshape sh

-- | All indices of a shape, last dimension varying fastest. This is the same
-- order as 'tolinear' / 'fromlinear', so the k-th index equals
-- @fromlinear sh k@. A shape with a zero-sized dimension yields no indices.
enumshape :: Shape sh -> [sh]
enumshape sh = take (shapesize sh) (iterate (next sh) (zeroshape sh))
  where
    -- Step to the next index in row-major order. The innermost component is
    -- incremented while it stays below its bound; when @i + 1@ would reach
    -- @n@ it resets to 0 and the carry moves to the outer dimensions. This
    -- keeps every component in @[0, n)@.
    next :: Shape s -> s -> s
    next Z () = ()
    next (sh' :. n) (idx, i)
      | i + 1 < n = (idx, i + 1)
      | otherwise = (next sh' idx, 0)

    -- The all-zeros index.
    zeroshape :: Shape s -> s
    zeroshape Z = ()
    zeroshape (sh' :. _) = (zeroshape sh', 0)

-- | Convert a 'Shape' back to the nested-tuple value it describes.
unshape :: Shape sh -> sh
unshape Z = ()
unshape (sh :. n) = (unshape sh, n)

-- | Convert a nested-tuple shape value to a 'Shape', guided by its
-- 'ShapeType' witness.
toshape :: ShapeType sh -> sh -> Shape sh
toshape STZ () = Z
toshape (STC sht) (sh, n) = toshape sht sh :. n

-- | Row-major linear position of an index: the last component has stride 1.
tolinear :: Shape sh -> sh -> Int
tolinear Z () = 0
tolinear (sh :. n) (idx, i) = n * tolinear sh idx + i

-- | Inverse of 'tolinear' for in-range positions. Every dimension must be
-- non-zero, because 'divMod' by 0 throws.
fromlinear :: Shape sh -> Int -> sh
fromlinear Z _ = ()
fromlinear (sh :. n) i =
  let (q, r) = i `divMod` n
   in (fromlinear sh q, r)

-- | Total number of elements: the product of all dimensions (1 for 'Z').
shapesize :: Shape sh -> Int
shapesize Z = 1
shapesize (sh :. n) = n * shapesize sh