{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE ViewPatterns #-}
module Ghci (
  Ghci,
  Result(..),
  makeGhci,
  runStmt,
  runStmtClever,
  terminateGhci,
  ParseSettings(..),
  parseSettingsReply,
  parseSettingsPaste,
) where

import Control.Exception (catch, IOException, SomeException)
import Control.Monad (replicateM, when)
import Data.Bifunctor (first)
import Data.Char (isSpace, toLower)
import Data.List (nub)
import Foreign (allocaBytes)
import System.IO (hClose, hFlush, hIsClosed, hGetBufSome, hGetLine, hPutStrLn, hSetEncoding, stdout, stderr, utf8, Handle)
import System.Timeout (timeout, Timeout)

import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BSB
import qualified Data.ByteString.Lazy as Lazy
import qualified Data.ByteString.UTF8 as UTF8
import qualified Data.ByteString.Lazy.UTF8 as LUTF8
import Data.ByteString.Short (ShortByteString)
import qualified Data.ByteString.Short as BSS
import System.Process
import System.Random (getStdRandom, uniformR)

import ExitEarly


-- | Compile-time switch for diagnostic tracing. When 'True', lines sent to
-- GHCi and chunks read back from it are echoed to stderr.
debugPrints :: Bool
debugPrints = False

-- | Write one line of input to the GHCi process. When 'debugPrints' is on,
-- the line is traced to stderr first. The handle is not flushed here.
ghciPutStrLn :: Handle -> String -> IO ()
ghciPutStrLn handle str = traceLine >> hPutStrLn handle str
  where
    traceLine = when debugPrints (hPutStrLn stderr ("Writing: <" ++ str ++ ">"))


-- | A running GHCi session as created by 'makeGhci'.
data Ghci = Ghci
  { ghciProc    :: ProcessHandle  -- ^ the spawned @start.sh@ process
  , ghciStdin   :: Handle         -- ^ where commands are written
  , ghciStdout  :: Handle         -- ^ read end of the pipe shared by the process's stdout and stderr
  , ghciVersion :: String         -- ^ version reported at startup, or a placeholder
  }

-- | Controls how raw GHCi output is turned into a reply.
data ParseSettings = ParseSettings
  { psMaxOutputLen :: Int
    -- ^ Maximum number of characters to return; longer output is cut and
    -- ends in @...@.
  , psJoinLines :: Bool
    -- ^ Whether to trim surrounding whitespace and join lines with @ ; @.
  }
  deriving (Show)

-- | Settings for inline chat replies: short (200 characters) and on one line.
parseSettingsReply :: ParseSettings
parseSettingsReply = ParseSettings { psMaxOutputLen = 200, psJoinLines = True }

-- | Settings for long-form output: up to 50000 characters, line breaks kept.
parseSettingsPaste :: ParseSettings
parseSettingsPaste = ParseSettings { psMaxOutputLen = 50_000, psJoinLines = False }

-- | Outcome of handling one input line.
data Result a
  = Error String  -- ^ the run failed; the string describes the failure
  | Ignored       -- ^ the input was filtered out and never sent to GHCi
  | Return a      -- ^ the run completed with this value
  | Exited        -- ^ the session was restarted in response to @:quit@
  deriving (Show)

-- | Start a new GHCi session by running @./start.sh@ inside @bwrap-files@.
--
-- The script's stdout and stderr share one pipe, which becomes 'ghciStdout'.
-- The first line read from it is expected to look like
-- @yahb2-ghci-version=VERSION@; otherwise the version is recorded as a
-- placeholder. GHCi is then told to print results with
-- @Yahb2Defs.limitedPrint@, its module context is reset with @:m@, and
-- @initdefs.hs@ is run as a script. The output of these commands is not read
-- here; it is left in the pipe.
makeGhci :: IO Ghci
makeGhci = do
  (pipeOut, pipeIn) <- createPipe
  (Just stdinH, _, _, proch) <- createProcess
    (proc "./start.sh" [])
      { cwd = Just "bwrap-files"
      , std_in = CreatePipe
      , std_out = UseHandle pipeIn
      , std_err = UseHandle pipeIn }
  -- GHCi's output is always decoded as UTF-8 (see 'hGetUntilUTF8'), so
  -- encode our input as UTF-8 too. Otherwise the handle would use this
  -- process's locale encoding, which may be unable to represent non-ASCII
  -- input and would throw on write.
  hSetEncoding stdinH utf8

  versionLine <- hGetLine pipeOut
  let versionPrefix = "yahb2-ghci-version="
      version = case splitAt (length versionPrefix) versionLine of
        (prefix, rest) | prefix == versionPrefix -> takeWhile (not . isSpace) rest
        _ -> "<unknown, sorry>"

  ghciPutStrLn stdinH ":set -interactive-print=Yahb2Defs.limitedPrint"
  ghciPutStrLn stdinH ":m"
  ghciPutStrLn stdinH ":script initdefs.hs"
  hFlush stdinH
  return Ghci { ghciProc = proch
              , ghciStdin = stdinH
              , ghciStdout = pipeOut
              , ghciVersion = version }

-- | Like 'runStmt', but handles a few GHCi meta-commands itself first.
--
-- Command names are compared case-insensitively. A name matches if it is a
-- prefix of the full command, so @:q@ counts as @:quit@.
--
-- * @:!…@, @:set prompt …@ and @:def …@ are not run; the result is 'Ignored'.
-- * @:quit@ restarts the session and yields 'Exited'.
-- * @:version@ (exact name only) returns the version recorded at startup.
--
-- Any other input is passed to 'runStmt' unchanged.
runStmtClever :: Ghci -> ParseSettings -> String -> IO (Ghci, Result String)
runStmtClever ghci pset line =
  case intercept (dropWhile isSpace line) of
    Just action -> action
    Nothing -> runStmt ghci pset line
  where
    -- 'Just' an action for meta-commands handled here, 'Nothing' to forward.
    intercept (':' : rest) = classify (words (map toLower rest))
    intercept _ = Nothing

    classify (('!' : _) : _) = Just ignore
    classify (cmd : "prompt" : _)
      | cmd `abbreviates` "set" = Just ignore
    classify (cmd : _)
      | cmd `abbreviates` "def" = Just ignore
      | cmd `abbreviates` "quit" = Just restartForQuit
      -- no prefix matching because :version is not a ghci command, and we
      -- don't want to shadow other actual ghci commands
      | cmd == "version" = Just (return (ghci, Return (ghciVersion ghci)))
    classify _ = Nothing

    ignore = return (ghci, Ignored)

    restartForQuit = do
      terminateGhci ghci
      putStrLn "ghci: restarting due to :quit"
      hFlush stdout
      fresh <- makeGhci
      return (fresh, Exited)

    -- @short `abbreviates` full@ holds when @short@ is a prefix of @full@.
    abbreviates :: String -> String -> Bool
    short `abbreviates` full = take (length short) full == short

-- | Run one statement in GHCi with a two-second limit. If the session times
-- out, throws, or has a closed stdin, it is replaced by a fresh one.
runStmt :: Ghci -> ParseSettings -> String -> IO (Ghci, Result String)
runStmt ghci pset line = timeouting twoSeconds (restarting execute) ghci
  where
    twoSeconds = 2_000_000  -- microseconds, as expected by 'timeouting'
    execute session = runStmt' session pset line

-- | Run @f ghci@ with a time limit in microseconds. If the limit is exceeded,
-- the session is terminated and a new one is started. The new session is
-- returned together with @'Error' "<timeout>"@.
--
-- NOTE: 'runStmt' passes a function built with 'restarting', which installs
-- its own handler for the 'Timeout' exception. That handler normally catches
-- the timeout first, so 'timeout' then returns 'Just'. The 'Nothing' branch
-- below is only reached if the timeout fires outside that handler's scope.
timeouting :: Int -> (Ghci -> IO (Ghci, Result a)) -> Ghci -> IO (Ghci, Result a)
timeouting microseconds f ghci = do
  outcome <- timeout microseconds (f ghci)
  case outcome of
    Just pair -> return pair
    Nothing -> do
      putStrLn "ghci: restarting due to timeout"
      hFlush stdout
      terminateGhci ghci
      fresh <- makeGhci
      return (fresh, Error "<timeout>")

-- | Run @f@ against a live GHCi session, replacing the session when needed.
--
-- If the session's stdin handle is already closed, a fresh session is started
-- before running @f@. If @f@ is interrupted by a 'Timeout', or throws any
-- other exception, the session it ran against is terminated and a new one is
-- started. The new session is returned with an 'Error' result. On success,
-- @f@'s value is wrapped in 'Return'.
restarting :: (Ghci -> IO a) -> Ghci -> IO (Ghci, Result a)
restarting f ghci = do
  closed <- hIsClosed (ghciStdin ghci)
  ghci' <- if closed
             then do
               putStrLn "ghci: restarting due to closed stdin"
               hFlush stdout
               terminateGhci ghci
               makeGhci
             else return ghci

  -- Replace the session that 'f' actually ran against. This must be ghci',
  -- not ghci: when the stdin check above started a new session, ghci has
  -- already been terminated, and ghci' is the process that would otherwise
  -- be leaked.
  let replaceSession logMsg errMsg = do
        putStrLn logMsg
        hFlush stdout
        terminateGhci ghci'
        ghci'' <- makeGhci
        return (ghci'', Error errMsg)

  -- The Timeout handler is inside the catch-all, so timeouts get their own
  -- message; the catch-all also covers exceptions from that handler.
  fmap (\x -> (ghci', Return x)) (f ghci')
    `catch` (\e -> do let _ = e :: Timeout
                      replaceSession "ghci: restarting due to timeout (caught in restarting)"
                                     "<timeout>")
    `catch` (\e -> do let _ = e :: SomeException
                      replaceSession ("ghci: restarting due to exception: " ++ show e)
                                     "Oops, something went wrong")

-- | Shut down a GHCi session. The process is sent a termination request and
-- both pipe handles held on this side are closed.
--
-- Closing the handles frees their file descriptors instead of leaving them
-- to the garbage collector. Closing 'ghciStdin' also lets 'restarting' notice
-- (via 'hIsClosed') that the session is gone and start a new one. Without
-- it, the next statement would first fail by writing to a dead process.
-- IOExceptions from closing are ignored, since the session is discarded
-- anyway; one example is flushing buffered input into a pipe whose reader
-- has exited.
--
-- NOTE(review): the process is not waited for here, so it is not reaped.
terminateGhci :: Ghci -> IO ()
terminateGhci ghci = do
  terminateProcess (ghciProc ghci)
  closeQuietly (ghciStdin ghci)
  closeQuietly (ghciStdout ghci)
  where
    closeQuietly h = hClose h `catch` ignoreIOError

    ignoreIOError :: IOException -> IO ()
    ignoreIOError _ = return ()

-- | Run one statement in a live GHCi session and return its formatted output.
--
-- A fresh random prompt is installed first ('updatePrompt'). The statement's
-- output is then read until that prompt appears again, reading roughly
-- @psMaxOutputLen pset + 8192@ bytes at most. If that limit is hit, or the
-- output ends before the prompt is seen, the session is terminated because
-- the prompt we were waiting for was lost.
runStmt' :: Ghci -> ParseSettings -> String -> IO String
runStmt' ghci pset stmt = do
  tag <- updatePrompt ghci
  ghciPutStrLn (ghciStdin ghci) stmt
  hFlush (ghciStdin ghci)
  let readmax = maxLen + 8192
  (output, reason) <- hGetUntilUTF8 (ghciStdout ghci) (Just readmax) tag
  case reason of
    ReachedMaxLen -> do
      terminateGhci ghci  -- because we lost the new prompt
      -- readmax counts bytes while maxLen counts characters, so multi-byte
      -- output can decode to fewer than maxLen characters even though it was
      -- cut off. The same holds when joinLines trims lots of whitespace.
      -- Force the "..." marker so truncation is always visible.
      return (formatOutput True output)
    ReachedTag -> return (formatOutput False (take (length output - length tag) output))
    ReachedEOF -> do
      terminateGhci ghci
      return (formatOutput False output)
  where
    maxLen = psMaxOutputLen pset

    -- Optionally trim and join lines, then cap the result at maxLen
    -- characters. It ends in "..." when it is too long, or when the caller
    -- already knows the input was cut off.
    formatOutput truncated output =
      let output' | psJoinLines pset = replaceNewlines (dropBothSlow isSpace output)
                  | otherwise        = output
      in if truncated || length output' > maxLen
           then take (maxLen - 3) output' ++ "..."
           else output'
    dropBothSlow f = reverse . dropWhile f . reverse . dropWhile f
    replaceNewlines = concatMap (\case '\n' -> " ; " ; c -> [c])

-- | Install a freshly generated prompt in GHCi and return it.
--
-- GHCi output is then read and discarded until the new prompt is seen, the
-- output ends, or 8192 bytes have been read.
--
-- NOTE(review): the cutoff reason is ignored. If the 8192-byte limit is hit
-- first, the new prompt may still be pending in the pipe — confirm whether
-- that can desynchronise the next read.
updatePrompt :: Ghci -> IO String
updatePrompt ghci = do
  newTag <- genTag
  let input = ghciStdin ghci
  ghciPutStrLn input (":set prompt " ++ newTag)
  hFlush input
  _ <- hGetUntilUTF8 (ghciStdout ghci) (Just 8192) newTag
  pure newTag

-- | Generate a prompt tag: 20 characters drawn uniformly from @'a'..'z'@
-- using the global random generator.
genTag :: IO String
genTag = replicateM tagLength letter
  where
    tagLength = 20
    letter = getStdRandom (uniformR ('a', 'z'))

-- | Why 'hGetUntil' stopped reading.
data CutoffReason
  = ReachedMaxLen  -- ^ the byte limit was reached before the tag was seen
  | ReachedTag     -- ^ the data read so far ends with the tag
  | ReachedEOF     -- ^ the handle reported end of input
  deriving (Show)

-- | 'hGetUntil' for 'String' tags: the tag is UTF-8 encoded before
-- searching, and the bytes read are decoded as UTF-8.
hGetUntilUTF8 :: Handle -> Maybe Int -> String -> IO (String, CutoffReason)
hGetUntilUTF8 h mmax tag = do
  raw <- hGetUntil h mmax encodedTag
  return (first LUTF8.toString raw)
  where
    encodedTag = BSS.toShort (UTF8.fromString tag)

-- | Read from a handle until one of these happens:
--
-- * the data read so far ends with @tag@;
-- * the handle reaches end of input;
-- * at least @m@ bytes have been read, when the limit is @Just m@.
--
-- @Nothing@ means no limit. Returns everything read (including the tag, if
-- found) and why reading stopped.
--
-- Reads happen in chunks of up to 1024 bytes. The limit is checked before
-- each read, so the result may overshoot it by less than one chunk. The tag
-- is only recognised at the end of a chunk. That is presumably fine here,
-- since GHCi's prompt should be the last thing printed before it waits for
-- input.
--
-- Because the tag may be split across chunks, the loop tracks
-- @havePrefixes@. These are the lengths @n@ (@0 < n < tag length@) such that
-- the data read so far ends with the first @n@ bytes of the tag.
hGetUntil :: Handle -> Maybe Int -> ShortByteString -> IO (Lazy.ByteString, CutoffReason)
hGetUntil h mmax tag =
  allocaBytes chunkSize $ \ptr -> do
    let loop yet havePrefixes builder = do
          when (exceedsMax yet) $ exitEarly (BSB.toLazyByteString builder, ReachedMaxLen)

          nread <- lift $ hGetBufSome h ptr chunkSize
          when (nread <= 0) $ exitEarly (BSB.toLazyByteString builder, ReachedEOF)

          bs <- lift $ BS.packCStringLen (ptr, nread)
          when debugPrints $ lift $ hPutStrLn stderr ("Read: " ++ show bs)
          let builder' = builder <> BSB.byteString bs
          when (completesTag havePrefixes bs) $ do
            when debugPrints $ lift $ hPutStrLn stderr "yay determined end"
            exitEarly (BSB.toLazyByteString builder', ReachedTag)

          let havePrefixes' = nextPrefixes havePrefixes bs
          when debugPrints $ lift $ hPutStrLn stderr ("nextPrefixes = " ++ show havePrefixes')
          loop (yet + nread) havePrefixes' builder'
    result <- execExitEarlyT (loop 0 [] mempty)
    when debugPrints $ hPutStrLn stderr ("Read: <" ++ show result ++ ">")
    return $! result
  where
    chunkSize = 1024
    tagLen = BSS.length tag

    -- 'Nothing' means "no limit", so it is never exceeded.
    exceedsMax yet = case mmax of
      Just m -> yet >= m
      Nothing -> False

    -- Does the stream end with the full tag once chunk @bs@ is appended?
    -- Either @bs@ itself ends with the whole tag, or @bs@ is *exactly* the
    -- rest of a tag prefix the stream already ended with. Merely ending
    -- with that rest is not enough: extra bytes in front of it would
    -- separate it from the earlier prefix.
    completesTag havePrefixes bs =
      BSS.toShort (BS.takeEnd tagLen bs) == tag
        || any (\n -> BSS.toShort bs == BSS.drop n tag) havePrefixes

    -- Partial-match lengths after appending @bs@. This is called only when
    -- the tag was not completed, so every length stays below tagLen.
    nextPrefixes havePrefixes bs =
      let len = BS.length bs
          short = BSS.toShort bs
      in nub $
           [ plen + len  -- bs continues an earlier partial match
           | plen <- havePrefixes
           , len < tagLen - plen
           , short == BSS.take len (BSS.drop plen tag) ]
           ++
           [ n  -- bs ends with a new partial match
           | n <- [1 .. min len (tagLen - 1)]
           , BSS.toShort (BS.takeEnd n bs) == BSS.take n tag ]