{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TupleSections #-}

-- | Evaluator for the language whose values are defined in "AST".
--
-- Evaluation runs in 'IM': a state monad holding a stack of variable scopes,
-- layered over 'ExceptT' for interpreter errors and 'IO' for the I/O builtins.
module Interpreter (newContext, interpret, interpretProgram, Context) where

import Control.Applicative
import qualified Control.Exception as E
import Control.Monad
import Control.Monad.Except
import Control.Monad.State
import Data.List
import Data.Maybe
import qualified Data.Map.Strict as Map
-- import Debug.Trace

import AST
import Parser

-- TODO: place bottom varmap separately for performance? (global define's in deeply nested contexts)
type VarMap = Map.Map Name Value

-- | Interpreter state: a stack of scopes, innermost first. The last map is
-- the global scope; 'newContext' seeds it with the builtins.
data Context = Context {varMapStk :: [VarMap]}

newtype IM a = IM {unIM :: StateT Context (ExceptT String IO) a}
  deriving (Functor, Applicative, Monad, MonadError String, MonadState Context, MonadIO)

-- | A builtin receives its call arguments unevaluated; each builtin decides
-- whether (and which of) them to evaluate.
type Builtin = [Value] -> IM Value

-- | A fresh context with a single (global) scope binding every builtin name
-- to its 'VBuiltin' value.
newContext :: Context
newContext = Context [Map.fromList [(k, VBuiltin k) | k <- Map.keys builtins]]

-- | Look a name up, searching from the innermost scope outwards.
lookupVar :: Name -> IM (Maybe Value)
lookupVar name = gets (msum . map (Map.lookup name) . varMapStk)

-- | Run an action with an extra innermost scope, popping it afterwards.
-- The pop is skipped if the action throws; nothing in this module catches
-- errors, and 'interpret' discards the state of a failed run.
withScopeMap :: VarMap -> IM a -> IM a
withScopeMap vm act = do
  modify $ \ctx -> ctx {varMapStk = vm : varMapStk ctx}
  x <- act
  modify $ \ctx -> ctx {varMapStk = tail (varMapStk ctx)}
  return x

-- | Every builtin, keyed by the name it is bound to in the global scope.
builtins :: Map.Map String Builtin
builtins =
  Map.fromList
    [ ("read", readBuiltin)
    , ("eval", evalBuiltin)
    , ("print", printBuiltin)
    , ("getline", getlineBuiltin)
    , ("loop", loopBuiltin)
    , ("do", doBuiltin)
    , ("if", ifBuiltin)
    , ("define", defineBuiltin)
    , ("lambda", lambdaBuiltin)
    , ("match", matchBuiltin)
    , ("+", plusBuiltin)
    , ("-", arithBuiltin "-" (-) 0)
    , ("*", arithBuiltin "*" (*) 1)
    , ("/", divisionBuiltin "/" div 1)
    , ("%", divisionBuiltin "%" mod 1)
    , ("<", compareBuiltin "<" (<) (<))
    , ("<=", compareBuiltin "<=" (<=) (<=))
    , (">", compareBuiltin ">" (>) (>))
    , (">=", compareBuiltin ">=" (>=) (>=))
    , ("=", compareBuiltin "=" (==) (==))
    , ("!=", neqBuiltin)
    ]

-- | Wrap a builtin with an exact-arity check. When @doeval@ is set, the
-- arguments are evaluated in order before being passed on; otherwise the
-- wrapped builtin receives them unevaluated.
nArguments :: String -> Int -> Bool -> Builtin -> Builtin
nArguments name n doeval f args
  | length args /= n =
      throwError $
        "Function '" ++ name ++ "' expects " ++ naStr n
          ++ " but got " ++ naStr (length args)
  | doeval = mapM evalValue args >>= f
  | otherwise = f args
  where
    naStr 0 = "no arguments"
    naStr 1 = "1 argument"
    naStr num = show num ++ " arguments"

-- | @(read s)@: parse the string @s@ into an (unevaluated) expression.
readBuiltin :: Builtin
readBuiltin = nArguments "read" 1 True go
  where
    go :: Builtin
    go [VString s] = either (throwError . show) return $ parseExpression s
    go _ = throwError "Can only 'read' a string"

-- | @(eval e)@: evaluate @e@, then evaluate the resulting value again.
evalBuiltin :: Builtin
evalBuiltin = nArguments "eval" 1 True $ \args -> case args of
  [v] -> evalValue v
  -- Unreachable: nArguments has already enforced exactly one argument.
  _ -> throwError "Function 'eval' expects 1 argument"

-- | @(print a b ...)@: evaluate all arguments and print them on one line,
-- space-separated. Strings are printed raw; other values via 'show'.
printBuiltin :: Builtin
printBuiltin args = do
  vals <- mapM evalValue args
  liftIO (putStrLn (unwords (map printShow vals)))
  return (VList [])
  where
    printShow :: Value -> String
    printShow (VString s) = s
    printShow v = show v

-- | @(loop body)@: evaluate @body@ forever; only an error ends the loop.
loopBuiltin :: Builtin
loopBuiltin = nArguments "loop" 1 False $ \args -> forever (mapM_ evalValue args)

-- | @(getline)@: read one line from standard input as a string. An I/O
-- failure (e.g. end of input) becomes an interpreter error instead of an
-- uncaught Haskell exception.
getlineBuiltin :: Builtin
getlineBuiltin = nArguments "getline" 0 True $ \_ -> do
  result <- liftIO (E.try getLine)
  case result of
    Left err -> throwError $ "getline: " ++ show (err :: E.IOException)
    Right line -> return (VString line)

-- | @(do e1 e2 ...)@: evaluate the expressions in order and return the last
-- value, or the empty list when there are none.
doBuiltin :: Builtin
doBuiltin = foldM (\_ ex -> evalValue ex) (VList [])

-- | @(if c t)@ / @(if c t e)@: evaluate the condition, then exactly one
-- branch. A missing else branch yields the empty list.
ifBuiltin :: Builtin
ifBuiltin [cond, v1] = evalValue cond >>= \c -> if truthy c then evalValue v1 else return (VList [])
ifBuiltin [cond, v1, v2] = evalValue cond >>= \c -> if truthy c then evalValue v1 else evalValue v2
ifBuiltin a = throwError $ "Cannot pass " ++ show (length a) ++ " arguments to 'if'"

-- | @(define name value)@ binds @name@ to the evaluated @value@;
-- @(define name (args...) body)@ is sugar for binding a lambda.
--
-- The binding goes into the innermost scope that already contains @name@,
-- or into the global (last) scope if no scope does.
defineBuiltin :: Builtin
defineBuiltin [VName name, val] = do
  val' <- evalValue val
  modify $ \ctx -> ctx {varMapStk = bindIn val' (varMapStk ctx)}
  return (VList [])
  where
    -- An empty stack cannot arise from newContext; create a global scope
    -- rather than crashing if it ever does.
    bindIn v [] = [Map.singleton name v]
    bindIn v [global] = [Map.insert name v global]
    bindIn v (vm : vms)
      | Map.member name vm = Map.insert name v vm : vms
      | otherwise = vm : bindIn v vms
defineBuiltin [name@(VName _), VList args, val]
  | Just names <- mapM fromVName args = defineBuiltin [name, VLambda names val]
  | otherwise = throwError "Invalid 'define' syntax: invalid argument list"
defineBuiltin _ = throwError "Invalid 'define' syntax"

-- | @(lambda (args...) body)@: build a lambda value; the body is not evaluated.
lambdaBuiltin :: Builtin
lambdaBuiltin = nArguments "lambda" 2 False go
  where
    go :: Builtin
    go [VList args, body]
      | Just names <- mapM fromVName args = return (VLambda names body)
      | otherwise = throwError "Invalid 'lambda' syntax: invalid argument list"
    go _ = throwError "Invalid 'lambda' syntax"

-- | @(match subject (pat expr) ... default)@: evaluate @subject@ and try
-- each @(pat expr)@ arm in order using 'match'. The first matching arm's
-- @expr@ is evaluated with the pattern's bindings pushed as a new scope.
-- The final element is always treated as the default expression.
matchBuiltin :: Builtin
matchBuiltin [] = throwError "Invalid 'match' syntax: empty match"
matchBuiltin [_] = throwError "Invalid 'match' syntax: no arms"
matchBuiltin (subject : arms) = do
  subject' <- evalValue subject
  go subject' arms
  where
    go :: Value -> [Value] -> IM Value
    go _ [def] = evalValue def
    go subj (VList [pat, value] : rest) = case match pat subj Map.empty of
      Nothing -> go subj rest
      Just mp -> withScopeMap mp (evalValue value)
    go _ _ = throwError "Invalid 'match' syntax: invalid arm"

-- | @(+ a b ...)@: evaluate the arguments, then sum them if all are
-- numbers, or concatenate them as strings if all are strings or numbers.
-- No arguments gives 0.
plusBuiltin :: Builtin
plusBuiltin args = mapM evalValue args >>= add
  where
    add vals
      | Just nums <- mapM fromVNum vals = return (VNum (sum nums))
      | Just strs <- mapM maybeStrings vals = return (VString (concat strs))
      | otherwise = throwError "Invalid argument types to operator '+'"

-- | Left-fold a numeric operator over the evaluated arguments. With no
-- arguments the result is @idelem@; with one it is that argument.
arithBuiltin :: String -> (Int -> Int -> Int) -> Int -> Builtin
arithBuiltin name oper = checkedArithBuiltin name (\acc x -> return $! oper acc x)

-- | Like 'arithBuiltin', but raises an interpreter error when any divisor
-- (any argument after the first) is zero. Without this check, 'div'/'mod'
-- would throw a Haskell exception when the result is forced.
divisionBuiltin :: String -> (Int -> Int -> Int) -> Int -> Builtin
divisionBuiltin name oper = checkedArithBuiltin name step
  where
    step _ 0 = throwError $ "Division by zero in operator '" ++ name ++ "'"
    step acc d = return $! oper acc d

-- | Shared implementation of the numeric folds. @step@ combines the
-- accumulator with the next argument and may fail.
checkedArithBuiltin :: String -> (Int -> Int -> IM Int) -> Int -> Builtin
checkedArithBuiltin name step idelem args = do
  vals <- mapM evalValue args
  case mapM fromVNum vals of
    Just [] -> return (VNum idelem)
    Just (hd : tl) -> VNum <$> foldM step hd tl
    Nothing -> throwError $ "Invalid argument types to operator '" ++ name ++ "'"

-- | @(!= a b ...)@: the negation of @(= a b ...)@.
neqBuiltin :: Builtin
neqBuiltin args = boolValue . not <$> compareArgs "!=" (==) (==) args

-- | Chained comparison builtin; see 'compareArgs'. Returns 1 or 0.
compareBuiltin :: String -> (Int -> Int -> Bool) -> (String -> String -> Bool) -> Builtin
compareBuiltin name oper soper args = boolValue <$> compareArgs name oper soper args

-- | Evaluate the arguments and check that the relation holds between every
-- adjacent pair. All-numeric arguments use @oper@; otherwise, if every
-- argument is a string or number (numbers rendered via 'show'), @soper@ is
-- used. Fewer than two arguments are trivially true.
compareArgs :: String -> (Int -> Int -> Bool) -> (String -> String -> Bool) -> [Value] -> IM Bool
compareArgs name oper soper args = mapM evalValue args >>= decide
  where
    decide vals
      | Just nums <- mapM fromVNum vals = return (chained oper nums)
      | Just strs <- mapM maybeStrings vals = return (chained soper strs)
      | otherwise = throwError $ "Invalid argument types to operator '" ++ name ++ "'"
    chained :: (a -> a -> Bool) -> [a] -> Bool
    chained rel xs = and (zipWith rel xs (drop 1 xs))

-- | Encode a boolean as the interpreter's numeric truth value (1 or 0).
boolValue :: Bool -> Value
boolValue b = VNum (if b then 1 else 0)

-- | Only the number 0 is false; every other value is true.
truthy :: Value -> Bool
truthy (VNum n) = n /= 0
truthy _ = True

-- | Match a pattern against a value, extending the given bindings.
--
-- * A name binds the value, or must equal its earlier binding.
-- * A list matches element-wise; a trailing ellipsis matches any remaining
--   elements.
-- * Quoted patterns match quoted values structurally.
-- * Lambdas never match; anything else matches by equality.
match :: Value -> Value -> VarMap -> Maybe VarMap
match (VList []) (VList []) mp = Just mp
match (VList [VEllipsis]) (VList _) mp = Just mp
match (VList (pat : pats)) (VList (val : vals)) mp =
  match pat val mp >>= match (VList pats) (VList vals)
match (VName name) val mp = case Map.lookup name mp of
  Nothing -> Just (Map.insert name val mp)
  Just val'
    | val == val' -> Just mp
    | otherwise -> Nothing
match (VQuoted a) (VQuoted b) mp = match a b mp
match (VLambda _ _) _ _ = Nothing
match a b mp
  | a == b = Just mp
  | otherwise = Nothing

-- | A string as-is, or a number rendered with 'show'; anything else fails.
maybeStrings :: Value -> Maybe String
maybeStrings = liftM2 (<|>) fromVString (fmap show . fromVNum)

-- | Evaluate an expression.
--
-- * Lists are calls.
-- * Names are looked up in scope.
-- * Quoted values yield their contents.
-- * Numbers, strings, lambdas and builtins evaluate to themselves.
evalValue :: Value -> IM Value
-- evalValue v | traceShow v False = undefined
evalValue (VList exs) = listCall exs
evalValue e@(VNum _) = return e
evalValue e@(VString _) = return e
evalValue (VName name) =
  lookupVar name >>= \mval -> case mval of
    Just value -> return value
    Nothing -> throwError $ "Use of undefined variable '" ++ name ++ "'"
evalValue (VQuoted e) = return e
evalValue e@(VLambda _ _) = return e
evalValue e@(VBuiltin _) = return e
evalValue VEllipsis = throwError "Unexpected ellipsis in code"

-- | Evaluate a call.
--
-- The head is evaluated first. A lambda gets its arguments evaluated and
-- bound in a new scope. A builtin receives the raw, unevaluated arguments.
listCall :: [Value] -> IM Value
listCall [] = throwError "Cannot call ()"
listCall (hd : args) =
  evalValue hd >>= \hd' -> case hd' of
    VLambda names body
      | length names == length args -> do
          args' <- mapM evalValue args
          withScopeMap (Map.fromList (zip names args')) (evalValue body)
      | otherwise ->
          throwError $
            "Invalid number of arguments in call to lambda "
              ++ "(" ++ show (length args) ++ " found, "
              ++ show (length names) ++ " needed)"
    VBuiltin name -> case Map.lookup name builtins of
      Just f -> f args
      Nothing -> throwError $ "Unknown builtin '" ++ name ++ "'"
    v -> throwError $ "Cannot call value: " ++ show v

-- | Evaluate one expression in a context, returning its value and the
-- updated context, or the interpreter error message.
interpret :: Context -> Value -> IO (Either String (Value, Context))
interpret ctx val = runExceptT $ flip runStateT ctx $ unIM $ evalValue val

-- | Evaluate a program's top-level expressions in order, threading the
-- context through. Stops at the first error.
interpretProgram :: Context -> Program -> IO (Either String Context)
interpretProgram rootctx (Program l) = go l rootctx
  where
    go :: [Value] -> Context -> IO (Either String Context)
    go [] ctx = return (Right ctx)
    go (val : vals) ctx = do
      e <- interpret ctx val
      case e of
        Left err -> return (Left err)
        Right (_, ctx') -> go vals ctx'