| /* |
| * Copyright 2015-2017 Rivoreo |
| * |
| * This program is free software; you can redistribute it and/or modify it |
| * under the terms of the GNU General Public License as published by the |
| * Free Software Foundation, either version 2 of the License, or (at your |
| * option) any later version. |
| * |
| * This program is distributed in the hope that it will be useful, but WITHOUT |
| * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or |
| * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for |
| * more details. |
| */ |
| |
| #include "common.h" |
| #include <sys/types.h> |
| #include <sys/socket.h> |
| #include <netinet/in.h> |
| #include <arpa/inet.h> |
| #include <sys/wait.h> |
| #include <unistd.h> |
| #include "syncrw.h" |
| #include <signal.h> |
| #include <string.h> |
| #include <stdlib.h> |
| #include <stdio.h> |
| #include <errno.h> |
| #include <time.h> |
| |
/* Implemented in another translation unit. main() calls it in the forked
 * child with the rpf server socket and the forward host socket, and treats
 * a negative return value as failure. */
extern int forward(int, int);

/* Set to 1 by the SIGALRM handler; main() checks it to tell a connect
 * timeout apart from other interruptions. It is written from a signal
 * handler, hence volatile sig_atomic_t. */
static volatile sig_atomic_t alrm = 0;

/*
 * Signal handler shared by SIGCHLD and SIGALRM.
 *
 * SIGCHLD: reaps every child that has already exited. Pending SIGCHLDs are
 *          not queued, so one delivery may stand for several dead children;
 *          the loop therefore runs until waitpid() has nothing left to report.
 * SIGALRM: raises the alrm flag.
 *
 * errno is saved and restored so that the interrupted code (which inspects
 * errno after connect()/select()) never sees a value set in here.
 */
static void signal_handler(int sig) {
	int saved_errno = errno;
	switch(sig) {
		case SIGCHLD:
			for(;;) {
				pid_t pid = waitpid(-1, NULL, WNOHANG);
				if(pid > 0) continue;			// reaped one child, look for more
				if(pid < 0 && errno == EINTR) continue;
				if(pid < 0 && errno != ECHILD) {
					// perror() is not async-signal-safe; write() is
					static const char msg[] = "waitpid failed\n";
					if(write(STDERR_FILENO, msg, sizeof msg - 1) < 0) {
						/* nothing more can be done from a handler */
					}
				}
				// pid == 0: children exist but none has exited yet
				// ECHILD:   no children left
				break;
			}
			break;
		case SIGALRM:
			alrm = 1;
			break;
		default:
			break;
	}
	errno = saved_errno;
}
| |
| int main(int argc, char **argv) { |
| const char *name = strchr(argv[0], '/'); |
| if(name) name++; else name = argv[0]; |
| |
| if(argc != 3) { |
| fprintf(stderr, "Usage: %s <forward-host>:<forward-port> <rpf-server>:<server-port>\n", name); |
| return -1; |
| } |
| |
| struct sigaction orig_act; |
| struct sigaction act = { .sa_handler = SIG_IGN }; |
| if(sigaction(SIGPIPE, &act, NULL) < 0) { |
| perror("sigaction"); |
| return 1; |
| } |
| act.sa_handler = signal_handler; |
| sigaction(SIGCHLD, &act, NULL); |
| |
| struct in_addr forward_host_address, rpf_server_address; |
| |
| char *col = strchr(argv[1], ':'); |
| if(!col) { |
| fprintf(stderr, "%s: Missing ':' in forward-host\n", argv[0]); |
| return -1; |
| } |
| int forward_host_port = atoi(col + 1); |
| if(forward_host_port < 1) { |
| fprintf(stderr, "%s: Wrong port of forward host; port number must greater that 0\n", argv[0]); |
| return -1; |
| } |
| *col = 0; |
| if(inet_pton(AF_INET, argv[1], &forward_host_address) < 1) { |
| fprintf(stderr, "%s: Invalid forward host address '%s'\n", argv[0], argv[1]); |
| return -1; |
| } |
| |
| col = strchr(argv[2], ':'); |
| if(!col) { |
| fprintf(stderr, "%s: Missing ':' in rpf-server\n", argv[0]); |
| return -1; |
| } |
| int rpf_server_port = atoi(col + 1); |
| if(rpf_server_port < 1) { |
| fprintf(stderr, "%s: Wrong port of rpf server; port number must greater that 0\n", argv[0]); |
| return -1; |
| } |
| *col = 0; |
| if(inet_pton(AF_INET, argv[2], &rpf_server_address) < 1) { |
| fprintf(stderr, "%s: Invalid rpf server address '%s'\n", argv[0], argv[1]); |
| return -1; |
| } |
| |
| fprintf(stderr, "\nTCP Reverse Port Forwarding Client - %s\n" |
| RIVOREO_COPYRIGHT_NOTICE "\n" |
| LICENSE_INFORMATION "\n\n", name); |
| |
| int fd = -1; |
| fd_set fdset; |
| struct timeval read_timeout; |
| struct sockaddr_in server_addr = { |
| .sin_family = AF_INET, |
| .sin_addr = rpf_server_address, |
| .sin_port = htons(rpf_server_port) |
| }; |
| while(1) { |
| //if(fd != -1) close(fd); |
| fd = socket(AF_INET, SOCK_STREAM, 0); |
| if(fd == -1) { |
| perror("socket"); |
| return 1; |
| } |
| static const int keepalive = 1; |
| if(setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof keepalive) < 0) perror("setsockopt"); |
| static const struct timeval sendtimeout = { .tv_sec = 10 }; |
| if(setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &sendtimeout, sizeof sendtimeout) < 0) perror("setsockopt"); |
| fprintf(stderr, "%s: Connecting to rpf server %s\n", argv[0], argv[2]); |
| sigaction(SIGALRM, &act, &orig_act); |
| alarm(20); |
| alrm = 0; |
| while(connect(fd, (struct sockaddr *)&server_addr, sizeof server_addr) < 0) { |
| if(errno == EINTR || errno == EINPROGRESS || errno == EALREADY) { |
| if(alrm) { |
| alarm(0); |
| sigaction(SIGALRM, &orig_act, NULL); |
| alrm = 0; |
| fprintf(stderr, "%s: Connection to rpf server %s timed out\n", argv[0], argv[2]); |
| goto close_and_recreate_socket; |
| } |
| continue; |
| } |
| perror("connect: rpf server"); |
| alarm(0); |
| sigaction(SIGALRM, &orig_act, NULL); |
| //alrm = 0; |
| sleep(10); |
| goto close_and_recreate_socket; |
| } |
| alarm(0); |
| sigaction(SIGALRM, &orig_act, NULL); |
| |
| int keep_alive_sent = 0; |
| |
| while(1) { |
| char magic[sizeof MAGIC - 1]; |
| FD_ZERO(&fdset); |
| FD_SET(fd, &fdset); |
| read_timeout.tv_sec = 120; |
| read_timeout.tv_usec = 0; |
| int n = select(fd + 1, &fdset, NULL, NULL, &read_timeout); |
| if(n < 0) { |
| if(errno == EINTR) continue; |
| perror("select"); |
| sleep(1); |
| break; |
| } else if(!n) { |
| #if 0 |
| if(keep_alive_sent) { |
| fprintf(stderr, "%s: Waiting for data from server timed out after keep alive is sent\n", argv[0]); |
| break; |
| } |
| #endif |
| uint16_t packet_type = htons(KEEP_ALIVE); |
| if(sync_write(fd, MAGIC, sizeof MAGIC - 1) < 0 || sync_write(fd, &packet_type, sizeof packet_type) < 0) { |
| perror("write"); |
| sleep(1); |
| break; |
| } |
| keep_alive_sent++; |
| continue; |
| } |
| int s = sync_read(fd, magic, sizeof magic); |
| //fprintf(stderr, "%s debug: " __FILE__ ":%d: s = %d\n", argv[0], __LINE__, s); |
| if(s < 0) { |
| perror("read"); |
| sleep(1); |
| break; |
| } |
| if(s < sizeof magic) { |
| fprintf(stderr, "%s: Server closed connection\n", argv[0]); |
| break; |
| } |
| if(memcmp(magic, MAGIC, sizeof magic)) { |
| fprintf(stderr, "%s: Protocol mismatch\n", argv[0]); |
| sleep(1); |
| break; |
| } |
| uint16_t packet_type; |
| s = sync_read(fd, &packet_type, sizeof packet_type); |
| //fprintf(stderr, "%s debug: " __FILE__ ":%d: s = %d\n", argv[0], __LINE__, s); |
| if(s < 0) { |
| perror("read"); |
| sleep(1); |
| break; |
| } |
| if(s < sizeof packet_type) { |
| fprintf(stderr, "%s: Server closed connection\n", argv[0]); |
| break; |
| } |
| //fprintf(stderr, "%s debug: " __FILE__ ":%d: packet_type = %hu\n", argv[0], __LINE__, packet_type); |
| packet_type = ntohs(packet_type); |
| //fprintf(stderr, "%s debug: " __FILE__ ":%d: packet_type = %hu\n", argv[0], __LINE__, packet_type); |
| if(packet_type == KEEP_ALIVE) { |
| fprintf(stderr, "%s: Keep alive from server\n", argv[0]); |
| packet_type = htons(KEEP_ALIVE_REPLY); |
| if(sync_write(fd, MAGIC, sizeof MAGIC - 1) < 0 || sync_write(fd, &packet_type, sizeof packet_type) < 0) { |
| perror("write"); |
| sleep(1); |
| break; |
| } |
| continue; |
| } |
| if(packet_type == KEEP_ALIVE_REPLY) { |
| if(!keep_alive_sent) { |
| fprintf(stderr, "%s: Unexpected keep alive reply from server, disconnecting\n", argv[0]); |
| sleep(2); |
| break; |
| } |
| keep_alive_sent--; |
| continue; |
| } |
| if(packet_type == NEW_CONNECTION) { |
| //fprintf(stderr, "%s debug: " __FILE__ ":%d: server has got a new connection\n", argv[0], __LINE__); |
| struct new_connection_packet packet; |
| s = sync_read(fd, &packet.len, sizeof packet.len); |
| //fprintf(stderr, "%s debug: " __FILE__ ":%d: s = %d\n", argv[0], __LINE__, s); |
| if(s < 0) { |
| perror("read"); |
| sleep(1); |
| break; |
| } |
| if(s < sizeof packet.len) break; |
| size_t remain_len = ntohl(packet.len) - sizeof packet.len; |
| //fprintf(stderr, "%s debug: " __FILE__ ":%d: remain_len = %zu\n", argv[0], __LINE__, remain_len); |
| if(remain_len < INITIAL_NEW_CONN_PACKET_LENGTH - sizeof packet.len) { |
| fprintf(stderr, "%s: Recevied packet too short, disconnecting\n", argv[0]); |
| break; |
| } |
| size_t extra_len = 0; |
| if(remain_len > sizeof packet - sizeof packet.len) { |
| fprintf(stderr, "%s: Recevied packet too long, skipping extra parts\n", argv[0]); |
| extra_len = remain_len - (sizeof packet - sizeof packet.len); |
| remain_len = sizeof packet - sizeof packet.len; |
| } |
| |
| s = sync_read(fd, (char *)&packet + sizeof packet.len, remain_len); |
| //fprintf(stderr, "%s debug: " __FILE__ ":%d: s = %d\n", argv[0], __LINE__, s); |
| if(s < remain_len) { |
| if(s < 0) perror("read"); |
| sleep(1); |
| break; |
| } |
| //struct sockaddr_in *addr = (struct sockaddr_in *)&packet.addr; |
| //fprintf(stderr, "%s: New connection recevied on server time %s from %s port %d\n", |
| // argv[0], ctime(&packet.tv.tv_sec), inet_ntoa(addr->sin_addr), ntohs(addr->sin_port)); |
| fprintf(stderr, "%s: New connection recevied on server %s from %s port %d\n", |
| argv[0], argv[2], inet_ntoa(packet.address), ntohs(packet.port)); |
| |
| if(extra_len) { |
| //size_t chunk_len = MIN(extra_len, 16384); |
| size_t chunk_len = MIN(extra_len, 4096); |
| char buffer[chunk_len]; |
| do { |
| int len = MIN(chunk_len, extra_len); |
| s = sync_read(fd, buffer, len); |
| //fprintf(stderr, "%s debug: " __FILE__ ":%d: s = %d\n", argv[0], __LINE__, s); |
| if(s < 0) { |
| perror("read"); |
| sleep(1); |
| break; |
| } |
| if(s < len) goto close_and_recreate_socket; |
| extra_len -= len; |
| } while(extra_len > 0); |
| } |
| // For future expansion of struct new_connection_packet, compare |
| // ntohl(packet.len) and INITIAL_NEW_CONN_PACKET_LENGTH here |
| |
| pid_t pid = fork(); |
| if(pid) { |
| if(pid < 0) perror("fork"); |
| break; |
| } else { |
| int forward_host_socket = socket(AF_INET, SOCK_STREAM, 0); |
| if(forward_host_socket == -1) { |
| perror("socket"); |
| exit(1); |
| } |
| struct sockaddr_in addr = { |
| .sin_family = AF_INET, |
| .sin_addr = forward_host_address, |
| .sin_port = htons(forward_host_port) |
| }; |
| fprintf(stderr, "%s: Connecting to %s:%d\n", argv[0], argv[1], forward_host_port); |
| while(connect(forward_host_socket, (struct sockaddr *)&addr, sizeof addr) < 0) { |
| if(errno == EINTR) continue; |
| perror("connect: forwarding"); |
| exit(1); |
| } |
| exit(forward(fd, forward_host_socket) < 0 ? 1 : 0); |
| } |
| } else { |
| fprintf(stderr, "%s: Unknown packet type %hu\n", argv[0], packet_type); |
| break; |
| } |
| } |
| close_and_recreate_socket: |
| close(fd); |
| } |
| } |