aboutsummaryrefslogtreecommitdiff
path: root/src/Ghci.hs
blob: d16d2a011ee505216d0033f3063fe4d921dfc463 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE ViewPatterns #-}
module Ghci (
  Ghci,
  Result(..),
  makeGhci,
  runStmt,
  runStmtClever,
  terminateGhci,
  ParseSettings(..),
  parseSettingsReply,
  parseSettingsPaste,
) where

import Control.Exception (catch, IOException, SomeException)
import Control.Monad (replicateM, when)
import Data.Bifunctor (first)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BSB
import qualified Data.ByteString.Lazy as Lazy
import qualified Data.ByteString.UTF8 as UTF8
import qualified Data.ByteString.Lazy.UTF8 as LUTF8
import Data.ByteString.Short (ShortByteString)
import qualified Data.ByteString.Short as BSS
import Data.Char (isSpace, toLower)
import Data.List (nub)
import Foreign (allocaBytes)
import System.IO (hClose, hFlush, hIsClosed, hGetBufSome, hPutStrLn, stdout, stderr, Handle)
import System.Process
import System.Random (getStdRandom, uniformR)
import System.Timeout (timeout, Timeout)

import ExitEarly


-- | Compile-time switch for tracing GHCi traffic to stderr. This covers the
-- lines written to GHCi, the raw chunks read back and prompt-detection
-- progress. Off by default.
debugPrints :: Bool
debugPrints = False

-- | Write one line to a GHCi input handle. When 'debugPrints' is on, the line
-- is first echoed to stderr wrapped in angle brackets. The handle is not
-- flushed; callers do that.
ghciPutStrLn :: Handle -> String -> IO ()
ghciPutStrLn h s = traceLine >> hPutStrLn h s
  where
    traceLine = when debugPrints (hPutStrLn stderr ("Writing: <" ++ s ++ ">"))


-- | Handles for one running GHCi session, as created by 'makeGhci'.
data Ghci = Ghci
  { ghciProc   :: ProcessHandle  -- ^ the spawned process
  , ghciStdin  :: Handle         -- ^ our end of its stdin pipe; commands are written here
  , ghciStdout :: Handle         -- ^ read end of a pipe receiving both its stdout and stderr
  }

-- | How 'runStmt' post-processes the output GHCi prints for a statement.
data ParseSettings = ParseSettings
  { psMaxOutputLen :: Int  -- ^ Maximum length of the returned output, in characters; longer output is cut and ends in "..."
  , psJoinLines :: Bool    -- ^ Whether to trim surrounding whitespace and join lines with " ; "
  }
  deriving (Show)

-- | Compact settings: output limited to 200 characters and joined onto a
-- single line.
parseSettingsReply :: ParseSettings
parseSettingsReply = ParseSettings { psMaxOutputLen = 200, psJoinLines = True }

-- | Long-form settings: output limited to 50000 characters, with line breaks
-- kept as they are.
parseSettingsPaste :: ParseSettings
parseSettingsPaste = ParseSettings { psMaxOutputLen = 50000, psJoinLines = False }

-- | Outcome of handling one input line.
data Result a
  = Error String  -- ^ the session failed or timed out; carries a message for the user
  | Ignored       -- ^ the line was deliberately not sent to GHCi
  | Return a      -- ^ the line ran; carries the action's result
  | Exited        -- ^ the line asked GHCi to quit and the session was restarted
  deriving (Show)

-- | Start a new GHCi session by running @./start.sh@ in the @bwrap-files@
-- directory.
--
-- The script's stdout and stderr both go to one new pipe, whose read end
-- becomes 'ghciStdout'. Three setup commands are then sent and flushed:
-- setting the interactive printer to @Yahb2Defs.limitedPrint@, @:m@, and
-- @:script initdefs.hs@. Their output is not read here.
makeGhci :: IO Ghci
makeGhci = do
  (readEnd, writeEnd) <- createPipe
  let spec = (proc "./start.sh" [])
               { cwd = Just "bwrap-files"
               , std_in = CreatePipe
               , std_out = UseHandle writeEnd
               , std_err = UseHandle writeEnd
               }
  (Just toGhci, _, _, ph) <- createProcess spec
  mapM_ (ghciPutStrLn toGhci)
    [ ":set -interactive-print=Yahb2Defs.limitedPrint"
    , ":m"
    , ":script initdefs.hs"
    ]
  hFlush toGhci
  return (Ghci ph toGhci readEnd)

-- | Run one line of user input, screening GHCi meta-commands first.
--
-- Leading whitespace is skipped. Command names are matched case-insensitively
-- and may be abbreviated to any prefix, e.g. @:q@ for @:quit@. The following
-- are not passed to GHCi:
--
--   * shell escapes (@:!...@) and @:set prompt@ (the prompt is what
--     'runStmt'' uses to find the end of each reply): returns 'Ignored';
--   * @:def@ and @:def!@ (defining or redefining GHCi commands): returns
--     'Ignored';
--   * @:quit@: the session is terminated and a new one is started; returns
--     'Exited' with the new session.
--
-- Everything else is run with 'runStmt'.
runStmtClever :: Ghci -> ParseSettings -> String -> IO (Ghci, Result String)
runStmtClever ghci pset line =
  case dropWhile isSpace line of
    ':' : line1 -> case words (map toLower (dropWhile isSpace line1)) of
      ('!':_) : _ -> return (ghci, Ignored)
      cmd : "prompt" : _
        | "set" `startsWith` cmd -> return (ghci, Ignored)
      cmd : _
        -- ":def!" is GHCi's overriding form of ":def". It is not a prefix of
        -- "def", so the prefix test alone would let it through.
        | cmd == "def!" || "def" `startsWith` cmd -> return (ghci, Ignored)
        | "quit" `startsWith` cmd -> do
            terminateGhci ghci
            putStrLn "ghci: restarting due to :quit"
            hFlush stdout
            ghci' <- makeGhci
            return (ghci', Exited)
      _ -> runStmt ghci pset line
    _ -> runStmt ghci pset line
  where
    -- long `startsWith` short: is short a prefix of long?
    startsWith :: String -> String -> Bool
    long `startsWith` short = take (length short) long == short

-- | Run a statement in GHCi with a 2-second limit, replacing the session if
-- anything goes wrong. Returns the session to use next together with the
-- outcome.
runStmt :: Ghci -> ParseSettings -> String -> IO (Ghci, Result String)
runStmt ghci pset line =
  timeouting limitMicros (restarting exec) ghci
  where
    limitMicros = 2_000_000
    exec session = runStmt' session pset line

-- | Run @f@ on a session with a time limit of @microseconds@.
--
-- If 'timeout' returns 'Nothing', the session passed in is terminated, a new
-- one is started, and @Error "<timeout>"@ is returned with the new session.
--
-- NOTE: when @f@ is wrapped in 'restarting' (as in 'runStmt'), this fallback
-- does not run. 'restarting' has its own 'Timeout' handler, which catches the
-- exception first, so 'timeout' returns 'Just' that handler's result.
timeouting :: Int -> (Ghci -> IO (Ghci, Result a)) -> Ghci -> IO (Ghci, Result a)
timeouting microseconds f ghci =
  timeout microseconds (f ghci) >>= maybe restart return
  where
    restart = do
      putStrLn "ghci: restarting due to timeout"
      hFlush stdout
      terminateGhci ghci
      ghci' <- makeGhci
      return (ghci', Error "<timeout>")

-- | Run @f@ on a usable session, replacing the session when something fails.
--
-- If the session's stdin handle is closed, a new session is started first.
-- Then @f@ runs on the session in use:
--
--   * on success, the result is returned as 'Return' with that session;
--   * if @f@ is interrupted by a 'Timeout' from an enclosing 'timeout', or
--     throws any other exception, the session in use is terminated, a new one
--     is started, and an 'Error' is returned with the new session.
restarting :: (Ghci -> IO a) -> Ghci -> IO (Ghci, Result a)
restarting f ghci = do
  closed <- hIsClosed (ghciStdin ghci)
  ghci' <- if closed
             then do
               putStrLn "ghci: restarting due to closed stdin"
               hFlush stdout
               terminateGhci ghci
               makeGhci
             else return ghci

  -- Shared recovery path. It terminates ghci' (the session f was actually
  -- running on), not the ghci argument: the two differ after the
  -- closed-stdin restart above, and terminating the wrong one would leave a
  -- possibly still-busy GHCi process running.
  let restartWith msg err = do
        putStrLn msg
        hFlush stdout
        terminateGhci ghci'
        ghci'' <- makeGhci
        return (ghci'', Error err)

  fmap (\x -> (ghci', Return x)) (f ghci')
    `catch` (\e -> do let _ = e :: Timeout
                      restartWith "ghci: restarting due to timeout (caught in restarting)"
                                  "<timeout>")
    `catch` (\e -> do let _ = e :: SomeException
                      restartWith ("ghci: restarting due to exception: " ++ show e)
                                  "Oops, something went wrong")

-- | Shut down a GHCi session. This sends the process a termination request,
-- then closes our ends of its stdin and stdout pipes.
--
-- Closing the handles frees their file descriptors promptly and marks the
-- session as dead. 'restarting' checks 'hIsClosed' on 'ghciStdin' and starts
-- a fresh session, instead of later writing to a process that is gone.
-- Closing an already-closed handle has no effect, so calling this twice is
-- harmless. IO errors while closing (e.g. flushing into a pipe whose reader
-- has exited) are ignored, because this is also called from exception
-- handlers.
terminateGhci :: Ghci -> IO ()
terminateGhci ghci = do
  terminateProcess (ghciProc ghci)
  closeQuietly (ghciStdin ghci)
  closeQuietly (ghciStdout ghci)
  where
    closeQuietly h = hClose h `catch` ignoreIOError

    ignoreIOError :: IOException -> IO ()
    ignoreIOError _ = return ()

-- | Send one statement to a running GHCi and collect what it prints in reply.
--
-- A fresh random prompt is installed first ('updatePrompt'). Everything read
-- up to the next occurrence of that prompt is the statement's output, with
-- the prompt itself stripped. Reading is capped at 8192 bytes beyond
-- 'psMaxOutputLen'. If the cap is hit, the session is terminated, because the
-- prompt marking the end of this output has been lost. The session is also
-- terminated when GHCi's output reaches EOF.
--
-- The result is post-processed per the 'ParseSettings': optionally trimmed
-- and joined onto one line, then limited to 'psMaxOutputLen' characters,
-- ending in "..." whenever anything was cut off.
runStmt' :: Ghci -> ParseSettings -> String -> IO String
runStmt' ghci pset stmt = do
  tag <- updatePrompt ghci
  ghciPutStrLn (ghciStdin ghci) stmt
  hFlush (ghciStdin ghci)
  let readmax = psMaxOutputLen pset + 8192
  (output, reason) <- hGetUntilUTF8 (ghciStdout ghci) (Just readmax) tag
  case reason of
    ReachedMaxLen -> do
      terminateGhci ghci  -- because we lost the new prompt
      -- The read stopped in the middle of the output, so always mark it as
      -- truncated. The byte cap can decode to fewer characters than
      -- psMaxOutputLen when the output is multi-byte UTF-8, and then the
      -- length check alone would not add the "...".
      return (formatOutput True output)
    ReachedTag -> return (formatOutput False $ take (length output - length tag) output)
    ReachedEOF -> do
      terminateGhci ghci
      return (formatOutput False output)
  where
    formatOutput truncated output =
      let output' | psJoinLines pset = replaceNewlines (dropBothSlow isSpace output)
                  | otherwise        = output
          maxLen = psMaxOutputLen pset
      in if truncated || length output' > maxLen
           then take (maxLen - 3) output' ++ "..."
           else output'
    dropBothSlow f = reverse . dropWhile f . reverse . dropWhile f
    replaceNewlines = concatMap (\case '\n' -> " ; " ; c -> [c])

-- | Install a fresh random prompt in the session and return it.
--
-- Sends @:set prompt <tag>@, then reads GHCi's output (with an 8192-byte
-- limit) until the new prompt appears, discarding everything before it. The
-- reason the read stopped is not checked. If the limit or EOF is hit first,
-- the tag is returned anyway.
updatePrompt :: Ghci -> IO String
updatePrompt ghci = do
  newTag <- genTag
  let toGhci = ghciStdin ghci
  ghciPutStrLn toGhci (":set prompt " ++ newTag)
  hFlush toGhci
  _ <- hGetUntilUTF8 (ghciStdout ghci) (Just 8192) newTag
  pure newTag

-- | A random 20-character string of lowercase ASCII letters, drawn from the
-- global standard generator. 'updatePrompt' uses it as the GHCi prompt.
genTag :: IO String
genTag = replicateM tagLength randomLetter
  where
    tagLength = 20
    randomLetter = getStdRandom (uniformR ('a', 'z'))

-- | Why 'hGetUntil' stopped reading.
data CutoffReason
  = ReachedMaxLen  -- ^ the byte limit was reached before the tag was seen
  | ReachedTag     -- ^ the data read so far ends with the tag
  | ReachedEOF     -- ^ a read returned no data (end of input)
  deriving (Show)

-- | 'hGetUntil' with a 'String' tag. The tag is UTF-8 encoded before matching,
-- and the bytes read are decoded from UTF-8 into a 'String'.
hGetUntilUTF8 :: Handle -> Maybe Int -> String -> IO (String, CutoffReason)
hGetUntilUTF8 h mmax tag =
  fmap (first LUTF8.toString) (hGetUntil h mmax encodedTag)
  where
    encodedTag = BSS.toShort (UTF8.fromString tag)

-- | Read from a handle until the data read ends with @tag@, until EOF, or
-- until at least @mmax@ bytes (if given) have been read.
--
-- Data is read with 'hGetBufSome' in chunks of at most 1024 bytes. The tag is
-- only recognised when it ends exactly at the end of a chunk. It is therefore
-- meant for something like a prompt, after which the writer pauses. A tag
-- spread over several chunks is recognised by tracking, in @havePrefixes@,
-- the lengths of tag prefixes that the data read so far ends with.
--
-- Returns everything read (including the tag, when found) and why reading
-- stopped. @mmax = Nothing@ means no length limit. The limit is checked before
-- each read, so the result may exceed @mmax@ by up to 1023 bytes.
hGetUntil :: Handle -> Maybe Int -> ShortByteString -> IO (Lazy.ByteString, CutoffReason)
hGetUntil h mmax tag = do
  let size = 1024
      taglen = BSS.length tag
      -- 'Nothing' means "no limit", so it must never count as exceeded.
      exceedsMax yet = case mmax of
        Just m  -> yet >= m
        Nothing -> False
  allocaBytes size $ \ptr -> do
    let loop yet havePrefixes builder = do
          when (exceedsMax yet) $ exitEarly (BSB.toLazyByteString builder, ReachedMaxLen)

          nread <- lift $ hGetBufSome h ptr size
          when (nread <= 0) $ exitEarly (BSB.toLazyByteString builder, ReachedEOF)

          bs <- lift $ BS.packCStringLen (ptr, nread)
          let sbs = BSS.toShort bs
          when debugPrints $ lift $ hPutStrLn stderr ("Read: " ++ show bs)
          -- The tag ends at the end of this chunk in one of two ways. Either
          -- the whole tag is a suffix of the chunk, or the earlier data ended
          -- with the tag's first n bytes and this chunk is exactly the
          -- remaining bytes. Merely ending with that remainder is not enough:
          -- other bytes between the prefix and the suffix would make it a
          -- false match.
          when (BSS.toShort (BS.takeEnd taglen bs) == tag
                  || any (\n -> sbs == BSS.drop n tag) havePrefixes) $ do
            when debugPrints $ lift $ hPutStrLn stderr "yay determined end"
            exitEarly (BSB.toLazyByteString (builder <> BSB.byteString bs)
                      ,ReachedTag)

          let nextPrefixes =
                nub $
                  [plen + BS.length bs  -- continuations of partial matches
                  | plen <- havePrefixes
                  , BS.length bs < taglen - plen
                  , sbs == BSS.take (BS.length bs) (BSS.drop plen tag)]
                  ++
                  [n  -- new matches
                  | n <- [1 .. min (BS.length bs) taglen]
                  , BSS.toShort (BS.takeEnd n bs) == BSS.take n tag]
          when debugPrints $ lift $ hPutStrLn stderr ("nextPrefixes = " ++ show nextPrefixes)
          loop (yet + nread) nextPrefixes (builder <> BSB.byteString bs)
    result <- execExitEarlyT (loop 0 [] mempty)
    when debugPrints $ hPutStrLn stderr ("Read: <" ++ show result ++ ">")
    return $! result