From 4d4cbdaf49f616fea47c543fe2cb74d1d8a1e7ff Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom.smeding@gmail.com>
Date: Thu, 25 Jun 2020 22:47:10 +0200
Subject: ssh: WIP ssh proxy server

---
 ssh/.gitignore |   3 +-
 ssh/Makefile   |   7 +-
 ssh/client.c   |  39 ++---
 ssh/server.c   | 509 +++++++++++++++++++++++++++++++--------------------------
 ssh/util.c     |  29 ++++
 ssh/util.h     |   7 +
 6 files changed, 328 insertions(+), 266 deletions(-)
 create mode 100644 ssh/util.c
 create mode 100644 ssh/util.h

diff --git a/ssh/.gitignore b/ssh/.gitignore
index 36f3e53..17baeb4 100644
--- a/ssh/.gitignore
+++ b/ssh/.gitignore
@@ -1,5 +1,4 @@
-host_key
-host_key.pub
+ssh_host_key
 server
 client
 *.o
diff --git a/ssh/Makefile b/ssh/Makefile
index b7c1d1b..68113cd 100644
--- a/ssh/Makefile
+++ b/ssh/Makefile
@@ -4,18 +4,19 @@ LDFLAGS = -pthread
 CFLAGS += $(shell pkg-config --cflags libssh)
 LDFLAGS += $(shell pkg-config --libs libssh)
 
+
 .PHONY: all clean
 
 all: server client
 
 clean:
-	rm -f server *.o
+	rm -f server client *.o
 
 
-server: server.o ../global.o ../memory.o
+server: server.o util.o
 	$(CC) -o $@ $^ $(LDFLAGS)
 
-client: client.o ../global.o ../memory.o
+client: client.o util.o
 	$(CC) -o $@ $^ $(LDFLAGS)
 
 %.o: %.c $(wildcard *.h)
diff --git a/ssh/client.c b/ssh/client.c
index 9e60033..0adfffb 100644
--- a/ssh/client.c
+++ b/ssh/client.c
@@ -8,28 +8,8 @@
 #include <libssh/callbacks.h>
 #include <sys/select.h>
 #include <poll.h>
-#include "../global.h"
-
-
-static void parse_args(const char *arg, const char **server_host, int *port) {
-	const char *ptr = strchr(arg, ':');
-	if (ptr == NULL) {
-		*server_host = arg;
-	} else {
-		size_t length = ptr - arg;
-		char *host = malloc(length + 1, char);
-		memcpy(host, arg, length);
-		host[length] = '\0';
-		*server_host = host;
-
-		char *endp;
-		*port = strtol(ptr + 1, &endp, 10);
-		if (endp == ptr || *endp != '\0') {
-			fprintf(stderr, "Cannot parse server:port from argument '%s'\n", arg);
-			exit(1);
-		}
-	}
-}
+#include "util.h"
+
 
 static bool prompt_yn(const char *text) {
 	printf("%s", text);
@@ -159,7 +139,10 @@ int main(int argc, char **argv) {
 		return 1;
 	}
 
-	parse_args(argv[1], &server_host, &port);
+	if (!parse_host_port(argv[1], &server_host, &port)) {
+		fprintf(stderr, "Cannot parse host:port from argument '%s'\n", argv[1]);
+		return 1;
+	}
 
 	ssh_session session = ssh_new();
 	if (!session) {
@@ -169,7 +152,8 @@ int main(int argc, char **argv) {
 
 	const char *ciphers_str = "aes256-gcm@openssh.com,aes256-ctr,aes256-cbc";
 	bool procconfig = false;
-	bool ok = ssh_options_set(session, SSH_OPTIONS_PROCESS_CONFIG, &procconfig) == SSH_OK;
+	bool ok = true;
+	ok &= ssh_options_set(session, SSH_OPTIONS_PROCESS_CONFIG, &procconfig) == SSH_OK;
 	ok &= ssh_options_set(session, SSH_OPTIONS_USER, "tomsg") == SSH_OK;
 	ok &= ssh_options_set(session, SSH_OPTIONS_HOST, server_host) == SSH_OK;
 	ok &= ssh_options_set(session, SSH_OPTIONS_PORT, &port) == SSH_OK;
@@ -257,7 +241,12 @@ int main(int argc, char **argv) {
 
 	printf("Obtained tomsg subsystem on channel\n");
 
-	struct session_data *sesdata = malloc(1, struct session_data);
+	struct session_data *sesdata = malloc(sizeof(struct session_data));
+	if (!sesdata) {
+		fprintf(stderr, "Out of memory (allocating session_data)!\n");
+		return 1;
+	}
+
 	sesdata->session = session;
 	sesdata->channel = channel;
 	sesdata->should_close = false;
diff --git a/ssh/server.c b/ssh/server.c
index 5a2b162..ed578c4 100644
--- a/ssh/server.c
+++ b/ssh/server.c
@@ -7,21 +7,20 @@
 #include <errno.h>
 #include <assert.h>
 #include <pthread.h>
-#include <sys/select.h>
+#include <netdb.h>
+#include <poll.h>
+#include <signal.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
 #include <libssh/server.h>
 #include <libssh/callbacks.h>
-#include "../global.h"
+#include "util.h"
 
 
-#define CHECK(obj_, expr_) do { \
-		if (!(expr_)) { \
-			fprintf(stderr, "libssh error! expression: " #expr_ "\nError description: %s\n", \
-					ssh_get_error((obj_))); \
-			exit(1); \
-		} \
-	} while (0)
+#define RESOURCE_ERROR_SLEEP_MS 10000
 
-void xxd(FILE *stream, const void *buf_, size_t length) {
+
+static void xxd(FILE *stream, const void *buf_, size_t length) {
 	unsigned char *buf = (unsigned char*)buf_;
 
 	for (size_t cursor = 0; cursor < length;) {
@@ -47,43 +46,13 @@ void xxd(FILE *stream, const void *buf_, size_t length) {
 	}
 }
 
-// struct sessions {
-//     // Always NULL-terminated
-//     ssh_session *list;
-//     ssh_session *outlist;  // same length as 'list', contents not managed
-//     size_t cap, len;
-// };
-
-// static struct sessions sessions_make() {
-//     size_t cap = 2;
-//     return (struct sessions) {
-//         .list = malloc(cap, ssh_session),
-//         .outlist = malloc(cap, ssh_session),
-//         .cap = cap,
-//         .len = 0,
-//     };
-// }
-
-// static void sessions_add(struct sessions *ss, ssh_session ses) {
-//     if (ss->len + 1 >= ss->cap) {
-//         ss->cap *= 2;
-//         ss->list = realloc(ss->list, ss->cap, ssh_session);
-//         ss->outlist = realloc(ss->outlist, ss->cap, ssh_session);
-//     }
-//     ss->list[ss->len++] = ses;
-//     ss->list[ss->len] = NULL;
-// }
-
-// static void sessions_remove(struct sessions *ss, size_t index) {
-//     assert(0 <= index && index < ss->len);
-//     if (index < ss->len - 1) ss->list[index] = ss->list[ss->len - 1];
-//     ss->len--;
-//     ss->list[ss->len] = NULL;
-// }
-
 static atomic_int g_thread_count;
 
 struct thread_data {
+	struct addrinfo backend_addr;
+
+	int backend_fd;
+
 	int thread_id;
 	ssh_session session;
 	ssh_channel channel;  // NULL before channel has been opened
@@ -95,7 +64,7 @@ struct thread_data {
 
 ///////// CHANNEL CALLBACKS //////////
 
-int channel_subsystem_request_cb(ssh_session session, ssh_channel channel, const char *subsystem, void *tdata_) {
+static int channel_subsystem_request_cb(ssh_session session, ssh_channel channel, const char *subsystem, void *tdata_) {
 	(void)session;
 	(void)channel;
 	struct thread_data *tdata = (struct thread_data*)tdata_;
@@ -108,89 +77,51 @@ int channel_subsystem_request_cb(ssh_session session, ssh_channel channel, const
 	}
 }
 
-void channel_close_cb(ssh_session session, ssh_channel channel, void *tdata_) {
+static void channel_close_cb(ssh_session session, ssh_channel channel, void *tdata_) {
 	(void)session; (void)channel;
 	struct thread_data *tdata = (struct thread_data*)tdata_;
 	printf("[%d] channel close!\n", tdata->thread_id);
 }
 
-int channel_shell_request_cb(ssh_session session, ssh_channel channel, void *tdata_) {
-	(void)session; (void)channel;
-	struct thread_data *tdata = (struct thread_data*)tdata_;
-	printf("[%d] shell request, denying\n", tdata->thread_id);
-	return 1;
-}
-
-void channel_eof_cb(ssh_session session, ssh_channel channel, void *tdata_) {
+static void channel_eof_cb(ssh_session session, ssh_channel channel, void *tdata_) {
 	(void)session; (void)channel;
 	struct thread_data *tdata = (struct thread_data*)tdata_;
 	printf("[%d] eof on channel, setting close flag\n", tdata->thread_id);
 	tdata->should_close = true;
 }
 
-int channel_data_cb(ssh_session session, ssh_channel channel, void *data, uint32_t len, int is_stderr, void *tdata_) {
+static int channel_data_cb(ssh_session session, ssh_channel channel, void *data, uint32_t len, int is_stderr, void *tdata_) {
 	(void)is_stderr; (void)data; (void)channel; (void)session;
 	struct thread_data *tdata = (struct thread_data*)tdata_;
 	printf("[%d] data on channel (length %u):\n", tdata->thread_id, len);
 	xxd(stdout, data, len);
-	printf("[%d] echoing back!\n", tdata->thread_id);
-	if (ssh_channel_write(channel, data, len) == SSH_ERROR) {
-		printf("[%d] write to channel failed! Setting close flag\n", tdata->thread_id);
-		tdata->should_close = true;
+	// printf("[%d] echoing back!\n", tdata->thread_id);
+	// if (ssh_channel_write(channel, data, len) == SSH_ERROR) {
+	//     printf("[%d] write to channel failed! Setting close flag\n", tdata->thread_id);
+	//     tdata->should_close = true;
+	// }
+	const char *start = (const char*)data;
+	const char *cursor = start;
+	const char *end = start + len;
+	while (cursor < end) {
+		ssize_t nw = write(tdata->backend_fd, cursor, end - cursor);
+		if (nw < 0) {
+			if (errno == EINTR) continue;
+			printf("[%d] error writing to backend socket: %s\n", tdata->thread_id, strerror(errno));
+			tdata->should_close = true;
+			return cursor - start;
+		}
+		if (nw == 0) {  // should not happen?
+			printf("[%d] write(2) returned 0?\n", tdata->thread_id);
+			tdata->should_close = true;
+			return cursor - start;
+		}
+		cursor += nw;
 	}
 	return len;
 }
 
-void channel_signal_cb(ssh_session session, ssh_channel channel, const char *signal, void *tdata_) {
-	(void)channel; (void)session;
-	printf("[%d] signal SIG%s\n", ((struct thread_data*)tdata_)->thread_id, signal);
-}
-
-void channel_exit_status_cb(ssh_session session, ssh_channel channel, int exit_status, void *tdata_) {
-	(void)channel; (void)session;
-	printf("[%d] exit status %d\n", ((struct thread_data*)tdata_)->thread_id, exit_status);
-}
-
-void channel_exit_signal_cb(ssh_session session, ssh_channel channel, const char *signal, int core, const char *errmsg, const char *lang, void *tdata_) {
-	(void)lang; (void)errmsg; (void)core; (void)channel; (void)session;
-	printf("[%d] exit signal %s\n", ((struct thread_data*)tdata_)->thread_id, signal);
-}
-
-int channel_pty_request_cb(ssh_session session, ssh_channel channel, const char *term, int width, int height, int pxwidth, int pwheight, void *tdata_) {
-	(void)pwheight; (void)pxwidth; (void)channel; (void)session;
-	printf("[%d] pty request (term %s, %dx%d), denying\n", ((struct thread_data*)tdata_)->thread_id, term, width, height);
-	return -1;
-}
-
-void channel_auth_agent_req_cb(ssh_session session, ssh_channel channel, void *tdata_) {
-	(void)channel; (void)session;
-	printf("[%d] auth agent request\n", ((struct thread_data*)tdata_)->thread_id);
-}
-
-void channel_x11_req_cb(ssh_session session, ssh_channel channel, int single_connection, const char *auth_protocol, const char *auth_cookie, uint32_t screen_number, void *tdata_) {
-	(void)screen_number; (void)auth_cookie; (void)auth_protocol; (void)single_connection; (void)channel; (void)session;
-	printf("[%d] X11 REQUEST WTF\n", ((struct thread_data*)tdata_)->thread_id);
-}
-
-int channel_pty_window_change_cb(ssh_session session, ssh_channel channel, int width, int height, int pxwidth, int pwheight, void *tdata_) {
-	(void)pwheight; (void)pxwidth; (void)channel; (void)session;
-	printf("[%d] pty window change (%dx%d), denying\n", ((struct thread_data*)tdata_)->thread_id, width, height);
-	return -1;
-}
-
-int channel_exec_request_cb(ssh_session session, ssh_channel channel, const char *command, void *tdata_) {
-	(void)channel; (void)session;
-	printf("[%d] exec request (<%s>), denying\n", ((struct thread_data*)tdata_)->thread_id, command);
-	return 1;
-}
-
-int channel_env_request_cb(ssh_session session, ssh_channel channel, const char *env_name, const char *env_value, void *tdata_) {
-	(void)channel; (void)session;
-	printf("[%d] environment request (<%s> = <%s>), denying\n", ((struct thread_data*)tdata_)->thread_id, env_name, env_value);
-	return 1;
-}
-
-int channel_write_wontblock_cb(ssh_session session, ssh_channel channel, size_t bytes, void *tdata_) {
+static int channel_write_wontblock_cb(ssh_session session, ssh_channel channel, size_t bytes, void *tdata_) {
 	(void)channel; (void)session;
 	printf("[%d] write won't block for %zu bytes notification\n", ((struct thread_data*)tdata_)->thread_id, bytes);
 	return 0;
@@ -198,35 +129,14 @@ int channel_write_wontblock_cb(ssh_session session, ssh_channel channel, size_t
 
 ////////// SERVER CALLBACKS //////////
 
-int auth_none_cb(ssh_session session, const char *user, void *tdata_) {
+static int auth_none_cb(ssh_session session, const char *user, void *tdata_) {
 	(void)session;
 	struct thread_data *tdata = (struct thread_data*)tdata_;
 	printf("[%d] auth none (user <%s>), accepting\n", tdata->thread_id, user);
 	return SSH_AUTH_SUCCESS;
 }
 
-int auth_password_cb(ssh_session session, const char *user, const char *password, void *tdata_) {
-	(void)session;
-	struct thread_data *tdata = (struct thread_data*)tdata_;
-	printf("[%d] auth password (user <%s> password <%s>), denying\n", tdata->thread_id, user, password);
-	return SSH_AUTH_DENIED;
-}
-
-int auth_gssapi_mic_cb(ssh_session session, const char *user, const char *principal, void *tdata_) {
-	(void)session;
-	struct thread_data *tdata = (struct thread_data*)tdata_;
-	printf("[%d] auth gssapi (user <%s> principal <%s>), denying\n", tdata->thread_id, user, principal);
-	return SSH_AUTH_DENIED;
-}
-
-int auth_pubkey_cb(ssh_session session, const char *user, struct ssh_key_struct *pubkey, char signature_state, void *tdata_) {
-	(void)session; (void)pubkey; (void)signature_state;
-	struct thread_data *tdata = (struct thread_data*)tdata_;
-	printf("[%d] auth pubkey (user <%s>), denying\n", tdata->thread_id, user);
-	return SSH_AUTH_DENIED;
-}
-
-int service_request_cb(ssh_session session, const char *service, void *tdata_) {
+static int service_request_cb(ssh_session session, const char *service, void *tdata_) {
 	(void)session;
 	struct thread_data *tdata = (struct thread_data*)tdata_;
 	if (strcmp(service, "ssh-userauth") == 0) {
@@ -238,7 +148,7 @@ int service_request_cb(ssh_session session, const char *service, void *tdata_) {
 	}
 }
 
-ssh_channel chan_open_request_cb(ssh_session session, void *tdata_) {
+static ssh_channel chan_open_request_cb(ssh_session session, void *tdata_) {
 	struct thread_data *tdata = (struct thread_data*)tdata_;
 	if (tdata->channel == NULL) {
 		ssh_channel chan = ssh_channel_new(session);
@@ -256,61 +166,100 @@ ssh_channel chan_open_request_cb(ssh_session session, void *tdata_) {
 	return NULL;
 }
 
-// int msg_callback(ssh_session session, ssh_message msg, void *tdata_) {
-//     (void)session;
-//     struct thread_data *tdata = (struct thread_data*)tdata_;
-//     const int tid = tdata->thread_id;
-
-//     const int subtype = ssh_message_subtype(msg);
-
-//     switch (ssh_message_type(msg)) {
-//         case SSH_REQUEST_AUTH:
-//             printf("[%d] message callback: type auth (subtype %d)\n", tid, subtype);
-//             if (subtype == SSH_AUTH_METHOD_NONE) {
-//                 if (ssh_message_auth_reply_success(msg, false) == SSH_OK) {
-//                     return 0;  // handled
-//                 } else {
-//                     printf("[%d]   failed to reply success for auth method none\n", tid);
-//                 }
-//             }
-//             break;
-
-//         case SSH_REQUEST_CHANNEL_OPEN:
-//             printf("[%d] message callback: type channel open (subtype %d)\n", tid, subtype);
-//             if (tdata->channel == NULL && subtype == SSH_CHANNEL_SESSION) {
-//                 ssh_channel chan = ssh_message_channel_request_open_reply_accept(msg);
-//                 chan_cb.userdata = tdata;
-
-//                 if (chan && ssh_set_channel_callbacks(chan, &chan_cb) == SSH_OK) {
-//                     tdata->channel = chan;
-//                     return 0;  // handled
-//                 } else {
-//                     if (chan) {
-//                         ssh_channel_close(chan);
-//                         ssh_channel_free(chan);
-//                     }
-//                     printf("[%d]   failed to accept channel open request for session\n", tid);
-//                 }
-//             }
-//             break;
-
-//         case SSH_REQUEST_CHANNEL:
-//             printf("[%d] message callback: type channel (subtype %d)\n", tid, subtype);
-//             break;
-
-//         case SSH_REQUEST_SERVICE:
-//             printf("[%d] message callback: type service\n", tid);
-//             break;
-
-//         case SSH_REQUEST_GLOBAL:
-//             printf("[%d] message callback: type global (subtype %d)\n", tid, subtype);
-//             break;
-//     }
-
-//     return 1;  // not handled
-// }
-
-void* thread_entry(void *tdata_) {
+static int backend_data_cb(int fd, int revents, void *tdata_) {
+	struct thread_data *tdata = (struct thread_data*)tdata_;
+
+	if (revents & POLLIN) {
+		char buffer[1024];
+		ssize_t nr = read(fd, buffer, sizeof buffer);
+		if (nr < 0) {
+			if (errno == EINTR) return 0;
+			printf("[%d] Error reading from backend socket: %s\n", tdata->thread_id, strerror(errno));
+			tdata->should_close = true;
+			return 0;
+		}
+
+		if (nr == 0) {  // eof
+			tdata->should_close = true;
+			return 0;
+		}
+
+		if (ssh_channel_write(tdata->channel, buffer, nr) != SSH_OK) {
+			printf("[%d] Error writing to ssh channel: %s\n", tdata->thread_id, ssh_get_error(tdata->channel));
+			tdata->should_close = true;
+			return 0;
+		}
+	}
+
+	return 0;
+}
+
+// Returns whether successful.
+static bool lookup_backend(const char *host, int port, struct addrinfo *dst) {
+	char port_string[16];
+	sprintf(port_string, "%d", port);
+
+	struct addrinfo hints;
+	memset(&hints, 0, sizeof hints);
+	hints.ai_family = AF_UNSPEC;
+	hints.ai_socktype = SOCK_STREAM;
+	hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
+
+	struct addrinfo *result;
+	int ret = getaddrinfo(host, port_string, &hints, &result);
+
+	if (ret < 0) {
+		fprintf(stderr, "Could not resolve backend: %s\n", gai_strerror(ret));
+		return false;
+	}
+
+	int last_failure = 0;
+	bool success = false;
+	for (struct addrinfo *item = result; item; item = item->ai_next) {
+		int sock = socket(item->ai_family, item->ai_socktype, item->ai_protocol);
+		if (sock == -1) {
+			last_failure = errno;
+			continue;
+		}
+
+		int ret = connect(sock, item->ai_addr, item->ai_addrlen);
+		close(sock);
+
+		if (ret == 0) {
+			success = true;
+			*dst = *item;
+			dst->ai_next = NULL;
+			break;
+		} else {
+			last_failure = errno;
+		}
+	}
+
+	freeaddrinfo(result);
+
+	if (success) {
+		return true;
+	} else {
+		fprintf(stderr, "Could not connect to backend: %s\n", strerror(last_failure));
+		return false;
+	}
+}
+
+static int connect_backend(const struct thread_data *tdata) {
+	const struct addrinfo *item = &tdata->backend_addr;
+	int sock = socket(item->ai_family, item->ai_socktype, item->ai_protocol);
+	if (sock == -1) return -1;
+
+	if (connect(sock, item->ai_addr, item->ai_addrlen) == 0) {
+		printf("connect_backend: sock=%d\n", sock);
+		return sock;
+	}
+
+	close(sock);
+	return -1;
+}
+
+static void* thread_entry(void *tdata_) {
 	struct thread_data *tdata = (struct thread_data*)tdata_;
 	const int tid = tdata->thread_id;
 	const ssh_session session = tdata->session;
@@ -318,19 +267,11 @@ void* thread_entry(void *tdata_) {
 
 	printf("[%d] Thread started\n", tid);
 
-	// ssh_set_message_callback(session, msg_callback, tdata);
-
 	memset(&tdata->server_cb, 0, sizeof tdata->server_cb);
 	ssh_callbacks_init(&tdata->server_cb);
 	tdata->server_cb.userdata = tdata;
 	tdata->server_cb.auth_none_function = auth_none_cb;
 	tdata->server_cb.channel_open_request_session_function = chan_open_request_cb;
-	tdata->server_cb.auth_password_function = auth_password_cb;
-	tdata->server_cb.auth_gssapi_mic_function = auth_gssapi_mic_cb;
-	tdata->server_cb.auth_pubkey_function = auth_pubkey_cb;
-	tdata->server_cb.gssapi_select_oid_function = (void*)0x424242;  // just crash if it attemps to invoke these
-	tdata->server_cb.gssapi_accept_sec_ctx_function = (void*)0x424242;
-	tdata->server_cb.gssapi_verify_mic_function = (void*)0x424242;
 	tdata->server_cb.service_request_function = service_request_cb;
 
 	memset(&tdata->chan_cb, 0, sizeof tdata->chan_cb);
@@ -338,18 +279,8 @@ void* thread_entry(void *tdata_) {
 	tdata->chan_cb.userdata = tdata;
 	tdata->chan_cb.channel_subsystem_request_function = channel_subsystem_request_cb;
 	tdata->chan_cb.channel_close_function = channel_close_cb;
-	tdata->chan_cb.channel_shell_request_function = channel_shell_request_cb;
 	tdata->chan_cb.channel_eof_function = channel_eof_cb;
 	tdata->chan_cb.channel_data_function = channel_data_cb;
-	tdata->chan_cb.channel_signal_function = channel_signal_cb;
-	tdata->chan_cb.channel_exit_status_function = channel_exit_status_cb;
-	tdata->chan_cb.channel_exit_signal_function = channel_exit_signal_cb;
-	tdata->chan_cb.channel_pty_request_function = channel_pty_request_cb;
-	tdata->chan_cb.channel_auth_agent_req_function = channel_auth_agent_req_cb;
-	tdata->chan_cb.channel_x11_req_function = channel_x11_req_cb;
-	tdata->chan_cb.channel_pty_window_change_function = channel_pty_window_change_cb;
-	tdata->chan_cb.channel_exec_request_function = channel_exec_request_cb;
-	tdata->chan_cb.channel_env_request_function = channel_env_request_cb;
 	tdata->chan_cb.channel_write_wontblock_function = channel_write_wontblock_cb;
 
 	if (ssh_set_server_callbacks(session, &tdata->server_cb) != SSH_OK) {
@@ -365,8 +296,18 @@ void* thread_entry(void *tdata_) {
 	}
 	printf("[%d] Handled key exchange\n", tid);
 
+	tdata->backend_fd = connect_backend(tdata);
+	if (tdata->backend_fd == -1) {
+		printf("[%d] Failed to connect to backend: %s\n", tid, strerror(errno));
+		goto cleanup;
+	}
+
+	printf("[%d] Connected to backend (fd=%d)\n", tid, tdata->backend_fd);
+
 	event = ssh_event_new();
-	if (!event || ssh_event_add_session(event, session) != SSH_OK) {
+	if (!event
+			|| ssh_event_add_session(event, session) != SSH_OK
+			|| ssh_event_add_fd(event, tdata->backend_fd, POLLIN, backend_data_cb, tdata) != SSH_OK) {
 		printf("[%d] Failed to create ssh event context\n", tid);
 		goto cleanup;
 	}
@@ -398,15 +339,80 @@ cleanup:
 	pthread_exit(NULL);
 }
 
-int main(void) {
+static void generate_key(const char *outfile) {
+	ssh_key host_key;
+	int ret = ssh_pki_generate(SSH_KEYTYPE_ED25519, 0, &host_key);
+	if (ret != SSH_OK) {
+		fprintf(stderr, "Key generation failed (ed25519)!\n");
+		exit(1);
+	}
+
+	ret = ssh_pki_export_privkey_file(host_key, NULL, NULL, NULL, outfile);
+	if (ret != SSH_OK) {
+		fprintf(stderr, "Failed to export generated host key to file '%s'; is that location accessible?\n", outfile);
+		exit(1);
+	}
+
+	if (chmod(outfile, S_IRUSR | S_IWUSR) != 0) {
+		fprintf(stderr, "Failed to set mode 600 on generated host key file '%s'; this is insecure!\n", outfile);
+		exit(1);
+	}
+
+	printf("ed25519 host key generated and written to '%s'.\n", outfile);
+}
+
+static void usage(const char *argv0) {
+	fprintf(stderr,
+			"Usage: %s <ssh host key file> <ssh port> [backendhost:port]\n"
+			"       %s --generate <host key output file>\n"
+			"SSH-TCP bridge for tomsg. Accepts SSH connections with a channel for subsystem\n"
+			"'tomsg', and matches each SSH connection with a plain TCP connection to the\n"
+			"backend server (which defaults to localhost:29536). All data is forwarded\n"
+			"transparently.\n"
+			"Use the '--generate' form to generate a host key for use in the main invocation\n"
+			"form.\n",
+			argv0, argv0);
+}
+
+int main(int argc, char **argv) {
+	const char *host_key_fname;
+	int ssh_port = 2222;
+	const char *backend_host = "localhost";
+	int backend_port = 29536;
+
+	if (argc == 3 && strcmp(argv[1], "--generate") == 0) {
+		generate_key(argv[2]);
+		return 0;
+	} else if (3 <= argc && argc <= 4) {
+		host_key_fname = argv[1];
+		char *endp;
+		ssh_port = strtol(argv[2], &endp, 10);
+		if (argv[2][0] == '\0' || *endp != '\0' || ssh_port < 0 || ssh_port > 65535) {
+			fprintf(stderr, "Cannot parse port number from argument '%s'\n", argv[2]);
+			return 1;
+		}
+		if (argc == 4) {
+			if (!parse_host_port(argv[3], &backend_host, &backend_port)) {
+				fprintf(stderr, "Cannot parse host:port from argument '%s'\n", argv[3]);
+				return 1;
+			}
+		}
+	} else {
+		usage(argv[0]);
+		return 1;
+	}
+
+	// We prefer to detect socket closure through return codes, not signals.
+	signal(SIGPIPE, SIG_IGN);
+
 	if (ssh_init() != SSH_OK) {
 		fprintf(stderr, "Could not initialise libssh\n");
 		return 1;
 	}
 
 	ssh_key host_key;
-	if (ssh_pki_import_privkey_file("host_key", NULL, NULL, NULL, &host_key) != SSH_OK) {
-		fprintf(stderr, "Failed to read host private key file 'host_key'\n");
+	if (ssh_pki_import_privkey_file(host_key_fname, NULL, NULL, NULL, &host_key) != SSH_OK) {
+		fprintf(stderr, "Failed to read host private key file '%s'\n", host_key_fname);
 		return 1;
 	}
 
@@ -422,23 +428,38 @@ int main(void) {
 	ssh_print_hash(SSH_PUBLICKEY_HASH_SHA256, host_key_hash, host_key_hash_length);
 
 	ssh_bind srvbind = ssh_bind_new();
-	CHECK(srvbind, srvbind);
+	if (!srvbind) {
+		fprintf(stderr, "Failed to create new bind socket\n");
+		return 1;
+	}
 
 	bool procconfig = false;
-	CHECK(srvbind, ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_PROCESS_CONFIG, &procconfig) == SSH_OK);
-	int port = 2222;
-	CHECK(srvbind, ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_BINDPORT, &port) == SSH_OK);
-	CHECK(srvbind, ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_IMPORT_KEY, host_key) == SSH_OK);
 	const char *ciphers_str = "aes256-gcm@openssh.com,aes256-ctr,aes256-cbc";
-	CHECK(srvbind, ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_CIPHERS_C_S, ciphers_str) == SSH_OK);
-	CHECK(srvbind, ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_CIPHERS_S_C, ciphers_str) == SSH_OK);
+	bool ok = true;
+	ok &= ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_PROCESS_CONFIG, &procconfig) == SSH_OK;
+	ok &= ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_BINDPORT, &ssh_port) == SSH_OK;
+	ok &= ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_IMPORT_KEY, host_key) == SSH_OK;
+	ok &= ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_CIPHERS_C_S, ciphers_str) == SSH_OK;
+	ok &= ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_CIPHERS_S_C, ciphers_str) == SSH_OK;
+
+	if (!ok) {
+		fprintf(stderr, "Could not set options on SSH bind socket: %s\n", ssh_get_error(srvbind));
+		return 1;
+	}
 
-	CHECK(srvbind, ssh_bind_listen(srvbind) == SSH_OK);
-	printf("Listening for SSH connections on port %d\n", port);
+	if (ssh_bind_listen(srvbind) != SSH_OK) {
+		fprintf(stderr, "Could not listen on SSH bind socket: %s\n", ssh_get_error(srvbind));
+		return 1;
+	}
 
-	// int srvbind_fd = ssh_bind_get_fd(srvbind);
+	struct addrinfo backend_addr;
+	if (!lookup_backend(backend_host, backend_port, &backend_addr)) {
+		// Error already printed in lookup_backend
+		return 1;
+	}
 
-	// struct sessions sessions = sessions_make();
+	printf("Listening for SSH connections on port %d\n", ssh_port);
+	printf("Forwarding to backend at %s:%d\n", backend_host, backend_port);
 
 	pthread_attr_t thread_attrs;
 	assert(pthread_attr_init(&thread_attrs) == 0);
@@ -448,32 +469,48 @@ int main(void) {
 	atomic_store(&g_thread_count, 0);
 
 	while (true) {
-		// fd_set inset;
-		// FD_ZERO(&inset);
-		// FD_SET(srvbind_fd, &inset);
-
-		// int ret = ssh_select(sessions.list, sessions.outlist, srvbind_fd + 1, &inset, NULL);
-		// if (ret == SSH_EINTR) continue;
-		// if (ret == SSH_ERROR) {
-		//     fprintf(stderr, "ssh_select reported error!\n");
-		//     return 1;
-		// }
-
 		ssh_session session = ssh_new();
-		CHECK(session, session);
+		if (!session) {
+			fprintf(stderr, "ERROR: Cannot create new SSH session object!\n");
+			usleep(1000 * RESOURCE_ERROR_SLEEP_MS);
+			continue;
+		}
+
+		if (ssh_bind_accept(srvbind, session) != SSH_OK) {
+			fprintf(stderr, "ERROR: Cannot accept on bind socket: %s", ssh_get_error(srvbind));
+			ssh_free(session);
+			usleep(1000 * RESOURCE_ERROR_SLEEP_MS);
+			continue;
+		}
 
-		CHECK(srvbind, ssh_bind_accept(srvbind, session) == SSH_OK);
 		int num_existing_threads = atomic_fetch_add(&g_thread_count, 1);
 		printf("Accepted connection, spinning up thread (currently %d threads)\n",
 				num_existing_threads + 1);
 
-		struct thread_data *tdata = calloc(1, struct thread_data);
+		struct thread_data *tdata = calloc(1, sizeof(struct thread_data));
+		if (!tdata) {
+			fprintf(stderr, "ERROR: Out of memory, cannot allocate thread_data!\n");
+			ssh_disconnect(session);
+			ssh_free(session);
+			usleep(1000 * RESOURCE_ERROR_SLEEP_MS);
+			continue;
+		}
+
+		tdata->backend_addr = backend_addr;
+		tdata->backend_fd = -1;
 		tdata->thread_id = next_thread_id++;
 		tdata->session = session;
 		tdata->channel = NULL;
 		tdata->should_close = false;
 
 		pthread_t thread;
-		assert(pthread_create(&thread, &thread_attrs, thread_entry, tdata) == 0);
+		if (pthread_create(&thread, &thread_attrs, thread_entry, tdata) != 0) {
+			fprintf(stderr, "ERROR: Could not spawn thread: %s!\n", strerror(errno));
+			free(tdata);
+			ssh_disconnect(session);
+			ssh_free(session);
+			usleep(1000 * RESOURCE_ERROR_SLEEP_MS);
+			continue;
+		}
 	}
 }
diff --git a/ssh/util.c b/ssh/util.c
new file mode 100644
index 0000000..a8f3c41
--- /dev/null
+++ b/ssh/util.c
@@ -0,0 +1,29 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include "util.h"
+
+
+bool parse_host_port(const char *arg, const char **server_host, int *port) {
+	const char *ptr = strchr(arg, ':');
+	if (ptr == NULL) {
+		*server_host = arg;
+	} else {
+		size_t length = ptr - arg;
+		char *host = malloc(length + 1);
+		if (!host) {
+			fprintf(stderr, "Cannot allocate memory!\n");
+			exit(1);
+		}
+		memcpy(host, arg, length);
+		host[length] = '\0';
+		*server_host = host;
+
+		char *endp;
+		*port = strtol(ptr + 1, &endp, 10);
+		if (endp == ptr || *endp != '\0') {
+			return false;
+		}
+	}
+	return true;
+}
diff --git a/ssh/util.h b/ssh/util.h
new file mode 100644
index 0000000..0ca7d45
--- /dev/null
+++ b/ssh/util.h
@@ -0,0 +1,7 @@
+#pragma once
+
+#include <stdbool.h>
+
+
+// Returns whether successful.
+bool parse_host_port(const char *arg, const char **server_host, int *port);
-- 
cgit v1.2.3-70-g09d2