blob: 28299665e62b2b68ea41997bc8231cb4cb6afff6 [file] [log] [blame] [raw]
/*
* Copyright 2015-2017 Rivoreo
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation, either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*/
#include "common.h"
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/wait.h>
#include <netdb.h>
#include <unistd.h>
#include "syncrw.h"
#include <signal.h>
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include <errno.h>
#include <time.h>
extern int forward(int, int);
/*
 * Set to 1 by the SIGALRM branch of signal_handler() and polled by main()
 * to detect that the connect timeout has fired.
 * volatile sig_atomic_t is the object type ISO C guarantees may be written
 * from an asynchronous signal handler.
 */
static volatile sig_atomic_t alrm = 0;

/*
 * Shared handler for SIGCHLD and SIGALRM.
 *
 * SIGCHLD: reap every child that has already terminated. Several children
 *          may exit before the handler runs and their signals are not
 *          queued individually, so waitpid() is repeated until it reports
 *          that nothing more can be collected. WNOHANG keeps the handler
 *          from blocking. errno is saved and restored because the
 *          interrupted code may be about to inspect it.
 * SIGALRM: raise the alrm flag.
 *
 * Only async-signal-safe operations are performed here (no stdio).
 */
static void signal_handler(int sig) {
	switch(sig) {
		case SIGCHLD: {
			int saved_errno = errno;
			while(waitpid(-1, NULL, WNOHANG) > 0) {
				// One child reaped; look for more
			}
			// 0 (children still running) or -1 (e.g. ECHILD) both end the loop
			errno = saved_errno;
			return;
		}
		case SIGALRM:
			alrm = 1;
			return;
		default:
			return;
	}
}
/*
 * TCP reverse port forwarding client.
 *
 * Usage: <name> <forward-host>:<forward-port> <rpf-server>:<server-port>
 *
 * forward-host must be a numeric IPv4 address. rpf-server may be a numeric
 * IPv4 address or a host name; a host name is resolved with getaddrinfo()
 * and its addresses are tried one after another.
 *
 * The program loops forever:
 *   1. connect to the rpf server (20 second limit enforced with SIGALRM);
 *   2. exchange MAGIC-prefixed packets on that connection, answering
 *      KEEP_ALIVE and sending one of its own after 120 idle seconds;
 *   3. on NEW_CONNECTION, fork: the child connects to the forward host and
 *      hands both sockets to forward(); the parent closes its copy of the
 *      server socket and goes back to step 1.
 * Any protocol or I/O error also drops the connection and restarts at 1.
 *
 * Returns -1 on bad command line arguments and 1 if SIGPIPE cannot be
 * ignored or a socket cannot be created; otherwise it does not return.
 *
 * NOTE(review): MAGIC, KEEP_ALIVE, KEEP_ALIVE_REPLY, NEW_CONNECTION,
 * INITIAL_NEW_CONN_PACKET_LENGTH, MIN and struct new_connection_packet are
 * not defined in this file; presumably they come from "common.h".
 */
int main(int argc, char **argv) {
	// Program name for the usage text and banner: the LAST path component
	// of argv[0] (strrchr, not strchr, so "/usr/bin/x" gives "x")
	const char *name = strrchr(argv[0], '/');
	if(name) name++; else name = argv[0];

	if(argc != 3) {
		fprintf(stderr, "Usage: %s <forward-host>:<forward-port> <rpf-server>:<server-port>\n", name);
		return -1;
	}

	struct sigaction orig_act;
	struct sigaction act = { .sa_handler = SIG_IGN };
	// Writes to a closed peer must fail with an error instead of killing us
	if(sigaction(SIGPIPE, &act, NULL) < 0) {
		perror("sigaction");
		return 1;
	}
	// sa_flags stays 0 (no SA_RESTART), so blocking calls are interrupted
	// with EINTR when a handled signal arrives
	act.sa_handler = signal_handler;
	sigaction(SIGCHLD, &act, NULL);

	int server_address_is_name = 0;
	struct addrinfo *server_name_info = NULL, *orig_addr_info_p = NULL;
	struct in_addr forward_host_address, rpf_server_address;

	// Split "<forward-host>:<forward-port>" in place
	char *col = strchr(argv[1], ':');
	if(!col) {
		fprintf(stderr, "%s: Missing ':' in forward-host\n", argv[0]);
		return -1;
	}
	int forward_host_port = atoi(col + 1);
	if(forward_host_port < 1) {
		fprintf(stderr, "%s: Wrong port of forward host; port number must greater that 0\n", argv[0]);
		return -1;
	}
	*col = 0;
	if(inet_pton(AF_INET, argv[1], &forward_host_address) < 1) {
		fprintf(stderr, "%s: Invalid forward host address '%s'\n", argv[0], argv[1]);
		return -1;
	}

	// Split "<rpf-server>:<server-port>" in place
	col = strchr(argv[2], ':');
	if(!col) {
		fprintf(stderr, "%s: Missing ':' in rpf-server\n", argv[0]);
		return -1;
	}
	int rpf_server_port = atoi(col + 1);
	if(rpf_server_port < 1) {
		fprintf(stderr, "%s: Wrong port of rpf server; port number must greater that 0\n", argv[0]);
		return -1;
	}
	*col = 0;
	if(inet_pton(AF_INET, argv[2], &rpf_server_address) < 1) {
		// Not a numeric IPv4 address: treat it as a host name and resolve
		// it inside the connect loop
		fprintf(stderr, "%s: Invalid rpf server address '%s', trying getaddrinfo\n", argv[0], argv[2]);
		server_address_is_name = 1;
	}

	fprintf(stderr, "\nTCP Reverse Port Forwarding Client - %s\n"
		RIVOREO_COPYRIGHT_NOTICE "\n"
		LICENSE_INFORMATION "\n\n", name);

	int fd = -1;
	fd_set fdset;
	struct timeval read_timeout;
	struct sockaddr_in server_addr = {
		.sin_family = AF_INET,
		.sin_port = htons(rpf_server_port)
	};
	if(!server_address_is_name) server_addr.sin_addr = rpf_server_address;

	while(1) {
		if(server_address_is_name) {
			if(!server_name_info) {
				// Address list exhausted (or never fetched): resolve again.
				// Free the previous list exactly once and never pass NULL
				// to freeaddrinfo().
				if(orig_addr_info_p) {
					freeaddrinfo(orig_addr_info_p);
					orig_addr_info_p = NULL;
				}
				struct addrinfo hints = {
					.ai_family = AF_INET,
					// Restrict to stream sockets so each address is not
					// returned once per socket type
					.ai_socktype = SOCK_STREAM,
					.ai_protocol = 0
				};
				int e = getaddrinfo(argv[2], NULL, &hints, &server_name_info);
				if(e) {
					fprintf(stderr, "%s: Cannot resolv server name '%s': %s\n", argv[0], argv[2], gai_strerror(e));
					server_name_info = NULL;
					sleep(4);
					continue;
				}
				orig_addr_info_p = server_name_info;
			}
			// Use the current address and advance, so that a failed attempt
			// moves on to the next one
			server_addr.sin_addr = ((struct sockaddr_in *)server_name_info->ai_addr)->sin_addr;
			server_name_info = server_name_info->ai_next;
		}

		fd = socket(AF_INET, SOCK_STREAM, 0);
		if(fd == -1) {
			perror("socket");
			return 1;
		}
		static const int keepalive = 1;
		if(setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &keepalive, sizeof keepalive) < 0) perror("setsockopt");
		static const struct timeval sendtimeout = { .tv_sec = 10 };
		if(setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &sendtimeout, sizeof sendtimeout) < 0) perror("setsockopt");

		fprintf(stderr, "%s: Connecting to rpf server %s", argv[0], argv[2]);
		if(server_address_is_name) fprintf(stderr, " (%s)\n", inet_ntoa(server_addr.sin_addr));
		else fputc('\n', stderr);

		// Bound the connect attempt to 20 seconds using SIGALRM
		alrm = 0;
		sigaction(SIGALRM, &act, &orig_act);
		alarm(20);
		while(connect(fd, (struct sockaddr *)&server_addr, sizeof server_addr) < 0) {
			// After an interrupted connect() the attempt continues in the
			// background; a repeated call may then report EISCONN, which
			// means the connection has been established
			if(errno == EISCONN) break;
			if(errno == EINTR || errno == EINPROGRESS || errno == EALREADY) {
				if(alrm) {
					alarm(0);
					sigaction(SIGALRM, &orig_act, NULL);
					alrm = 0;
					fprintf(stderr, "%s: Connection to rpf server %s timed out\n", argv[0], argv[2]);
					goto close_and_recreate_socket;
				}
				continue;
			}
			perror("connect: rpf server");
			alarm(0);
			sigaction(SIGALRM, &orig_act, NULL);
			sleep(10);
			goto close_and_recreate_socket;
		}
		alarm(0);
		sigaction(SIGALRM, &orig_act, NULL);

		// Number of our KEEP_ALIVE packets not yet answered by the server
		int keep_alive_sent = 0;
		while(1) {
			char magic[sizeof MAGIC - 1];
			FD_ZERO(&fdset);
			FD_SET(fd, &fdset);
			read_timeout.tv_sec = 120;
			read_timeout.tv_usec = 0;
			int n = select(fd + 1, &fdset, NULL, NULL, &read_timeout);
			if(n < 0) {
				if(errno == EINTR) continue;
				perror("select");
				sleep(1);
				break;
			} else if(!n) {
				// Idle for 120 seconds
				if(keep_alive_sent) {
					fprintf(stderr, "%s: Waiting for data from server timed out after keep alive is sent\n", argv[0]);
					break;
				}
				uint16_t keep_alive_type = htons(KEEP_ALIVE);
				if(sync_write(fd, MAGIC, sizeof MAGIC - 1) < 0 || sync_write(fd, &keep_alive_type, sizeof keep_alive_type) < 0) {
					perror("write");
					sleep(1);
					break;
				}
				keep_alive_sent++;
				continue;
			}

			// Every packet starts with MAGIC followed by a 16-bit type in
			// network byte order
			int s = sync_read(fd, magic, sizeof magic);
			if(s < 0) {
				perror("read");
				sleep(1);
				break;
			}
			if(s < (int)sizeof magic) {
				fprintf(stderr, "%s: Server closed connection\n", argv[0]);
				break;
			}
			if(memcmp(magic, MAGIC, sizeof magic)) {
				fprintf(stderr, "%s: Protocol mismatch\n", argv[0]);
				sleep(1);
				break;
			}
			uint16_t packet_type;
			s = sync_read(fd, &packet_type, sizeof packet_type);
			if(s < 0) {
				perror("read");
				sleep(1);
				break;
			}
			if(s < (int)sizeof packet_type) {
				fprintf(stderr, "%s: Server closed connection\n", argv[0]);
				break;
			}
			packet_type = ntohs(packet_type);

			if(packet_type == KEEP_ALIVE) {
				fprintf(stderr, "%s: Keep alive from server\n", argv[0]);
				packet_type = htons(KEEP_ALIVE_REPLY);
				if(sync_write(fd, MAGIC, sizeof MAGIC - 1) < 0 || sync_write(fd, &packet_type, sizeof packet_type) < 0) {
					perror("write");
					sleep(1);
					break;
				}
				continue;
			}
			if(packet_type == KEEP_ALIVE_REPLY) {
				if(!keep_alive_sent) {
					fprintf(stderr, "%s: Unexpected keep alive reply from server, disconnecting\n", argv[0]);
					sleep(2);
					break;
				}
				keep_alive_sent--;
				continue;
			}
			if(packet_type == NEW_CONNECTION) {
				struct new_connection_packet packet;
				s = sync_read(fd, &packet.len, sizeof packet.len);
				if(s < 0) {
					perror("read");
					sleep(1);
					break;
				}
				if(s < (int)sizeof packet.len) break;

				// packet.len counts the length field itself. Validate it
				// BEFORE subtracting so that a bogus tiny value cannot wrap
				// around to a huge remaining length.
				uint32_t packet_len = ntohl(packet.len);
				if(packet_len < sizeof packet.len ||
				   packet_len - sizeof packet.len < INITIAL_NEW_CONN_PACKET_LENGTH - sizeof packet.len) {
					fprintf(stderr, "%s: Recevied packet too short, disconnecting\n", argv[0]);
					break;
				}
				size_t remain_len = packet_len - sizeof packet.len;
				size_t extra_len = 0;
				if(remain_len > sizeof packet - sizeof packet.len) {
					// Longer than the structure known to this build: read
					// what fits and discard the rest below
					fprintf(stderr, "%s: Recevied packet too long, skipping extra parts\n", argv[0]);
					extra_len = remain_len - (sizeof packet - sizeof packet.len);
					remain_len = sizeof packet - sizeof packet.len;
				}
				s = sync_read(fd, (char *)&packet + sizeof packet.len, remain_len);
				if(s < (int)remain_len) {
					if(s < 0) perror("read");
					sleep(1);
					break;
				}
				fprintf(stderr, "%s: New connection recevied on server %s from %s port %d\n",
					argv[0], argv[2], inet_ntoa(packet.address), ntohs(packet.port));

				if(extra_len) {
					char buffer[4096];
					do {
						int len = (int)MIN(extra_len, sizeof buffer);
						s = sync_read(fd, buffer, len);
						if(s < len) {
							// Read error or premature end of stream: the
							// connection is unusable, so do NOT go on to
							// fork; drop it and start over
							if(s < 0) {
								perror("read");
								sleep(1);
							}
							if(server_address_is_name) server_name_info = NULL;
							goto close_and_recreate_socket;
						}
						extra_len -= len;
					} while(extra_len > 0);
				}

				// For future expansion of struct new_connection_packet, compare
				// ntohl(packet.len) and INITIAL_NEW_CONN_PACKET_LENGTH here
				pid_t pid = fork();
				if(pid) {
					// Parent (or fork failure): leave this connection and
					// open a new one to the server
					if(pid < 0) perror("fork");
					break;
				} else {
					// Child: connect to the forward host and pass both
					// sockets to forward(); presumably it relays data
					// between them -- defined outside this file
					int forward_host_socket = socket(AF_INET, SOCK_STREAM, 0);
					if(forward_host_socket == -1) {
						perror("socket");
						exit(1);
					}
					struct sockaddr_in addr = {
						.sin_family = AF_INET,
						.sin_addr = forward_host_address,
						.sin_port = htons(forward_host_port)
					};
					fprintf(stderr, "%s: Connecting to %s:%d\n", argv[0], argv[1], forward_host_port);
					while(connect(forward_host_socket, (struct sockaddr *)&addr, sizeof addr) < 0) {
						if(errno == EINTR) continue;
						perror("connect: forwarding");
						exit(1);
					}
					exit(forward(fd, forward_host_socket) < 0 ? 1 : 0);
				}
			} else {
				fprintf(stderr, "%s: Unknown packet type %hu\n", argv[0], packet_type);
				break;
			}
		}

		if(server_address_is_name) {
			// This address did connect; skip any remaining addresses so the
			// next round resolves the name again
			server_name_info = NULL;
		}
close_and_recreate_socket:
		close(fd);
	}
}