summaryrefslogtreecommitdiff
path: root/icmpd.c
diff options
context:
space:
mode:
Diffstat (limited to 'icmpd.c')
-rw-r--r--icmpd.c502
1 files changed, 502 insertions, 0 deletions
diff --git a/icmpd.c b/icmpd.c
new file mode 100644
index 0000000..e3127c7
--- /dev/null
+++ b/icmpd.c
@@ -0,0 +1,502 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <stdbool.h>
+#include <string.h>
+#include <fcntl.h>
+#include <unistd.h>
+#include <errno.h>
+#include <assert.h>
+#include <sys/select.h>
+#include "icmpd.h"
+#include "util.h"
+#include "icmp_server.h"
+#include "icmp_client.h"
+
+
+struct icmpd {
+ struct mt_mutex mut; // for this struct
+
+ bool isserver; // constant
+
+ // IPv4 address of other party
+ uint32_t other_addr; // constant
+
+ bool outstanding; // server: whether an echo request hasn't been matched yet
+
+ // id for messages, -1 if not set
+ int id; // constant
+
+ // Client: seqnum can change before every send
+ // Server: seqnum is the sequence number of the last-received ECHO
+ // value is -1 if not set
+ int seqnum;
+};
+
+// Global state
+bool thread_running = false;
+struct mt_mutex thread_mutex; // protecting these global values
+struct mt_thread thread;
+int thread_in, thread_out;
+int host_in, host_out;
+
+
+// Arguments to messages are listed in comments.
+enum {
+ MSG_THREAD_UP, // -
+ MSG_NEWCONN, // struct icmpd*
+ MSG_ENDCONN, // struct icmpd*; free() done by thread
+ MSG_PEEK, // struct icmpd*
+ MSG_PEEK_ANS, // uint8_t (bool)
+ MSG_RECV, // struct icmpd*
+ MSG_RECV_ANS, // struct icmpd_received
+ MSG_SEND, // struct msg_form_send
+};
+
+struct msg_header {
+ int type;
+ size_t size; // size in bytes of argument
+};
+
+struct msg_in {
+ struct msg_header hdr;
+ void *data; // should be free()'d
+};
+
+struct msg_form_send {
+ struct icmpd *d;
+ void *data; // to be free()'d by thread when sent
+ size_t length;
+};
+
+
+static struct icmpd* find_conn(int id, uint32_t addr, struct icmpd **conns, size_t conns_len) {
+ struct icmpd *ret = NULL;
+ for (size_t i = 0; i < conns_len; i++) {
+ struct icmpd *d = conns[i];
+ if ((d->id == -1 || d->id == id) &&
+ (d->other_addr == 0 || d->other_addr == addr)) {
+ if (ret == NULL || (ret->other_addr == 0 && d->other_addr != 0)) {
+ ret = d;
+ } else {
+ fprintf(stderr, "icmpd thread: warning: multiple connections match id=%d addr=%x\n",
+ id, addr);
+ break;
+ }
+ }
+ }
+ return ret;
+}
+
+static void send_message(int sock, int type, const void *data, size_t size) {
+ struct msg_header head = {type, size};
+
+ assert(writeall(sock, &head, sizeof head) == sizeof head);
+ assert(writeall(sock, data, size) == (ssize_t)size);
+}
+
+// On error, {.type=-1}, errno is set
+static struct msg_in recv_message(int sock) {
+ struct msg_in msg;
+
+ int ret = readall(sock, &msg.hdr, sizeof msg.hdr);
+ if (ret < 0) return (struct msg_in){.hdr.type = -1, .hdr.size = 0, .data = NULL};
+ assert(ret == sizeof msg.hdr);
+
+ msg.data = malloc(msg.hdr.size);
+ assert(readall(sock, msg.data, msg.hdr.size) == (ssize_t)msg.hdr.size);
+
+ return msg;
+}
+
+static void client_increment_seqnum(struct icmpd *d) {
+ d->seqnum = (d->seqnum + 1) & 0xffff;
+}
+
+static void* thread_entry(void *arg) {
+ (void)arg;
+
+ send_message(thread_out, MSG_THREAD_UP, NULL, 0);
+
+ struct recvqu_item {
+ struct icmpd *d;
+ struct icmpd_received msg;
+ };
+
+ struct sendqu_item {
+ struct icmpd *d;
+ void *data;
+ size_t length;
+ };
+
+ size_t conns_cap = 8, conns_len = 0;
+ struct icmpd **conns = malloc(conns_cap * sizeof(struct icmpd*));
+
+ int sock_server = -1, sock_client = -1;
+
+ size_t recvqu_cap = 8, recvqu_len = 0;
+ struct recvqu_item *recvqu = malloc(recvqu_cap * sizeof(struct recvqu_item));
+
+ // Server messages submitted for transmission while no non-ponged client
+ // ping is available.
+ size_t sendqu_cap = 8, sendqu_len = 0;
+ struct sendqu_item *sendqu = malloc(sendqu_cap * sizeof(struct sendqu_item));
+
+ while (true) {
+ fd_set inset;
+ FD_ZERO(&inset);
+ FD_SET(thread_in, &inset);
+ if (sock_server >= 0) FD_SET(sock_server, &inset);
+ if (sock_client >= 0) FD_SET(sock_client, &inset);
+
+ int nfds = thread_in;
+ if (sock_server > nfds) nfds = sock_server;
+ if (sock_client > nfds) nfds = sock_client;
+
+ int ret = select(nfds + 1, &inset, NULL, NULL, NULL);
+ if (ret < 0) {
+ if (errno == EINTR) continue;
+ perror("select");
+ assert(false);
+ }
+ assert(ret > 0);
+
+ if (FD_ISSET(thread_in, &inset)) {
+ struct msg_in msg = recv_message(thread_in);
+ if (msg.hdr.type == -1) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK) continue;
+ perror("recv_message");
+ assert(false);
+ }
+
+ switch (msg.hdr.type) {
+ case MSG_NEWCONN: {
+ struct icmpd *d = *(struct icmpd**)msg.data;
+ free(msg.data);
+ if (conns_len == conns_cap) {
+ conns_cap *= 2;
+ conns = realloc(conns, conns_cap * sizeof(struct icmpd*));
+ }
+ conns[conns_len++] = d;
+
+ if (d->isserver) {
+ if (sock_server < 0) {
+ sock_server = icmp_server_open_socket();
+ if (sock_server < 0) {
+ perror("icmp_server_open_socket");
+ assert(false);
+ }
+ }
+ } else {
+ if (sock_client < 0) {
+ sock_client = icmp_client_open_socket();
+ if (sock_client < 0) {
+ perror("icmp_client_open_socket");
+ assert(false);
+ }
+ }
+ }
+ break;
+ }
+
+ case MSG_ENDCONN: {
+ struct icmpd *d = *(struct icmpd**)msg.data;
+ free(msg.data);
+
+ bool found = false;
+ for (size_t i = 0; i < conns_len; i++) {
+ if (conns[i] == d) {
+ memmove(conns + i, conns + i + 1, conns_len - i - 1);
+ conns_len--;
+ found = true;
+ break;
+ }
+ }
+ assert(found);
+
+ mt_mutex_destroy(&d->mut);
+ free(d);
+ break;
+ }
+
+ case MSG_PEEK: {
+ struct icmpd *d = *(struct icmpd**)msg.data;
+ free(msg.data);
+
+ uint8_t ret = 0;
+ for (size_t i = 0; i < recvqu_len; i++) {
+ if (recvqu[i].d == d) {
+ ret = 1;
+ break;
+ }
+ }
+
+ send_message(thread_out, MSG_PEEK_ANS, &ret, 1);
+ break;
+ }
+
+ case MSG_RECV: {
+ struct icmpd *d = *(struct icmpd**)msg.data;
+ free(msg.data);
+
+ ssize_t index = -1;
+ for (size_t i = 0; i < recvqu_len; i++) {
+ if (recvqu[i].d == d) {
+ index = i;
+ break;
+ }
+ }
+
+ assert(index != -1); // TODO: make host wait for something
+
+ send_message(thread_out, MSG_RECV_ANS,
+ &recvqu[index].msg, sizeof recvqu[index].msg);
+
+ memcpy(recvqu + index, recvqu + index + 1,
+ (recvqu_len - index - 1) * sizeof recvqu[0]);
+
+ recvqu_len--;
+ break;
+ }
+
+ case MSG_SEND: {
+ struct msg_form_send form = *(struct msg_form_send*)msg.data;
+ free(msg.data);
+
+ struct icmpd *d = form.d;
+
+ if (d->isserver) {
+ mt_mutex_lock(&d->mut);
+ bool outstanding = d->outstanding;
+ int id = d->id, seqnum = d->seqnum;
+ mt_mutex_unlock(&d->mut);
+
+ if (outstanding) {
+ int ret = icmp_server_send_reply(
+ sock_server, d->other_addr, id, seqnum,
+ form.data, form.length);
+ if (ret < 0) {
+ perror("icmp_server_send_reply");
+ }
+ } else {
+ if (sendqu_len == sendqu_cap) {
+ sendqu_cap *= 2;
+ sendqu = realloc(sendqu, sendqu_cap * sizeof sendqu[0]);
+ }
+
+ sendqu[sendqu_len].d = d;
+ sendqu[sendqu_len].data = form.data;
+ sendqu[sendqu_len].length = form.length;
+ sendqu_len++;
+ }
+ } else {
+ mt_mutex_lock(&d->mut);
+ client_increment_seqnum(d);
+ int seqnum = d->seqnum;
+ mt_mutex_unlock(&d->mut);
+
+ int ret = icmp_client_send(
+ sock_client, d->other_addr, seqnum,
+ form.data, form.length);
+ if (ret < 0) {
+ perror("icmp_client_send");
+ }
+ }
+
+ break;
+ }
+ }
+ }
+
+ if (FD_ISSET(sock_server, &inset)) {
+ struct icmp_incoming msg = icmp_server_receive(sock_server);
+ struct icmpd *d = find_conn(msg.id, msg.source_addr, conns, conns_len);
+
+ if (d == NULL) {
+ fprintf(stderr, "icmpd thread: ping received with unknown id=%d addr=%x\n",
+ msg.id, msg.source_addr);
+ int ret = icmp_server_send_reply(
+ sock_server, msg.source_addr, msg.id, msg.seqnum,
+ msg.data, msg.length);
+ if (ret < 0) {
+ perror("icmpd thread: unknown ping reply: icmp_server_send_reply");
+ }
+ } else {
+ if (msg.length == 0) {
+ fprintf(stderr, "server recv: empty\n");
+ } else {
+ if (recvqu_len == recvqu_cap) {
+ recvqu_cap *= 2;
+ recvqu = realloc(recvqu, recvqu_cap * sizeof(struct recvqu_item));
+ }
+
+ recvqu[recvqu_len].d = d;
+ recvqu[recvqu_len].msg.data = malloc(msg.length);
+ memcpy(recvqu[recvqu_len].msg.data, msg.data, msg.length);
+ recvqu[recvqu_len].msg.length = msg.length;
+ recvqu[recvqu_len].msg.source_addr = msg.source_addr;
+ recvqu[recvqu_len].msg.id = msg.id;
+ recvqu[recvqu_len].msg.seqnum = msg.seqnum;
+ recvqu_len++;
+
+ fprintf(stderr, "server recv: recvqu_len = %zu\n", recvqu_len);
+ }
+
+ ssize_t index = -1;
+ for (size_t i = 0; i < sendqu_len; i++) {
+ if (sendqu[i].d == d) {
+ index = i;
+ break;
+ }
+ }
+
+ if (index != -1) {
+ int ret = icmp_server_send_reply(
+ sock_server, d->other_addr, d->id, d->seqnum,
+ sendqu[index].data, sendqu[index].length);
+ if (ret < 0) {
+ perror("icmp_server_send_reply");
+ }
+
+ memmove(sendqu + index, sendqu + index + 1,
+ (sendqu_len - index - 1) * sizeof sendqu[0]);
+ sendqu_len--;
+ } else {
+ mt_mutex_lock(&d->mut);
+ d->outstanding = true;
+ d->seqnum = msg.seqnum;
+ mt_mutex_unlock(&d->mut);
+ }
+ }
+ }
+
+ if (FD_ISSET(sock_client, &inset)) {
+ struct icmp_incoming msg = icmp_client_receive(sock_client);
+ struct icmpd *d = find_conn(msg.id, msg.source_addr, conns, conns_len);
+
+ if (d == NULL) {
+ fprintf(stderr, "icmpd thread: pong received with unknown id=%d addr=%x\n",
+ msg.id, msg.source_addr);
+ } else {
+ if (recvqu_len == recvqu_cap) {
+ recvqu_cap *= 2;
+ recvqu = realloc(recvqu, recvqu_cap * sizeof(struct recvqu_item));
+ }
+
+ recvqu[recvqu_len].d = d;
+ recvqu[recvqu_len].msg.data = malloc(msg.length);
+ memcpy(recvqu[recvqu_len].msg.data, msg.data, msg.length);
+ recvqu[recvqu_len].msg.length = msg.length;
+ recvqu[recvqu_len].msg.source_addr = msg.source_addr;
+ recvqu[recvqu_len].msg.id = msg.id;
+ recvqu[recvqu_len].msg.seqnum = msg.seqnum;
+ recvqu_len++;
+
+ fprintf(stderr, "client recv: recvqu_len = %zu\n", recvqu_len);
+
+ mt_mutex_lock(&d->mut);
+ client_increment_seqnum(d);
+ int seqnum = d->seqnum;
+ mt_mutex_unlock(&d->mut);
+
+ int ret = icmp_client_send(sock_client, d->other_addr, seqnum, NULL, 0);
+ if (ret < 0) {
+ perror("icmp_client_send");
+ }
+ }
+ }
+ }
+
+ return NULL;
+}
+
+static void spawn_icmpd_thread(void) {
+ int pp[2];
+
+ assert(pipe(pp) == 0);
+ thread_out = pp[1];
+ host_in = pp[0];
+
+ assert(pipe(pp) == 0);
+ host_out = pp[1];
+ thread_in = pp[0];
+
+ mt_mutex_init(&thread_mutex);
+
+ mt_thread_create(&thread, thread_entry, NULL);
+
+ struct msg_in msg = recv_message(host_in);
+ assert(msg.hdr.type == MSG_THREAD_UP);
+ free(msg.data);
+}
+
+static struct icmpd* icmpd_create_base(int id, bool isserver, uint32_t other_addr) {
+ if (!thread_running) {
+ spawn_icmpd_thread();
+ thread_running = true;
+ }
+
+ struct icmpd *d = malloc(sizeof(struct icmpd));
+ mt_mutex_init(&d->mut);
+ d->id = id;
+ d->seqnum = -1;
+ d->isserver = isserver;
+ d->other_addr = other_addr;
+ d->outstanding = false;
+
+ send_message(host_out, MSG_NEWCONN, &d, sizeof d);
+
+ return d;
+}
+
+struct icmpd* icmpd_create_server(int id, uint32_t client_addr) {
+ return icmpd_create_base(id, true, client_addr);
+}
+
+struct icmpd* icmpd_create_client(uint32_t server_addr) {
+ assert(server_addr != 0);
+ return icmpd_create_base(-1, false, server_addr);
+}
+
+void icmpd_destroy(struct icmpd *d) {
+ send_message(host_out, MSG_ENDCONN, &d, sizeof d);
+ // free() is done by thread
+}
+
+void icmpd_server_set_outstanding(struct icmpd *d, int seqnum) {
+ assert(d->isserver);
+ mt_mutex_lock(&d->mut);
+ d->outstanding = true;
+ d->seqnum = seqnum;
+ mt_mutex_unlock(&d->mut);
+}
+
+bool icmpd_peek(struct icmpd *d) {
+ send_message(host_out, MSG_PEEK, &d, sizeof d);
+ struct msg_in msg = recv_message(host_in);
+ assert(msg.hdr.type == MSG_PEEK_ANS);
+
+ bool ret = ((uint8_t*)msg.data)[0];
+ free(msg.data);
+ return ret;
+}
+
+struct icmpd_received icmpd_recv(struct icmpd *d) {
+ send_message(host_out, MSG_RECV, &d, sizeof d);
+ struct msg_in msg = recv_message(host_in);
+ assert(msg.hdr.type == MSG_RECV_ANS);
+
+ struct icmpd_received r;
+ assert(msg.hdr.size == sizeof r);
+ memcpy(&r, msg.data, sizeof r);
+ return r;
+}
+
+void icmpd_send(struct icmpd *d, const void *data, size_t length) {
+ struct msg_form_send form;
+ form.d = d;
+ form.data = malloc(length);
+ memcpy(form.data, data, length);
+ form.length = length;
+
+ send_message(host_out, MSG_SEND, &form, sizeof form);
+}