author     Tom Smeding <tom.smeding@gmail.com>    2020-03-27 22:47:57 +0100
committer  Tom Smeding <tom.smeding@gmail.com>    2020-03-27 22:47:57 +0100
commit     fd421e32780cad46782c16cd4e15947f295a08c7 (patch)
tree       04632f49f7c8860dee4237a0afe8292a949bdc9e /controller
Initial, untested version of controller and worker
The worker has been tested to a marginal extent, but the controller is
literally untested.
Diffstat (limited to 'controller')
-rw-r--r--   controller/Cargo.toml      9
-rw-r--r--   controller/src/lib.rs    492
2 files changed, 501 insertions, 0 deletions
diff --git a/controller/Cargo.toml b/controller/Cargo.toml
new file mode 100644
index 0000000..319e1b7
--- /dev/null
+++ b/controller/Cargo.toml
@@ -0,0 +1,9 @@
+[package]
+name = "dd-controller"
+version = "0.1.0"
+authors = ["Tom Smeding <tom.smeding@gmail.com>"]
+edition = "2018"
+
+[dependencies]
+tokio = { version = "0.2", features = ["full"] }
+dd-utils = { path = "../utils" }
diff --git a/controller/src/lib.rs b/controller/src/lib.rs
new file mode 100644
index 0000000..ee02b3b
--- /dev/null
+++ b/controller/src/lib.rs
@@ -0,0 +1,492 @@
+use std::collections::{HashMap, VecDeque};
+use std::convert::TryInto;
+use std::io::{self, ErrorKind};
+use std::net::Shutdown;
+use std::sync::Arc;
+use std::thread;
+use tokio::io::{AsyncReadExt, AsyncWriteExt, WriteHalf};
+use tokio::net::{TcpListener, TcpStream};
+use tokio::runtime;
+use tokio::task;
+use tokio::sync::{mpsc, oneshot, Mutex};
+use dd_utils::error::*;
+use dd_utils::idgen::IdGen;
+use dd_utils::protocol::*;
+use dd_utils::read_ext::ReadExt;
+
+async fn send_vectored(
+    writer: &mut (impl AsyncWriteExt + Unpin),
+    typ: u8,
+    id: u64,
+    payload: &[impl AsRef<[u8]>]
+) -> io::Result<()> {
+    let mut header = [0u8; 17];
+    header[0] = typ;
+    header[1..9].copy_from_slice(&id.to_le_bytes());
+    let sumlen: usize = payload.iter().map(|v| v.as_ref().len()).sum();
+    header[9..17].copy_from_slice(&sumlen.to_le_bytes());
+    writer.write_all(&header).await?;
+    for part in payload {
+        writer.write_all(part.as_ref()).await?;
+    }
+    Ok(())
+}
+
+async fn receive_message(reader: &mut (impl AsyncReadExt + Unpin))
+        -> io::Result<Option<RawMessage>> {
+    let mut header = [0u8; 17];
+    if let Err(e) = reader.read(&mut header).await {
+        if e.kind() == ErrorKind::UnexpectedEof { return Ok(None); }
+        else { return Err(e); }
+    }
+
+    let typ = header[0];
+    let id = u64::from_le_bytes(header[1..9].try_into().unwrap());
+    let length = usize::from_le_bytes(header[9..17].try_into().unwrap());
+
+    let mut payload = Vec::new();
+    payload.resize(length, 0u8);
+    if let Err(e) = reader.read(&mut payload).await {
+        if e.kind() == ErrorKind::UnexpectedEof { return Ok(None); }
+        else { return Err(e); }
+    }
+
+    Ok(Some(RawMessage { typ, id, payload }))
+}
+
+fn encode_message(msg: MessageBody) -> Vec<Vec<u8>> {
+    let mut payload = Vec::new();
+
+    match msg {
+        MessageBody::Version(version) => {
+            payload.push(version.to_le_bytes().to_vec());
+        }
+
+        MessageBody::NewCore(name, libfile) => {
+            payload.push(name.len().to_le_bytes().to_vec());
+            payload.push(name.into_bytes());
+            payload.push(libfile.len().to_le_bytes().to_vec());
+            payload.push(libfile);
+        }
+
+        MessageBody::Job(jobid, input) => {
+            payload.push(jobid.to_le_bytes().to_vec());
+            payload.push(input.len().to_le_bytes().to_vec());
+            payload.push(input);
+        }
+    }
+
+    payload
+}
+
+#[derive(Debug)]
+pub struct CompletionEvent {
+    pub jobid: u64,
+    pub result: Result<(i32, Vec<u8>), String>,
+}
+
+#[derive(Debug)]
+enum Inbound {
+    Completion(CompletionEvent),
+}
+
+#[derive(Debug)]
+enum Outbound {
+    NewCore(String, Vec<u8>),
+    NewJob(u64, Vec<u8>),
+    NumWorkers(oneshot::Sender<u64>),
+    Quit,
+}
+
+pub struct ComputePool {
+    runtime: runtime::Runtime,
+
+    iothread: Option<thread::JoinHandle<()>>,
+    inbound: mpsc::UnboundedReceiver<Inbound>,
+    outbound: mpsc::Sender<Outbound>,
+
+    // The number of jobs for which the completion event has not yet been consumed
+    num_running: u64,
+}
+
+#[derive(Debug)]
+enum ThreadCollect {
+    NewWorker(TcpStream, u64),  // Worker socket, and next unused message id
+    WorkerReady(u64),  // Worker id (indicates that this worker has its core initialised)
+    Query(Outbound),
+    Completion(u64, CompletionEvent),  // Worker id, and event
+}
+
+async fn thread_handshake_handler(mut listener: TcpListener, sink: mpsc::Sender<ThreadCollect>) {
+    loop {
+        let (mut sock, _) = listener.accept().await.expect("Accept failed on TCP server socket");
+        let mut sink = sink.clone();
+        task::spawn(async move {
+            let payload = encode_message(MessageBody::Version(1));
+            if send_vectored(&mut sock, 1, 1, &payload).await.is_err() {
+                match sock.shutdown(Shutdown::Both) {
+                    Ok(()) => {}
+                    Err(_) => {}  // explicitly ignore errors here, we're closing anyway
+                }
+                return;
+            }
+
+            match receive_message(&mut sock).await {
+                Ok(Some(rawmsg)) if rawmsg.typ == 1 && rawmsg.id == 1 &&
+                                    rawmsg.payload.len() == 1 &&
+                                    rawmsg.payload[0] == 1
+                => {
+                    sink.send(ThreadCollect::NewWorker(sock, 2)).await.unwrap();
+                }
+
+                _ => {
+                    match sock.shutdown(Shutdown::Both) {
+                        Ok(()) => {}
+                        Err(_) => {}  // explicitly ignore errors here, we're closing anyway
+                    }
+                }
+            }
+        });
+    }
+}
+
+async fn thread_query_handler(
+    mut chan: mpsc::Receiver<Outbound>,
+    mut sink: mpsc::Sender<ThreadCollect>
+) {
+    loop {
+        if let Some(msg) = chan.recv().await {
+            sink.send(ThreadCollect::Query(msg)).await.unwrap();
+        } else {
+            return;
+        }
+    }
+}
+
+struct Worker {
+    socket: WriteHalf<TcpStream>,
+    msg_idgen: IdGen,
+    loaded_core: u64,
+    handler_map: Arc<Mutex<HashMap<u64, Box<dyn FnOnce(u64, RawMessage) + Send>>>>,
+}
+
+impl Worker {
+    fn new(worker_id: u64, next_msg_id: u64, socket: TcpStream) -> Self {
+        let (mut read_half, write_half) = tokio::io::split(socket);
+        let worker = Worker {
+            socket: write_half,
+            msg_idgen: IdGen::new(next_msg_id),
+            loaded_core: 0,
+            handler_map: Arc::new(Mutex::new(HashMap::new())),
+        };
+
+        {
+            let handler_map = worker.handler_map.clone();
+            task::spawn(async move {
+                loop {
+                    let rawmsg = match receive_message(&mut read_half).await {
+                        Ok(Some(rawmsg)) => rawmsg,
+                        _ => break,
+                    };
+
+                    let mut handler_map = handler_map.lock().await;
+                    match handler_map.remove(&rawmsg.id) {
+                        Some(handler) => handler(worker_id, rawmsg),
+                        None => {
+                            eprintln!("Warning: no handler found for worker reply id {}", rawmsg.id);
+                        }
+                    }
+                }
+            });
+        }
+
+        worker
+    }
+
+    /// Returns whether the send succeeded. In case of failure, the handler is not registered.
+    async fn send(
+        &mut self,
+        typ: u8,
+        payload: &[impl AsRef<[u8]>],
+        handler: impl FnOnce(u64, RawMessage) + Send + 'static
+    ) -> bool {
+        let msgid = self.msg_idgen.gen();
+        self.handler_map.lock().await.insert(msgid, Box::new(handler));
+        match send_vectored(&mut self.socket, typ, msgid, &payload).await {
+            Ok(()) => true,
+            Err(_) => {
+                self.handler_map.lock().await.remove(&msgid);
+                false
+            }
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Job {
+    id: u64,
+    input: Vec<u8>,
+}
+
+#[derive(Debug, Clone)]
+struct ComputeCore {
+    name: String,
+    libfile: Vec<u8>,
+}
+
+/// Returns whether the job was successfully sent to the worker
+async fn worker_run_job(
+    worker: &mut Worker,
+    job: Job,
+    mut result_chan: mpsc::Sender<ThreadCollect>
+) -> bool {
+    let Job { id: jobid, input } = job;
+    let payload = encode_message(MessageBody::Job(jobid, input));
+    let handler = move |wid: u64, rawmsg: RawMessage| {
+        task::spawn(async move {
+            let result = if rawmsg.typ == 1 {
+                let mut reader: &[u8] = &rawmsg.payload;
+                if let Some(retval) = reader.read_le_i32() {
+                    if let Some(output) = reader.read_pascal_blob() {
+                        Ok((retval, output.to_vec()))
+                    } else {
+                        Err("<Invalid reply format!>".to_string())
+                    }
+                } else {
+                    Err("<Invalid reply format!>".to_string())
+                }
+            } else {
+                Err(String::from_utf8_lossy(&rawmsg.payload).to_string())
+            };
+
+            let event = ThreadCollect::Completion(
+                wid,
+                CompletionEvent { jobid, result }
+            );
+            result_chan.send(event).await.unwrap();
+        });
+    };
+
+    worker.send(3, &payload, handler).await
+}
+
+/// Returns whether the message was successfully sent to the worker.
+async fn worker_set_new_core(
+    worker: &mut Worker,
+    new_core_id: u64,
+    core: ComputeCore,
+    workers_map: Arc<Mutex<HashMap<u64, Worker>>>,
+    result_chan: mpsc::Sender<ThreadCollect>
+) -> bool {
+    worker_set_new_core_payload(
+        worker,
+        new_core_id,
+        &encode_message(MessageBody::NewCore(core.name, core.libfile)),
+        workers_map,
+        result_chan
+    ).await
+}
+
+/// Returns whether the message was successfully sent to the worker.
+async fn worker_set_new_core_payload(
+    worker: &mut Worker,
+    new_core_id: u64,
+    payload: &[impl AsRef<[u8]>],
+    workers_map: Arc<Mutex<HashMap<u64, Worker>>>,
+    mut result_chan: mpsc::Sender<ThreadCollect>
+) -> bool {
+    let handler = move |wid: u64, rawmsg: RawMessage| {
+        task::spawn(async move {
+            if rawmsg.typ == 1 {
+                if let Some(worker) = workers_map.lock().await.get_mut(&wid) {
+                    worker.loaded_core = new_core_id;
+                    result_chan.send(ThreadCollect::WorkerReady(wid)).await.unwrap();
+                }
+            } else {
+                eprintln!("Worker {} could not load new core:\n{}",
+                          wid, String::from_utf8_lossy(&rawmsg.payload));
+            }
+        });
+    };
+    worker.send(2, payload, handler).await
+}
+
+// IO thread:
+// - waits for worker connections
+// - has a channel open to the parent thread on which it receives commands to send to workers
+// - receives responses from workers and puts those on the channel
+fn thread_entry(
+    listen_port: u16,
+    query_chan: mpsc::Receiver<Outbound>,
+    event_chan: mpsc::UnboundedSender<Inbound>
+) {
+    runtime::Runtime::new().unwrap().block_on(async {
+        let (mut collector_sink, mut collector_source) = mpsc::channel(10);
+        let listener = TcpListener::bind(("0.0.0.0", listen_port)).await.unwrap();
+
+        task::spawn(thread_handshake_handler(listener, collector_sink.clone()));
+        task::spawn(thread_query_handler(query_chan, collector_sink.clone()));
+
+        let mut worker_idgen = IdGen::new(1);
+
+        let workers: Arc<Mutex<HashMap<u64, Worker>>> = Arc::new(Mutex::new(HashMap::new()));
+        let mut free_workers: VecDeque<u64> = VecDeque::new();
+        let mut current_core: Option<ComputeCore> = None;
+        let mut current_core_id: u64 = 0;
+        let mut core_idgen = IdGen::new(1);
+        let mut job_queue: VecDeque<Job> = VecDeque::new();
+
+        loop {
+            let message = match collector_source.recv().await {
+                Some(message) => message,
+                None => break,
+            };
+
+            match message {
+                ThreadCollect::Query(Outbound::NewCore(name, libfile)) => {
+                    let new_core_id = core_idgen.gen();
+                    current_core = Some(ComputeCore { name: name.clone(), libfile: libfile.clone() });
+                    current_core_id = new_core_id;
+
+                    let payload = encode_message(MessageBody::NewCore(name, libfile));
+
+                    let mut workers_locked = workers.lock().await;
+
+                    while let Some(worker_id) = free_workers.pop_front() {
+                        if let Some(worker) = workers_locked.get_mut(&worker_id) {
+                            if !worker_set_new_core_payload(
+                                    worker, new_core_id, &payload, workers.clone(),
+                                    collector_sink.clone()).await {
+                                workers_locked.remove(&worker_id);
+                            }
+                        }
+                    }
+                }
+
+                ThreadCollect::Query(Outbound::NewJob(jobid, input)) => {
+                    job_queue.push_back(Job { id: jobid, input });
+                }
+
+                ThreadCollect::Query(Outbound::NumWorkers(chan)) => {
+                    chan.send(workers.lock().await.len() as u64).unwrap();
+                }
+
+                ThreadCollect::Query(Outbound::Quit) => {
+                    break;
+                }
+
+                ThreadCollect::Completion(worker_id, event) => {
+                    event_chan.send(Inbound::Completion(event)).unwrap();
+                    collector_sink.send(ThreadCollect::WorkerReady(worker_id)).await.unwrap();
+                }
+
+                ThreadCollect::NewWorker(socket, next_msg_id) => {
+                    let wid = worker_idgen.gen();
+                    let worker = Worker::new(wid, next_msg_id, socket);
+
+                    workers.lock().await.insert(wid, worker);
+
+                    // Delegate the core setup to the WorkerReady event
+                    collector_sink.send(ThreadCollect::WorkerReady(wid)).await.unwrap();
+                }
+
+                ThreadCollect::WorkerReady(worker_id) => {
+                    let mut workers_locked = workers.lock().await;
+
+                    if let Some(worker) = workers_locked.get_mut(&worker_id) {
+                        if worker.loaded_core != current_core_id {
+                            let current_core = current_core.unwrap().clone();
+                            worker_set_new_core(
+                                worker, current_core_id, current_core, workers.clone(),
+                                collector_sink.clone()
+                            ).await;
+                            return;
+                        }
+
+                        free_workers.push_back(worker_id);
+                    }
+                }
+            }
+
+            if job_queue.len() > 0 && free_workers.len() > 0 {
+                let job = job_queue.pop_front().unwrap();
+                let mut workers_locked = workers.lock().await;
+
+                // Loop until a worker exists and is functional
+                loop {
+                    let worker_id = free_workers.pop_front().unwrap();
+                    if let Some(worker) = workers_locked.get_mut(&worker_id) {
+                        let result_chan = collector_sink.clone();
+                        if !worker_run_job(worker, job.clone(), result_chan).await {
+                            workers_locked.remove(&worker_id);
+                        } else {
+                            break;
+                        }
+                    }
+                }
+            }
+        }
+    });
+}
+
+/// Will call ComputePool::close(), `unwrap`-ing the result.
+impl Drop for ComputePool {
+    fn drop(&mut self) {
+        self.close().unwrap();
+    }
+}
+
+impl ComputePool {
+    pub fn new(port: u16) -> io::Result<Self> {
+        // Spawn the IO thread
+        let (inbound_sender, inbound_receiver) = mpsc::unbounded_channel();
+        let (outbound_sender, outbound_receiver) = mpsc::channel(1);
+        let jh = thread::spawn(move || { thread_entry(port, outbound_receiver, inbound_sender) });
+        Ok(ComputePool {
+            runtime: runtime::Runtime::new()?,
+            iothread: Some(jh),
+            inbound: inbound_receiver,
+            outbound: outbound_sender,
+            num_running: 0,
+        })
+    }
+
+    pub fn close(&mut self) -> io::Result<()> {
+        self.runtime.block_on(self.outbound.send(Outbound::Quit)).iores()?;
+        self.iothread.take().map(|jh| jh.join().unwrap());
+        Ok(())
+    }
+
+    pub fn set_core(&mut self, name: String, libfile: Vec<u8>) -> io::Result<()> {
+        // Instruct the IO thread to send the core to all workers, and also set it on every new
+        // worker that arrives
+        self.runtime.block_on(self.outbound.send(Outbound::NewCore(name, libfile))).iores()
+    }
+
+    pub fn current_job_parallelism(&mut self) -> io::Result<u64> {
+        // Query the IO thread for the number of workers currently registered
+        let mut outbound = self.outbound.clone();
+        self.runtime.block_on(async {
+            let (sender, receiver) = oneshot::channel();
+            outbound.send(Outbound::NumWorkers(sender)).await.iores()?;
+            receiver.await.iores()
+        })
+    }
+
+    pub fn submit_job(&mut self, jobid: u64, input: Vec<u8>) -> io::Result<()> {
+        // Send the job to the IO thread, which will send it to a round-robin worker
+        self.runtime.block_on(self.outbound.send(Outbound::NewJob(jobid, input))).iores()
+    }
+
+    pub fn next_completion_event(&mut self) -> io::Result<Option<CompletionEvent>> {
+        // If the counter is still positive, wait for events from the IO thread
+        if self.num_running > 0 {
+            match self.runtime.block_on(self.inbound.recv()) {
+                Some(Inbound::Completion(event)) => Ok(Some(event)),
+                None => Err("IO thread unexpectedly quit".ioerr()),
+            }
+        } else {
+            Ok(None)
+        }
+    }
+}
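
For orientation, a minimal sketch of how the public ComputePool API introduced in this commit might be driven from a caller. This is not part of the commit: the port number, core name, library path, job ids, and input blobs are hypothetical placeholders, and it relies only on the methods visible in the diff above (new, set_core, current_job_parallelism, submit_job, next_completion_event).

// Hypothetical usage sketch of dd-controller's ComputePool (not in the commit).
use dd_controller::ComputePool;

fn main() -> std::io::Result<()> {
    // Listen for workers on a placeholder port.
    let mut pool = ComputePool::new(37001)?;

    // Broadcast a compute core to all current and future workers.
    // The core name and library path are placeholders for illustration.
    let libfile = std::fs::read("example_core.so")?;
    pool.set_core("example-core".to_string(), libfile)?;

    println!("workers connected: {}", pool.current_job_parallelism()?);

    // Submit jobs with caller-chosen ids and opaque input blobs.
    pool.submit_job(1, b"first input".to_vec())?;
    pool.submit_job(2, b"second input".to_vec())?;

    // Drain completion events; each carries the job id and either
    // (exit status, output bytes) or an error message from the worker.
    while let Some(event) = pool.next_completion_event()? {
        match event.result {
            Ok((retval, output)) => println!(
                "job {} finished with status {} ({} output bytes)",
                event.jobid, retval, output.len()),
            Err(msg) => eprintln!("job {} failed: {}", event.jobid, msg),
        }
    }

    // ComputePool implements Drop, which calls close() when the pool goes out of scope.
    Ok(())
}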