// NOTE(review): the header names of the ten system #include lines below are
// missing in this copy of the file (they appear to have been stripped);
// restore them. TODO confirm which headers were meant.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "icmpd.h"
#include "util.h"
#include "icmp_server.h"
#include "icmp_client.h"
#include "mt.h"

#define KEEPALIVE_DELAY 5000 // maximum time to send nothing (milliseconds)

// Per-connection state, shared between the host-side API functions and the
// icmpd thread (which also free()s it on MSG_ENDCONN).
struct icmpd {
	// CONSTANTS
	// (written once at creation and then read by both sides without locking)
	bool isserver;
	uint32_t other_addr; // IPv4 address of other party; 0 matches any address in find_conn()
	int id; // id for messages, -1 if not set; -1 matches any id in find_conn()
	int signal_out, signal_in; // 1 byte can be read from signal_out for every message received

	// VARIABLES
	struct mt_mutex mut; // for the variables in this struct
	bool outstanding; // server: whether an echo request hasn't been matched yet
	// client: seqnum can change before every send
	// server: seqnum is the sequence number of the last-received ECHO
	// value is -1 if not set
	int seqnum;
	// client: timestamp of when last message was sent
	// value is -1 if no message sent yet
	// NOTE(review): icmpd_create_base() initialises this to gettimestamp(), not
	// -1; the thread treats it as microseconds — TODO reconcile this comment.
	int64_t last_send_stamp;
};

// Global state
// Whether spawn_icmpd_thread() has already run.
static bool thread_running = false;
static struct mt_thread thread;
// The thread reads requests from thread_in and writes answers to thread_out;
// the host side uses host_out / host_in for the other ends of those pipes.
static int thread_in, thread_out;
static int host_in, host_out;

// Arguments to messages are listed in comments.
enum { MSG_THREAD_UP, // - MSG_NEWCONN, // struct icmpd* MSG_ENDCONN, // struct icmpd*; free() done by thread MSG_PEEK, // struct icmpd* MSG_PEEK_ANS, // uint8_t (bool) MSG_RECV, // struct icmpd* MSG_RECV_ANS, // struct icmpd_received MSG_SEND, // struct msg_form_send }; struct msg_header { int type; size_t size; // size in bytes of argument }; struct msg_in { struct msg_header hdr; void *data; // should be free()'d }; struct msg_form_send { struct icmpd *d; void *data; // to be free()'d by thread when sent size_t length; }; static struct icmpd* find_conn(int id, uint32_t addr, struct icmpd **conns, size_t conns_len) { struct icmpd *ret = NULL; for (size_t i = 0; i < conns_len; i++) { struct icmpd *d = conns[i]; if ((d->id == -1 || d->id == id) && (d->other_addr == 0 || d->other_addr == addr)) { if (ret == NULL || (ret->other_addr == 0 && d->other_addr != 0)) { ret = d; } else { fprintf(stderr, "icmpd thread: warning: multiple connections match id=%d addr=%x\n", id, addr); break; } } } return ret; } static void send_message(int sock, int type, const void *data, size_t size) { struct msg_header head = {type, size}; assert(writeall(sock, &head, sizeof head) == sizeof head); assert(writeall(sock, data, size) == (ssize_t)size); } // On error, {.type=-1}, errno is set static struct msg_in recv_message(int sock) { struct msg_in msg; int ret = readall(sock, &msg.hdr, sizeof msg.hdr); if (ret < 0) return (struct msg_in){.hdr.type = -1, .hdr.size = 0, .data = NULL}; assert(ret == sizeof msg.hdr); msg.data = malloc(msg.hdr.size); assert(readall(sock, msg.data, msg.hdr.size) == (ssize_t)msg.hdr.size); return msg; } static void client_increment_seqnum(struct icmpd *d) { d->seqnum = (d->seqnum + 1) & 0xffff; } static void* thread_entry(void *arg) { (void)arg; send_message(thread_out, MSG_THREAD_UP, NULL, 0); struct recvqu_item { struct icmpd *d; struct icmpd_received msg; }; struct sendqu_item { struct icmpd *d; void *data; size_t length; }; size_t conns_cap = 8, conns_len = 0; 
struct icmpd **conns = malloc(conns_cap * sizeof(struct icmpd*)); int sock_server = -1, sock_client = -1; size_t recvqu_cap = 8, recvqu_len = 0; struct recvqu_item *recvqu = malloc(recvqu_cap * sizeof(struct recvqu_item)); // Server messages submitted for transmission while no outstanding client // ping is available. size_t sendqu_cap = 8, sendqu_len = 0; struct sendqu_item *sendqu = malloc(sendqu_cap * sizeof(struct sendqu_item)); while (true) { fd_set inset; FD_ZERO(&inset); FD_SET(thread_in, &inset); if (sock_server >= 0) FD_SET(sock_server, &inset); if (sock_client >= 0) FD_SET(sock_client, &inset); int nfds = maxi(maxi(thread_in, sock_server), sock_client) + 1; int64_t now = gettimestamp(); int64_t wait_interval = INT64_MAX; // microseconds for (size_t i = 0; i < conns_len; i++) { if (conns[i]->isserver) continue; mt_mutex_lock(&conns[i]->mut); int64_t stamp = conns[i]->last_send_stamp; mt_mutex_unlock(&conns[i]->mut); if (stamp + KEEPALIVE_DELAY * 1000 - now < wait_interval) { wait_interval = stamp + KEEPALIVE_DELAY * 1000 - now; } } struct timeval timeout_tv; timeout_tv.tv_sec = wait_interval / 1000000; timeout_tv.tv_usec = wait_interval % 1000000; int ret = select(nfds, &inset, NULL, NULL, wait_interval == INT64_MAX ? 
NULL : &timeout_tv); if (ret < 0) { if (errno == EINTR) continue; perror("select"); assert(false); } if (FD_ISSET(thread_in, &inset)) { struct msg_in msg = recv_message(thread_in); if (msg.hdr.type == -1) { if (errno == EAGAIN || errno == EWOULDBLOCK) continue; perror("recv_message"); assert(false); } switch (msg.hdr.type) { case MSG_NEWCONN: { struct icmpd *d = *(struct icmpd**)msg.data; free(msg.data); if (conns_len == conns_cap) { conns_cap *= 2; conns = realloc(conns, conns_cap * sizeof(struct icmpd*)); } conns[conns_len++] = d; if (d->isserver) { if (sock_server < 0) { sock_server = icmp_server_open_socket(); if (sock_server < 0) { perror("icmp_server_open_socket"); assert(false); } } } else { if (sock_client < 0) { sock_client = icmp_client_open_socket(); if (sock_client < 0) { perror("icmp_client_open_socket"); assert(false); } } } break; } case MSG_ENDCONN: { struct icmpd *d = *(struct icmpd**)msg.data; free(msg.data); bool found = false; for (size_t i = 0; i < conns_len; i++) { if (conns[i] == d) { memmove(conns + i, conns + i + 1, conns_len - i - 1); conns_len--; found = true; break; } } assert(found); mt_mutex_destroy(&d->mut); close(d->signal_out); close(d->signal_in); free(d); break; } case MSG_PEEK: { struct icmpd *d = *(struct icmpd**)msg.data; free(msg.data); uint8_t ret = 0; for (size_t i = 0; i < recvqu_len; i++) { if (recvqu[i].d == d) { ret = 1; break; } } send_message(thread_out, MSG_PEEK_ANS, &ret, 1); break; } case MSG_RECV: { struct icmpd *d = *(struct icmpd**)msg.data; free(msg.data); ssize_t index = -1; for (size_t i = 0; i < recvqu_len; i++) { if (recvqu[i].d == d) { index = i; break; } } assert(index != -1); // TODO: make host wait for something send_message(thread_out, MSG_RECV_ANS, &recvqu[index].msg, sizeof recvqu[index].msg); memcpy(recvqu + index, recvqu + index + 1, (recvqu_len - index - 1) * sizeof recvqu[0]); recvqu_len--; break; } case MSG_SEND: { struct msg_form_send form = *(struct msg_form_send*)msg.data; free(msg.data); 
struct icmpd *d = form.d; if (d->isserver) { mt_mutex_lock(&d->mut); bool outstanding = d->outstanding; int id = d->id, seqnum = d->seqnum; mt_mutex_unlock(&d->mut); if (outstanding) { int ret = icmp_server_send_reply( sock_server, d->other_addr, id, seqnum, form.data, form.length); if (ret < 0) { perror("icmp_server_send_reply"); } } else { if (sendqu_len == sendqu_cap) { sendqu_cap *= 2; sendqu = realloc(sendqu, sendqu_cap * sizeof sendqu[0]); } sendqu[sendqu_len].d = d; sendqu[sendqu_len].data = form.data; sendqu[sendqu_len].length = form.length; sendqu_len++; } } else { mt_mutex_lock(&d->mut); client_increment_seqnum(d); int seqnum = d->seqnum; d->last_send_stamp = gettimestamp(); mt_mutex_unlock(&d->mut); int ret = icmp_client_send( sock_client, d->other_addr, seqnum, form.data, form.length); if (ret < 0) { perror("icmp_client_send"); } } break; } } } if (FD_ISSET(sock_server, &inset)) { struct icmp_incoming msg = icmp_server_receive(sock_server); struct icmpd *d = find_conn(msg.id, msg.source_addr, conns, conns_len); if (d == NULL) { fprintf(stderr, "icmpd thread: ping received with unknown id=%d addr=%x\n", msg.id, msg.source_addr); int ret = icmp_server_send_reply( sock_server, msg.source_addr, msg.id, msg.seqnum, msg.data, msg.length); if (ret < 0) { perror("icmpd thread: unknown ping reply: icmp_server_send_reply"); } } else { if (msg.length == 0) { fprintf(stderr, "server recv: empty\n"); } else { if (recvqu_len == recvqu_cap) { recvqu_cap *= 2; recvqu = realloc(recvqu, recvqu_cap * sizeof(struct recvqu_item)); } recvqu[recvqu_len].d = d; recvqu[recvqu_len].msg.data = malloc(msg.length); memcpy(recvqu[recvqu_len].msg.data, msg.data, msg.length); recvqu[recvqu_len].msg.length = msg.length; recvqu[recvqu_len].msg.source_addr = msg.source_addr; recvqu[recvqu_len].msg.id = msg.id; recvqu[recvqu_len].msg.seqnum = msg.seqnum; recvqu_len++; char c = 42; assert(writeall(d->signal_in, &c, 1) == 1); fprintf(stderr, "server recv: recvqu_len = %zu\n", recvqu_len); } 
ssize_t index = -1; for (size_t i = 0; i < sendqu_len; i++) { if (sendqu[i].d == d) { index = i; break; } } if (index != -1) { int ret = icmp_server_send_reply( sock_server, d->other_addr, d->id, d->seqnum, sendqu[index].data, sendqu[index].length); if (ret < 0) { perror("icmp_server_send_reply"); } memmove(sendqu + index, sendqu + index + 1, (sendqu_len - index - 1) * sizeof sendqu[0]); sendqu_len--; } else { mt_mutex_lock(&d->mut); d->outstanding = true; d->seqnum = msg.seqnum; mt_mutex_unlock(&d->mut); } } } if (FD_ISSET(sock_client, &inset)) { struct icmp_incoming msg = icmp_client_receive(sock_client); struct icmpd *d = find_conn(msg.id, msg.source_addr, conns, conns_len); if (d == NULL) { fprintf(stderr, "icmpd thread: pong received with unknown id=%d addr=%x\n", msg.id, msg.source_addr); } else { if (recvqu_len == recvqu_cap) { recvqu_cap *= 2; recvqu = realloc(recvqu, recvqu_cap * sizeof(struct recvqu_item)); } recvqu[recvqu_len].d = d; recvqu[recvqu_len].msg.data = malloc(msg.length); memcpy(recvqu[recvqu_len].msg.data, msg.data, msg.length); recvqu[recvqu_len].msg.length = msg.length; recvqu[recvqu_len].msg.source_addr = msg.source_addr; recvqu[recvqu_len].msg.id = msg.id; recvqu[recvqu_len].msg.seqnum = msg.seqnum; recvqu_len++; char c = 42; assert(writeall(d->signal_in, &c, 1) == 1); fprintf(stderr, "client recv: recvqu_len = %zu\n", recvqu_len); mt_mutex_lock(&d->mut); client_increment_seqnum(d); int seqnum = d->seqnum; d->last_send_stamp = gettimestamp(); mt_mutex_unlock(&d->mut); int ret = icmp_client_send(sock_client, d->other_addr, seqnum, NULL, 0); if (ret < 0) { perror("icmp_client_send"); } } } now = gettimestamp(); for (size_t i = 0; i < conns_len; i++) { if (conns[i]->isserver) continue; mt_mutex_lock(&conns[i]->mut); int64_t stamp = conns[i]->last_send_stamp; mt_mutex_unlock(&conns[i]->mut); if (now - stamp >= KEEPALIVE_DELAY) { mt_mutex_lock(&conns[i]->mut); client_increment_seqnum(conns[i]); int seqnum = conns[i]->seqnum; 
conns[i]->last_send_stamp = gettimestamp(); mt_mutex_unlock(&conns[i]->mut); int ret = icmp_client_send(sock_client, conns[i]->other_addr, seqnum, NULL, 0); if (ret < 0) { perror("icmp_client_send"); } } } } return NULL; } static void spawn_icmpd_thread(void) { int pp[2]; assert(pipe(pp) == 0); thread_out = pp[1]; host_in = pp[0]; assert(pipe(pp) == 0); host_out = pp[1]; thread_in = pp[0]; mt_thread_create(&thread, thread_entry, NULL); struct msg_in msg = recv_message(host_in); assert(msg.hdr.type == MSG_THREAD_UP); free(msg.data); } static struct icmpd* icmpd_create_base(int id, bool isserver, uint32_t other_addr) { if (!thread_running) { spawn_icmpd_thread(); thread_running = true; } struct icmpd *d = malloc(sizeof(struct icmpd)); mt_mutex_init(&d->mut); d->isserver = isserver; d->other_addr = other_addr; d->id = id; d->seqnum = -1; d->outstanding = false; d->last_send_stamp = gettimestamp(); int pp[2]; assert(pipe(pp) == 0); d->signal_in = pp[1]; d->signal_out = pp[0]; send_message(host_out, MSG_NEWCONN, &d, sizeof d); return d; } struct icmpd* icmpd_create_server(int id, uint32_t client_addr) { return icmpd_create_base(id, true, client_addr); } struct icmpd* icmpd_create_client(uint32_t server_addr) { assert(server_addr != 0); return icmpd_create_base(-1, false, server_addr); } void icmpd_destroy(struct icmpd *d) { send_message(host_out, MSG_ENDCONN, &d, sizeof d); // free() is done by thread } void icmpd_server_set_outstanding(struct icmpd *d, int seqnum) { assert(d->isserver); mt_mutex_lock(&d->mut); d->outstanding = true; d->seqnum = seqnum; mt_mutex_unlock(&d->mut); } bool icmpd_peek(struct icmpd *d) { #if 1 send_message(host_out, MSG_PEEK, &d, sizeof d); struct msg_in msg = recv_message(host_in); assert(msg.hdr.type == MSG_PEEK_ANS); bool ret = ((uint8_t*)msg.data)[0]; free(msg.data); return ret; #else // This version has the same fallacy as just selecting on icmpd_get_select_fd() directly, // i.e. 
that select(2) can return when there's nothing to read fd_set inset; FD_ZERO(&inset); FD_SET(d->signal_out, &inset); while (true) { struct timeval tv; tv.tv_sec = tv.tv_usec = 0; int ret = select(d->signal_out + 1, &inset, NULL, NULL, &tv); if (ret < 0) { if (errno == EINTR) continue; perror("select"); assert(false); } return ret > 0; } #endif } struct icmpd_received icmpd_recv(struct icmpd *d) { send_message(host_out, MSG_RECV, &d, sizeof d); char c; assert(readall(d->signal_out, &c, 1) == 1); struct msg_in msg = recv_message(host_in); assert(msg.hdr.type == MSG_RECV_ANS); struct icmpd_received r; assert(msg.hdr.size == sizeof r); memcpy(&r, msg.data, sizeof r); return r; } void icmpd_send(struct icmpd *d, const void *data, size_t length) { struct msg_form_send form; form.d = d; form.data = malloc(length); memcpy(form.data, data, length); form.length = length; send_message(host_out, MSG_SEND, &form, sizeof form); } int icmpd_get_select_fd(struct icmpd *d) { return d->signal_out; }