use std::collections::{HashMap, VecDeque}; use std::convert::TryInto; use std::io::{self, ErrorKind}; use std::net::Shutdown; use std::sync::Arc; use std::thread; use tokio::io::{AsyncReadExt, AsyncWriteExt, WriteHalf}; use tokio::net::{TcpListener, TcpStream}; use tokio::runtime; use tokio::task; use tokio::sync::{mpsc, oneshot, Mutex}; use dd_utils::error::*; use dd_utils::idgen::IdGen; use dd_utils::protocol::*; use dd_utils::read_ext::ReadExt; async fn send_vectored( writer: &mut (impl AsyncWriteExt + Unpin), typ: u8, id: u64, payload: &[impl AsRef<[u8]>] ) -> io::Result<()> { let mut header = [0u8; 17]; header[0] = typ; header[1..9].copy_from_slice(&id.to_le_bytes()); let sumlen: usize = payload.iter().map(|v| v.as_ref().len()).sum(); header[9..17].copy_from_slice(&sumlen.to_le_bytes()); writer.write_all(&header).await?; for part in payload { writer.write_all(part.as_ref()).await?; } Ok(()) } async fn receive_message(reader: &mut (impl AsyncReadExt + Unpin)) -> io::Result> { let mut header = [0u8; 17]; if let Err(e) = reader.read(&mut header).await { if e.kind() == ErrorKind::UnexpectedEof { return Ok(None); } else { return Err(e); } } let typ = header[0]; let id = u64::from_le_bytes(header[1..9].try_into().unwrap()); let length = usize::from_le_bytes(header[9..17].try_into().unwrap()); let mut payload = Vec::new(); payload.resize(length, 0u8); if let Err(e) = reader.read(&mut payload).await { if e.kind() == ErrorKind::UnexpectedEof { return Ok(None); } else { return Err(e); } } Ok(Some(RawMessage { typ, id, payload })) } fn encode_message(msg: MessageBody) -> Vec> { let mut payload = Vec::new(); match msg { MessageBody::Version(version) => { payload.push(version.to_le_bytes().to_vec()); } MessageBody::NewCore(name, libfile) => { payload.push(name.len().to_le_bytes().to_vec()); payload.push(name.into_bytes()); payload.push(libfile.len().to_le_bytes().to_vec()); payload.push(libfile); } MessageBody::Job(jobid, input) => { payload.push(jobid.to_le_bytes().to_vec()); payload.push(input.len().to_le_bytes().to_vec()); payload.push(input); } } payload } #[derive(Debug)] pub struct CompletionEvent { pub jobid: u64, pub result: Result<(i32, Vec), String>, } #[derive(Debug)] enum Inbound { Completion(CompletionEvent), } #[derive(Debug)] enum Outbound { NewCore(String, Vec), NewJob(u64, Vec), NumWorkers(oneshot::Sender), Quit, } pub struct ComputePool { runtime: runtime::Runtime, iothread: Option>, inbound: mpsc::UnboundedReceiver, outbound: mpsc::Sender, // The number of jobs for which the completion event has not yet been consumed num_running: u64, } #[derive(Debug)] enum ThreadCollect { NewWorker(TcpStream, u64), // Worker socket, and next unused message id WorkerReady(u64), // Worker id (indicates that this worker has its core initialised) Query(Outbound), Completion(u64, CompletionEvent), // Worker id, and event } async fn thread_handshake_handler(mut listener: TcpListener, sink: mpsc::Sender) { loop { let (mut sock, _) = listener.accept().await.expect("Accept failed on TCP server socket"); let mut sink = sink.clone(); task::spawn(async move { let payload = encode_message(MessageBody::Version(1)); if send_vectored(&mut sock, 1, 1, &payload).await.is_err() { match sock.shutdown(Shutdown::Both) { Ok(()) => {} Err(_) => {} // explicitly ignore errors here, we're closing anyway } return; } match receive_message(&mut sock).await { Ok(Some(rawmsg)) if rawmsg.typ == 1 && rawmsg.id == 1 && rawmsg.payload.len() == 1 && rawmsg.payload[0] == 1 => { sink.send(ThreadCollect::NewWorker(sock, 2)).await.unwrap(); } _ => { match sock.shutdown(Shutdown::Both) { Ok(()) => {} Err(_) => {} // explicitly ignore errors here, we're closing anyway } } } }); } } async fn thread_query_handler( mut chan: mpsc::Receiver, mut sink: mpsc::Sender ) { loop { if let Some(msg) = chan.recv().await { sink.send(ThreadCollect::Query(msg)).await.unwrap(); } else { return; } } } struct Worker { socket: WriteHalf, msg_idgen: IdGen, loaded_core: u64, handler_map: Arc>>>, } impl Worker { fn new(worker_id: u64, next_msg_id: u64, socket: TcpStream) -> Self { let (mut read_half, write_half) = tokio::io::split(socket); let worker = Worker { socket: write_half, msg_idgen: IdGen::new(next_msg_id), loaded_core: 0, handler_map: Arc::new(Mutex::new(HashMap::new())), }; { let handler_map = worker.handler_map.clone(); task::spawn(async move { loop { let rawmsg = match receive_message(&mut read_half).await { Ok(Some(rawmsg)) => rawmsg, _ => break, }; let mut handler_map = handler_map.lock().await; match handler_map.remove(&rawmsg.id) { Some(handler) => handler(worker_id, rawmsg), None => { eprintln!("Warning: no handler found for worker reply id {}", rawmsg.id); } } } }); } worker } /// Returns whether the send succeeded. In case of failure, the handler is not registered. async fn send( &mut self, typ: u8, payload: &[impl AsRef<[u8]>], handler: impl FnOnce(u64, RawMessage) + Send + 'static ) -> bool { let msgid = self.msg_idgen.gen(); self.handler_map.lock().await.insert(msgid, Box::new(handler)); match send_vectored(&mut self.socket, typ, msgid, &payload).await { Ok(()) => true, Err(_) => { self.handler_map.lock().await.remove(&msgid); false } } } } #[derive(Debug, Clone)] struct Job { id: u64, input: Vec, } #[derive(Debug, Clone)] struct ComputeCore { name: String, libfile: Vec, } /// Returns whether job was successfully sent to the worker async fn worker_run_job( worker: &mut Worker, job: Job, mut result_chan: mpsc::Sender ) -> bool { let Job { id: jobid, input } = job; let payload = encode_message(MessageBody::Job(jobid, input)); let handler = move |wid: u64, rawmsg: RawMessage| { task::spawn(async move { let result = if rawmsg.typ == 1 { let mut reader: &[u8] = &rawmsg.payload; if let Some(retval) = reader.read_le_i32() { if let Some(output) = reader.read_pascal_blob() { Ok((retval, output.to_vec())) } else { Err("".to_string()) } } else { Err("".to_string()) } } else { Err(String::from_utf8_lossy(&rawmsg.payload).to_string()) }; let event = ThreadCollect::Completion( wid, CompletionEvent { jobid, result } ); result_chan.send(event).await.unwrap(); }); }; worker.send(3, &payload, handler).await } /// Returns whether the message was successfully sent to the worker. async fn worker_set_new_core( worker: &mut Worker, new_core_id: u64, core: ComputeCore, workers_map: Arc>>, result_chan: mpsc::Sender ) -> bool { worker_set_new_core_payload( worker, new_core_id, &encode_message(MessageBody::NewCore(core.name, core.libfile)), workers_map, result_chan ).await } /// Returns whether the message was successfully sent to the worker. async fn worker_set_new_core_payload( worker: &mut Worker, new_core_id: u64, payload: &[impl AsRef<[u8]>], workers_map: Arc>>, mut result_chan: mpsc::Sender ) -> bool { let handler = move |wid: u64, rawmsg: RawMessage| { task::spawn(async move { if rawmsg.typ == 1 { if let Some(worker) = workers_map.lock().await.get_mut(&wid) { worker.loaded_core = new_core_id; result_chan.send(ThreadCollect::WorkerReady(wid)).await.unwrap(); } } else { eprintln!("Worker {} could not load new core:\n{}", wid, String::from_utf8_lossy(&rawmsg.payload)); } }); }; worker.send(2, payload, handler).await } // IO thread: // - waits for worker connections // - has a channel open to the parent thread on which it receives commands to send to workers // - receives responses from workers and puts those on the channel fn thread_entry( listen_port: u16, query_chan: mpsc::Receiver, event_chan: mpsc::UnboundedSender ) { runtime::Runtime::new().unwrap().block_on(async { let (mut collector_sink, mut collector_source) = mpsc::channel(10); let listener = TcpListener::bind(("0.0.0.0", listen_port)).await.unwrap(); task::spawn(thread_handshake_handler(listener, collector_sink.clone())); task::spawn(thread_query_handler(query_chan, collector_sink.clone())); let mut worker_idgen = IdGen::new(1); let workers: Arc>> = Arc::new(Mutex::new(HashMap::new())); let mut free_workers: VecDeque = VecDeque::new(); let mut current_core: Option = None; let mut current_core_id: u64 = 0; let mut core_idgen = IdGen::new(1); let mut job_queue: VecDeque = VecDeque::new(); loop { let message = match collector_source.recv().await { Some(message) => message, None => break, }; match message { ThreadCollect::Query(Outbound::NewCore(name, libfile)) => { let new_core_id = core_idgen.gen(); current_core = Some(ComputeCore { name: name.clone(), libfile: libfile.clone() }); current_core_id = new_core_id; let payload = encode_message(MessageBody::NewCore(name, libfile)); let mut workers_locked = workers.lock().await; while let Some(worker_id) = free_workers.pop_front() { if let Some(worker) = workers_locked.get_mut(&worker_id) { if !worker_set_new_core_payload( worker, new_core_id, &payload, workers.clone(), collector_sink.clone()).await { workers_locked.remove(&worker_id); } } } } ThreadCollect::Query(Outbound::NewJob(jobid, input)) => { job_queue.push_back(Job { id: jobid, input }); } ThreadCollect::Query(Outbound::NumWorkers(chan)) => { chan.send(workers.lock().await.len() as u64).unwrap(); } ThreadCollect::Query(Outbound::Quit) => { break; } ThreadCollect::Completion(worker_id, event) => { event_chan.send(Inbound::Completion(event)).unwrap(); collector_sink.send(ThreadCollect::WorkerReady(worker_id)).await.unwrap(); } ThreadCollect::NewWorker(socket, next_msg_id) => { let wid = worker_idgen.gen(); let worker = Worker::new(wid, next_msg_id, socket); workers.lock().await.insert(wid, worker); // Delegate the core setup to the WorkerReady event collector_sink.send(ThreadCollect::WorkerReady(wid)).await.unwrap(); } ThreadCollect::WorkerReady(worker_id) => { let mut workers_locked = workers.lock().await; if let Some(worker) = workers_locked.get_mut(&worker_id) { if worker.loaded_core != current_core_id { let current_core = current_core.unwrap().clone(); worker_set_new_core( worker, current_core_id, current_core, workers.clone(), collector_sink.clone() ).await; return; } free_workers.push_back(worker_id); } } } if job_queue.len() > 0 && free_workers.len() > 0 { let job = job_queue.pop_front().unwrap(); let mut workers_locked = workers.lock().await; // Loop until a worker exists and is functional loop { let worker_id = free_workers.pop_front().unwrap(); if let Some(worker) = workers_locked.get_mut(&worker_id) { let result_chan = collector_sink.clone(); if !worker_run_job(worker, job.clone(), result_chan).await { workers_locked.remove(&worker_id); } else { break; } } } } } }); } /// Will call ComputePool::close(), `unwrap`-ing the result. impl Drop for ComputePool { fn drop(&mut self) { self.close().unwrap(); } } impl ComputePool { pub fn new(port: u16) -> io::Result { // Spawn the IO thread let (inbound_sender, inbound_receiver) = mpsc::unbounded_channel(); let (outbound_sender, outbound_receiver) = mpsc::channel(1); let jh = thread::spawn(move || { thread_entry(port, outbound_receiver, inbound_sender) }); Ok(ComputePool { runtime: runtime::Runtime::new()?, iothread: Some(jh), inbound: inbound_receiver, outbound: outbound_sender, num_running: 0, }) } pub fn close(&mut self) -> io::Result<()> { self.runtime.block_on(self.outbound.send(Outbound::Quit)).iores()?; self.iothread.take().map(|jh| jh.join().unwrap()); Ok(()) } pub fn set_core(&mut self, name: String, libfile: Vec) -> io::Result<()> { // Instruct the IO thread to send the core to all workers, and also set it on every new // worker that arrives self.runtime.block_on(self.outbound.send(Outbound::NewCore(name, libfile))).iores() } pub fn current_job_parallelism(&mut self) -> io::Result { // Query the IO thread for the number of workers currently registered let mut outbound = self.outbound.clone(); self.runtime.block_on(async { let (sender, receiver) = oneshot::channel(); outbound.send(Outbound::NumWorkers(sender)).await.iores()?; receiver.await.iores() }) } pub fn submit_job(&mut self, jobid: u64, input: Vec) -> io::Result<()> { // Send the job to the IO thread, which will send it to a round-robin worker self.runtime.block_on(self.outbound.send(Outbound::NewJob(jobid, input))).iores() } pub fn next_completion_event(&mut self) -> io::Result> { // If the counter is still positive, wait for events from the IO thread if self.num_running > 0 { match self.runtime.block_on(self.inbound.recv()) { Some(Inbound::Completion(event)) => Ok(Some(event)), None => Err("IO thread unexpectedly quit".ioerr()), } } else { Ok(None) } } }