1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
|
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE ViewPatterns #-}
module Ghci (
Ghci,
Result(..),
makeGhci,
runStmt,
runStmtClever,
terminateGhci,
) where
import Control.Exception (catch, SomeException)
import Control.Exception (SomeAsyncException, fromException, throwIO)
import Control.Monad (replicateM, when)
import Data.Bifunctor (first)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BSB
import qualified Data.ByteString.Lazy as Lazy
import qualified Data.ByteString.UTF8 as UTF8
import qualified Data.ByteString.Lazy.UTF8 as LUTF8
import Data.ByteString.Short (ShortByteString)
import qualified Data.ByteString.Short as BSS
import Data.Char (isSpace, toLower)
import Data.List (nub)
import Foreign (allocaBytes)
import System.IO (hFlush, hIsClosed, hGetBufSome, hPutStrLn, stdout, stderr, Handle)
import System.Process
import System.Random (getStdRandom, uniformR)
import System.Timeout (timeout)
import ExitEarly
-- | Compile-time switch for debug tracing. When 'True', 'ghciPutStrLn' and
-- 'hGetUntil' echo what they write to / read from GHCi on stderr.
debugPrints :: Bool
debugPrints = False
-- | Write one line of input to the given handle (GHCi's stdin).
--
-- When 'debugPrints' is enabled the line is first echoed to stderr. The
-- handle is not flushed here; callers flush explicitly.
ghciPutStrLn :: Handle -> String -> IO ()
ghciPutStrLn h s = traceWrite >> hPutStrLn h s
  where
    traceWrite :: IO ()
    traceWrite = when debugPrints (hPutStrLn stderr ("Writing: <" ++ s ++ ">"))
-- | A running GHCi subprocess together with the handles used to talk to it.
data Ghci = Ghci
  { ghciProc   :: ProcessHandle -- ^ the spawned process
  , ghciStdin  :: Handle        -- ^ handle on which commands are written to the process
  , ghciStdout :: Handle        -- ^ handle from which the process's output is read
  }
-- | Outcome of submitting one line of input.
data Result a
  = Error String -- ^ the line could not be run; carries a message for the user
  | Ignored      -- ^ the line was deliberately not forwarded to GHCi
  | Return a     -- ^ the line was run; carries what it produced
  deriving Show
-- | Launch a fresh GHCi session.
--
-- Runs @./start.sh@ with working directory @bwrap-files@. The child's stdin
-- is a new pipe; its stdout and stderr are both pointed at the write end of
-- a single pipe, so normal output and error text arrive on one handle.
-- Two setup commands are sent before the session is handed back.
makeGhci :: IO Ghci
makeGhci = do
  (readEnd, writeEnd) <- createPipe
  let launcher =
        (proc "./start.sh" [])
          { cwd = Just "bwrap-files"
          , std_in = CreatePipe
          , std_out = UseHandle writeEnd
          , std_err = UseHandle writeEnd
          }
  -- std_in is CreatePipe, so createProcess hands back a stdin handle.
  (Just toGhci, _, _, procHandle) <- createProcess launcher
  mapM_
    (ghciPutStrLn toGhci)
    [ ":set -interactive-print=Yahb2Defs.limitedPrint"
    , ":m"
    ]
  hFlush toGhci
  return
    Ghci
      { ghciProc = procHandle
      , ghciStdin = toGhci
      , ghciStdout = readEnd
      }
-- | Like 'runStmt', but screens GHCi meta-commands first.
--
-- A line whose first non-blank character is @:@ is inspected
-- case-insensitively:
--
-- * @:!...@, @:set prompt ...@ and @:def ...@ (each command may be
--   abbreviated to any non-empty prefix) are not forwarded; the result is
--   'Ignored'.
-- * @:quit@ (or any prefix of it) terminates the session, starts a new one
--   and yields @Return ""@.
--
-- Every other line is passed to 'runStmt' unchanged.
runStmtClever :: Ghci -> String -> IO (Ghci, Result String)
runStmtClever ghci line
  | ':' : rest <- dropWhile isSpace line = dispatch (words (map toLower rest))
  | otherwise = forward
  where
    forward = runStmt ghci line
    skip = return (ghci, Ignored)

    -- Decide what to do based on the lower-cased words after the colon.
    dispatch (('!' : _) : _) = skip
    dispatch (cmd : "prompt" : _)
      | cmd `abbreviates` "set" = skip
    dispatch (cmd : _)
      | cmd `abbreviates` "def" = skip
      | cmd `abbreviates` "quit" = restartAfterQuit
    dispatch _ = forward

    restartAfterQuit = do
      terminateGhci ghci
      putStrLn "ghci: restarting due to :quit"
      hFlush stdout
      fresh <- makeGhci
      return (fresh, Return "")

    -- True when the first string is a prefix of the second.
    abbreviates :: String -> String -> Bool
    short `abbreviates` full = short == take (length short) full
-- | Send one line to GHCi and collect its output.
--
-- The work is done by @runStmt'@, wrapped in 'restarting' with one retry and
-- in 'timeouting' with a budget of 2,000,000 microseconds.
runStmt :: Ghci -> String -> IO (Ghci, Result String)
runStmt ghci line = timeouting budgetMicros attempt ghci
  where
    budgetMicros :: Int
    budgetMicros = 2_000_000

    attempt = restarting 1 (`runStmt'` line)
-- | Run an action on a GHCi session with a time limit.
--
-- If @f@ finishes within @microseconds@ its result is returned unchanged.
-- Otherwise the session that was passed in is terminated, a new one is
-- started, and @Error "<timeout>"@ is returned alongside the new session.
--
-- This relies on the 'timeout' exception propagating out of @f@; 'restarting'
-- therefore rethrows asynchronous exceptions instead of handling them.
timeouting :: Int -> (Ghci -> IO (Ghci, Result a)) -> Ghci -> IO (Ghci, Result a)
timeouting microseconds f ghci =
  timeout microseconds (f ghci) >>= \case
    Nothing -> do
      putStrLn "ghci: restarting due to timeout"
      hFlush stdout
      -- NOTE: only the session given to us is terminated. If @f@ replaced it
      -- with a new one before the timeout fired, that replacement is not
      -- known here and is not terminated.
      terminateGhci ghci
      ghci' <- makeGhci
      return (ghci', Error "<timeout>")
    Just pair -> return pair

-- | Run an action on a GHCi session, replacing the session when it is
-- unusable.
--
-- * If the session's stdin handle is already closed, the session is
--   terminated and a new one is started before @f@ runs.
-- * If @f@ throws a synchronous exception, the session is terminated, a new
--   one is started, and @f@ is retried up to @numExcRestarts@ more times.
--   Once the retries are used up the result is
--   @Error "Oops, something went wrong"@.
-- * Asynchronous exceptions (in particular the one 'timeout' uses to
--   interrupt the action) are rethrown untouched, so that an enclosing
--   'timeouting' can observe them.
--
-- Returns the session that should be used from now on, paired with the
-- result.
restarting :: Int -> (Ghci -> IO a) -> Ghci -> IO (Ghci, Result a)
restarting numExcRestarts f ghci = do
  closed <- hIsClosed (ghciStdin ghci)
  ghci' <-
    if closed
      then do
        putStrLn "ghci: restarting due to closed stdin"
        hFlush stdout
        terminateGhci ghci
        makeGhci
      else return ghci
  fmap (\x -> (ghci', Return x)) (f ghci')
    `catch` ( \e ->
                if isAsync e
                  then throwIO e -- not ours to handle; let it reach 'timeout'
                  else do
                    putStrLn $ "ghci: restarting due to exception: " ++ show e
                    hFlush stdout
                    terminateGhci ghci'
                    ghci'' <- makeGhci
                    if numExcRestarts >= 1
                      then restarting (numExcRestarts - 1) f ghci''
                      else return (ghci'', Error "Oops, something went wrong")
            )
  where
    -- Does this exception belong to the asynchronous-exception hierarchy?
    isAsync :: SomeException -> Bool
    isAsync e =
      case (fromException e :: Maybe SomeAsyncException) of
        Just _ -> True
        Nothing -> False
-- | Ask the GHCi process to terminate (via 'terminateProcess').
--
-- This only issues the termination request: it does not wait for the process
-- to exit and does not close the session's handles.
terminateGhci :: Ghci -> IO ()
terminateGhci = terminateProcess . ghciProc
-- | Send a statement to GHCi and return its (tidied) output.
--
-- A fresh prompt tag is installed first; output is then read until that tag
-- shows up again, at most 8192 bytes. If the tag was not seen (limit reached
-- or end of stream) the session is terminated, because the position of the
-- next prompt is no longer known.
--
-- The returned text has surrounding whitespace removed, newlines replaced by
-- @" ; "@, and is cut to 200 characters (the last three being @"..."@).
runStmt' :: Ghci -> String -> IO String
runStmt' ghci stmt = do
  tag <- updatePrompt ghci
  let toGhci = ghciStdin ghci
  ghciPutStrLn toGhci stmt
  hFlush toGhci
  (raw, why) <- hGetUntilUTF8 (ghciStdout ghci) (Just 8192) tag
  case why of
    -- The tag is the tail of what was read; cut it off.
    ReachedTag -> return (tidy (take (length raw - length tag) raw))
    -- No need to strip a tag here: the output is cut to 200 characters,
    -- far below the 8192-byte read limit.
    ReachedMaxLen -> terminateGhci ghci >> return (tidy raw)
    ReachedEOF -> terminateGhci ghci >> return (tidy raw)
  where
    tidy :: String -> String
    tidy = abbreviate . flatten . trim

    trim :: String -> String
    trim = reverse . dropWhile isSpace . reverse . dropWhile isSpace

    flatten :: String -> String
    flatten = concatMap (\c -> if c == '\n' then " ; " else [c])

    abbreviate :: String -> String
    abbreviate text
      | length text > 200 = take 197 text ++ "..."
      | otherwise = text
-- | Install a freshly generated prompt in GHCi and return it.
--
-- After sending @:set prompt <tag>@ the output is consumed until the tag is
-- seen (or 8192 bytes / end of stream), so that the next read starts after
-- the new prompt. The reason the read stopped is not inspected.
updatePrompt :: Ghci -> IO String
updatePrompt ghci = do
  tag <- genTag
  ghciPutStrLn toGhci (":set prompt " ++ tag)
  hFlush toGhci
  tag <$ hGetUntilUTF8 (ghciStdout ghci) (Just 8192) tag
  where
    toGhci = ghciStdin ghci
-- | Generate a prompt tag: 20 characters, each drawn from @'a'..'z'@ using
-- the global standard random generator.
genTag :: IO String
genTag = replicateM tagLength randomLetter
  where
    tagLength :: Int
    tagLength = 20

    randomLetter :: IO Char
    randomLetter = getStdRandom (uniformR ('a', 'z'))
-- | Why 'hGetUntil' stopped reading.
data CutoffReason
  = ReachedMaxLen -- ^ the byte limit was reached before the tag was seen
  | ReachedTag    -- ^ the data read so far ends with the tag
  | ReachedEOF    -- ^ a read returned no bytes
  deriving Show
-- | 'String' front end for 'hGetUntil': the tag is encoded as UTF-8 before
-- the read, and the bytes that were read are decoded from UTF-8 afterwards.
-- The length limit is passed through unchanged (it counts bytes).
hGetUntilUTF8 :: Handle -> Maybe Int -> String -> IO (String, CutoffReason)
hGetUntilUTF8 h mmax tag = do
  (bytes, reason) <- hGetUntil h mmax encodedTag
  return (LUTF8.toString bytes, reason)
  where
    encodedTag = BSS.toShort (UTF8.fromString tag)
-- | Read from a handle until the data read so far ends with @tag@.
--
-- Data is read in chunks of up to 1024 bytes. Reading stops when
--
-- * the accumulated data ends with @tag@ ('ReachedTag'; the tag is included
--   in the returned bytes),
-- * a read returns no bytes ('ReachedEOF'), or
-- * at least @mmax@ bytes have been accumulated ('ReachedMaxLen'). The limit
--   is checked before each read, so the result may exceed it by less than
--   one chunk. 'Nothing' means there is no limit.
--
-- The tag may be split across several reads. To detect that, the loop
-- carries @havePrefixes@: every length @p@ such that the data read so far
-- ends with the first @p@ bytes of the tag.
hGetUntil :: Handle -> Maybe Int -> ShortByteString -> IO (Lazy.ByteString, CutoffReason)
hGetUntil h mmax tag = do
  let size = 1024
      tagLen = BSS.length tag
      -- With no limit configured, the limit can never be exceeded.
      exceedsMax yet = maybe False (yet >=) mmax
  allocaBytes size $ \ptr -> do
    let loop yet havePrefixes builder = do
          when (exceedsMax yet) $
            exitEarly (BSB.toLazyByteString builder, ReachedMaxLen)
          nread <- lift $ hGetBufSome h ptr size
          when (nread <= 0) $
            exitEarly (BSB.toLazyByteString builder, ReachedEOF)
          bs <- lift $ BS.packCStringLen (ptr, nread)
          let chunk = BSS.toShort bs
              chunkLen = BS.length bs
              builder' = builder <> BSB.byteString bs
              -- The whole tag sits at the end of this chunk.
              endsWithTag = BSS.toShort (BS.takeEnd tagLen bs) == tag
              -- A partial match from earlier reads is completed only if this
              -- chunk is exactly the missing remainder; a longer chunk that
              -- merely ends with the remainder is not a contiguous match.
              completesPartial =
                any (\plen -> chunk == BSS.drop plen tag) havePrefixes
          when (endsWithTag || completesPartial) $
            exitEarly (BSB.toLazyByteString builder', ReachedTag)
          let nextPrefixes =
                nub $
                  -- partial matches that this (short) chunk extends
                  [ plen + chunkLen
                  | plen <- havePrefixes
                  , chunkLen < tagLen - plen
                  , chunk == BSS.take chunkLen (BSS.drop plen tag)
                  ]
                  ++
                  -- partial matches that start inside this chunk
                  [ n
                  | n <- [1 .. min chunkLen tagLen]
                  , BSS.toShort (BS.takeEnd n bs) == BSS.take n tag
                  ]
          loop (yet + nread) nextPrefixes builder'
    result <- execExitEarlyT (loop 0 [] mempty)
    when debugPrints $ hPutStrLn stderr ("Read: <" ++ show result ++ ">")
    return $! result
|