aboutsummaryrefslogtreecommitdiff
path: root/controller/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'controller/src/lib.rs')
-rw-r--r--controller/src/lib.rs492
1 files changed, 492 insertions, 0 deletions
diff --git a/controller/src/lib.rs b/controller/src/lib.rs
new file mode 100644
index 0000000..ee02b3b
--- /dev/null
+++ b/controller/src/lib.rs
@@ -0,0 +1,492 @@
+use std::collections::{HashMap, VecDeque};
+use std::convert::TryInto;
+use std::io::{self, ErrorKind};
+use std::net::Shutdown;
+use std::sync::Arc;
+use std::thread;
+use tokio::io::{AsyncReadExt, AsyncWriteExt, WriteHalf};
+use tokio::net::{TcpListener, TcpStream};
+use tokio::runtime;
+use tokio::task;
+use tokio::sync::{mpsc, oneshot, Mutex};
+use dd_utils::error::*;
+use dd_utils::idgen::IdGen;
+use dd_utils::protocol::*;
+use dd_utils::read_ext::ReadExt;
+
+async fn send_vectored(
+ writer: &mut (impl AsyncWriteExt + Unpin),
+ typ: u8,
+ id: u64,
+ payload: &[impl AsRef<[u8]>]
+) -> io::Result<()> {
+ let mut header = [0u8; 17];
+ header[0] = typ;
+ header[1..9].copy_from_slice(&id.to_le_bytes());
+ let sumlen: usize = payload.iter().map(|v| v.as_ref().len()).sum();
+ header[9..17].copy_from_slice(&sumlen.to_le_bytes());
+ writer.write_all(&header).await?;
+ for part in payload {
+ writer.write_all(part.as_ref()).await?;
+ }
+ Ok(())
+}
+
+async fn receive_message(reader: &mut (impl AsyncReadExt + Unpin))
+ -> io::Result<Option<RawMessage>> {
+ let mut header = [0u8; 17];
+ if let Err(e) = reader.read(&mut header).await {
+ if e.kind() == ErrorKind::UnexpectedEof { return Ok(None); }
+ else { return Err(e); }
+ }
+
+ let typ = header[0];
+ let id = u64::from_le_bytes(header[1..9].try_into().unwrap());
+ let length = usize::from_le_bytes(header[9..17].try_into().unwrap());
+
+ let mut payload = Vec::new();
+ payload.resize(length, 0u8);
+ if let Err(e) = reader.read(&mut payload).await {
+ if e.kind() == ErrorKind::UnexpectedEof { return Ok(None); }
+ else { return Err(e); }
+ }
+
+ Ok(Some(RawMessage { typ, id, payload }))
+}
+
+fn encode_message(msg: MessageBody) -> Vec<Vec<u8>> {
+ let mut payload = Vec::new();
+
+ match msg {
+ MessageBody::Version(version) => {
+ payload.push(version.to_le_bytes().to_vec());
+ }
+
+ MessageBody::NewCore(name, libfile) => {
+ payload.push(name.len().to_le_bytes().to_vec());
+ payload.push(name.into_bytes());
+ payload.push(libfile.len().to_le_bytes().to_vec());
+ payload.push(libfile);
+ }
+
+ MessageBody::Job(jobid, input) => {
+ payload.push(jobid.to_le_bytes().to_vec());
+ payload.push(input.len().to_le_bytes().to_vec());
+ payload.push(input);
+ }
+ }
+
+ payload
+}
+
+#[derive(Debug)]
+pub struct CompletionEvent {
+ pub jobid: u64,
+ pub result: Result<(i32, Vec<u8>), String>,
+}
+
+#[derive(Debug)]
+enum Inbound {
+ Completion(CompletionEvent),
+}
+
+#[derive(Debug)]
+enum Outbound {
+ NewCore(String, Vec<u8>),
+ NewJob(u64, Vec<u8>),
+ NumWorkers(oneshot::Sender<u64>),
+ Quit,
+}
+
+pub struct ComputePool {
+ runtime: runtime::Runtime,
+
+ iothread: Option<thread::JoinHandle<()>>,
+ inbound: mpsc::UnboundedReceiver<Inbound>,
+ outbound: mpsc::Sender<Outbound>,
+
+ // The number of jobs for which the completion event has not yet been consumed
+ num_running: u64,
+}
+
+#[derive(Debug)]
+enum ThreadCollect {
+ NewWorker(TcpStream, u64), // Worker socket, and next unused message id
+ WorkerReady(u64), // Worker id (indicates that this worker has its core initialised)
+ Query(Outbound),
+ Completion(u64, CompletionEvent), // Worker id, and event
+}
+
+async fn thread_handshake_handler(mut listener: TcpListener, sink: mpsc::Sender<ThreadCollect>) {
+ loop {
+ let (mut sock, _) = listener.accept().await.expect("Accept failed on TCP server socket");
+ let mut sink = sink.clone();
+ task::spawn(async move {
+ let payload = encode_message(MessageBody::Version(1));
+ if send_vectored(&mut sock, 1, 1, &payload).await.is_err() {
+ match sock.shutdown(Shutdown::Both) {
+ Ok(()) => {}
+ Err(_) => {} // explicitly ignore errors here, we're closing anyway
+ }
+ return;
+ }
+
+ match receive_message(&mut sock).await {
+ Ok(Some(rawmsg)) if rawmsg.typ == 1 && rawmsg.id == 1 &&
+ rawmsg.payload.len() == 1 &&
+ rawmsg.payload[0] == 1
+ => {
+ sink.send(ThreadCollect::NewWorker(sock, 2)).await.unwrap();
+ }
+
+ _ => {
+ match sock.shutdown(Shutdown::Both) {
+ Ok(()) => {}
+ Err(_) => {} // explicitly ignore errors here, we're closing anyway
+ }
+ }
+ }
+ });
+ }
+}
+
+async fn thread_query_handler(
+ mut chan: mpsc::Receiver<Outbound>,
+ mut sink: mpsc::Sender<ThreadCollect>
+) {
+ loop {
+ if let Some(msg) = chan.recv().await {
+ sink.send(ThreadCollect::Query(msg)).await.unwrap();
+ } else {
+ return;
+ }
+ }
+}
+
+struct Worker {
+ socket: WriteHalf<TcpStream>,
+ msg_idgen: IdGen,
+ loaded_core: u64,
+ handler_map: Arc<Mutex<HashMap<u64, Box<dyn FnOnce(u64, RawMessage) + Send>>>>,
+}
+
+impl Worker {
+ fn new(worker_id: u64, next_msg_id: u64, socket: TcpStream) -> Self {
+ let (mut read_half, write_half) = tokio::io::split(socket);
+ let worker = Worker {
+ socket: write_half,
+ msg_idgen: IdGen::new(next_msg_id),
+ loaded_core: 0,
+ handler_map: Arc::new(Mutex::new(HashMap::new())),
+ };
+
+ {
+ let handler_map = worker.handler_map.clone();
+ task::spawn(async move {
+ loop {
+ let rawmsg = match receive_message(&mut read_half).await {
+ Ok(Some(rawmsg)) => rawmsg,
+ _ => break,
+ };
+
+ let mut handler_map = handler_map.lock().await;
+ match handler_map.remove(&rawmsg.id) {
+ Some(handler) => handler(worker_id, rawmsg),
+ None => {
+ eprintln!("Warning: no handler found for worker reply id {}", rawmsg.id);
+ }
+ }
+ }
+ });
+ }
+
+ worker
+ }
+
+ /// Returns whether the send succeeded. In case of failure, the handler is not registered.
+ async fn send(
+ &mut self,
+ typ: u8,
+ payload: &[impl AsRef<[u8]>],
+ handler: impl FnOnce(u64, RawMessage) + Send + 'static
+ ) -> bool {
+ let msgid = self.msg_idgen.gen();
+ self.handler_map.lock().await.insert(msgid, Box::new(handler));
+ match send_vectored(&mut self.socket, typ, msgid, &payload).await {
+ Ok(()) => true,
+ Err(_) => {
+ self.handler_map.lock().await.remove(&msgid);
+ false
+ }
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+struct Job {
+ id: u64,
+ input: Vec<u8>,
+}
+
+#[derive(Debug, Clone)]
+struct ComputeCore {
+ name: String,
+ libfile: Vec<u8>,
+}
+
+/// Returns whether job was successfully sent to the worker
+async fn worker_run_job(
+ worker: &mut Worker,
+ job: Job,
+ mut result_chan: mpsc::Sender<ThreadCollect>
+) -> bool {
+ let Job { id: jobid, input } = job;
+ let payload = encode_message(MessageBody::Job(jobid, input));
+ let handler = move |wid: u64, rawmsg: RawMessage| {
+ task::spawn(async move {
+ let result = if rawmsg.typ == 1 {
+ let mut reader: &[u8] = &rawmsg.payload;
+ if let Some(retval) = reader.read_le_i32() {
+ if let Some(output) = reader.read_pascal_blob() {
+ Ok((retval, output.to_vec()))
+ } else {
+ Err("<Invalid reply format!>".to_string())
+ }
+ } else {
+ Err("<Invalid reply format!>".to_string())
+ }
+ } else {
+ Err(String::from_utf8_lossy(&rawmsg.payload).to_string())
+ };
+
+ let event = ThreadCollect::Completion(
+ wid,
+ CompletionEvent { jobid, result }
+ );
+ result_chan.send(event).await.unwrap();
+ });
+ };
+
+ worker.send(3, &payload, handler).await
+}
+
+/// Returns whether the message was successfully sent to the worker.
+async fn worker_set_new_core(
+ worker: &mut Worker,
+ new_core_id: u64,
+ core: ComputeCore,
+ workers_map: Arc<Mutex<HashMap<u64, Worker>>>,
+ result_chan: mpsc::Sender<ThreadCollect>
+) -> bool {
+ worker_set_new_core_payload(
+ worker,
+ new_core_id,
+ &encode_message(MessageBody::NewCore(core.name, core.libfile)),
+ workers_map,
+ result_chan
+ ).await
+}
+
+/// Returns whether the message was successfully sent to the worker.
+async fn worker_set_new_core_payload(
+ worker: &mut Worker,
+ new_core_id: u64,
+ payload: &[impl AsRef<[u8]>],
+ workers_map: Arc<Mutex<HashMap<u64, Worker>>>,
+ mut result_chan: mpsc::Sender<ThreadCollect>
+) -> bool {
+ let handler = move |wid: u64, rawmsg: RawMessage| {
+ task::spawn(async move {
+ if rawmsg.typ == 1 {
+ if let Some(worker) = workers_map.lock().await.get_mut(&wid) {
+ worker.loaded_core = new_core_id;
+ result_chan.send(ThreadCollect::WorkerReady(wid)).await.unwrap();
+ }
+ } else {
+ eprintln!("Worker {} could not load new core:\n{}",
+ wid, String::from_utf8_lossy(&rawmsg.payload));
+ }
+ });
+ };
+ worker.send(2, payload, handler).await
+}
+
+// IO thread:
+// - waits for worker connections
+// - has a channel open to the parent thread on which it receives commands to send to workers
+// - receives responses from workers and puts those on the channel
+fn thread_entry(
+ listen_port: u16,
+ query_chan: mpsc::Receiver<Outbound>,
+ event_chan: mpsc::UnboundedSender<Inbound>
+) {
+ runtime::Runtime::new().unwrap().block_on(async {
+ let (mut collector_sink, mut collector_source) = mpsc::channel(10);
+ let listener = TcpListener::bind(("0.0.0.0", listen_port)).await.unwrap();
+
+ task::spawn(thread_handshake_handler(listener, collector_sink.clone()));
+ task::spawn(thread_query_handler(query_chan, collector_sink.clone()));
+
+ let mut worker_idgen = IdGen::new(1);
+
+ let workers: Arc<Mutex<HashMap<u64, Worker>>> = Arc::new(Mutex::new(HashMap::new()));
+ let mut free_workers: VecDeque<u64> = VecDeque::new();
+ let mut current_core: Option<ComputeCore> = None;
+ let mut current_core_id: u64 = 0;
+ let mut core_idgen = IdGen::new(1);
+ let mut job_queue: VecDeque<Job> = VecDeque::new();
+
+ loop {
+ let message = match collector_source.recv().await {
+ Some(message) => message,
+ None => break,
+ };
+
+ match message {
+ ThreadCollect::Query(Outbound::NewCore(name, libfile)) => {
+ let new_core_id = core_idgen.gen();
+ current_core = Some(ComputeCore { name: name.clone(), libfile: libfile.clone() });
+ current_core_id = new_core_id;
+
+ let payload = encode_message(MessageBody::NewCore(name, libfile));
+
+ let mut workers_locked = workers.lock().await;
+
+ while let Some(worker_id) = free_workers.pop_front() {
+ if let Some(worker) = workers_locked.get_mut(&worker_id) {
+ if !worker_set_new_core_payload(
+ worker, new_core_id, &payload, workers.clone(),
+ collector_sink.clone()).await {
+ workers_locked.remove(&worker_id);
+ }
+ }
+ }
+ }
+
+ ThreadCollect::Query(Outbound::NewJob(jobid, input)) => {
+ job_queue.push_back(Job { id: jobid, input });
+ }
+
+ ThreadCollect::Query(Outbound::NumWorkers(chan)) => {
+ chan.send(workers.lock().await.len() as u64).unwrap();
+ }
+
+ ThreadCollect::Query(Outbound::Quit) => {
+ break;
+ }
+
+ ThreadCollect::Completion(worker_id, event) => {
+ event_chan.send(Inbound::Completion(event)).unwrap();
+ collector_sink.send(ThreadCollect::WorkerReady(worker_id)).await.unwrap();
+ }
+
+ ThreadCollect::NewWorker(socket, next_msg_id) => {
+ let wid = worker_idgen.gen();
+ let worker = Worker::new(wid, next_msg_id, socket);
+
+ workers.lock().await.insert(wid, worker);
+
+ // Delegate the core setup to the WorkerReady event
+ collector_sink.send(ThreadCollect::WorkerReady(wid)).await.unwrap();
+ }
+
+ ThreadCollect::WorkerReady(worker_id) => {
+ let mut workers_locked = workers.lock().await;
+
+ if let Some(worker) = workers_locked.get_mut(&worker_id) {
+ if worker.loaded_core != current_core_id {
+ let current_core = current_core.unwrap().clone();
+ worker_set_new_core(
+ worker, current_core_id, current_core, workers.clone(),
+ collector_sink.clone()
+ ).await;
+ return;
+ }
+
+ free_workers.push_back(worker_id);
+ }
+ }
+ }
+
+ if job_queue.len() > 0 && free_workers.len() > 0 {
+ let job = job_queue.pop_front().unwrap();
+ let mut workers_locked = workers.lock().await;
+
+ // Loop until a worker exists and is functional
+ loop {
+ let worker_id = free_workers.pop_front().unwrap();
+ if let Some(worker) = workers_locked.get_mut(&worker_id) {
+ let result_chan = collector_sink.clone();
+ if !worker_run_job(worker, job.clone(), result_chan).await {
+ workers_locked.remove(&worker_id);
+ } else {
+ break;
+ }
+ }
+ }
+ }
+ }
+ });
+}
+
+/// Will call ComputePool::close(), `unwrap`-ing the result.
+impl Drop for ComputePool {
+ fn drop(&mut self) {
+ self.close().unwrap();
+ }
+}
+
+impl ComputePool {
+ pub fn new(port: u16) -> io::Result<Self> {
+ // Spawn the IO thread
+ let (inbound_sender, inbound_receiver) = mpsc::unbounded_channel();
+ let (outbound_sender, outbound_receiver) = mpsc::channel(1);
+ let jh = thread::spawn(move || { thread_entry(port, outbound_receiver, inbound_sender) });
+ Ok(ComputePool {
+ runtime: runtime::Runtime::new()?,
+ iothread: Some(jh),
+ inbound: inbound_receiver,
+ outbound: outbound_sender,
+ num_running: 0,
+ })
+ }
+
+ pub fn close(&mut self) -> io::Result<()> {
+ self.runtime.block_on(self.outbound.send(Outbound::Quit)).iores()?;
+ self.iothread.take().map(|jh| jh.join().unwrap());
+ Ok(())
+ }
+
+ pub fn set_core(&mut self, name: String, libfile: Vec<u8>) -> io::Result<()> {
+ // Instruct the IO thread to send the core to all workers, and also set it on every new
+ // worker that arrives
+ self.runtime.block_on(self.outbound.send(Outbound::NewCore(name, libfile))).iores()
+ }
+
+ pub fn current_job_parallelism(&mut self) -> io::Result<u64> {
+ // Query the IO thread for the number of workers currently registered
+ let mut outbound = self.outbound.clone();
+ self.runtime.block_on(async {
+ let (sender, receiver) = oneshot::channel();
+ outbound.send(Outbound::NumWorkers(sender)).await.iores()?;
+ receiver.await.iores()
+ })
+ }
+
+ pub fn submit_job(&mut self, jobid: u64, input: Vec<u8>) -> io::Result<()> {
+ // Send the job to the IO thread, which will send it to a round-robin worker
+ self.runtime.block_on(self.outbound.send(Outbound::NewJob(jobid, input))).iores()
+ }
+
+ pub fn next_completion_event(&mut self) -> io::Result<Option<CompletionEvent>> {
+ // If the counter is still positive, wait for events from the IO thread
+ if self.num_running > 0 {
+ match self.runtime.block_on(self.inbound.recv()) {
+ Some(Inbound::Completion(event)) => Ok(Some(event)),
+ None => Err("IO thread unexpectedly quit".ioerr()),
+ }
+ } else {
+ Ok(None)
+ }
+ }
+}