diff options
Diffstat (limited to 'ssh/client_proxy.c')
-rw-r--r-- | ssh/client_proxy.c | 225 |
1 files changed, 225 insertions, 0 deletions
diff --git a/ssh/client_proxy.c b/ssh/client_proxy.c new file mode 100644 index 0000000..58f40cb --- /dev/null +++ b/ssh/client_proxy.c @@ -0,0 +1,225 @@ +#include <stdio.h> +#include <stdbool.h> +#include <stdlib.h> +#include <string.h> +#include <ctype.h> +#include <poll.h> +#include <errno.h> +#include <unistd.h> +#include <pthread.h> +#include <sys/socket.h> +#include <arpa/inet.h> +#include <assert.h> +#include "util.h" +#include "sshnc.h" + + +const char *g_hostkey_hash_preload; + +static int writeall(int sock, const char *buffer, size_t length) { + size_t cursor = 0; + while (cursor < length) { + ssize_t nw = write(sock, buffer + cursor, length - cursor); + if (nw < 0) { + if (errno == EINTR) continue; + return -1; // just dump all error conditions together + } + if (nw == 0) return -1; // eof is also an error condition + cursor += nw; + } + return 0; +} + +static bool hostkey_checker(const unsigned char *hash, size_t length, void *userdata) { + (void)userdata; + const char* showed = sshnc_print_hash(hash, length); + + bool ok = strcmp(showed, g_hostkey_hash_preload) == 0; + if (!ok) { + fprintf(stderr, "Rejecting host key '%s'!\n", showed); + } + return ok; +} + +struct thread_data { + const char *server_host; + int server_port; + int client_sock; +}; + +static void* proxy_thread_entry(void *thread_data_) { + struct thread_data *thread_data = thread_data_; + int client_sock = thread_data->client_sock; + + struct sshnc_client *client; + enum sshnc_retval ret = sshnc_connect( + thread_data->server_host, thread_data->server_port, "tomsg", "tomsg", + hostkey_checker, NULL, &client); + + if (ret != SSHNC_OK) { + fprintf(stderr, "Could not connect over SSH: %s\n", sshnc_strerror(ret)); + return NULL; + } + + struct pollfd polls[2]; + polls[0] = (struct pollfd){ + .fd = sshnc_poll_fd(client), + .events = POLLIN, + }; + polls[1] = (struct pollfd){ + .fd = client_sock, + .events = POLLIN, + }; + + while (true) { + int pollret = poll(polls, sizeof polls / sizeof polls[0], -1); + if (pollret < 0) { + perror("poll"); + goto cleanup; + } + + if (polls[0].revents & (POLLERR | POLLNVAL)) { + fprintf(stderr, "Error reading from SSH socket\n"); + goto cleanup; + } + if (polls[1].revents & (POLLERR | POLLNVAL)) { + // Assume downstream has been closed + break; + } + + if (polls[0].revents & (POLLIN | POLLHUP)) { + char buffer[4096]; + size_t length = 0; + ret = sshnc_maybe_recv(client, sizeof buffer, buffer, &length); + if (ret == SSHNC_OK) { + if (writeall(client_sock, buffer, length) < 0) { + // Error writing back to downstream, let's just close + break; + } + } else if (ret == SSHNC_EOF) { + break; + } else if (ret != SSHNC_AGAIN) { + fprintf(stderr, "Error on SSH recv: %s\n", sshnc_strerror(ret)); + goto cleanup; + } + } + + if (polls[1].revents & (POLLIN | POLLHUP)) { + char buffer[4096]; + ssize_t nr = read(client_sock, buffer, sizeof buffer); + if (nr < 0) { + if (errno == ECONNRESET) break; + perror("Error reading from downstream"); + goto cleanup; + } + if (nr == 0) { + break; + } + + ret = sshnc_send(client, buffer, nr); + if (ret == SSHNC_EOF) { + break; + } else if (ret != SSHNC_OK) { + fprintf(stderr, "Error on SSH send: %s\n", sshnc_strerror(ret)); + goto cleanup; + } + } + } + +cleanup: + close(client_sock); + sshnc_close(client); + return NULL; +} + +int main(int argc, char **argv) { + const char *server_host = NULL; + int server_port = 2222; + + if (argc != 4) { + fprintf(stderr, "Usage: %s <server[:port]> <listen_port> <hostkey_hash_preload>\n", argv[0]); + fprintf(stderr, "If :port is not specified for the backend server, %d is assumed.\n", server_port); + fprintf(stderr, "Will listen for connections to proxy on <listen_port>.\n"); + fprintf(stderr, "The hostkey hash should be in the form 'SHA256:base64'.\n"); + return 1; + } + + if (!parse_host_port(argv[1], &server_host, &server_port)) { + fprintf(stderr, "Cannot parse host:port from argument '%s'\n", argv[1]); + return 1; + } + + int bindport = strtol(argv[2], NULL, 10); + if (bindport == 0) { + fprintf(stderr, "Invalid listen port given\n"); + return 1; + } + + if (memcmp(argv[3], "SHA256:", 7) != 0) { + fprintf(stderr, "Preloaded host key in invalid format\n"); + return 1; + } + + g_hostkey_hash_preload = argv[3]; + + int bindsock = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP); + if (bindsock < 0) { + perror("socket(bind)"); + return 1; + } + + const int yes = 1; + if (setsockopt(bindsock, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof yes) < 0) { + fprintf(stderr, "Warning: could not set SO_REUSEADDR\n"); + } + + struct sockaddr_in bindaddr; + memset(&bindaddr, 0, sizeof bindaddr); + bindaddr.sin_family = AF_INET; + bindaddr.sin_port = htons(bindport); + bindaddr.sin_addr.s_addr = htonl(INADDR_ANY); + + if (bind(bindsock, (const struct sockaddr*)&bindaddr, sizeof bindaddr) < 0) { + perror("bind"); + return 1; + } + + if (listen(bindsock, 5) < 0) { + perror("listen"); + return 1; + } + + pthread_attr_t thread_attr; + assert(pthread_attr_init(&thread_attr) == 0); + assert(pthread_attr_setdetachstate(&thread_attr, PTHREAD_CREATE_DETACHED) == 0); + + while (true) { + int sock = accept(bindsock, NULL, NULL); + if (sock < 0) { + perror("accept"); + return 1; + } + + struct thread_data *data = malloc(sizeof(struct thread_data)); + if (!data) { // OOM, reject this connection and try to go on + close(sock); + usleep(500000); + continue; + } + + data->server_host = server_host; + data->server_port = server_port; + data->client_sock = sock; + + pthread_t thread; + int ret = pthread_create(&thread, &thread_attr, proxy_thread_entry, data); + if (ret < 0) { + // Something happened; reject this connection and try to go on + free(data); + close(sock); + perror("pthread_create"); + usleep(500000); + continue; + } + } +} |