aboutsummaryrefslogtreecommitdiff
path: root/worker/src/connection.rs
blob: 1222c928ed517c079cdc8055f0950adc9718447e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use std::io::{self, BufReader, Write};
use std::net::TcpStream;
use net2::TcpStreamExt;
use dd_utils::read_ext::ReadExt;
use dd_utils::error::*;
use dd_utils::protocol::*;

/// Writes one framed message to `writer` and flushes it.
///
/// The frame is a 17-byte header followed by the concatenation of every
/// element of `payload`:
///
/// | offset | size | field                                          |
/// |--------|------|------------------------------------------------|
/// | 0      | 1    | `typ`                                          |
/// | 1      | 8    | `id`, little-endian                            |
/// | 9      | 8    | total payload length in bytes, little-endian   |
///
/// The header and payload parts are coalesced through a buffer, so a small
/// frame reaches `writer` in a single `write` call rather than one per part.
///
/// # Errors
///
/// Propagates any I/O error from writing to or flushing `writer`.
fn send_vectored<W: Write, V: AsRef<[u8]>>(
    writer: &mut W,
    typ: u8,
    id: u64,
    payload: &[V],
) -> io::Result<()> {
    let total_len: usize = payload.iter().map(|part| part.as_ref().len()).sum();

    let mut header = [0u8; 17];
    header[0] = typ;
    header[1..9].copy_from_slice(&id.to_le_bytes());
    // The wire format carries the length in 8 bytes. `usize::to_le_bytes` is
    // only 8 bytes on 64-bit targets (4 on 32-bit, which would make
    // `copy_from_slice` panic), so widen to `u64` first; usize -> u64 never
    // truncates on any Rust target.
    header[9..17].copy_from_slice(&(total_len as u64).to_le_bytes());

    // Writing the header and each part straight to a TCP socket issues one send
    // per piece; with Nagle's algorithm enabled, that write-write-then-read
    // pattern can hold the trailing segment until the peer's (possibly delayed)
    // ACK arrives. Coalesce through a buffer and flush once instead.
    let mut out = io::BufWriter::new(writer);
    out.write_all(&header)?;
    for part in payload {
        out.write_all(part.as_ref())?;
    }
    out.flush()
}

/// A message received from the server, decoded from a `RawMessage` by
/// `interpret_message`.
pub struct Message {
    /// Identifier copied from the raw message header. Presumably this is the
    /// value to pass back as `msgid` to `Connection::reply` — confirm against
    /// the server side of the protocol.
    pub id: u64,

    /// The decoded, type-specific contents of the message.
    pub body: MessageBody,
}

/// Decodes the payload of a `RawMessage` into a typed [`Message`].
///
/// The payload layout is selected by `raw.typ`:
///
/// * `1` — a little-endian `u32`, producing `MessageBody::Version`.
/// * `2` — a string read with `read_pascal_string` (the core name) followed by
///   a blob read with `read_pascal_blob` (the library file), producing
///   `MessageBody::NewCore`.
/// * `3` — a little-endian `u64` job id followed by a blob read with
///   `read_pascal_blob` (the job data), producing `MessageBody::Job`.
///
/// Any bytes left in the payload after the expected fields are ignored.
///
/// # Errors
///
/// * "Message payload too short!" if the payload ends before every field of
///   the message type has been read.
/// * Any error produced by `iores()` when converting the type-2 name
///   (presumably invalid UTF-8 — confirm in `dd_utils`).
/// * "Unknown message type N" if `raw.typ` is not listed above.
fn interpret_message(raw: RawMessage) -> io::Result<Message> {
    let mut reader: &[u8] = &raw.payload;

    // Used through `ok_or_else` so the error value is only constructed when a
    // read actually runs out of bytes, not on every successful field read.
    macro_rules! too_short { () => { "Message payload too short!".ioerr() } }

    let body = match raw.typ {
        1 => {
            let version = reader.read_le_u32().ok_or_else(|| too_short!())?;
            MessageBody::Version(version)
        }

        2 => {
            // Outer `?`: payload too short; inner `?`: name conversion failed.
            let name = reader
                .read_pascal_string()
                .map(|r| r.iores())
                .ok_or_else(|| too_short!())??;
            let libfile = reader
                .read_pascal_blob()
                .map(|b| b.to_vec())
                .ok_or_else(|| too_short!())?;
            MessageBody::NewCore(name, libfile)
        }

        3 => {
            let jobid = reader.read_le_u64().ok_or_else(|| too_short!())?;
            let jobdata = reader
                .read_pascal_blob()
                .map(|b| b.to_vec())
                .ok_or_else(|| too_short!())?;
            MessageBody::Job(jobid, jobdata)
        }

        other => return Err(format!("Unknown message type {}", other).ioerr()),
    };

    Ok(Message { id: raw.id, body })
}

pub struct Connection {
    reader: BufReader<TcpStream>,
}

impl Connection {
    pub fn new(address: &str) -> io::Result<Self> {
        let socket = TcpStream::connect(address)?;

        socket.set_keepalive_ms(Some(60000))?;

        Ok(Self {
            reader: BufReader::new(socket),
        })
    }

    pub fn close(self) {}

    pub fn receive(&mut self) -> io::Result<Option<Message>> {
        RawMessage::receive(&mut self.reader)
            .and_then(|opt|
                opt.map(|rawmsg| interpret_message(rawmsg)).transpose()
            )
    }

    pub fn reply(&mut self, msgid: u64, rmsg: Result<Reply, String>) -> io::Result<()> {
        let socket = self.reader.get_mut();
        let mut payload: Vec<Vec<u8>> = Vec::new();

        let response_type;

        match rmsg {
            Ok(msg) => {
                response_type = 1;
                match msg {
                    Reply::Version(ok) => {
                        payload.push(vec![if ok { 1u8 } else { 0u8 }]);
                    }

                    Reply::NewCore => {}

                    Reply::Job(jobid, output) => {
                        payload.push(jobid.to_le_bytes().to_vec());
                        payload.push(output);
                    }
                }
            }

            Err(err) => {
                response_type = 0xff;
                payload.push(err.into_bytes());
            }
        }

        send_vectored(socket, response_type, msgid, &payload)
    }
}