aboutsummaryrefslogtreecommitdiff
path: root/worker/src/connection.rs
diff options
context:
space:
mode:
Diffstat (limited to 'worker/src/connection.rs')
-rw-r--r--worker/src/connection.rs119
1 files changed, 119 insertions, 0 deletions
diff --git a/worker/src/connection.rs b/worker/src/connection.rs
new file mode 100644
index 0000000..1222c92
--- /dev/null
+++ b/worker/src/connection.rs
@@ -0,0 +1,119 @@
+use std::io::{self, BufReader, Write};
+use std::net::TcpStream;
+use net2::TcpStreamExt;
+use dd_utils::read_ext::ReadExt;
+use dd_utils::error::*;
+use dd_utils::protocol::*;
+
+fn send_vectored<W: Write, V: AsRef<[u8]>>(
+ writer: &mut W,
+ typ: u8,
+ id: u64,
+ payload: &[V]
+) -> io::Result<()> {
+ let mut header = [0u8; 17];
+ header[0] = typ;
+ header[1..9].copy_from_slice(&id.to_le_bytes());
+ let sumlen: usize = payload.iter().map(|v| v.as_ref().len()).sum();
+ header[9..17].copy_from_slice(&sumlen.to_le_bytes());
+ writer.write_all(&header)?;
+ for part in payload {
+ writer.write_all(part.as_ref())?;
+ }
+ Ok(())
+}
+
+pub struct Message {
+ pub id: u64,
+ pub body: MessageBody,
+}
+
+fn interpret_message(raw: RawMessage) -> io::Result<Message> {
+ let mut reader: &[u8] = &raw.payload;
+
+ macro_rules! too_short { () => { "Message payload too short!".ioerr() } }
+
+ match raw.typ {
+ 1 => {
+ let version = reader.read_le_u32().ok_or(too_short!())?;
+ Ok(Message { id: raw.id, body: MessageBody::Version(version) })
+ }
+
+ 2 => {
+ let name = reader.read_pascal_string()
+ .map(|r| r.iores())
+ .ok_or(too_short!())??;
+ let libfile = reader.read_pascal_blob()
+ .map(|b| b.to_vec())
+ .ok_or(too_short!())?;
+ Ok(Message { id: raw.id, body: MessageBody::NewCore(name, libfile) })
+ }
+
+ 3 => {
+ let jobid = reader.read_le_u64().ok_or(too_short!())?;
+ let jobdata = reader.read_pascal_blob().map(|b| b.to_vec()).ok_or(too_short!())?;
+ Ok(Message { id: raw.id, body: MessageBody::Job(jobid, jobdata) })
+ }
+
+ _ => {
+ Err(format!("Unknown message type {}", raw.typ).ioerr())
+ }
+ }
+}
+
+pub struct Connection {
+ reader: BufReader<TcpStream>,
+}
+
+impl Connection {
+ pub fn new(address: &str) -> io::Result<Self> {
+ let socket = TcpStream::connect(address)?;
+
+ socket.set_keepalive_ms(Some(60000))?;
+
+ Ok(Self {
+ reader: BufReader::new(socket),
+ })
+ }
+
+ pub fn close(self) {}
+
+ pub fn receive(&mut self) -> io::Result<Option<Message>> {
+ RawMessage::receive(&mut self.reader)
+ .and_then(|opt|
+ opt.map(|rawmsg| interpret_message(rawmsg)).transpose()
+ )
+ }
+
+ pub fn reply(&mut self, msgid: u64, rmsg: Result<Reply, String>) -> io::Result<()> {
+ let socket = self.reader.get_mut();
+ let mut payload: Vec<Vec<u8>> = Vec::new();
+
+ let response_type;
+
+ match rmsg {
+ Ok(msg) => {
+ response_type = 1;
+ match msg {
+ Reply::Version(ok) => {
+ payload.push(vec![if ok { 1u8 } else { 0u8 }]);
+ }
+
+ Reply::NewCore => {}
+
+ Reply::Job(jobid, output) => {
+ payload.push(jobid.to_le_bytes().to_vec());
+ payload.push(output);
+ }
+ }
+ }
+
+ Err(err) => {
+ response_type = 0xff;
+ payload.push(err.into_bytes());
+ }
+ }
+
+ send_vectored(socket, response_type, msgid, &payload)
+ }
+}