diff options
Diffstat (limited to 'worker/src/connection.rs')
-rw-r--r-- | worker/src/connection.rs | 119 |
1 files changed, 119 insertions, 0 deletions
diff --git a/worker/src/connection.rs b/worker/src/connection.rs new file mode 100644 index 0000000..1222c92 --- /dev/null +++ b/worker/src/connection.rs @@ -0,0 +1,119 @@ +use std::io::{self, BufReader, Write}; +use std::net::TcpStream; +use net2::TcpStreamExt; +use dd_utils::read_ext::ReadExt; +use dd_utils::error::*; +use dd_utils::protocol::*; + +fn send_vectored<W: Write, V: AsRef<[u8]>>( + writer: &mut W, + typ: u8, + id: u64, + payload: &[V] +) -> io::Result<()> { + let mut header = [0u8; 17]; + header[0] = typ; + header[1..9].copy_from_slice(&id.to_le_bytes()); + let sumlen: usize = payload.iter().map(|v| v.as_ref().len()).sum(); + header[9..17].copy_from_slice(&sumlen.to_le_bytes()); + writer.write_all(&header)?; + for part in payload { + writer.write_all(part.as_ref())?; + } + Ok(()) +} + +pub struct Message { + pub id: u64, + pub body: MessageBody, +} + +fn interpret_message(raw: RawMessage) -> io::Result<Message> { + let mut reader: &[u8] = &raw.payload; + + macro_rules! too_short { () => { "Message payload too short!".ioerr() } } + + match raw.typ { + 1 => { + let version = reader.read_le_u32().ok_or(too_short!())?; + Ok(Message { id: raw.id, body: MessageBody::Version(version) }) + } + + 2 => { + let name = reader.read_pascal_string() + .map(|r| r.iores()) + .ok_or(too_short!())??; + let libfile = reader.read_pascal_blob() + .map(|b| b.to_vec()) + .ok_or(too_short!())?; + Ok(Message { id: raw.id, body: MessageBody::NewCore(name, libfile) }) + } + + 3 => { + let jobid = reader.read_le_u64().ok_or(too_short!())?; + let jobdata = reader.read_pascal_blob().map(|b| b.to_vec()).ok_or(too_short!())?; + Ok(Message { id: raw.id, body: MessageBody::Job(jobid, jobdata) }) + } + + _ => { + Err(format!("Unknown message type {}", raw.typ).ioerr()) + } + } +} + +pub struct Connection { + reader: BufReader<TcpStream>, +} + +impl Connection { + pub fn new(address: &str) -> io::Result<Self> { + let socket = TcpStream::connect(address)?; + + socket.set_keepalive_ms(Some(60000))?; + + Ok(Self { + reader: BufReader::new(socket), + }) + } + + pub fn close(self) {} + + pub fn receive(&mut self) -> io::Result<Option<Message>> { + RawMessage::receive(&mut self.reader) + .and_then(|opt| + opt.map(|rawmsg| interpret_message(rawmsg)).transpose() + ) + } + + pub fn reply(&mut self, msgid: u64, rmsg: Result<Reply, String>) -> io::Result<()> { + let socket = self.reader.get_mut(); + let mut payload: Vec<Vec<u8>> = Vec::new(); + + let response_type; + + match rmsg { + Ok(msg) => { + response_type = 1; + match msg { + Reply::Version(ok) => { + payload.push(vec![if ok { 1u8 } else { 0u8 }]); + } + + Reply::NewCore => {} + + Reply::Job(jobid, output) => { + payload.push(jobid.to_le_bytes().to_vec()); + payload.push(output); + } + } + } + + Err(err) => { + response_type = 0xff; + payload.push(err.into_bytes()); + } + } + + send_vectored(socket, response_type, msgid, &payload) + } +} |