aboutsummaryrefslogtreecommitdiff
path: root/ssh/tomsg_clientlib.c
diff options
context:
space:
mode:
Diffstat (limited to 'ssh/tomsg_clientlib.c')
-rw-r--r--ssh/tomsg_clientlib.c337
1 files changed, 337 insertions, 0 deletions
diff --git a/ssh/tomsg_clientlib.c b/ssh/tomsg_clientlib.c
index 80a9bf3..9c2d80a 100644
--- a/ssh/tomsg_clientlib.c
+++ b/ssh/tomsg_clientlib.c
@@ -4,7 +4,10 @@
#include <string.h>
#include <ctype.h>
#include <inttypes.h>
+#include <unistd.h>
#include <assert.h>
+#include <errno.h>
+#include <pthread.h>
#include <poll.h>
#include "tomsg_clientlib.h"
#include "sshnc.h"
@@ -32,6 +35,25 @@ struct tomsg_client {
struct inflight *inflight;
};
+enum tomsg_async_connect_state {
+ STATE_CONNECTING,
+ STATE_KEY_RECEIVED,
+ STATE_ACCEPTED,
+};
+
+// Only field written to by the thread is 'client'.
+struct tomsg_async_connect {
+ enum tomsg_async_connect_state state;
+
+ char *hostname;
+ int port;
+
+ // Two pipes for communicating with the thread
+ int host_w, host_r, thread_w, thread_r;
+
+ struct tomsg_client *client; // filled by thread in case all is successful
+};
+
static size_t min_size_t(size_t a, size_t b) { return a < b ? a : b; }
static bool hasspacelf(const char *string) {
@@ -140,6 +162,7 @@ const char* tomsg_strerror(enum tomsg_retval code) {
switch (code) {
case TOMSG_OK: return "Success";
case TOMSG_ERR_CONNECT: return "Server refused connection";
+ case TOMSG_ERR_UNTRUSTED: return "Hostkey was rejected";
case TOMSG_ERR_CLOSED: return "Server connection unexpectedly closed";
case TOMSG_ERR_VERSION: return "Server protocol version incompatible";
case TOMSG_ERR_TRANSPORT: return "Error in the underlying SSH transport";
@@ -198,6 +221,7 @@ enum tomsg_retval tomsg_connect(
hostname, port, "tomsg", "tomsg", checker, userdata, &conn);
if (ret == SSHNC_ERR_CONNECT) return TOMSG_ERR_CONNECT;
+ if (ret == SSHNC_ERR_UNTRUSTED) return TOMSG_ERR_UNTRUSTED;
if (ret != SSHNC_OK) return TOMSG_ERR_TRANSPORT;
struct tomsg_client *client = malloc(sizeof(struct tomsg_client));
@@ -219,6 +243,319 @@ enum tomsg_retval tomsg_connect(
return version_negotiation(client);
}
+// Returns whether successful
+static bool writeall(int fd, const unsigned char *buffer, size_t length) {
+ size_t cursor = 0;
+ while (cursor < length) {
+ ssize_t nw = write(fd, buffer + cursor, length - cursor);
+ if (nw < 0) {
+ if (errno == EINTR) continue;
+ return false;
+ }
+ if (nw == 0) return false;
+ cursor += nw;
+ }
+ return true;
+}
+
+// Returns whether successful
+static bool readall(int fd, unsigned char *buffer, size_t length) {
+ size_t cursor = 0;
+ while (cursor < length) {
+ ssize_t nr = read(fd, buffer + cursor, length - cursor);
+ if (nr < 0) {
+ if (errno == EINTR) continue;
+ return false;
+ }
+ if (nr == 0) return false;
+ cursor += nr;
+ }
+ return true;
+}
+
+// Async socket protocol:
+// Sizes are always indicated using a host-order 8-byte signed integer.
+// - Thread sends an error byte; if that's TOMSG_OK, it is followed by size and
+// hashbytes of the hostkey. Otherwise, the thread exits.
+// - Host sends one byte: 0 for reject, 1 for accept.
+// - For reject, the thread closes the connection and exits. For accept, the
+// thread negotiates the protocol version and if successful, initialises and
+// populates a client structure in the state, and sends a TOMSG_OK byte. If
+// unsuccessful, sends an error byte and exits.
+// After the thread sends either an error byte or the final OK byte, or after
+// it has received a reject message, it will not access 'state' anymore; it can
+// thus be freed by the host.
+
+static bool async_hostkey_checker(const unsigned char *hash, size_t length, void *state_) {
+ struct tomsg_async_connect *state = state_;
+
+ unsigned char buffer[256];
+ if (9 + length > sizeof buffer) { assert(false); return false; }
+ buffer[0] = TOMSG_OK;
+ *(int64_t*)&buffer[1] = length;
+ memcpy(buffer + 9, hash, length);
+
+ if (!writeall(state->thread_w, buffer, 9 + length)) return false;
+
+ char response;
+ ssize_t nr = -1;
+ while (nr < 0) {
+ nr = read(state->thread_r, &response, 1);
+ if (nr == 0 || (nr < 0 && errno != EINTR)) break;
+ }
+ if (nr < 0) return false;
+
+ return response == 1;
+}
+
+static void* async_connect_thread_entry(void *state_) {
+ struct tomsg_async_connect *state = state_;
+
+ const int thread_r = state->thread_r;
+ const int thread_w = state->thread_w;
+
+ struct sshnc_client *conn;
+ enum sshnc_retval ret = sshnc_connect(
+ state->hostname, state->port, "tomsg", "tomsg",
+ async_hostkey_checker, state, &conn);
+
+ enum tomsg_retval sendret;
+ if (ret == SSHNC_ERR_CONNECT) sendret = TOMSG_ERR_CONNECT;
+ else if (ret == SSHNC_ERR_UNTRUSTED) sendret = TOMSG_ERR_UNTRUSTED;
+ else if (ret != SSHNC_OK) sendret = TOMSG_ERR_TRANSPORT;
+ else sendret = TOMSG_OK;
+
+ // If sendret is TOMSG_ERR_UNTRUSTED, the host may free 'state' at this point. Thus, don't access it, please.
+
+ struct tomsg_client *client = NULL;
+ unsigned char byte;
+
+ if (sendret == TOMSG_ERR_UNTRUSTED) goto pipe_return;
+ if (sendret != TOMSG_OK) goto send_error_return;
+
+ client = calloc(1, sizeof(struct tomsg_client));
+ if (!client) { sendret = TOMSG_ERR_MEMORY; goto send_error_return; }
+ client->conn = conn;
+ client->buffer_len = 0;
+ client->buffer_cap = 1024;
+ client->buffer = malloc(client->buffer_cap);
+ if (!client->buffer) { sendret = TOMSG_ERR_MEMORY; goto send_error_return; }
+ client->buffer_newline_cursor = 0;
+ client->next_tag = 1;
+ client->inflight_num = 0;
+ client->inflight_cap = 2;
+ client->inflight = malloc(client->inflight_cap * sizeof(struct inflight));
+ if (!client->inflight) { sendret = TOMSG_ERR_MEMORY; goto send_error_return; }
+
+ enum tomsg_retval versionret = version_negotiation(client);
+ if (versionret != TOMSG_OK) { sendret = versionret; goto send_error_return; }
+
+ state->client = client;
+
+ byte = TOMSG_OK;
+ // After the writeall, it is forbidden to access anything in 'state',
+ // because it may have been freed by the host at this point.
+ if (!writeall(thread_w, &byte, 1)) goto free_client;
+ goto pipe_return;
+
+send_error_return:
+ // After this, it is forbidden to access anything in 'state', because it
+ // may have been freed by the host at this point.
+ byte = sendret;
+ writeall(thread_w, &byte, 1);
+free_client:
+ sshnc_close(conn);
+ if (client) {
+ if (client->buffer) free(client->buffer);
+ if (client->inflight) free(client->inflight);
+ free(client);
+ }
+pipe_return:
+ close(thread_r);
+ close(thread_w);
+ return NULL;
+}
+
+enum tomsg_retval tomsg_async_connect(
+ const char *hostname, int port,
+ struct tomsg_async_connect **clientp) {
+ // In case we throw an error along the way
+ *clientp = NULL;
+
+ struct tomsg_async_connect *client = malloc(sizeof(struct tomsg_async_connect));
+ if (!client) return TOMSG_ERR_MEMORY;
+
+ client->state = STATE_CONNECTING;
+
+ client->hostname = strdup(hostname);
+ client->port = port;
+
+ pthread_attr_t attr;
+ if (pthread_attr_init(&attr) < 0) {
+ free(client->hostname);
+ free(client);
+ if (errno == ENOMEM) return TOMSG_ERR_MEMORY;
+ return TOMSG_ERR_CONNECT;
+ }
+
+ pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_DETACHED);
+
+ int pipeHT[2] = {-1, -1}, pipeTH[2];
+ if (pipe(pipeHT) < 0 || pipe(pipeTH) < 0) {
+ if (pipeHT[0] != -1) { close(pipeHT[0]); close(pipeHT[1]); }
+ free(client->hostname);
+ free(client);
+ return TOMSG_ERR_CONNECT;
+ }
+
+ client->host_w = pipeHT[1];
+ client->thread_r = pipeHT[0];
+ client->thread_w = pipeTH[1];
+ client->host_r = pipeTH[0];
+
+ pthread_t thread;
+ if (pthread_create(&thread, &attr, async_connect_thread_entry, client) < 0) {
+ pthread_attr_destroy(&attr);
+ close(client->host_w); close(client->host_r);
+ close(client->thread_w); close(client->thread_r);
+ free(client->hostname);
+ free(client);
+ return TOMSG_ERR_CONNECT;
+ }
+
+ *clientp = client;
+ return TOMSG_OK;
+}
+
+static bool check_readable(int fd) {
+ struct pollfd pfd;
+ pfd.fd = fd;
+ pfd.events = POLLIN;
+ pfd.revents = 0;
+ poll(&pfd, 1, 0);
+ return pfd.revents & (POLLIN | POLLHUP | POLLERR);
+}
+
+void tomsg_async_connect_event_nullify(struct tomsg_async_connect_event *event) {
+ switch (event->type) {
+ case TOMSG_AC_HOSTKEY:
+ free(event->key.hostkey);
+ break;
+
+ case TOMSG_AC_SUCCESS:
+ tomsg_close(event->client);
+ break;
+ }
+}
+
+enum tomsg_retval tomsg_async_connect_next_event(
+ struct tomsg_async_connect *client,
+ struct tomsg_async_connect_event *event // output
+) {
+ enum tomsg_retval final_retval;
+
+ switch (client->state) {
+ case STATE_CONNECTING: {
+ if (!check_readable(client->host_r)) return TOMSG_ERR_AGAIN;
+
+ unsigned char byte;
+ if (!readall(client->host_r, &byte, 1)) goto non_recoverable_error;
+ if (byte != TOMSG_OK) {
+ final_retval = byte;
+ goto free_return;
+ }
+
+ // Now read the hostkey from the pipe
+ int64_t length;
+ if (!readall(client->host_r, (unsigned char*)&length, 8)) goto non_recoverable_error;
+ unsigned char *hash = malloc(length);
+ if (!hash) goto non_recoverable_error;
+ if (!readall(client->host_r, hash, length)) goto non_recoverable_error;
+
+ client->state = STATE_KEY_RECEIVED;
+ event->type = TOMSG_AC_HOSTKEY;
+ event->key.hostkey = hash;
+ event->key.length = length;
+ return TOMSG_OK;
+ }
+
+ case STATE_KEY_RECEIVED:
+ return TOMSG_ERR_AGAIN; // you need to accept or reject!
+
+ case STATE_ACCEPTED: {
+ if (!check_readable(client->host_r)) return TOMSG_ERR_AGAIN;
+
+ unsigned char byte;
+ if (!readall(client->host_r, &byte, 1)) goto non_recoverable_error;
+ if (byte == TOMSG_OK) {
+ event->type = TOMSG_AC_SUCCESS;
+ event->client = client->client;
+ client->client = NULL;
+ final_retval = TOMSG_OK;
+ goto free_return;
+ } else {
+ final_retval = byte;
+ goto free_return;
+ }
+ break;
+ }
+ }
+
+non_recoverable_error:
+ // Don't even know how to handle this correctly; let's just hope the thread kills itself somehow
+ close(client->host_w);
+ close(client->host_r);
+ return TOMSG_ERR_TRANSPORT;
+
+free_return:
+ free(client->hostname);
+ close(client->host_w);
+ close(client->host_r);
+ if (client->client) tomsg_close(client->client);
+ return final_retval;
+}
+
+enum tomsg_retval tomsg_async_connect_accept(struct tomsg_async_connect *client, bool accept) {
+ enum tomsg_retval final_retval;
+ if (client->state != STATE_KEY_RECEIVED) {
+ fprintf(stderr, "connect_accept: client->state = %d != STATE_KEY_RECEIVED\n", client->state);
+ final_retval = TOMSG_ERR_TRANSPORT; // shrug
+ goto free_return;
+ }
+
+ unsigned char byte = accept ? 1 : 0;
+ if (!writeall(client->host_w, &byte, 1)) {
+ fprintf(stderr, "writeall failed: %s\n", strerror(errno));
+ goto non_recoverable_error;
+ }
+
+ if (!accept) {
+ final_retval = TOMSG_ERR_UNTRUSTED;
+ goto free_return;
+ }
+
+ client->state = STATE_ACCEPTED;
+
+ return TOMSG_OK;
+
+non_recoverable_error:
+ // Don't even know how to handle this correctly; let's just hope the thread kills itself somehow
+ close(client->host_w);
+ close(client->host_r);
+ return TOMSG_ERR_TRANSPORT;
+
+free_return:
+ free(client->hostname);
+ close(client->host_w);
+ close(client->host_r);
+ if (client->client) tomsg_close(client->client);
+ return final_retval;
+}
+
+int tomsg_async_connect_poll_fd(const struct tomsg_async_connect *client) {
+ return client->host_r;
+}
+
void tomsg_close(struct tomsg_client *client) {
if (client->conn) sshnc_close(client->conn);
free(client->buffer);