diff options
Diffstat (limited to 'icmpd.c')
-rw-r--r-- | icmpd.c | 502 |
1 files changed, 502 insertions, 0 deletions
@@ -0,0 +1,502 @@ +#include <stdio.h> +#include <stdlib.h> +#include <stdbool.h> +#include <string.h> +#include <fcntl.h> +#include <unistd.h> +#include <errno.h> +#include <assert.h> +#include <sys/select.h> +#include "icmpd.h" +#include "util.h" +#include "icmp_server.h" +#include "icmp_client.h" + + +struct icmpd { + struct mt_mutex mut; // for this struct + + bool isserver; // constant + + // IPv4 address of other party + uint32_t other_addr; // constant + + bool outstanding; // server: whether an echo request hasn't been matched yet + + // id for messages, -1 if not set + int id; // constant + + // Client: seqnum can change before every send + // Server: seqnum is the sequence number of the last-received ECHO + // value is -1 if not set + int seqnum; +}; + +// Global state +bool thread_running = false; +struct mt_mutex thread_mutex; // protecting these global values +struct mt_thread thread; +int thread_in, thread_out; +int host_in, host_out; + + +// Arguments to messages are listed in comments. +enum { + MSG_THREAD_UP, // - + MSG_NEWCONN, // struct icmpd* + MSG_ENDCONN, // struct icmpd*; free() done by thread + MSG_PEEK, // struct icmpd* + MSG_PEEK_ANS, // uint8_t (bool) + MSG_RECV, // struct icmpd* + MSG_RECV_ANS, // struct icmpd_received + MSG_SEND, // struct msg_form_send +}; + +struct msg_header { + int type; + size_t size; // size in bytes of argument +}; + +struct msg_in { + struct msg_header hdr; + void *data; // should be free()'d +}; + +struct msg_form_send { + struct icmpd *d; + void *data; // to be free()'d by thread when sent + size_t length; +}; + + +static struct icmpd* find_conn(int id, uint32_t addr, struct icmpd **conns, size_t conns_len) { + struct icmpd *ret = NULL; + for (size_t i = 0; i < conns_len; i++) { + struct icmpd *d = conns[i]; + if ((d->id == -1 || d->id == id) && + (d->other_addr == 0 || d->other_addr == addr)) { + if (ret == NULL || (ret->other_addr == 0 && d->other_addr != 0)) { + ret = d; + } else { + fprintf(stderr, "icmpd thread: warning: multiple connections match id=%d addr=%x\n", + id, addr); + break; + } + } + } + return ret; +} + +static void send_message(int sock, int type, const void *data, size_t size) { + struct msg_header head = {type, size}; + + assert(writeall(sock, &head, sizeof head) == sizeof head); + assert(writeall(sock, data, size) == (ssize_t)size); +} + +// On error, {.type=-1}, errno is set +static struct msg_in recv_message(int sock) { + struct msg_in msg; + + int ret = readall(sock, &msg.hdr, sizeof msg.hdr); + if (ret < 0) return (struct msg_in){.hdr.type = -1, .hdr.size = 0, .data = NULL}; + assert(ret == sizeof msg.hdr); + + msg.data = malloc(msg.hdr.size); + assert(readall(sock, msg.data, msg.hdr.size) == (ssize_t)msg.hdr.size); + + return msg; +} + +static void client_increment_seqnum(struct icmpd *d) { + d->seqnum = (d->seqnum + 1) & 0xffff; +} + +static void* thread_entry(void *arg) { + (void)arg; + + send_message(thread_out, MSG_THREAD_UP, NULL, 0); + + struct recvqu_item { + struct icmpd *d; + struct icmpd_received msg; + }; + + struct sendqu_item { + struct icmpd *d; + void *data; + size_t length; + }; + + size_t conns_cap = 8, conns_len = 0; + struct icmpd **conns = malloc(conns_cap * sizeof(struct icmpd*)); + + int sock_server = -1, sock_client = -1; + + size_t recvqu_cap = 8, recvqu_len = 0; + struct recvqu_item *recvqu = malloc(recvqu_cap * sizeof(struct recvqu_item)); + + // Server messages submitted for transmission while no non-ponged client + // ping is available. + size_t sendqu_cap = 8, sendqu_len = 0; + struct sendqu_item *sendqu = malloc(sendqu_cap * sizeof(struct sendqu_item)); + + while (true) { + fd_set inset; + FD_ZERO(&inset); + FD_SET(thread_in, &inset); + if (sock_server >= 0) FD_SET(sock_server, &inset); + if (sock_client >= 0) FD_SET(sock_client, &inset); + + int nfds = thread_in; + if (sock_server > nfds) nfds = sock_server; + if (sock_client > nfds) nfds = sock_client; + + int ret = select(nfds + 1, &inset, NULL, NULL, NULL); + if (ret < 0) { + if (errno == EINTR) continue; + perror("select"); + assert(false); + } + assert(ret > 0); + + if (FD_ISSET(thread_in, &inset)) { + struct msg_in msg = recv_message(thread_in); + if (msg.hdr.type == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) continue; + perror("recv_message"); + assert(false); + } + + switch (msg.hdr.type) { + case MSG_NEWCONN: { + struct icmpd *d = *(struct icmpd**)msg.data; + free(msg.data); + if (conns_len == conns_cap) { + conns_cap *= 2; + conns = realloc(conns, conns_cap * sizeof(struct icmpd*)); + } + conns[conns_len++] = d; + + if (d->isserver) { + if (sock_server < 0) { + sock_server = icmp_server_open_socket(); + if (sock_server < 0) { + perror("icmp_server_open_socket"); + assert(false); + } + } + } else { + if (sock_client < 0) { + sock_client = icmp_client_open_socket(); + if (sock_client < 0) { + perror("icmp_client_open_socket"); + assert(false); + } + } + } + break; + } + + case MSG_ENDCONN: { + struct icmpd *d = *(struct icmpd**)msg.data; + free(msg.data); + + bool found = false; + for (size_t i = 0; i < conns_len; i++) { + if (conns[i] == d) { + memmove(conns + i, conns + i + 1, conns_len - i - 1); + conns_len--; + found = true; + break; + } + } + assert(found); + + mt_mutex_destroy(&d->mut); + free(d); + break; + } + + case MSG_PEEK: { + struct icmpd *d = *(struct icmpd**)msg.data; + free(msg.data); + + uint8_t ret = 0; + for (size_t i = 0; i < recvqu_len; i++) { + if (recvqu[i].d == d) { + ret = 1; + break; + } + } + + send_message(thread_out, MSG_PEEK_ANS, &ret, 1); + break; + } + + case MSG_RECV: { + struct icmpd *d = *(struct icmpd**)msg.data; + free(msg.data); + + ssize_t index = -1; + for (size_t i = 0; i < recvqu_len; i++) { + if (recvqu[i].d == d) { + index = i; + break; + } + } + + assert(index != -1); // TODO: make host wait for something + + send_message(thread_out, MSG_RECV_ANS, + &recvqu[index].msg, sizeof recvqu[index].msg); + + memcpy(recvqu + index, recvqu + index + 1, + (recvqu_len - index - 1) * sizeof recvqu[0]); + + recvqu_len--; + break; + } + + case MSG_SEND: { + struct msg_form_send form = *(struct msg_form_send*)msg.data; + free(msg.data); + + struct icmpd *d = form.d; + + if (d->isserver) { + mt_mutex_lock(&d->mut); + bool outstanding = d->outstanding; + int id = d->id, seqnum = d->seqnum; + mt_mutex_unlock(&d->mut); + + if (outstanding) { + int ret = icmp_server_send_reply( + sock_server, d->other_addr, id, seqnum, + form.data, form.length); + if (ret < 0) { + perror("icmp_server_send_reply"); + } + } else { + if (sendqu_len == sendqu_cap) { + sendqu_cap *= 2; + sendqu = realloc(sendqu, sendqu_cap * sizeof sendqu[0]); + } + + sendqu[sendqu_len].d = d; + sendqu[sendqu_len].data = form.data; + sendqu[sendqu_len].length = form.length; + sendqu_len++; + } + } else { + mt_mutex_lock(&d->mut); + client_increment_seqnum(d); + int seqnum = d->seqnum; + mt_mutex_unlock(&d->mut); + + int ret = icmp_client_send( + sock_client, d->other_addr, seqnum, + form.data, form.length); + if (ret < 0) { + perror("icmp_client_send"); + } + } + + break; + } + } + } + + if (FD_ISSET(sock_server, &inset)) { + struct icmp_incoming msg = icmp_server_receive(sock_server); + struct icmpd *d = find_conn(msg.id, msg.source_addr, conns, conns_len); + + if (d == NULL) { + fprintf(stderr, "icmpd thread: ping received with unknown id=%d addr=%x\n", + msg.id, msg.source_addr); + int ret = icmp_server_send_reply( + sock_server, msg.source_addr, msg.id, msg.seqnum, + msg.data, msg.length); + if (ret < 0) { + perror("icmpd thread: unknown ping reply: icmp_server_send_reply"); + } + } else { + if (msg.length == 0) { + fprintf(stderr, "server recv: empty\n"); + } else { + if (recvqu_len == recvqu_cap) { + recvqu_cap *= 2; + recvqu = realloc(recvqu, recvqu_cap * sizeof(struct recvqu_item)); + } + + recvqu[recvqu_len].d = d; + recvqu[recvqu_len].msg.data = malloc(msg.length); + memcpy(recvqu[recvqu_len].msg.data, msg.data, msg.length); + recvqu[recvqu_len].msg.length = msg.length; + recvqu[recvqu_len].msg.source_addr = msg.source_addr; + recvqu[recvqu_len].msg.id = msg.id; + recvqu[recvqu_len].msg.seqnum = msg.seqnum; + recvqu_len++; + + fprintf(stderr, "server recv: recvqu_len = %zu\n", recvqu_len); + } + + ssize_t index = -1; + for (size_t i = 0; i < sendqu_len; i++) { + if (sendqu[i].d == d) { + index = i; + break; + } + } + + if (index != -1) { + int ret = icmp_server_send_reply( + sock_server, d->other_addr, d->id, d->seqnum, + sendqu[index].data, sendqu[index].length); + if (ret < 0) { + perror("icmp_server_send_reply"); + } + + memmove(sendqu + index, sendqu + index + 1, + (sendqu_len - index - 1) * sizeof sendqu[0]); + sendqu_len--; + } else { + mt_mutex_lock(&d->mut); + d->outstanding = true; + d->seqnum = msg.seqnum; + mt_mutex_unlock(&d->mut); + } + } + } + + if (FD_ISSET(sock_client, &inset)) { + struct icmp_incoming msg = icmp_client_receive(sock_client); + struct icmpd *d = find_conn(msg.id, msg.source_addr, conns, conns_len); + + if (d == NULL) { + fprintf(stderr, "icmpd thread: pong received with unknown id=%d addr=%x\n", + msg.id, msg.source_addr); + } else { + if (recvqu_len == recvqu_cap) { + recvqu_cap *= 2; + recvqu = realloc(recvqu, recvqu_cap * sizeof(struct recvqu_item)); + } + + recvqu[recvqu_len].d = d; + recvqu[recvqu_len].msg.data = malloc(msg.length); + memcpy(recvqu[recvqu_len].msg.data, msg.data, msg.length); + recvqu[recvqu_len].msg.length = msg.length; + recvqu[recvqu_len].msg.source_addr = msg.source_addr; + recvqu[recvqu_len].msg.id = msg.id; + recvqu[recvqu_len].msg.seqnum = msg.seqnum; + recvqu_len++; + + fprintf(stderr, "client recv: recvqu_len = %zu\n", recvqu_len); + + mt_mutex_lock(&d->mut); + client_increment_seqnum(d); + int seqnum = d->seqnum; + mt_mutex_unlock(&d->mut); + + int ret = icmp_client_send(sock_client, d->other_addr, seqnum, NULL, 0); + if (ret < 0) { + perror("icmp_client_send"); + } + } + } + } + + return NULL; +} + +static void spawn_icmpd_thread(void) { + int pp[2]; + + assert(pipe(pp) == 0); + thread_out = pp[1]; + host_in = pp[0]; + + assert(pipe(pp) == 0); + host_out = pp[1]; + thread_in = pp[0]; + + mt_mutex_init(&thread_mutex); + + mt_thread_create(&thread, thread_entry, NULL); + + struct msg_in msg = recv_message(host_in); + assert(msg.hdr.type == MSG_THREAD_UP); + free(msg.data); +} + +static struct icmpd* icmpd_create_base(int id, bool isserver, uint32_t other_addr) { + if (!thread_running) { + spawn_icmpd_thread(); + thread_running = true; + } + + struct icmpd *d = malloc(sizeof(struct icmpd)); + mt_mutex_init(&d->mut); + d->id = id; + d->seqnum = -1; + d->isserver = isserver; + d->other_addr = other_addr; + d->outstanding = false; + + send_message(host_out, MSG_NEWCONN, &d, sizeof d); + + return d; +} + +struct icmpd* icmpd_create_server(int id, uint32_t client_addr) { + return icmpd_create_base(id, true, client_addr); +} + +struct icmpd* icmpd_create_client(uint32_t server_addr) { + assert(server_addr != 0); + return icmpd_create_base(-1, false, server_addr); +} + +void icmpd_destroy(struct icmpd *d) { + send_message(host_out, MSG_ENDCONN, &d, sizeof d); + // free() is done by thread +} + +void icmpd_server_set_outstanding(struct icmpd *d, int seqnum) { + assert(d->isserver); + mt_mutex_lock(&d->mut); + d->outstanding = true; + d->seqnum = seqnum; + mt_mutex_unlock(&d->mut); +} + +bool icmpd_peek(struct icmpd *d) { + send_message(host_out, MSG_PEEK, &d, sizeof d); + struct msg_in msg = recv_message(host_in); + assert(msg.hdr.type == MSG_PEEK_ANS); + + bool ret = ((uint8_t*)msg.data)[0]; + free(msg.data); + return ret; +} + +struct icmpd_received icmpd_recv(struct icmpd *d) { + send_message(host_out, MSG_RECV, &d, sizeof d); + struct msg_in msg = recv_message(host_in); + assert(msg.hdr.type == MSG_RECV_ANS); + + struct icmpd_received r; + assert(msg.hdr.size == sizeof r); + memcpy(&r, msg.data, sizeof r); + return r; +} + +void icmpd_send(struct icmpd *d, const void *data, size_t length) { + struct msg_form_send form; + form.d = d; + form.data = malloc(length); + memcpy(form.data, data, length); + form.length = length; + + send_message(host_out, MSG_SEND, &form, sizeof form); +} |