diff options
| -rw-r--r-- | ssh/.gitignore | 3 | ||||
| -rw-r--r-- | ssh/Makefile | 7 | ||||
| -rw-r--r-- | ssh/client.c | 37 | ||||
| -rw-r--r-- | ssh/server.c | 491 | ||||
| -rw-r--r-- | ssh/util.c | 29 | ||||
| -rw-r--r-- | ssh/util.h | 7 | 
6 files changed, 318 insertions, 256 deletions
| diff --git a/ssh/.gitignore b/ssh/.gitignore index 36f3e53..17baeb4 100644 --- a/ssh/.gitignore +++ b/ssh/.gitignore @@ -1,5 +1,4 @@ -host_key -host_key.pub +ssh_host_key  server  client  *.o diff --git a/ssh/Makefile b/ssh/Makefile index b7c1d1b..68113cd 100644 --- a/ssh/Makefile +++ b/ssh/Makefile @@ -4,18 +4,19 @@ LDFLAGS = -pthread  CFLAGS += $(shell pkg-config --cflags libssh)  LDFLAGS += $(shell pkg-config --libs libssh) +  .PHONY: all clean  all: server client  clean: -	rm -f server *.o +	rm -f server client *.o -server: server.o ../global.o ../memory.o +server: server.o util.o  	$(CC) -o $@ $^ $(LDFLAGS) -client: client.o ../global.o ../memory.o +client: client.o util.o  	$(CC) -o $@ $^ $(LDFLAGS)  %.o: %.c $(wildcard *.h) diff --git a/ssh/client.c b/ssh/client.c index 9e60033..0adfffb 100644 --- a/ssh/client.c +++ b/ssh/client.c @@ -8,29 +8,9 @@  #include <libssh/callbacks.h>  #include <sys/select.h>  #include <poll.h> -#include "../global.h" +#include "util.h" -static void parse_args(const char *arg, const char **server_host, int *port) { -	const char *ptr = strchr(arg, ':'); -	if (ptr == NULL) { -		*server_host = arg; -	} else { -		size_t length = ptr - arg; -		char *host = malloc(length + 1, char); -		memcpy(host, arg, length); -		host[length] = '\0'; -		*server_host = host; - -		char *endp; -		*port = strtol(ptr + 1, &endp, 10); -		if (endp == ptr || *endp != '\0') { -			fprintf(stderr, "Cannot parse server:port from argument '%s'\n", arg); -			exit(1); -		} -	} -} -  static bool prompt_yn(const char *text) {  	printf("%s", text);  	fflush(stdout); @@ -159,7 +139,10 @@ int main(int argc, char **argv) {  		return 1;  	} -	parse_args(argv[1], &server_host, &port); +	if (!parse_host_port(argv[1], &server_host, &port)) { +		fprintf(stderr, "Cannot parse host:port from argument '%s'\n", argv[1]); +		return 1; +	}  	ssh_session session = ssh_new();  	if (!session) { @@ -169,7 +152,8 @@ int main(int argc, char **argv) {  	const char *ciphers_str = "aes256-gcm@openssh.com,aes256-ctr,aes256-cbc";  	bool procconfig = false; -	bool ok = ssh_options_set(session, SSH_OPTIONS_PROCESS_CONFIG, &procconfig) == SSH_OK; +	bool ok = true; +	ok &= ssh_options_set(session, SSH_OPTIONS_PROCESS_CONFIG, &procconfig) == SSH_OK;  	ok &= ssh_options_set(session, SSH_OPTIONS_USER, "tomsg") == SSH_OK;  	ok &= ssh_options_set(session, SSH_OPTIONS_HOST, server_host) == SSH_OK;  	ok &= ssh_options_set(session, SSH_OPTIONS_PORT, &port) == SSH_OK; @@ -257,7 +241,12 @@ int main(int argc, char **argv) {  	printf("Obtained tomsg subsystem on channel\n"); -	struct session_data *sesdata = malloc(1, struct session_data); +	struct session_data *sesdata = malloc(sizeof(struct session_data)); +	if (!sesdata) { +		fprintf(stderr, "Out of memory (allocating session_data)!\n"); +		return 1; +	} +  	sesdata->session = session;  	sesdata->channel = channel;  	sesdata->should_close = false; diff --git a/ssh/server.c b/ssh/server.c index 5a2b162..ed578c4 100644 --- a/ssh/server.c +++ b/ssh/server.c @@ -7,21 +7,20 @@  #include <errno.h>  #include <assert.h>  #include <pthread.h> -#include <sys/select.h> +#include <netdb.h> +#include <poll.h> +#include <signal.h> +#include <sys/socket.h> +#include <sys/stat.h>  #include <libssh/server.h>  #include <libssh/callbacks.h> -#include "../global.h" +#include "util.h" -#define CHECK(obj_, expr_) do { \ -		if (!(expr_)) { \ -			fprintf(stderr, "libssh error! expression: " #expr_ "\nError description: %s\n", \ -					ssh_get_error((obj_))); \ -			exit(1); \ -		} \ -	} while (0) +#define RESOURCE_ERROR_SLEEP_MS 10000 -void xxd(FILE *stream, const void *buf_, size_t length) { + +static void xxd(FILE *stream, const void *buf_, size_t length) {  	unsigned char *buf = (unsigned char*)buf_;  	for (size_t cursor = 0; cursor < length;) { @@ -47,43 +46,13 @@ void xxd(FILE *stream, const void *buf_, size_t length) {  	}  } -// struct sessions { -//     // Always NULL-terminated -//     ssh_session *list; -//     ssh_session *outlist;  // same length as 'list', contents not managed -//     size_t cap, len; -// }; - -// static struct sessions sessions_make() { -//     size_t cap = 2; -//     return (struct sessions) { -//         .list = malloc(cap, ssh_session), -//         .outlist = malloc(cap, ssh_session), -//         .cap = cap, -//         .len = 0, -//     }; -// } - -// static void sessions_add(struct sessions *ss, ssh_session ses) { -//     if (ss->len + 1 >= ss->cap) { -//         ss->cap *= 2; -//         ss->list = realloc(ss->list, ss->cap, ssh_session); -//         ss->outlist = realloc(ss->outlist, ss->cap, ssh_session); -//     } -//     ss->list[ss->len++] = ses; -//     ss->list[ss->len] = NULL; -// } - -// static void sessions_remove(struct sessions *ss, size_t index) { -//     assert(0 <= index && index < ss->len); -//     if (index < ss->len - 1) ss->list[index] = ss->list[ss->len - 1]; -//     ss->len--; -//     ss->list[ss->len] = NULL; -// } -  static atomic_int g_thread_count;  struct thread_data { +	struct addrinfo backend_addr; + +	int backend_fd; +  	int thread_id;  	ssh_session session;  	ssh_channel channel;  // NULL before channel has been opened @@ -95,7 +64,7 @@ struct thread_data {  ///////// CHANNEL CALLBACKS ////////// -int channel_subsystem_request_cb(ssh_session session, ssh_channel channel, const char *subsystem, void *tdata_) { +static int channel_subsystem_request_cb(ssh_session session, ssh_channel channel, const char *subsystem, void *tdata_) {  	(void)session;  	(void)channel;  	struct thread_data *tdata = (struct thread_data*)tdata_; @@ -108,89 +77,51 @@ int channel_subsystem_request_cb(ssh_session session, ssh_channel channel, const  	}  } -void channel_close_cb(ssh_session session, ssh_channel channel, void *tdata_) { +static void channel_close_cb(ssh_session session, ssh_channel channel, void *tdata_) {  	(void)session; (void)channel;  	struct thread_data *tdata = (struct thread_data*)tdata_;  	printf("[%d] channel close!\n", tdata->thread_id);  } -int channel_shell_request_cb(ssh_session session, ssh_channel channel, void *tdata_) { -	(void)session; (void)channel; -	struct thread_data *tdata = (struct thread_data*)tdata_; -	printf("[%d] shell request, denying\n", tdata->thread_id); -	return 1; -} - -void channel_eof_cb(ssh_session session, ssh_channel channel, void *tdata_) { +static void channel_eof_cb(ssh_session session, ssh_channel channel, void *tdata_) {  	(void)session; (void)channel;  	struct thread_data *tdata = (struct thread_data*)tdata_;  	printf("[%d] eof on channel, setting close flag\n", tdata->thread_id);  	tdata->should_close = true;  } -int channel_data_cb(ssh_session session, ssh_channel channel, void *data, uint32_t len, int is_stderr, void *tdata_) { +static int channel_data_cb(ssh_session session, ssh_channel channel, void *data, uint32_t len, int is_stderr, void *tdata_) {  	(void)is_stderr; (void)data; (void)channel; (void)session;  	struct thread_data *tdata = (struct thread_data*)tdata_;  	printf("[%d] data on channel (length %u):\n", tdata->thread_id, len);  	xxd(stdout, data, len); -	printf("[%d] echoing back!\n", tdata->thread_id); -	if (ssh_channel_write(channel, data, len) == SSH_ERROR) { -		printf("[%d] write to channel failed! Setting close flag\n", tdata->thread_id); -		tdata->should_close = true; +	// printf("[%d] echoing back!\n", tdata->thread_id); +	// if (ssh_channel_write(channel, data, len) == SSH_ERROR) { +	//     printf("[%d] write to channel failed! Setting close flag\n", tdata->thread_id); +	//     tdata->should_close = true; +	// } +	const char *start = (const char*)data; +	const char *cursor = start; +	const char *end = start + len; +	while (cursor < end) { +		ssize_t nw = write(tdata->backend_fd, cursor, end - cursor); +		if (nw < 0) { +			if (errno == EINTR) continue; +			printf("[%d] error writing to backend socket: %s\n", tdata->thread_id, strerror(errno)); +			tdata->should_close = true; +			return cursor - start; +		} +		if (nw == 0) {  // should not happen? +			printf("[%d] write(2) returned 0?\n", tdata->thread_id); +			tdata->should_close = true; +			return cursor - start; +		} +		cursor += nw;  	}  	return len;  } -void channel_signal_cb(ssh_session session, ssh_channel channel, const char *signal, void *tdata_) { -	(void)channel; (void)session; -	printf("[%d] signal SIG%s\n", ((struct thread_data*)tdata_)->thread_id, signal); -} - -void channel_exit_status_cb(ssh_session session, ssh_channel channel, int exit_status, void *tdata_) { -	(void)channel; (void)session; -	printf("[%d] exit status %d\n", ((struct thread_data*)tdata_)->thread_id, exit_status); -} - -void channel_exit_signal_cb(ssh_session session, ssh_channel channel, const char *signal, int core, const char *errmsg, const char *lang, void *tdata_) { -	(void)lang; (void)errmsg; (void)core; (void)channel; (void)session; -	printf("[%d] exit signal %s\n", ((struct thread_data*)tdata_)->thread_id, signal); -} - -int channel_pty_request_cb(ssh_session session, ssh_channel channel, const char *term, int width, int height, int pxwidth, int pwheight, void *tdata_) { -	(void)pwheight; (void)pxwidth; (void)channel; (void)session; -	printf("[%d] pty request (term %s, %dx%d), denying\n", ((struct thread_data*)tdata_)->thread_id, term, width, height); -	return -1; -} - -void channel_auth_agent_req_cb(ssh_session session, ssh_channel channel, void *tdata_) { -	(void)channel; (void)session; -	printf("[%d] auth agent request\n", ((struct thread_data*)tdata_)->thread_id); -} - -void channel_x11_req_cb(ssh_session session, ssh_channel channel, int single_connection, const char *auth_protocol, const char *auth_cookie, uint32_t screen_number, void *tdata_) { -	(void)screen_number; (void)auth_cookie; (void)auth_protocol; (void)single_connection; (void)channel; (void)session; -	printf("[%d] X11 REQUEST WTF\n", ((struct thread_data*)tdata_)->thread_id); -} - -int channel_pty_window_change_cb(ssh_session session, ssh_channel channel, int width, int height, int pxwidth, int pwheight, void *tdata_) { -	(void)pwheight; (void)pxwidth; (void)channel; (void)session; -	printf("[%d] pty window change (%dx%d), denying\n", ((struct thread_data*)tdata_)->thread_id, width, height); -	return -1; -} - -int channel_exec_request_cb(ssh_session session, ssh_channel channel, const char *command, void *tdata_) { -	(void)channel; (void)session; -	printf("[%d] exec request (<%s>), denying\n", ((struct thread_data*)tdata_)->thread_id, command); -	return 1; -} - -int channel_env_request_cb(ssh_session session, ssh_channel channel, const char *env_name, const char *env_value, void *tdata_) { -	(void)channel; (void)session; -	printf("[%d] environment request (<%s> = <%s>), denying\n", ((struct thread_data*)tdata_)->thread_id, env_name, env_value); -	return 1; -} - -int channel_write_wontblock_cb(ssh_session session, ssh_channel channel, size_t bytes, void *tdata_) { +static int channel_write_wontblock_cb(ssh_session session, ssh_channel channel, size_t bytes, void *tdata_) {  	(void)channel; (void)session;  	printf("[%d] write won't block for %zu bytes notification\n", ((struct thread_data*)tdata_)->thread_id, bytes);  	return 0; @@ -198,35 +129,14 @@ int channel_write_wontblock_cb(ssh_session session, ssh_channel channel, size_t  ////////// SERVER CALLBACKS ////////// -int auth_none_cb(ssh_session session, const char *user, void *tdata_) { +static int auth_none_cb(ssh_session session, const char *user, void *tdata_) {  	(void)session;  	struct thread_data *tdata = (struct thread_data*)tdata_;  	printf("[%d] auth none (user <%s>), accepting\n", tdata->thread_id, user);  	return SSH_AUTH_SUCCESS;  } -int auth_password_cb(ssh_session session, const char *user, const char *password, void *tdata_) { -	(void)session; -	struct thread_data *tdata = (struct thread_data*)tdata_; -	printf("[%d] auth password (user <%s> password <%s>), denying\n", tdata->thread_id, user, password); -	return SSH_AUTH_DENIED; -} - -int auth_gssapi_mic_cb(ssh_session session, const char *user, const char *principal, void *tdata_) { -	(void)session; -	struct thread_data *tdata = (struct thread_data*)tdata_; -	printf("[%d] auth gssapi (user <%s> principal <%s>), denying\n", tdata->thread_id, user, principal); -	return SSH_AUTH_DENIED; -} - -int auth_pubkey_cb(ssh_session session, const char *user, struct ssh_key_struct *pubkey, char signature_state, void *tdata_) { -	(void)session; (void)pubkey; (void)signature_state; -	struct thread_data *tdata = (struct thread_data*)tdata_; -	printf("[%d] auth pubkey (user <%s>), denying\n", tdata->thread_id, user); -	return SSH_AUTH_DENIED; -} - -int service_request_cb(ssh_session session, const char *service, void *tdata_) { +static int service_request_cb(ssh_session session, const char *service, void *tdata_) {  	(void)session;  	struct thread_data *tdata = (struct thread_data*)tdata_;  	if (strcmp(service, "ssh-userauth") == 0) { @@ -238,7 +148,7 @@ int service_request_cb(ssh_session session, const char *service, void *tdata_) {  	}  } -ssh_channel chan_open_request_cb(ssh_session session, void *tdata_) { +static ssh_channel chan_open_request_cb(ssh_session session, void *tdata_) {  	struct thread_data *tdata = (struct thread_data*)tdata_;  	if (tdata->channel == NULL) {  		ssh_channel chan = ssh_channel_new(session); @@ -256,61 +166,100 @@ ssh_channel chan_open_request_cb(ssh_session session, void *tdata_) {  	return NULL;  } -// int msg_callback(ssh_session session, ssh_message msg, void *tdata_) { -//     (void)session; -//     struct thread_data *tdata = (struct thread_data*)tdata_; -//     const int tid = tdata->thread_id; +static int backend_data_cb(int fd, int revents, void *tdata_) { +	struct thread_data *tdata = (struct thread_data*)tdata_; + +	if (revents & POLLIN) { +		char buffer[1024]; +		ssize_t nr = read(fd, buffer, sizeof buffer); +		if (nr < 0) { +			if (errno == EINTR) return 0; +			printf("[%d] Error reading from backend socket: %s\n", tdata->thread_id, strerror(errno)); +			tdata->should_close = true; +			return 0; +		} -//     const int subtype = ssh_message_subtype(msg); +		if (nr == 0) {  // eof +			tdata->should_close = true; +			return 0; +		} -//     switch (ssh_message_type(msg)) { -//         case SSH_REQUEST_AUTH: -//             printf("[%d] message callback: type auth (subtype %d)\n", tid, subtype); -//             if (subtype == SSH_AUTH_METHOD_NONE) { -//                 if (ssh_message_auth_reply_success(msg, false) == SSH_OK) { -//                     return 0;  // handled -//                 } else { -//                     printf("[%d]   failed to reply success for auth method none\n", tid); -//                 } -//             } -//             break; +		if (ssh_channel_write(tdata->channel, buffer, nr) != SSH_OK) { +			printf("[%d] Error writing to ssh channel: %s\n", tdata->thread_id, ssh_get_error(tdata->channel)); +			tdata->should_close = true; +			return 0; +		} +	} -//         case SSH_REQUEST_CHANNEL_OPEN: -//             printf("[%d] message callback: type channel open (subtype %d)\n", tid, subtype); -//             if (tdata->channel == NULL && subtype == SSH_CHANNEL_SESSION) { -//                 ssh_channel chan = ssh_message_channel_request_open_reply_accept(msg); -//                 chan_cb.userdata = tdata; +	return 0; +} -//                 if (chan && ssh_set_channel_callbacks(chan, &chan_cb) == SSH_OK) { -//                     tdata->channel = chan; -//                     return 0;  // handled -//                 } else { -//                     if (chan) { -//                         ssh_channel_close(chan); -//                         ssh_channel_free(chan); -//                     } -//                     printf("[%d]   failed to accept channel open request for session\n", tid); -//                 } -//             } -//             break; +// Returns whether successful. +static bool lookup_backend(const char *host, int port, struct addrinfo *dst) { +	char port_string[16]; +	sprintf(port_string, "%d", port); -//         case SSH_REQUEST_CHANNEL: -//             printf("[%d] message callback: type channel (subtype %d)\n", tid, subtype); -//             break; +	struct addrinfo hints; +	memset(&hints, 0, sizeof hints); +	hints.ai_family = AF_UNSPEC; +	hints.ai_socktype = SOCK_STREAM; +	hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; -//         case SSH_REQUEST_SERVICE: -//             printf("[%d] message callback: type service\n", tid); -//             break; +	struct addrinfo *result; +	int ret = getaddrinfo(host, port_string, &hints, &result); -//         case SSH_REQUEST_GLOBAL: -//             printf("[%d] message callback: type global (subtype %d)\n", tid, subtype); -//             break; -//     } +	if (ret < 0) { +		fprintf(stderr, "Could not resolve backend: %s\n", gai_strerror(ret)); +		return false; +	} -//     return 1;  // not handled -// } +	int last_failure = 0; +	bool success = false; +	for (struct addrinfo *item = result; item; item = item->ai_next) { +		int sock = socket(item->ai_family, item->ai_socktype, item->ai_protocol); +		if (sock == -1) { +			last_failure = errno; +			continue; +		} -void* thread_entry(void *tdata_) { +		int ret = connect(sock, item->ai_addr, item->ai_addrlen); +		close(sock); + +		if (ret == 0) { +			success = true; +			*dst = *item; +			dst->ai_next = NULL; +			break; +		} else { +			last_failure = errno; +		} +	} + +	freeaddrinfo(result); + +	if (success) { +		return true; +	} else { +		fprintf(stderr, "Could not connect to backend: %s\n", strerror(last_failure)); +		return false; +	} +} + +static int connect_backend(const struct thread_data *tdata) { +	const struct addrinfo *item = &tdata->backend_addr; +	int sock = socket(item->ai_family, item->ai_socktype, item->ai_protocol); +	if (sock == -1) return -1; + +	if (connect(sock, item->ai_addr, item->ai_addrlen) == 0) { +		printf("connect_backend: sock=%d\n", sock); +		return sock; +	} + +	close(sock); +	return -1; +} + +static void* thread_entry(void *tdata_) {  	struct thread_data *tdata = (struct thread_data*)tdata_;  	const int tid = tdata->thread_id;  	const ssh_session session = tdata->session; @@ -318,19 +267,11 @@ void* thread_entry(void *tdata_) {  	printf("[%d] Thread started\n", tid); -	// ssh_set_message_callback(session, msg_callback, tdata); -  	memset(&tdata->server_cb, 0, sizeof tdata->server_cb);  	ssh_callbacks_init(&tdata->server_cb);  	tdata->server_cb.userdata = tdata;  	tdata->server_cb.auth_none_function = auth_none_cb;  	tdata->server_cb.channel_open_request_session_function = chan_open_request_cb; -	tdata->server_cb.auth_password_function = auth_password_cb; -	tdata->server_cb.auth_gssapi_mic_function = auth_gssapi_mic_cb; -	tdata->server_cb.auth_pubkey_function = auth_pubkey_cb; -	tdata->server_cb.gssapi_select_oid_function = (void*)0x424242;  // just crash if it attemps to invoke these -	tdata->server_cb.gssapi_accept_sec_ctx_function = (void*)0x424242; -	tdata->server_cb.gssapi_verify_mic_function = (void*)0x424242;  	tdata->server_cb.service_request_function = service_request_cb;  	memset(&tdata->chan_cb, 0, sizeof tdata->chan_cb); @@ -338,18 +279,8 @@ void* thread_entry(void *tdata_) {  	tdata->chan_cb.userdata = tdata;  	tdata->chan_cb.channel_subsystem_request_function = channel_subsystem_request_cb;  	tdata->chan_cb.channel_close_function = channel_close_cb; -	tdata->chan_cb.channel_shell_request_function = channel_shell_request_cb;  	tdata->chan_cb.channel_eof_function = channel_eof_cb;  	tdata->chan_cb.channel_data_function = channel_data_cb; -	tdata->chan_cb.channel_signal_function = channel_signal_cb; -	tdata->chan_cb.channel_exit_status_function = channel_exit_status_cb; -	tdata->chan_cb.channel_exit_signal_function = channel_exit_signal_cb; -	tdata->chan_cb.channel_pty_request_function = channel_pty_request_cb; -	tdata->chan_cb.channel_auth_agent_req_function = channel_auth_agent_req_cb; -	tdata->chan_cb.channel_x11_req_function = channel_x11_req_cb; -	tdata->chan_cb.channel_pty_window_change_function = channel_pty_window_change_cb; -	tdata->chan_cb.channel_exec_request_function = channel_exec_request_cb; -	tdata->chan_cb.channel_env_request_function = channel_env_request_cb;  	tdata->chan_cb.channel_write_wontblock_function = channel_write_wontblock_cb;  	if (ssh_set_server_callbacks(session, &tdata->server_cb) != SSH_OK) { @@ -365,8 +296,18 @@ void* thread_entry(void *tdata_) {  	}  	printf("[%d] Handled key exchange\n", tid); +	tdata->backend_fd = connect_backend(tdata); +	if (tdata->backend_fd == -1) { +		printf("[%d] Failed to connect to backend: %s\n", tid, strerror(errno)); +		goto cleanup; +	} + +	printf("[%d] Connected to backend (fd=%d)\n", tid, tdata->backend_fd); +  	event = ssh_event_new(); -	if (!event || ssh_event_add_session(event, session) != SSH_OK) { +	if (!event +			|| ssh_event_add_session(event, session) != SSH_OK +			|| ssh_event_add_fd(event, tdata->backend_fd, POLLIN, backend_data_cb, tdata) != SSH_OK) {  		printf("[%d] Failed to create ssh event context\n", tid);  		goto cleanup;  	} @@ -398,15 +339,80 @@ cleanup:  	pthread_exit(NULL);  } -int main(void) { +static void generate_key(const char *outfile) { +	ssh_key host_key; +	int ret = ssh_pki_generate(SSH_KEYTYPE_ED25519, 0, &host_key); +	if (ret != SSH_OK) { +		fprintf(stderr, "Key generation failed (ed25519)!\n"); +		exit(1); +	} + +	ret = ssh_pki_export_privkey_file(host_key, NULL, NULL, NULL, outfile); +	if (ret != SSH_OK) { +		fprintf(stderr, "Failed to export generated host key to file '%s'; is that location accessible?\n", outfile); +		exit(1); +	} + +	if (chmod(outfile, S_IRUSR | S_IWUSR) != 0) { +		fprintf(stderr, "Failed to set mode 600 on generated host key file '%s'; this is insecure!\n", outfile); +		exit(1); +	} + +	printf("ed25519 host key generated and written to '%s'.\n", outfile); +} + +static void usage(const char *argv0) { +	fprintf(stderr, +			"Usage: %s <ssh host key file> <ssh port> [backendhost:port]\n" +			"       %s --generate <host key output file>\n" +			"SSH-TCP bridge for tomsg. Accepts SSH connections with a channel for subsystem\n" +			"'tomsg', and matches each SSH connection with a plain TCP connection to the\n" +			"backend server (which defaults to localhost:29536). All data is forwarded\n" +			"transparently.\n" +			"Use the '--generate' form to generate a host key for use in the main invocation\n" +			"form.\n", +			argv0, argv0); +} + +int main(int argc, char **argv) { +	const char *host_key_fname; +	int ssh_port = 2222; +	const char *backend_host = "localhost"; +	int backend_port = 29536; + +	if (argc == 3 && strcmp(argv[1], "--generate") == 0) { +		generate_key(argv[2]); +		return 0; +	} else if (3 <= argc && argc <= 4) { +		host_key_fname = argv[1]; +		char *endp; +		ssh_port = strtol(argv[2], &endp, 10); +		if (argv[2][0] == '\0' || *endp != '\0' || ssh_port < 0 || ssh_port > 65535) { +			fprintf(stderr, "Cannot parse port number from argument '%s'\n", argv[2]); +			return 1; +		} +		if (argc == 4) { +			if (!parse_host_port(argv[3], &backend_host, &backend_port)) { +				fprintf(stderr, "Cannot parse host:port from argument '%s'\n", argv[3]); +				return 1; +			} +		} +	} else { +		usage(argv[0]); +		return 1; +	} + +	// We prefer to detect socket closure through return codes, not signals. +	signal(SIGPIPE, SIG_IGN); +  	if (ssh_init() != SSH_OK) {  		fprintf(stderr, "Could not initialise libssh\n");  		return 1;  	}  	ssh_key host_key; -	if (ssh_pki_import_privkey_file("host_key", NULL, NULL, NULL, &host_key) != SSH_OK) { -		fprintf(stderr, "Failed to read host private key file 'host_key'\n"); +	if (ssh_pki_import_privkey_file(host_key_fname, NULL, NULL, NULL, &host_key) != SSH_OK) { +		fprintf(stderr, "Failed to read host private key file '%s'\n", host_key_fname);  		return 1;  	} @@ -422,23 +428,38 @@ int main(void) {  	ssh_print_hash(SSH_PUBLICKEY_HASH_SHA256, host_key_hash, host_key_hash_length);  	ssh_bind srvbind = ssh_bind_new(); -	CHECK(srvbind, srvbind); +	if (!srvbind) { +		fprintf(stderr, "Failed to create new bind socket\n"); +		return 1; +	}  	bool procconfig = false; -	CHECK(srvbind, ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_PROCESS_CONFIG, &procconfig) == SSH_OK); -	int port = 2222; -	CHECK(srvbind, ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_BINDPORT, &port) == SSH_OK); -	CHECK(srvbind, ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_IMPORT_KEY, host_key) == SSH_OK);  	const char *ciphers_str = "aes256-gcm@openssh.com,aes256-ctr,aes256-cbc"; -	CHECK(srvbind, ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_CIPHERS_C_S, ciphers_str) == SSH_OK); -	CHECK(srvbind, ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_CIPHERS_S_C, ciphers_str) == SSH_OK); +	bool ok = true; +	ok &= ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_PROCESS_CONFIG, &procconfig) == SSH_OK; +	ok &= ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_BINDPORT, &ssh_port) == SSH_OK; +	ok &= ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_IMPORT_KEY, host_key) == SSH_OK; +	ok &= ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_CIPHERS_C_S, ciphers_str) == SSH_OK; +	ok &= ssh_bind_options_set(srvbind, SSH_BIND_OPTIONS_CIPHERS_S_C, ciphers_str) == SSH_OK; -	CHECK(srvbind, ssh_bind_listen(srvbind) == SSH_OK); -	printf("Listening for SSH connections on port %d\n", port); +	if (!ok) { +		fprintf(stderr, "Could not set options on SSH bind socket: %s\n", ssh_get_error(srvbind)); +		return 1; +	} -	// int srvbind_fd = ssh_bind_get_fd(srvbind); +	if (ssh_bind_listen(srvbind) != SSH_OK) { +		fprintf(stderr, "Could not listen on SSH bind socket: %s\n", ssh_get_error(srvbind)); +		return 1; +	} -	// struct sessions sessions = sessions_make(); +	struct addrinfo backend_addr; +	if (!lookup_backend(backend_host, backend_port, &backend_addr)) { +		// Error already printed in lookup_backend +		return 1; +	} + +	printf("Listening for SSH connections on port %d\n", ssh_port); +	printf("Forwarding to backend at %s:%d\n", backend_host, backend_port);  	pthread_attr_t thread_attrs;  	assert(pthread_attr_init(&thread_attrs) == 0); @@ -448,32 +469,48 @@ int main(void) {  	atomic_store(&g_thread_count, 0);  	while (true) { -		// fd_set inset; -		// FD_ZERO(&inset); -		// FD_SET(srvbind_fd, &inset); - -		// int ret = ssh_select(sessions.list, sessions.outlist, srvbind_fd + 1, &inset, NULL); -		// if (ret == SSH_EINTR) continue; -		// if (ret == SSH_ERROR) { -		//     fprintf(stderr, "ssh_select reported error!\n"); -		//     return 1; -		// } -  		ssh_session session = ssh_new(); -		CHECK(session, session); +		if (!session) { +			fprintf(stderr, "ERROR: Cannot create new SSH session object!\n"); +			usleep(1000 * RESOURCE_ERROR_SLEEP_MS); +			continue; +		} + +		if (ssh_bind_accept(srvbind, session) != SSH_OK) { +			fprintf(stderr, "ERROR: Cannot accept on bind socket: %s", ssh_get_error(srvbind)); +			ssh_free(session); +			usleep(1000 * RESOURCE_ERROR_SLEEP_MS); +			continue; +		} -		CHECK(srvbind, ssh_bind_accept(srvbind, session) == SSH_OK);  		int num_existing_threads = atomic_fetch_add(&g_thread_count, 1);  		printf("Accepted connection, spinning up thread (currently %d threads)\n",  				num_existing_threads + 1); -		struct thread_data *tdata = calloc(1, struct thread_data); +		struct thread_data *tdata = calloc(1, sizeof(struct thread_data)); +		if (!tdata) { +			fprintf(stderr, "ERROR: Out of memory, cannot allocate thread_data!\n"); +			ssh_disconnect(session); +			ssh_free(session); +			usleep(1000 * RESOURCE_ERROR_SLEEP_MS); +			continue; +		} + +		tdata->backend_addr = backend_addr; +		tdata->backend_fd = -1;  		tdata->thread_id = next_thread_id++;  		tdata->session = session;  		tdata->channel = NULL;  		tdata->should_close = false;  		pthread_t thread; -		assert(pthread_create(&thread, &thread_attrs, thread_entry, tdata) == 0); +		if (pthread_create(&thread, &thread_attrs, thread_entry, tdata) != 0) { +			fprintf(stderr, "ERROR: Could not spawn thread: %s!\n", strerror(errno)); +			free(tdata); +			ssh_disconnect(session); +			ssh_free(session); +			usleep(1000 * RESOURCE_ERROR_SLEEP_MS); +			continue; +		}  	}  } diff --git a/ssh/util.c b/ssh/util.c new file mode 100644 index 0000000..a8f3c41 --- /dev/null +++ b/ssh/util.c @@ -0,0 +1,29 @@ +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include "util.h" + + +bool parse_host_port(const char *arg, const char **server_host, int *port) { +	const char *ptr = strchr(arg, ':'); +	if (ptr == NULL) { +		*server_host = arg; +	} else { +		size_t length = ptr - arg; +		char *host = malloc(length + 1); +		if (!host) { +			fprintf(stderr, "Cannot allocate memory!\n"); +			exit(1); +		} +		memcpy(host, arg, length); +		host[length] = '\0'; +		*server_host = host; + +		char *endp; +		*port = strtol(ptr + 1, &endp, 10); +		if (endp == ptr || *endp != '\0') { +			return false; +		} +	} +	return true; +} diff --git a/ssh/util.h b/ssh/util.h new file mode 100644 index 0000000..0ca7d45 --- /dev/null +++ b/ssh/util.h @@ -0,0 +1,7 @@ +#pragma once + +#include <stdbool.h> + + +// Returns whether successful. +bool parse_host_port(const char *arg, const char **server_host, int *port); | 
