#include #include #include #include #include #include #include #include #include "prot.h" #include "protm.h" #include "icmpd.h" #include "mt.h" #include "util.h" #define SEQ_SIZE 256 #define WIN_SIZE 16 struct prot { // CONSTANTS bool is_accept; uint32_t other_addr; struct icmpd *d; int d_fd; int event_fd_in, event_fd_out; // VARIABLES struct mt_mutex mut; int my_win_start, other_win_start; bool recv_success[WIN_SIZE]; int my_next_seq; bool peer_closed; }; // r/w host until spawn, thereafter constant static bool thread_spawned = false; static struct mt_thread thread; static int host_in, host_out, thread_in, thread_out; // host-only static bool have_accept_channel = false; // Arguments to messages are listed in comments. enum { MSG_THREAD_UP, // - MSG_NEWCONN, // struct prot* MSG_ENDCONN, // struct prot*; cleanup and free() done by thread MSG_ACCEPT, // - MSG_ACCEPT_ANS, // struct prot*; allocated by thread MSG_RECV, // struct prot* MSG_RECV_ANS, // struct prot_msg MSG_SEND, // struct msg_form_send }; struct msg_header { int type; size_t size; // size in bytes of argument }; struct msg_in { struct msg_header hdr; void *data; // should be free()'d }; struct msg_form_send { struct prot *ch; void *data; // to be free()'d by thread when sent size_t length; }; static void send_message(int sock, int type, const void *data, size_t size) { struct msg_header head = {type, size}; assert(writeall(sock, &head, sizeof head) == sizeof head); assert(writeall(sock, data, size) == (ssize_t)size); } // On error, {.type=-1}, errno is set static struct msg_in recv_message(int sock) { struct msg_in msg; int ret = readall(sock, &msg.hdr, sizeof msg.hdr); if (ret < 0) return (struct msg_in){.hdr.type = -1, .hdr.size = 0, .data = NULL}; assert(ret == sizeof msg.hdr); msg.data = malloc(msg.hdr.size); assert(readall(sock, msg.data, msg.hdr.size) == (ssize_t)msg.hdr.size); return msg; } static void populate_prot(struct prot *ch, bool is_accept, struct icmpd *d, uint32_t other_addr) { ch->is_accept = is_accept; ch->other_addr = other_addr; ch->d = d; ch->d_fd = d != NULL ? icmpd_get_select_fd(d) : -1; int pp[2]; assert(pipe(pp) == 0); ch->event_fd_in = pp[1]; ch->event_fd_out = pp[0]; mt_mutex_init(&ch->mut); ch->my_win_start = 0; ch->other_win_start = 0; memset(ch->recv_success, 0, WIN_SIZE * sizeof(bool)); ch->my_next_seq = 0; ch->peer_closed = false; } static void* thread_entry(void *arg_) { (void)arg_; send_message(thread_out, MSG_THREAD_UP, NULL, 0); struct prot *accept_ch = NULL; size_t chans_cap = 8, chans_len = 0; struct prot **chans = malloc(chans_cap * sizeof(struct prot*)); while (true) { fd_set inset; FD_ZERO(&inset); FD_SET(thread_in, &inset); int nfds = thread_in; for (size_t i = 0; i < chans_len; i++) { if (chans[i]->peer_closed) continue; FD_SET(chans[i]->d_fd, &inset); if (chans[i]->d_fd > nfds) nfds = chans[i]->d_fd; } int ret = select(nfds + 1, &inset, NULL, NULL, NULL); if (ret < 0) { if (errno == EINTR) continue; perror("select"); assert(false); } assert(ret > 0); if (FD_ISSET(thread_in, &inset)) { struct msg_in msg = recv_message(thread_in); switch (msg.hdr.type) { case MSG_NEWCONN: { struct prot *ch = *(struct prot**)msg.data; free(msg.data); if (ch->is_accept) { assert(accept_ch == NULL); accept_ch = ch; break; } if (chans_len == chans_cap) { chans_cap *= 2; chans = realloc(chans, chans_cap * sizeof(struct prot*)); } chans[chans_len++] = ch; break; } case MSG_ENDCONN: { struct prot *ch = *(struct prot**)msg.data; free(msg.data); if (ch == accept_ch) { accept_ch = NULL; } else { for (size_t i = 0; i < chans_len; i++) { if (chans[i] == ch) { memmove(chans + i, chans + i + 1, (chans_len - i - 1) * sizeof(struct prot*)); chans_len--; break; } } } icmpd_destroy(ch->d); close(ch->event_fd_in); close(ch->event_fd_out); mt_mutex_destroy(&ch->mut); free(ch); break; } case MSG_ACCEPT: { free(msg.data); assert(accept_ch != NULL); if (!icmpd_peek(accept_ch->d)) { struct prot *ch = NULL; send_message(thread_out, MSG_ACCEPT_ANS, &ch, sizeof ch); break; } struct icmpd_received re = icmpd_recv(accept_ch->d); assert(re.length <= PROTM_MAX_SIZE); struct protm *m = (struct protm*)re.data; uint8_t type = m->type; uint8_t ack = m->estab.ack; free(re.data); if (type != PROTM_TYPE_ESTAB || ack != 0) { struct prot *ch = NULL; send_message(thread_out, MSG_ACCEPT_ANS, &ch, sizeof ch); break; } ssize_t found = -1; for (size_t i = 0; i < chans_len; i++) { if (chans[i]->other_addr == re.source_addr) { found = i; break; } } if (found != -1) { mt_mutex_lock(&chans[found]->mut); chans[found]->peer_closed = true; mt_mutex_unlock(&chans[found]->mut); struct protm pm; pm.type = PROTM_TYPE_TERM; pm.term.ack = 0; icmpd_send(chans[found]->d, &pm, protm_size(&pm)); break; } struct prot *ch = malloc(sizeof(struct prot)); populate_prot(ch, false, icmpd_create_server(re.id, re.source_addr), re.source_addr); icmpd_server_set_outstanding(ch->d, re.seqnum); send_message(thread_out, MSG_ACCEPT_ANS, &ch, sizeof ch); break; } case MSG_RECV: { struct prot ch = *(struct prot*)msg.data; free(msg.data); ; break; } } } } return NULL; } static void spawn_thread(void) { int pp[2]; assert(pipe(pp) == 0); thread_out = pp[1]; host_in = pp[0]; assert(pipe(pp) == 0); host_out = pp[1]; thread_in = pp[0]; mt_thread_create(&thread, thread_entry, NULL); struct msg_in msg = recv_message(host_in); assert(msg.hdr.type == MSG_THREAD_UP); free(msg.data); thread_spawned = true; } int prot_get_select_fd(struct prot *ch) { return ch->event_fd_out; } void prot_terminate(struct prot *ch) { send_message(host_out, MSG_ENDCONN, &ch, sizeof ch); if (ch->is_accept) { have_accept_channel = false; } } struct prot* prot_create_accept_channel() { if (have_accept_channel) return NULL; if (!thread_spawned) spawn_thread(); have_accept_channel = true; struct prot *ach = malloc(sizeof(struct prot)); populate_prot(ach, true, icmpd_create_server(-1, 0), 0); return ach; } struct prot* prot_accept(struct prot *ach) { assert(ach->is_accept); send_message(host_out, MSG_ACCEPT, NULL, 0); struct msg_in msg = recv_message(host_in); assert(msg.hdr.type == MSG_ACCEPT_ANS); struct prot *ch = *(struct prot**)msg.data; free(msg.data); return ch; } struct prot* prot_connect(uint32_t server_addr) { if (!thread_spawned) spawn_thread(); struct prot *ch = malloc(sizeof(struct prot)); populate_prot(ch, false, icmpd_create_client(server_addr), server_addr); return ch; } struct prot_msg prot_recv(struct prot *ch) { mt_mutex_lock(&ch->mut); bool closed = ch->peer_closed; mt_mutex_unlock(&ch->mut); if (closed) { errno = ECONNRESET; return (struct prot_msg){.data = NULL}; } send_message(host_out, MSG_RECV, &ch, sizeof ch); struct msg_in msg = recv_message(host_in); assert(msg.hdr.type == MSG_RECV_ANS); struct prot_msg res; assert(msg.hdr.size == sizeof res); memcpy(&res, msg.data, sizeof res); free(msg.data); if (res.data == NULL) errno = EAGAIN; return res; } int prot_send(struct prot *ch, const void *data, size_t length) { mt_mutex_lock(&ch->mut); bool closed = ch->peer_closed; mt_mutex_unlock(&ch->mut); if (closed) { errno = ECONNRESET; return -1; } struct msg_form_send form; form.ch = ch; form.data = malloc(length); memcpy(form.data, data, length); form.length = length; send_message(host_out, MSG_SEND, &form, sizeof form); return 0; }