{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE ViewPatterns #-}
module Ghci (
  Ghci,
  Result(..),
  makeGhci,
  runStmt,
  runStmtClever,
  terminateGhci,
  ParseSettings(..),
  parseSettingsReply,
  parseSettingsPaste,
) where

import Control.Exception (catch, IOException, SomeException)
import Control.Monad (replicateM, when)
import Data.Bifunctor (first)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BSB
import qualified Data.ByteString.Lazy as Lazy
import qualified Data.ByteString.UTF8 as UTF8
import qualified Data.ByteString.Lazy.UTF8 as LUTF8
import Data.ByteString.Short (ShortByteString)
import qualified Data.ByteString.Short as BSS
import Data.Char (isSpace, toLower)
import Data.List (nub)
import Foreign (allocaBytes)
import System.IO (hClose, hFlush, hIsClosed, hGetBufSome, hPutStrLn, stdout, stderr, Handle)
import System.Process
import System.Random (getStdRandom, uniformR)
import System.Timeout (timeout, Timeout)

import ExitEarly


-- | Set to 'True' to trace everything written to and read from the GHCi
-- subprocess on stderr.
debugPrints :: Bool
debugPrints = False

-- | Write one line to GHCi's stdin (traced when 'debugPrints' is on).
-- Does not flush; callers 'hFlush' themselves.
ghciPutStrLn :: Handle -> String -> IO ()
ghciPutStrLn h s = do
  when debugPrints $ hPutStrLn stderr ("Writing: <" ++ s ++ ">")
  hPutStrLn h s

-- | A running GHCi subprocess together with our ends of its pipes.
data Ghci = Ghci
  { ghciProc :: ProcessHandle
  , ghciStdin :: Handle
  , ghciStdout :: Handle  -- ^ Read end of one pipe carrying both stdout and stderr.
  }

-- | How the output of a statement is post-processed before being returned.
data ParseSettings = ParseSettings
  { psMaxOutputLen :: Int  -- ^ How much data to return, at most (will add '...')
  , psJoinLines :: Bool  -- ^ Whether to join lines with ';'
  }
  deriving (Show)

-- | Short, single-line output.
parseSettingsReply :: ParseSettings
parseSettingsReply = ParseSettings 200 True

-- | Long, multi-line output.
parseSettingsPaste :: ParseSettings
parseSettingsPaste = ParseSettings 50000 False

-- | Outcome of running a statement: an error message, a command that was
-- deliberately not forwarded to GHCi, or the (formatted) output.
data Result a = Error String | Ignored | Return a
  deriving (Show)

-- | Start a fresh GHCi by running @./start.sh@ in the @bwrap-files@
-- directory. Its stdout and stderr both go into a single pipe, whose read
-- end becomes 'ghciStdout'. The interactive printer is switched to
-- @Yahb2Defs.limitedPrint@ and @:m@ is issued before returning.
makeGhci :: IO Ghci
makeGhci = do
  (pipeOut, pipeIn) <- createPipe
  (Just stdinH, _, _, proch) <-
    createProcess (proc "./start.sh" [])
      { cwd = Just "bwrap-files"
      , std_in = CreatePipe
      , std_out = UseHandle pipeIn
      , std_err = UseHandle pipeIn
      }
  ghciPutStrLn stdinH ":set -interactive-print=Yahb2Defs.limitedPrint"
  ghciPutStrLn stdinH ":m"
  hFlush stdinH
  return Ghci
    { ghciProc = proch
    , ghciStdin = stdinH
    , ghciStdout = pipeOut
    }

-- | Like 'runStmt', but intercepts some GHCi colon-commands first:
--
-- * @:!...@ (shell escapes), @:set prompt ...@ and @:def ...@ (including
--   abbreviations of @set@/@def@) are not forwarded and yield 'Ignored'.
-- * @:quit@ (or an abbreviation) restarts GHCi and yields @Return ""@.
--
-- Everything else is passed to 'runStmt' unchanged.
runStmtClever :: Ghci -> ParseSettings -> String -> IO (Ghci, Result String)
runStmtClever ghci pset line =
  case dropWhile isSpace line of
    ':' : line1 -> case words (map toLower (dropWhile isSpace line1)) of
      ('!':_) : _ -> return (ghci, Ignored)
      cmd : "prompt" : _
        | "set" `startsWith` cmd -> return (ghci, Ignored)
      cmd : _
        | "def" `startsWith` cmd -> return (ghci, Ignored)
        | "quit" `startsWith` cmd -> do
            terminateGhci ghci
            putStrLn "ghci: restarting due to :quit"
            hFlush stdout
            ghci' <- makeGhci
            return (ghci', Return "")
      _ -> runStmt ghci pset line
    _ -> runStmt ghci pset line
  where
    -- @long `startsWith` short@: is @short@ a prefix of @long@?
    startsWith :: String -> String -> Bool
    long `startsWith` short = take (length short) long == short

-- | Run one statement in GHCi with a 2-second limit, restarting GHCi when it
-- times out, when its stdin has been closed, or when an exception occurs.
-- Returns the GHCi to use for the next statement.
runStmt :: Ghci -> ParseSettings -> String -> IO (Ghci, Result String)
runStmt ghci pset line =
  timeouting 2_000_000 (restarting (\g -> runStmt' g pset line)) ghci

-- | Run the action under 'timeout'; if it does not finish in time, restart
-- GHCi and return @Error ""@.
timeouting :: Int -> (Ghci -> IO (Ghci, Result a)) -> Ghci -> IO (Ghci, Result a)
timeouting microseconds f ghci =
  -- TODO: The 'Nothing' branch never actually runs, because the 'Timeout'
  -- exception thrown by 'timeout' is already caught by the 'Timeout' handler
  -- in 'restarting', so the wrapped action always returns normally.
  timeout microseconds (f ghci) >>= \case
    Nothing -> do
      putStrLn "ghci: restarting due to timeout"
      hFlush stdout
      terminateGhci ghci
      ghci' <- makeGhci
      return (ghci', Error "")
    Just pair -> return pair

-- | Run @f@ on a live GHCi.
--
-- If the given GHCi's stdin is already closed, it is replaced by a fresh one
-- first. If @f@ is interrupted by a 'Timeout', or throws any other exception,
-- the GHCi that @f@ was using is terminated and replaced. The result is then
-- an 'Error' instead of a 'Return'.
restarting :: (Ghci -> IO a) -> Ghci -> IO (Ghci, Result a)
restarting f ghci = do
  closed <- hIsClosed (ghciStdin ghci)
  ghci' <-
    if closed
      then do
        putStrLn "ghci: restarting due to closed stdin"
        hFlush stdout
        terminateGhci ghci
        makeGhci
      else return ghci
  -- The outer handler also covers exceptions escaping the timeout handler.
  (fmap (\x -> (ghci', Return x)) (f ghci') `catch` onTimeout ghci')
    `catch` onException ghci'
  where
    -- Both handlers must terminate the GHCi that @f@ was actually using
    -- (@ghci'@), which may differ from the one passed in.
    onTimeout :: Ghci -> Timeout -> IO (Ghci, Result b)
    onTimeout running _ = do
      putStrLn "ghci: restarting due to timeout (caught in restarting)"
      hFlush stdout
      terminateGhci running
      ghci'' <- makeGhci
      return (ghci'', Error "")

    onException :: Ghci -> SomeException -> IO (Ghci, Result b)
    onException running e = do
      putStrLn $ "ghci: restarting due to exception: " ++ show e
      hFlush stdout
      terminateGhci running
      ghci'' <- makeGhci
      return (ghci'', Error "Oops, something went wrong")

-- | Kill the GHCi process and close our ends of its pipes.
--
-- Closing 'ghciStdin' lets 'restarting' recognise (via 'hIsClosed') a
-- session that was torn down, e.g. by 'runStmt'' after over-long output.
-- It then starts a fresh GHCi instead of writing into a dead pipe. Closing
-- also stops the pipe handles from leaking across restarts. 'IOException's
-- raised while closing (such as a failed final flush into a dead process)
-- are ignored. Calling this more than once on the same 'Ghci' is harmless.
terminateGhci :: Ghci -> IO ()
terminateGhci ghci = do
  terminateProcess (ghciProc ghci)
  closeQuietly (ghciStdin ghci)
  closeQuietly (ghciStdout ghci)
  where
    closeQuietly h = hClose h `catch` (\e -> let _ = e :: IOException in return ())

-- | Send one statement to GHCi and collect its output up to the next prompt.
--
-- A fresh random prompt tag is installed first (see 'updatePrompt'), so the
-- end of the output can be recognised. If more than @psMaxOutputLen + 200@
-- bytes arrive before the prompt, or the stream hits EOF, the GHCi is
-- terminated: the prompt is no longer in sync.
runStmt' :: Ghci -> ParseSettings -> String -> IO String
runStmt' ghci pset stmt = do
  tag <- updatePrompt ghci
  ghciPutStrLn (ghciStdin ghci) stmt
  hFlush (ghciStdin ghci)
  let readmax = psMaxOutputLen pset + 200
  (output, reason) <- hGetUntilUTF8 (ghciStdout ghci) (Just readmax) tag
  case reason of
    ReachedMaxLen -> do
      terminateGhci ghci  -- because we lost the new prompt
      -- don't need to strip tag because we read more than the max output len
      return (formatOutput output)
    ReachedTag ->
      -- The tag is ASCII, so its Char length equals its byte length.
      return (formatOutput $ take (length output - length tag) output)
    ReachedEOF -> do
      terminateGhci ghci
      return (formatOutput output)
  where
    -- Optionally trim and join lines, then truncate to psMaxOutputLen
    -- characters (ending in "..." when truncated).
    formatOutput output =
      let output'
            | psJoinLines pset = replaceNewlines (dropBothSlow isSpace output)
            | otherwise = output
      in if length output' > psMaxOutputLen pset
           then take (psMaxOutputLen pset - 3) output' ++ "..."
           else output'
    dropBothSlow f = reverse . dropWhile f . reverse . dropWhile f
    replaceNewlines = concatMap (\case '\n' -> " ; " ; c -> [c])

-- | Set GHCi's prompt to a fresh random tag and consume output up to (and
-- including) the first occurrence of that tag (reading at most ~8 KiB).
-- Returns new prompt tag.
updatePrompt :: Ghci -> IO String
updatePrompt ghci = do
  tag <- genTag
  ghciPutStrLn (ghciStdin ghci) (":set prompt " ++ tag)
  hFlush (ghciStdin ghci)
  _ <- hGetUntilUTF8 (ghciStdout ghci) (Just 8192) tag
  return tag

-- | 20 random lowercase ASCII letters.
genTag :: IO String
genTag = replicateM 20 (getStdRandom (uniformR ('a', 'z')))

-- | Why 'hGetUntil' stopped reading.
data CutoffReason
  = ReachedMaxLen  -- ^ The byte limit was reached before the tag was seen.
  | ReachedTag     -- ^ The data read so far ends with the tag (included).
  | ReachedEOF     -- ^ A read returned no data.
  deriving (Show)

-- | 'hGetUntil' with a 'String' tag (UTF-8 encoded), decoding the result as
-- UTF-8.
hGetUntilUTF8 :: Handle -> Maybe Int -> String -> IO (String, CutoffReason)
hGetUntilUTF8 h mmax tag =
  first LUTF8.toString <$> hGetUntil h mmax (BSS.toShort (UTF8.fromString tag))

-- | Read from the handle in chunks of at most 1024 bytes until the data read
-- ends with @tag@, the optional byte limit is reached, or EOF.
--
-- The tag is only recognised when it ends exactly at the end of a chunk. A
-- tag split across chunks is tracked through the lengths of partial matches
-- ("prefixes") at the end of the data read so far. The limit is checked
-- before each read, so the result may exceed it by up to one chunk. With
-- 'Nothing' there is no limit.
hGetUntil :: Handle -> Maybe Int -> ShortByteString -> IO (Lazy.ByteString, CutoffReason)
hGetUntil h mmax tag = do
  let size = 1024
      exceedsMax yet = case mmax of
        Just m -> yet >= m
        Nothing -> False  -- no limit: never cut off on length
  allocaBytes size $ \ptr -> do
    let loop yet havePrefixes builder = do
          when (exceedsMax yet) $
            exitEarly (BSB.toLazyByteString builder, ReachedMaxLen)
          nread <- lift $ hGetBufSome h ptr size
          when (nread <= 0) $
            exitEarly (BSB.toLazyByteString builder, ReachedEOF)
          bs <- lift $ BS.packCStringLen (ptr, nread)
          -- The tag is complete if it lies wholly at the end of this chunk
          -- (n == 0), or if an earlier partial match of length n is followed
          -- by a chunk that is EXACTLY the rest of the tag. Merely ending
          -- with that rest is not enough: the partial match must be adjacent.
          when (or [ BSS.toShort (BS.takeEnd suflen bs) == BSS.takeEnd suflen tag
                   | n <- 0 : havePrefixes
                   , let suflen = BSS.length tag - n
                   , n == 0 || BS.length bs == suflen ]) $
            exitEarly (BSB.toLazyByteString (builder <> BSB.byteString bs), ReachedTag)
          let nextPrefixes = nub $
                [ plen + BS.length bs  -- continuations of partial matches
                | plen <- havePrefixes
                , BS.length bs < BSS.length tag - plen
                , BSS.toShort bs == BSS.take (BS.length bs) (BSS.drop plen tag) ]
                ++
                [ n  -- new matches
                | n <- [1 .. min (BS.length bs) (BSS.length tag)]
                , BSS.toShort (BS.takeEnd n bs) == BSS.take n tag ]
          loop (yet + nread) nextPrefixes (builder <> BSB.byteString bs)
    result <- execExitEarlyT (loop 0 [] mempty)
    when debugPrints $ hPutStrLn stderr ("Read: <" ++ show result ++ ">")
    return $! result