diff options
Diffstat (limited to 'ssh/tomsg_clientlib.c')
-rw-r--r-- | ssh/tomsg_clientlib.c | 337 |
1 files changed, 337 insertions, 0 deletions
diff --git a/ssh/tomsg_clientlib.c b/ssh/tomsg_clientlib.c index 80a9bf3..9c2d80a 100644 --- a/ssh/tomsg_clientlib.c +++ b/ssh/tomsg_clientlib.c @@ -4,7 +4,10 @@ #include <string.h> #include <ctype.h> #include <inttypes.h> +#include <unistd.h> #include <assert.h> +#include <errno.h> +#include <pthread.h> #include <poll.h> #include "tomsg_clientlib.h" #include "sshnc.h" @@ -32,6 +35,25 @@ struct tomsg_client { struct inflight *inflight; }; +enum tomsg_async_connect_state { + STATE_CONNECTING, + STATE_KEY_RECEIVED, + STATE_ACCEPTED, +}; + +// Only field written to by the thread is 'client'. +struct tomsg_async_connect { + enum tomsg_async_connect_state state; + + char *hostname; + int port; + + // Two pipes for communicating with the thread + int host_w, host_r, thread_w, thread_r; + + struct tomsg_client *client; // filled by thread in case all is successful +}; + static size_t min_size_t(size_t a, size_t b) { return a < b ? a : b; } static bool hasspacelf(const char *string) { @@ -140,6 +162,7 @@ const char* tomsg_strerror(enum tomsg_retval code) { switch (code) { case TOMSG_OK: return "Success"; case TOMSG_ERR_CONNECT: return "Server refused connection"; + case TOMSG_ERR_UNTRUSTED: return "Hostkey was rejected"; case TOMSG_ERR_CLOSED: return "Server connection unexpectedly closed"; case TOMSG_ERR_VERSION: return "Server protocol version incompatible"; case TOMSG_ERR_TRANSPORT: return "Error in the underlying SSH transport"; @@ -198,6 +221,7 @@ enum tomsg_retval tomsg_connect( hostname, port, "tomsg", "tomsg", checker, userdata, &conn); if (ret == SSHNC_ERR_CONNECT) return TOMSG_ERR_CONNECT; + if (ret == SSHNC_ERR_UNTRUSTED) return TOMSG_ERR_UNTRUSTED; if (ret != SSHNC_OK) return TOMSG_ERR_TRANSPORT; struct tomsg_client *client = malloc(sizeof(struct tomsg_client)); @@ -219,6 +243,319 @@ enum tomsg_retval tomsg_connect( return version_negotiation(client); } +// Returns whether successful +static bool writeall(int fd, const unsigned char *buffer, size_t length) { + size_t cursor = 0; + while (cursor < length) { + ssize_t nw = write(fd, buffer + cursor, length - cursor); + if (nw < 0) { + if (errno == EINTR) continue; + return false; + } + if (nw == 0) return false; + cursor += nw; + } + return true; +} + +// Returns whether successful +static bool readall(int fd, unsigned char *buffer, size_t length) { + size_t cursor = 0; + while (cursor < length) { + ssize_t nr = read(fd, buffer + cursor, length - cursor); + if (nr < 0) { + if (errno == EINTR) continue; + return false; + } + if (nr == 0) return false; + cursor += nr; + } + return true; +} + +// Async socket protocol: +// Sizes are always indicated using a host-order 8-byte signed integer. +// - Thread sends an error byte; if that's TOMSG_OK, it is followed by size and +// hashbytes of the hostkey. Otherwise, the thread exits. +// - Host sends one byte: 0 for reject, 1 for accept. +// - For reject, the thread closes the connection and exits. For accept, the +// thread negotiates the protocol version and if successful, initialises and +// populates a client structure in the state, and sends a TOMSG_OK byte. If +// unsuccessful, sends an error byte and exits. +// After the thread sends either an error byte or the final OK byte, or after +// it has received a reject message, it will not access 'state' anymore; it can +// thus be freed by the host. + +static bool async_hostkey_checker(const unsigned char *hash, size_t length, void *state_) { + struct tomsg_async_connect *state = state_; + + unsigned char buffer[256]; + if (9 + length > sizeof buffer) { assert(false); return false; } + buffer[0] = TOMSG_OK; + *(int64_t*)&buffer[1] = length; + memcpy(buffer + 9, hash, length); + + if (!writeall(state->thread_w, buffer, 9 + length)) return false; + + char response; + ssize_t nr = -1; + while (nr < 0) { + nr = read(state->thread_r, &response, 1); + if (nr == 0 || (nr < 0 && errno != EINTR)) break; + } + if (nr < 0) return false; + + return response == 1; +} + +static void* async_connect_thread_entry(void *state_) { + struct tomsg_async_connect *state = state_; + + const int thread_r = state->thread_r; + const int thread_w = state->thread_w; + + struct sshnc_client *conn; + enum sshnc_retval ret = sshnc_connect( + state->hostname, state->port, "tomsg", "tomsg", + async_hostkey_checker, state, &conn); + + enum tomsg_retval sendret; + if (ret == SSHNC_ERR_CONNECT) sendret = TOMSG_ERR_CONNECT; + else if (ret == SSHNC_ERR_UNTRUSTED) sendret = TOMSG_ERR_UNTRUSTED; + else if (ret != SSHNC_OK) sendret = TOMSG_ERR_TRANSPORT; + else sendret = TOMSG_OK; + + // If sendret is TOMSG_ERR_UNTRUSTED, the host may free 'state' at this point. Thus, don't access it, please. + + struct tomsg_client *client = NULL; + unsigned char byte; + + if (sendret == TOMSG_ERR_UNTRUSTED) goto pipe_return; + if (sendret != TOMSG_OK) goto send_error_return; + + client = calloc(1, sizeof(struct tomsg_client)); + if (!client) { sendret = TOMSG_ERR_MEMORY; goto send_error_return; } + client->conn = conn; + client->buffer_len = 0; + client->buffer_cap = 1024; + client->buffer = malloc(client->buffer_cap); + if (!client->buffer) { sendret = TOMSG_ERR_MEMORY; goto send_error_return; } + client->buffer_newline_cursor = 0; + client->next_tag = 1; + client->inflight_num = 0; + client->inflight_cap = 2; + client->inflight = malloc(client->inflight_cap * sizeof(struct inflight)); + if (!client->inflight) { sendret = TOMSG_ERR_MEMORY; goto send_error_return; } + + enum tomsg_retval versionret = version_negotiation(client); + if (versionret != TOMSG_OK) { sendret = versionret; goto send_error_return; } + + state->client = client; + + byte = TOMSG_OK; + // After the writeall, it is forbidden to access anything in 'state', + // because it may have been freed by the host at this point. + if (!writeall(thread_w, &byte, 1)) goto free_client; + goto pipe_return; + +send_error_return: + // After this, it is forbidden to access anything in 'state', because it + // may have been freed by the host at this point. + byte = sendret; + writeall(thread_w, &byte, 1); +free_client: + sshnc_close(conn); + if (client) { + if (client->buffer) free(client->buffer); + if (client->inflight) free(client->inflight); + free(client); + } +pipe_return: + close(thread_r); + close(thread_w); + return NULL; +} + +enum tomsg_retval tomsg_async_connect( + const char *hostname, int port, + struct tomsg_async_connect **clientp) { + // In case we throw an error along the way + *clientp = NULL; + + struct tomsg_async_connect *client = malloc(sizeof(struct tomsg_async_connect)); + if (!client) return TOMSG_ERR_MEMORY; + + client->state = STATE_CONNECTING; + + client->hostname = strdup(hostname); + client->port = port; + + pthread_attr_t attr; + if (pthread_attr_init(&attr) < 0) { + free(client->hostname); + free(client); + if (errno == ENOMEM) return TOMSG_ERR_MEMORY; + return TOMSG_ERR_CONNECT; + } + + pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_DETACHED); + + int pipeHT[2] = {-1, -1}, pipeTH[2]; + if (pipe(pipeHT) < 0 || pipe(pipeTH) < 0) { + if (pipeHT[0] != -1) { close(pipeHT[0]); close(pipeHT[1]); } + free(client->hostname); + free(client); + return TOMSG_ERR_CONNECT; + } + + client->host_w = pipeHT[1]; + client->thread_r = pipeHT[0]; + client->thread_w = pipeTH[1]; + client->host_r = pipeTH[0]; + + pthread_t thread; + if (pthread_create(&thread, &attr, async_connect_thread_entry, client) < 0) { + pthread_attr_destroy(&attr); + close(client->host_w); close(client->host_r); + close(client->thread_w); close(client->thread_r); + free(client->hostname); + free(client); + return TOMSG_ERR_CONNECT; + } + + *clientp = client; + return TOMSG_OK; +} + +static bool check_readable(int fd) { + struct pollfd pfd; + pfd.fd = fd; + pfd.events = POLLIN; + pfd.revents = 0; + poll(&pfd, 1, 0); + return pfd.revents & (POLLIN | POLLHUP | POLLERR); +} + +void tomsg_async_connect_event_nullify(struct tomsg_async_connect_event *event) { + switch (event->type) { + case TOMSG_AC_HOSTKEY: + free(event->key.hostkey); + break; + + case TOMSG_AC_SUCCESS: + tomsg_close(event->client); + break; + } +} + +enum tomsg_retval tomsg_async_connect_next_event( + struct tomsg_async_connect *client, + struct tomsg_async_connect_event *event // output +) { + enum tomsg_retval final_retval; + + switch (client->state) { + case STATE_CONNECTING: { + if (!check_readable(client->host_r)) return TOMSG_ERR_AGAIN; + + unsigned char byte; + if (!readall(client->host_r, &byte, 1)) goto non_recoverable_error; + if (byte != TOMSG_OK) { + final_retval = byte; + goto free_return; + } + + // Now read the hostkey from the pipe + int64_t length; + if (!readall(client->host_r, (unsigned char*)&length, 8)) goto non_recoverable_error; + unsigned char *hash = malloc(length); + if (!hash) goto non_recoverable_error; + if (!readall(client->host_r, hash, length)) goto non_recoverable_error; + + client->state = STATE_KEY_RECEIVED; + event->type = TOMSG_AC_HOSTKEY; + event->key.hostkey = hash; + event->key.length = length; + return TOMSG_OK; + } + + case STATE_KEY_RECEIVED: + return TOMSG_ERR_AGAIN; // you need to accept or reject! + + case STATE_ACCEPTED: { + if (!check_readable(client->host_r)) return TOMSG_ERR_AGAIN; + + unsigned char byte; + if (!readall(client->host_r, &byte, 1)) goto non_recoverable_error; + if (byte == TOMSG_OK) { + event->type = TOMSG_AC_SUCCESS; + event->client = client->client; + client->client = NULL; + final_retval = TOMSG_OK; + goto free_return; + } else { + final_retval = byte; + goto free_return; + } + break; + } + } + +non_recoverable_error: + // Don't even know how to handle this correctly; let's just hope the thread kills itself somehow + close(client->host_w); + close(client->host_r); + return TOMSG_ERR_TRANSPORT; + +free_return: + free(client->hostname); + close(client->host_w); + close(client->host_r); + if (client->client) tomsg_close(client->client); + return final_retval; +} + +enum tomsg_retval tomsg_async_connect_accept(struct tomsg_async_connect *client, bool accept) { + enum tomsg_retval final_retval; + if (client->state != STATE_KEY_RECEIVED) { + fprintf(stderr, "connect_accept: client->state = %d != STATE_KEY_RECEIVED\n", client->state); + final_retval = TOMSG_ERR_TRANSPORT; // shrug + goto free_return; + } + + unsigned char byte = accept ? 1 : 0; + if (!writeall(client->host_w, &byte, 1)) { + fprintf(stderr, "writeall failed: %s\n", strerror(errno)); + goto non_recoverable_error; + } + + if (!accept) { + final_retval = TOMSG_ERR_UNTRUSTED; + goto free_return; + } + + client->state = STATE_ACCEPTED; + + return TOMSG_OK; + +non_recoverable_error: + // Don't even know how to handle this correctly; let's just hope the thread kills itself somehow + close(client->host_w); + close(client->host_r); + return TOMSG_ERR_TRANSPORT; + +free_return: + free(client->hostname); + close(client->host_w); + close(client->host_r); + if (client->client) tomsg_close(client->client); + return final_retval; +} + +int tomsg_async_connect_poll_fd(const struct tomsg_async_connect *client) { + return client->host_r; +} + void tomsg_close(struct tomsg_client *client) { if (client->conn) sshnc_close(client->conn); free(client->buffer); |