#include #include #include #include #include #include #include #include #include #include #include #include #include "util.h" #include "sshnc.h" const char *g_hostkey_hash_preload; static int writeall(int sock, const char *buffer, size_t length) { size_t cursor = 0; while (cursor < length) { ssize_t nw = write(sock, buffer + cursor, length - cursor); if (nw < 0) { if (errno == EINTR) continue; return -1; // just dump all error conditions together } if (nw == 0) return -1; // eof is also an error condition cursor += nw; } return 0; } static bool hostkey_checker(const unsigned char *hash, size_t length, void *userdata) { (void)userdata; const char* showed = sshnc_print_hash(hash, length); bool ok = strcmp(showed, g_hostkey_hash_preload) == 0; if (!ok) { fprintf(stderr, "Rejecting host key '%s'!\n", showed); } return ok; } struct thread_data { const char *server_host; int server_port; int client_sock; }; static void* proxy_thread_entry(void *thread_data_) { struct thread_data *thread_data = thread_data_; int client_sock = thread_data->client_sock; struct sshnc_client *client = NULL; enum sshnc_retval ret = sshnc_connect( thread_data->server_host, thread_data->server_port, "tomsg", "tomsg", hostkey_checker, NULL, &client); if (ret != SSHNC_OK) { fprintf(stderr, "Could not connect over SSH: %s\n", sshnc_strerror(ret)); goto cleanup; } struct pollfd polls[2]; polls[0] = (struct pollfd){ .fd = sshnc_poll_fd(client), .events = POLLIN, }; polls[1] = (struct pollfd){ .fd = client_sock, .events = POLLIN, }; while (true) { int pollret = poll(polls, sizeof polls / sizeof polls[0], -1); if (pollret < 0) { perror("poll"); goto cleanup; } if (polls[0].revents & (POLLERR | POLLNVAL)) { fprintf(stderr, "Error reading from SSH socket\n"); goto cleanup; } if (polls[1].revents & (POLLERR | POLLNVAL)) { // Assume downstream has been closed break; } if (polls[0].revents & (POLLIN | POLLHUP)) { // Get all data currently available bool should_exit = false; while (true) { char buffer[4096]; size_t length = 0; ret = sshnc_maybe_recv(client, sizeof buffer, buffer, &length); if (ret == SSHNC_AGAIN) break; if (ret == SSHNC_OK) { if (writeall(client_sock, buffer, length) < 0) { // Error writing back to downstream, let's just close should_exit = true; break; } } else if (ret == SSHNC_EOF) { should_exit = true; break; } else { fprintf(stderr, "Error on SSH recv: %s\n", sshnc_strerror(ret)); goto cleanup; } } if (should_exit) break; } if (polls[1].revents & (POLLIN | POLLHUP)) { char buffer[4096]; ssize_t nr = read(client_sock, buffer, sizeof buffer); if (nr < 0) { if (errno == ECONNRESET) break; perror("Error reading from downstream"); goto cleanup; } if (nr == 0) { break; } ret = sshnc_send(client, buffer, nr); if (ret == SSHNC_EOF) { break; } else if (ret != SSHNC_OK) { fprintf(stderr, "Error on SSH send: %s\n", sshnc_strerror(ret)); goto cleanup; } } } cleanup: close(client_sock); if (client) sshnc_close(client); return NULL; } int main(int argc, char **argv) { const char *server_host = NULL; int server_port = 2222; if (argc != 4) { fprintf(stderr, "Usage: %s \n", argv[0]); fprintf(stderr, "If :port is not specified for the backend server, %d is assumed.\n", server_port); fprintf(stderr, "Will listen for connections to proxy on .\n"); fprintf(stderr, "The hostkey hash should be in the form 'SHA256:base64'.\n"); return 1; } if (!parse_host_port(argv[1], &server_host, &server_port)) { fprintf(stderr, "Cannot parse host:port from argument '%s'\n", argv[1]); return 1; } int bindport = strtol(argv[2], NULL, 10); if (bindport == 0) { fprintf(stderr, "Invalid listen port given\n"); return 1; } if (memcmp(argv[3], "SHA256:", 7) != 0) { fprintf(stderr, "Preloaded host key in invalid format\n"); return 1; } g_hostkey_hash_preload = argv[3]; int bindsock = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP); if (bindsock < 0) { perror("socket(bind)"); return 1; } const int yes = 1; if (setsockopt(bindsock, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof yes) < 0) { fprintf(stderr, "Warning: could not set SO_REUSEADDR\n"); } struct sockaddr_in bindaddr; memset(&bindaddr, 0, sizeof bindaddr); bindaddr.sin_family = AF_INET; bindaddr.sin_port = htons(bindport); bindaddr.sin_addr.s_addr = htonl(INADDR_ANY); if (bind(bindsock, (const struct sockaddr*)&bindaddr, sizeof bindaddr) < 0) { perror("bind"); return 1; } if (listen(bindsock, 5) < 0) { perror("listen"); return 1; } pthread_attr_t thread_attr; assert(pthread_attr_init(&thread_attr) == 0); assert(pthread_attr_setdetachstate(&thread_attr, PTHREAD_CREATE_DETACHED) == 0); while (true) { int sock = accept(bindsock, NULL, NULL); if (sock < 0) { perror("accept"); return 1; } struct thread_data *data = malloc(sizeof(struct thread_data)); if (!data) { // OOM, reject this connection and try to go on close(sock); usleep(500000); continue; } data->server_host = server_host; data->server_port = server_port; data->client_sock = sock; pthread_t thread; int ret = pthread_create(&thread, &thread_attr, proxy_thread_entry, data); if (ret < 0) { // Something happened; reject this connection and try to go on free(data); close(sock); perror("pthread_create"); usleep(500000); continue; } } }