aboutsummaryrefslogtreecommitdiff
path: root/ssh/client_proxy.c
diff options
context:
space:
mode:
Diffstat (limited to 'ssh/client_proxy.c')
-rw-r--r--ssh/client_proxy.c225
1 files changed, 225 insertions, 0 deletions
diff --git a/ssh/client_proxy.c b/ssh/client_proxy.c
new file mode 100644
index 0000000..58f40cb
--- /dev/null
+++ b/ssh/client_proxy.c
@@ -0,0 +1,225 @@
+#include <stdio.h>
+#include <stdbool.h>
+#include <stdlib.h>
+#include <string.h>
+#include <ctype.h>
+#include <poll.h>
+#include <errno.h>
+#include <unistd.h>
+#include <pthread.h>
+#include <sys/socket.h>
+#include <arpa/inet.h>
+#include <assert.h>
+#include "util.h"
+#include "sshnc.h"
+
+
+const char *g_hostkey_hash_preload;
+
+static int writeall(int sock, const char *buffer, size_t length) {
+ size_t cursor = 0;
+ while (cursor < length) {
+ ssize_t nw = write(sock, buffer + cursor, length - cursor);
+ if (nw < 0) {
+ if (errno == EINTR) continue;
+ return -1; // just dump all error conditions together
+ }
+ if (nw == 0) return -1; // eof is also an error condition
+ cursor += nw;
+ }
+ return 0;
+}
+
+static bool hostkey_checker(const unsigned char *hash, size_t length, void *userdata) {
+ (void)userdata;
+ const char* showed = sshnc_print_hash(hash, length);
+
+ bool ok = strcmp(showed, g_hostkey_hash_preload) == 0;
+ if (!ok) {
+ fprintf(stderr, "Rejecting host key '%s'!\n", showed);
+ }
+ return ok;
+}
+
+struct thread_data {
+ const char *server_host;
+ int server_port;
+ int client_sock;
+};
+
+static void* proxy_thread_entry(void *thread_data_) {
+ struct thread_data *thread_data = thread_data_;
+ int client_sock = thread_data->client_sock;
+
+ struct sshnc_client *client;
+ enum sshnc_retval ret = sshnc_connect(
+ thread_data->server_host, thread_data->server_port, "tomsg", "tomsg",
+ hostkey_checker, NULL, &client);
+
+ if (ret != SSHNC_OK) {
+ fprintf(stderr, "Could not connect over SSH: %s\n", sshnc_strerror(ret));
+ return NULL;
+ }
+
+ struct pollfd polls[2];
+ polls[0] = (struct pollfd){
+ .fd = sshnc_poll_fd(client),
+ .events = POLLIN,
+ };
+ polls[1] = (struct pollfd){
+ .fd = client_sock,
+ .events = POLLIN,
+ };
+
+ while (true) {
+ int pollret = poll(polls, sizeof polls / sizeof polls[0], -1);
+ if (pollret < 0) {
+ perror("poll");
+ goto cleanup;
+ }
+
+ if (polls[0].revents & (POLLERR | POLLNVAL)) {
+ fprintf(stderr, "Error reading from SSH socket\n");
+ goto cleanup;
+ }
+ if (polls[1].revents & (POLLERR | POLLNVAL)) {
+ // Assume downstream has been closed
+ break;
+ }
+
+ if (polls[0].revents & (POLLIN | POLLHUP)) {
+ char buffer[4096];
+ size_t length = 0;
+ ret = sshnc_maybe_recv(client, sizeof buffer, buffer, &length);
+ if (ret == SSHNC_OK) {
+ if (writeall(client_sock, buffer, length) < 0) {
+ // Error writing back to downstream, let's just close
+ break;
+ }
+ } else if (ret == SSHNC_EOF) {
+ break;
+ } else if (ret != SSHNC_AGAIN) {
+ fprintf(stderr, "Error on SSH recv: %s\n", sshnc_strerror(ret));
+ goto cleanup;
+ }
+ }
+
+ if (polls[1].revents & (POLLIN | POLLHUP)) {
+ char buffer[4096];
+ ssize_t nr = read(client_sock, buffer, sizeof buffer);
+ if (nr < 0) {
+ if (errno == ECONNRESET) break;
+ perror("Error reading from downstream");
+ goto cleanup;
+ }
+ if (nr == 0) {
+ break;
+ }
+
+ ret = sshnc_send(client, buffer, nr);
+ if (ret == SSHNC_EOF) {
+ break;
+ } else if (ret != SSHNC_OK) {
+ fprintf(stderr, "Error on SSH send: %s\n", sshnc_strerror(ret));
+ goto cleanup;
+ }
+ }
+ }
+
+cleanup:
+ close(client_sock);
+ sshnc_close(client);
+ return NULL;
+}
+
+int main(int argc, char **argv) {
+ const char *server_host = NULL;
+ int server_port = 2222;
+
+ if (argc != 4) {
+ fprintf(stderr, "Usage: %s <server[:port]> <listen_port> <hostkey_hash_preload>\n", argv[0]);
+ fprintf(stderr, "If :port is not specified for the backend server, %d is assumed.\n", server_port);
+ fprintf(stderr, "Will listen for connections to proxy on <listen_port>.\n");
+ fprintf(stderr, "The hostkey hash should be in the form 'SHA256:base64'.\n");
+ return 1;
+ }
+
+ if (!parse_host_port(argv[1], &server_host, &server_port)) {
+ fprintf(stderr, "Cannot parse host:port from argument '%s'\n", argv[1]);
+ return 1;
+ }
+
+ int bindport = strtol(argv[2], NULL, 10);
+ if (bindport == 0) {
+ fprintf(stderr, "Invalid listen port given\n");
+ return 1;
+ }
+
+ if (memcmp(argv[3], "SHA256:", 7) != 0) {
+ fprintf(stderr, "Preloaded host key in invalid format\n");
+ return 1;
+ }
+
+ g_hostkey_hash_preload = argv[3];
+
+ int bindsock = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP);
+ if (bindsock < 0) {
+ perror("socket(bind)");
+ return 1;
+ }
+
+ const int yes = 1;
+ if (setsockopt(bindsock, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof yes) < 0) {
+ fprintf(stderr, "Warning: could not set SO_REUSEADDR\n");
+ }
+
+ struct sockaddr_in bindaddr;
+ memset(&bindaddr, 0, sizeof bindaddr);
+ bindaddr.sin_family = AF_INET;
+ bindaddr.sin_port = htons(bindport);
+ bindaddr.sin_addr.s_addr = htonl(INADDR_ANY);
+
+ if (bind(bindsock, (const struct sockaddr*)&bindaddr, sizeof bindaddr) < 0) {
+ perror("bind");
+ return 1;
+ }
+
+ if (listen(bindsock, 5) < 0) {
+ perror("listen");
+ return 1;
+ }
+
+ pthread_attr_t thread_attr;
+ assert(pthread_attr_init(&thread_attr) == 0);
+ assert(pthread_attr_setdetachstate(&thread_attr, PTHREAD_CREATE_DETACHED) == 0);
+
+ while (true) {
+ int sock = accept(bindsock, NULL, NULL);
+ if (sock < 0) {
+ perror("accept");
+ return 1;
+ }
+
+ struct thread_data *data = malloc(sizeof(struct thread_data));
+ if (!data) { // OOM, reject this connection and try to go on
+ close(sock);
+ usleep(500000);
+ continue;
+ }
+
+ data->server_host = server_host;
+ data->server_port = server_port;
+ data->client_sock = sock;
+
+ pthread_t thread;
+ int ret = pthread_create(&thread, &thread_attr, proxy_thread_entry, data);
+ if (ret < 0) {
+ // Something happened; reject this connection and try to go on
+ free(data);
+ close(sock);
+ perror("pthread_create");
+ usleep(500000);
+ continue;
+ }
+ }
+}