aboutsummaryrefslogtreecommitdiff
path: root/src/Ghci.hs
blob: f931a7b9922302f34fe284aee7fbf763c1c05b1f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE ViewPatterns #-}
module Ghci (
  Ghci,
  Result(..),
  makeGhci,
  runStmt,
  runStmtClever,
  terminateGhci,
) where

import Control.Exception (catch, fromException, throwIO, IOException, SomeAsyncException, SomeException)
import Control.Monad (replicateM, when)
import Data.Bifunctor (first)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BSB
import qualified Data.ByteString.Lazy as Lazy
import qualified Data.ByteString.UTF8 as UTF8
import qualified Data.ByteString.Lazy.UTF8 as LUTF8
import Data.ByteString.Short (ShortByteString)
import qualified Data.ByteString.Short as BSS
import Data.Char (isSpace, toLower)
import Data.List (nub)
import Foreign (allocaBytes)
import System.IO (hClose, hFlush, hIsClosed, hGetBufSome, hPutStrLn, stderr, Handle)
import System.Process
import System.Random (getStdRandom, uniformR)
import System.Timeout (timeout)

import ExitEarly


-- | Set to 'True' to echo on stderr every line written to GHCi and every
-- result read back from it.
debugPrints :: Bool
debugPrints = False

-- | Write the line @s@ to the handle @h@ (in this module: GHCi's stdin).
-- When 'debugPrints' is on, the line is first echoed to stderr.
-- The handle is not flushed.
ghciPutStrLn :: Handle -> String -> IO ()
ghciPutStrLn h s
  | debugPrints = hPutStrLn stderr ("Writing: <" ++ s ++ ">") >> hPutStrLn h s
  | otherwise   = hPutStrLn h s


-- | A running GHCi session as created by 'makeGhci'.
data Ghci = Ghci
  { ghciProc   :: ProcessHandle  -- ^ the process started by 'makeGhci'
  , ghciStdin  :: Handle         -- ^ our end of the process's stdin
  , ghciStdout :: Handle         -- ^ read end of the pipe receiving both its stdout and stderr
  }

-- | Outcome of handing one input line to the session.
data Result a
  = Error String  -- ^ running it failed; carries a message for the user
  | Ignored       -- ^ the input was deliberately not sent to GHCi
  | Return a      -- ^ it ran and produced this value
  deriving Show

-- | Start a new session by running @./start.sh@ inside the @bwrap-files@
-- directory.
--
-- The child's stdout and stderr are both redirected into a single pipe, whose
-- read end becomes 'ghciStdout'. Before returning, the session is sent
-- @:set -interactive-print=Yahb2Defs.limitedPrint@ followed by @:m@, and
-- stdin is flushed. None of the startup output is consumed here.
makeGhci :: IO Ghci
makeGhci = do
  (readEnd, writeEnd) <- createPipe
  let spec = (proc "./start.sh" [])
        { cwd = Just "bwrap-files"
        , std_in = CreatePipe
        , std_out = UseHandle writeEnd
        , std_err = UseHandle writeEnd
        }
  -- std_in = CreatePipe, so the first component is always a Just.
  (Just input, _, _, ph) <- createProcess spec
  mapM_ (ghciPutStrLn input)
    [ ":set -interactive-print=Yahb2Defs.limitedPrint"
    , ":m"
    ]
  hFlush input
  pure Ghci { ghciProc = ph, ghciStdin = input, ghciStdout = readEnd }

-- | Like 'runStmt', but first screens GHCi commands, i.e. lines whose first
-- non-space character is @:@. Command and option names are lower-cased
-- before matching, and a command matches if it is a prefix of the full name
-- (so @:q@ counts as @:quit@).
--
-- * @:!…@ (shell escapes) and @:def@ are ignored.
-- * @:set prompt…@ is ignored for every option whose name starts with
--   @prompt@ (@prompt@, @prompt-cont@, @prompt-function@,
--   @prompt-cont-function@). 'runStmt' finds the end of the output by looking
--   for its own random prompt. A user-installed prompt (function) would hide
--   that prompt and leave the read waiting until the timeout.
-- * @:quit@ terminates the session, starts a fresh one and returns an empty
--   result.
--
-- Everything else is passed to 'runStmt' unchanged.
runStmtClever :: Ghci -> String -> IO (Ghci, Result String)
runStmtClever ghci line =
  case dropWhile isSpace line of
    ':' : line1 -> case words (map toLower (dropWhile isSpace line1)) of
      ('!':_) : _ -> return (ghci, Ignored)
      cmd : opt : _
        | "set" `startsWith` cmd, opt `startsWith` "prompt" -> return (ghci, Ignored)
      cmd : _
        | "def" `startsWith` cmd -> return (ghci, Ignored)
        | "quit" `startsWith` cmd -> do
            terminateGhci ghci
            putStrLn "ghci: restarting due to :quit"
            ghci' <- makeGhci
            return (ghci', Return "")
      _ -> runStmt ghci line
    _ -> runStmt ghci line
  where
    -- @long `startsWith` short@: is @short@ a prefix of @long@?
    startsWith :: String -> String -> Bool
    long `startsWith` short = take (length short) long == short

-- | Run one statement in the session via 'runStmt''. 'restarting' is allowed
-- one retry on a fresh session after an exception, and the whole attempt is
-- given a 2-second (2,000,000 µs) limit through 'timeouting'.
runStmt :: Ghci -> String -> IO (Ghci, Result String)
runStmt ghci line = timeouting twoSeconds (restarting 1 (`runStmt'` line)) ghci
  where
    twoSeconds = 2_000_000

-- | Run @f ghci@ with a limit of @microseconds@.
--
-- If the limit is hit, @ghci@ is terminated and replaced by a fresh session,
-- and the result is @'Error' "<timeout>"@. This relies on @f@ letting the
-- asynchronous exception used by 'timeout' propagate. 'restarting' does this
-- by rethrowing asynchronous exceptions.
timeouting :: Int -> (Ghci -> IO (Ghci, Result a)) -> Ghci -> IO (Ghci, Result a)
timeouting microseconds f ghci =
  timeout microseconds (f ghci) >>= \case
    Nothing -> do putStrLn "ghci: restarting due to timeout"
                  terminateGhci ghci
                  ghci' <- makeGhci
                  return (ghci', Error "<timeout>")
    Just pair -> return pair

-- | Run @f@ on a usable session and return the session that was used.
--
-- If @ghci@'s stdin is already closed, the session is replaced by a fresh one
-- first. If @f@ throws a synchronous exception, the session is terminated and
-- replaced, and @f@ is retried up to @numExcRestarts@ more times. After that,
-- the result is @'Error' "Oops, something went wrong"@.
--
-- Asynchronous exceptions (for example the one 'timeout' uses, or a
-- 'killThread') are not handled here. The session is terminated, because its
-- prompt state is unknown after the interruption. The exception is then
-- rethrown so that whoever raised it (e.g. 'timeouting') can react.
-- Previously the catch-all handler swallowed timeouts. The retry then ran
-- without any time limit.
restarting :: Int -> (Ghci -> IO a) -> Ghci -> IO (Ghci, Result a)
restarting numExcRestarts f ghci = do
  closed <- hIsClosed (ghciStdin ghci)
  ghci' <- if closed
             then do
               putStrLn "ghci: restarting due to closed stdin"
               terminateGhci ghci
               makeGhci
             else return ghci
  (f ghci' >>= \x -> return (ghci', Return x))
    `catch` (\e -> do
      when (isAsync e) $ do
        terminateGhci ghci'
        throwIO e
      putStrLn $ "ghci: restarting due to exception: " ++ show e
      terminateGhci ghci'
      ghci'' <- makeGhci
      if numExcRestarts >= 1
        then restarting (numExcRestarts - 1) f ghci''
        else return (ghci'', Error "Oops, something went wrong"))
  where
    -- Is this an asynchronous exception (a child of 'SomeAsyncException')?
    isAsync :: SomeException -> Bool
    isAsync e = case (fromException e :: Maybe SomeAsyncException) of
      Just _ -> True
      Nothing -> False

-- | Shut down a session. This sends its process a termination request and
-- closes our ends of its stdin and of its output pipe. It does not wait for
-- the process to exit.
--
-- Closing stdin lets 'restarting' see, through 'hIsClosed', that this session
-- is dead and must be replaced. Otherwise the next statement would be written
-- to a dying process, and its leftover output would be read back as the
-- result. Closing both handles also frees their file descriptors now rather
-- than at garbage collection.
--
-- 'IOException's while closing are ignored, since the session is being
-- discarded anyway. An example is flushing leftover buffered input into a pipe
-- whose reader has gone away. Closing an already-closed handle is a no-op,
-- so calling this twice is harmless.
terminateGhci :: Ghci -> IO ()
terminateGhci ghci = do
  terminateProcess (ghciProc ghci)
  closeQuietly (ghciStdin ghci)
  closeQuietly (ghciStdout ghci)
  where
    closeQuietly :: Handle -> IO ()
    closeQuietly h = hClose h `catch` \e -> let _ = e :: IOException in return ()

-- | Send one statement to the session and return its formatted output.
--
-- A fresh random prompt is installed first (see 'updatePrompt'). Output is
-- then read until that prompt appears, until about 8192 bytes have been read,
-- or until the output pipe reaches EOF. In the latter two cases the session is
-- terminated.
--
-- Formatting: surrounding whitespace is trimmed and each newline becomes
-- @" ; "@. A result longer than 200 characters is cut to 197 characters
-- followed by @"..."@.
runStmt' :: Ghci -> String -> IO String
runStmt' ghci stmt = do
  let input = ghciStdin ghci
  tag <- updatePrompt ghci
  ghciPutStrLn input stmt
  hFlush input
  (raw, why) <- hGetUntilUTF8 (ghciStdout ghci) (Just 8192) tag
  formatOutput <$> case why of
    ReachedTag -> return (take (length raw - length tag) raw)
    -- The new prompt was not consumed, so the session is unusable. No need
    -- to strip the tag: 8192 bytes is far beyond the 200-character cap.
    ReachedMaxLen -> discardSession raw
    ReachedEOF -> discardSession raw
  where
    discardSession raw = terminateGhci ghci >> return raw
    formatOutput str =
      let cleaned = concatMap semicolonize (trim str)
      in if length cleaned > 200 then take 197 cleaned ++ "..." else cleaned
    trim = reverse . dropWhile isSpace . reverse . dropWhile isSpace
    semicolonize c = if c == '\n' then " ; " else [c]

-- | Install a fresh random prompt (see 'genTag') via @:set prompt@, then
-- consume output up to that prompt's appearance. Reading also stops after
-- about 8192 bytes or at EOF. What was read, and why reading stopped, are
-- discarded. Returns the new prompt tag.
updatePrompt :: Ghci -> IO String
updatePrompt ghci = do
  newTag <- genTag
  let input = ghciStdin ghci
  ghciPutStrLn input (":set prompt " ++ newTag)
  hFlush input
  newTag <$ hGetUntilUTF8 (ghciStdout ghci) (Just 8192) newTag

-- | A random prompt tag of 20 letters, each drawn uniformly from @'a'..'z'@
-- using the global generator.
genTag :: IO String
genTag = replicateM tagLength randomLetter
  where
    tagLength = 20
    randomLetter = getStdRandom (uniformR ('a', 'z'))

-- | Why 'hGetUntil' stopped reading.
data CutoffReason
  = ReachedMaxLen  -- ^ the byte limit was reached before the tag was seen
  | ReachedTag     -- ^ the data read so far ends with the tag
  | ReachedEOF     -- ^ the handle hit end-of-file
  deriving Show

-- | 'hGetUntil' on 'String's. The tag is UTF-8 encoded before searching, and
-- the bytes read are UTF-8 decoded before being returned.
hGetUntilUTF8 :: Handle -> Maybe Int -> String -> IO (String, CutoffReason)
hGetUntilUTF8 h mmax tag = fmap (first LUTF8.toString) (hGetUntil h mmax encodedTag)
  where
    encodedTag = BSS.toShort (UTF8.fromString tag)

-- | Read from @h@ in chunks of up to 1024 bytes and stop at the first of:
--
-- * the data read so far ends with @tag@ at the end of a chunk, possibly
--   spread over several consecutive chunks ('ReachedTag');
-- * at least @m@ bytes have been read, when @mmax = Just m@ ('ReachedMaxLen').
--   The limit is checked before each read, so the result can overshoot it by
--   less than one chunk. @Nothing@ means no limit;
-- * end-of-file ('ReachedEOF').
--
-- Returns everything read, including the tag when it was found.
--
-- Only a tag at the very end of a chunk counts. This suits a tag that is the
-- last thing the other side writes before waiting for input.
hGetUntil :: Handle -> Maybe Int -> ShortByteString -> IO (Lazy.ByteString, CutoffReason)
hGetUntil h mmax tag = do
  let size = 1024
      tagLen = BSS.length tag
      -- 'Nothing' means "unbounded". It used to count as always exceeded,
      -- which returned before reading anything.
      exceedsMax yet = case mmax of Just m -> yet >= m
                                    Nothing -> False
  allocaBytes size $ \ptr -> do
    -- havePrefixes: every n (0 < n < tagLen) such that the data read so far
    -- ends with the first n bytes of the tag.
    let loop yet havePrefixes builder = do
          when (exceedsMax yet) $ exitEarly (BSB.toLazyByteString builder, ReachedMaxLen)

          nread <- lift $ hGetBufSome h ptr size
          when (nread <= 0) $ exitEarly (BSB.toLazyByteString builder, ReachedEOF)

          bs <- lift $ BS.packCStringLen (ptr, nread)
          let sbs = BSS.toShort bs
              -- Either this chunk ends with the whole tag, or the chunk is
              -- *exactly* the rest of a tag prefix that ended the previous
              -- data. Merely ending with that rest is not enough: other bytes
              -- before it would separate it from the prefix.
              endsWithTag =
                BSS.toShort (BS.takeEnd tagLen bs) == tag
                  || any (\n -> sbs == BSS.drop n tag) havePrefixes
          when endsWithTag $
            exitEarly (BSB.toLazyByteString (builder <> BSB.byteString bs)
                      ,ReachedTag)

          let nextPrefixes =
                nub $
                  [plen + BS.length bs  -- continuations of partial matches
                  | plen <- havePrefixes
                  , BS.length bs < tagLen - plen
                  , sbs == BSS.take (BS.length bs) (BSS.drop plen tag)]
                  ++
                  [n  -- new partial matches at the end of this chunk
                  | n <- [1 .. min (BS.length bs) (tagLen - 1)]
                  , BSS.toShort (BS.takeEnd n bs) == BSS.take n tag]
          loop (yet + nread) nextPrefixes (builder <> BSB.byteString bs)
    result <- execExitEarlyT (loop 0 [] mempty)
    when debugPrints $ hPutStrLn stderr ("Read: <" ++ show result ++ ">")
    return $! result