summaryrefslogtreecommitdiff
path: root/prot.c
diff options
context:
space:
mode:
Diffstat (limited to 'prot.c')
-rw-r--r--prot.c368
1 files changed, 368 insertions, 0 deletions
diff --git a/prot.c b/prot.c
new file mode 100644
index 0000000..3b580b8
--- /dev/null
+++ b/prot.c
@@ -0,0 +1,368 @@
+#include <stdio.h>
+#include <stdbool.h>
+#include <stdlib.h>
+#include <string.h>
+#include <unistd.h>
+#include <errno.h>
+#include <assert.h>
+#include <sys/select.h>
+#include "prot.h"
+#include "protm.h"
+#include "icmpd.h"
+#include "mt.h"
+#include "util.h"
+
+
+#define SEQ_SIZE 256
+#define WIN_SIZE 16
+
+
+struct prot {
+ // CONSTANTS
+ bool is_accept;
+ uint32_t other_addr;
+ struct icmpd *d;
+ int d_fd;
+ int event_fd_in, event_fd_out;
+
+ // VARIABLES
+ struct mt_mutex mut;
+ int my_win_start, other_win_start;
+ bool recv_success[WIN_SIZE];
+ int my_next_seq;
+ bool peer_closed;
+};
+
+
+// r/w host until spawn, thereafter constant
+static bool thread_spawned = false;
+static struct mt_thread thread;
+static int host_in, host_out, thread_in, thread_out;
+
+// host-only
+static bool have_accept_channel = false;
+
+
+// Arguments to messages are listed in comments.
+enum {
+ MSG_THREAD_UP, // -
+ MSG_NEWCONN, // struct prot*
+ MSG_ENDCONN, // struct prot*; cleanup and free() done by thread
+ MSG_ACCEPT, // -
+ MSG_ACCEPT_ANS, // struct prot*; allocated by thread
+ MSG_RECV, // struct prot*
+ MSG_RECV_ANS, // struct prot_msg
+ MSG_SEND, // struct msg_form_send
+};
+
+struct msg_header {
+ int type;
+ size_t size; // size in bytes of argument
+};
+
+struct msg_in {
+ struct msg_header hdr;
+ void *data; // should be free()'d
+};
+
+struct msg_form_send {
+ struct prot *ch;
+ void *data; // to be free()'d by thread when sent
+ size_t length;
+};
+
+
+static void send_message(int sock, int type, const void *data, size_t size) {
+ struct msg_header head = {type, size};
+
+ assert(writeall(sock, &head, sizeof head) == sizeof head);
+ assert(writeall(sock, data, size) == (ssize_t)size);
+}
+
+// On error, {.type=-1}, errno is set
+static struct msg_in recv_message(int sock) {
+ struct msg_in msg;
+
+ int ret = readall(sock, &msg.hdr, sizeof msg.hdr);
+ if (ret < 0) return (struct msg_in){.hdr.type = -1, .hdr.size = 0, .data = NULL};
+ assert(ret == sizeof msg.hdr);
+
+ msg.data = malloc(msg.hdr.size);
+ assert(readall(sock, msg.data, msg.hdr.size) == (ssize_t)msg.hdr.size);
+
+ return msg;
+}
+
+static void populate_prot(struct prot *ch, bool is_accept, struct icmpd *d, uint32_t other_addr) {
+ ch->is_accept = is_accept;
+ ch->other_addr = other_addr;
+ ch->d = d;
+ ch->d_fd = d != NULL ? icmpd_get_select_fd(d) : -1;
+ int pp[2];
+ assert(pipe(pp) == 0);
+ ch->event_fd_in = pp[1];
+ ch->event_fd_out = pp[0];
+
+ mt_mutex_init(&ch->mut);
+ ch->my_win_start = 0;
+ ch->other_win_start = 0;
+ memset(ch->recv_success, 0, WIN_SIZE * sizeof(bool));
+ ch->my_next_seq = 0;
+ ch->peer_closed = false;
+}
+
+
+static void* thread_entry(void *arg_) {
+ (void)arg_;
+ send_message(thread_out, MSG_THREAD_UP, NULL, 0);
+
+ struct prot *accept_ch = NULL;
+
+ size_t chans_cap = 8, chans_len = 0;
+ struct prot **chans = malloc(chans_cap * sizeof(struct prot*));
+
+ while (true) {
+ fd_set inset;
+ FD_ZERO(&inset);
+ FD_SET(thread_in, &inset);
+ int nfds = thread_in;
+ for (size_t i = 0; i < chans_len; i++) {
+ if (chans[i]->peer_closed) continue;
+ FD_SET(chans[i]->d_fd, &inset);
+ if (chans[i]->d_fd > nfds) nfds = chans[i]->d_fd;
+ }
+
+ int ret = select(nfds + 1, &inset, NULL, NULL, NULL);
+ if (ret < 0) {
+ if (errno == EINTR) continue;
+ perror("select");
+ assert(false);
+ }
+ assert(ret > 0);
+
+ if (FD_ISSET(thread_in, &inset)) {
+ struct msg_in msg = recv_message(thread_in);
+
+ switch (msg.hdr.type) {
+ case MSG_NEWCONN: {
+ struct prot *ch = *(struct prot**)msg.data;
+ free(msg.data);
+
+ if (ch->is_accept) {
+ assert(accept_ch == NULL);
+ accept_ch = ch;
+ break;
+ }
+
+ if (chans_len == chans_cap) {
+ chans_cap *= 2;
+ chans = realloc(chans, chans_cap * sizeof(struct prot*));
+ }
+
+ chans[chans_len++] = ch;
+ break;
+ }
+
+ case MSG_ENDCONN: {
+ struct prot *ch = *(struct prot**)msg.data;
+ free(msg.data);
+
+ if (ch == accept_ch) {
+ accept_ch = NULL;
+ } else {
+ for (size_t i = 0; i < chans_len; i++) {
+ if (chans[i] == ch) {
+ memmove(chans + i, chans + i + 1, (chans_len - i - 1) * sizeof(struct prot*));
+ chans_len--;
+ break;
+ }
+ }
+ }
+
+ icmpd_destroy(ch->d);
+ close(ch->event_fd_in);
+ close(ch->event_fd_out);
+ mt_mutex_destroy(&ch->mut);
+
+ free(ch);
+ break;
+ }
+
+ case MSG_ACCEPT: {
+ free(msg.data);
+ assert(accept_ch != NULL);
+
+ if (!icmpd_peek(accept_ch->d)) {
+ struct prot *ch = NULL;
+ send_message(thread_out, MSG_ACCEPT_ANS, &ch, sizeof ch);
+ break;
+ }
+
+ struct icmpd_received re = icmpd_recv(accept_ch->d);
+ assert(re.length <= PROTM_MAX_SIZE);
+
+ struct protm *m = (struct protm*)re.data;
+ uint8_t type = m->type;
+ uint8_t ack = m->estab.ack;
+
+ free(re.data);
+
+ if (type != PROTM_TYPE_ESTAB || ack != 0) {
+ struct prot *ch = NULL;
+ send_message(thread_out, MSG_ACCEPT_ANS, &ch, sizeof ch);
+ break;
+ }
+
+ ssize_t found = -1;
+ for (size_t i = 0; i < chans_len; i++) {
+ if (chans[i]->other_addr == re.source_addr) {
+ found = i;
+ break;
+ }
+ }
+
+ if (found != -1) {
+ mt_mutex_lock(&chans[found]->mut);
+ chans[found]->peer_closed = true;
+ mt_mutex_unlock(&chans[found]->mut);
+
+ struct protm pm;
+ pm.type = PROTM_TYPE_TERM;
+ pm.term.ack = 0;
+ icmpd_send(chans[found]->d, &pm, protm_size(&pm));
+ break;
+ }
+
+ struct prot *ch = malloc(sizeof(struct prot));
+ populate_prot(ch, false, icmpd_create_server(re.id, re.source_addr), re.source_addr);
+ icmpd_server_set_outstanding(ch->d, re.seqnum);
+
+ send_message(thread_out, MSG_ACCEPT_ANS, &ch, sizeof ch);
+ break;
+ }
+
+ case MSG_RECV: {
+ struct prot ch = *(struct prot*)msg.data;
+ free(msg.data);
+
+ ;
+ break;
+ }
+ }
+ }
+ }
+
+ return NULL;
+}
+
+
+static void spawn_thread(void) {
+ int pp[2];
+
+ assert(pipe(pp) == 0);
+ thread_out = pp[1];
+ host_in = pp[0];
+
+ assert(pipe(pp) == 0);
+ host_out = pp[1];
+ thread_in = pp[0];
+
+ mt_thread_create(&thread, thread_entry, NULL);
+
+ struct msg_in msg = recv_message(host_in);
+ assert(msg.hdr.type == MSG_THREAD_UP);
+ free(msg.data);
+
+ thread_spawned = true;
+}
+
+int prot_get_select_fd(struct prot *ch) {
+ return ch->event_fd_out;
+}
+
+void prot_terminate(struct prot *ch) {
+ send_message(host_out, MSG_ENDCONN, &ch, sizeof ch);
+
+ if (ch->is_accept) {
+ have_accept_channel = false;
+ }
+}
+
+struct prot* prot_create_accept_channel() {
+ if (have_accept_channel) return NULL;
+
+ if (!thread_spawned) spawn_thread();
+
+ have_accept_channel = true;
+
+ struct prot *ach = malloc(sizeof(struct prot));
+ populate_prot(ach, true, icmpd_create_server(-1, 0), 0);
+ return ach;
+}
+
+struct prot* prot_accept(struct prot *ach) {
+ assert(ach->is_accept);
+
+ send_message(host_out, MSG_ACCEPT, NULL, 0);
+
+ struct msg_in msg = recv_message(host_in);
+ assert(msg.hdr.type == MSG_ACCEPT_ANS);
+
+ struct prot *ch = *(struct prot**)msg.data;
+ free(msg.data);
+ return ch;
+}
+
+struct prot* prot_connect(uint32_t server_addr) {
+ if (!thread_spawned) spawn_thread();
+
+ struct prot *ch = malloc(sizeof(struct prot));
+ populate_prot(ch, false, icmpd_create_client(server_addr), server_addr);
+ return ch;
+}
+
+struct prot_msg prot_recv(struct prot *ch) {
+ mt_mutex_lock(&ch->mut);
+ bool closed = ch->peer_closed;
+ mt_mutex_unlock(&ch->mut);
+
+ if (closed) {
+ errno = ECONNRESET;
+ return (struct prot_msg){.data = NULL};
+ }
+
+ send_message(host_out, MSG_RECV, &ch, sizeof ch);
+
+ struct msg_in msg = recv_message(host_in);
+ assert(msg.hdr.type == MSG_RECV_ANS);
+
+ struct prot_msg res;
+ assert(msg.hdr.size == sizeof res);
+ memcpy(&res, msg.data, sizeof res);
+ free(msg.data);
+
+ if (res.data == NULL) errno = EAGAIN;
+
+ return res;
+}
+
+int prot_send(struct prot *ch, const void *data, size_t length) {
+ mt_mutex_lock(&ch->mut);
+ bool closed = ch->peer_closed;
+ mt_mutex_unlock(&ch->mut);
+
+ if (closed) {
+ errno = ECONNRESET;
+ return -1;
+ }
+
+ struct msg_form_send form;
+ form.ch = ch;
+ form.data = malloc(length);
+ memcpy(form.data, data, length);
+ form.length = length;
+
+ send_message(host_out, MSG_SEND, &form, sizeof form);
+
+ return 0;
+}