blob: 4f71ed88be9ee1c6c5ccd164e05c620f4f7352ab [file] [log] [blame] [raw]
/*
* Copyright 2015-2017 Rivoreo
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation, either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*/
#include "common.h"
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/wait.h>
#include <unistd.h>
#include "syncrw.h"
#include <signal.h>
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include <errno.h>
#include <time.h>
/*
 * Defined in another translation unit. The child process forked in main()
 * calls it as forward(rpf_server_socket, forward_host_socket) and exits with
 * status 1 when it returns a negative value, 0 otherwise. Presumably relays
 * data between the two connected sockets — confirm against its definition.
 */
extern int forward(int, int);
/*
 * SIGCHLD handler: reap every child process that has already terminated so
 * finished forwarding children do not linger as zombies. Other signals are
 * ignored.
 *
 * Pending signals of the same kind are not queued, so a single SIGCHLD may
 * stand for several exited children. waitpid(WNOHANG) is therefore called
 * until it reports that nothing more can be reaped: 0 means the remaining
 * children are still running, -1/ECHILD means there are no children left.
 *
 * errno is saved and restored because the interrupted code in main() checks
 * errno after failing system calls. Nothing is printed here: perror() is not
 * async-signal-safe.
 */
static void signal_handler(int sig) {
	if(sig != SIGCHLD) return;
	int saved_errno = errno;
	while(1) {
		pid_t pid = waitpid(-1, NULL, WNOHANG);
		if(pid > 0) continue;                   /* reaped one, look for more */
		if(pid < 0 && errno == EINTR) continue; /* interrupted, try again */
		break;                                  /* 0 or ECHILD: nothing left */
	}
	errno = saved_errno;
}
/*
 * Parse a decimal TCP port number.
 * Returns the port (1-65535), or -1 if s is empty, has trailing characters,
 * or is out of range.
 */
static int parse_port_number(const char *s) {
	char *end;
	errno = 0;
	long port = strtol(s, &end, 10);
	if(end == s || *end != '\0' || errno == ERANGE || port < 1 || port > 65535) return -1;
	return (int)port;
}

/*
 * Read and throw away len bytes from fd through a fixed-size buffer, so a
 * length announced by the peer never sizes a stack allocation.
 * Returns 1 when all bytes were consumed, 0 when sync_read() came up short
 * (peer closed the connection), or -1 when sync_read() reported an error.
 */
static int discard_bytes(int fd, size_t len) {
	char buffer[4096];
	while(len > 0) {
		size_t chunk = len < sizeof buffer ? len : sizeof buffer;
		int s = sync_read(fd, buffer, chunk);
		if(s < 0) return -1;
		if((size_t)s < chunk) return 0;
		len -= chunk;
	}
	return 1;
}

/*
 * Reverse port forwarding client.
 *
 * Usage: <prog> <forward-host>:<forward-port> <rpf-server>:<server-port>
 * Both hosts must be IPv4 addresses (parsed with inet_pton(AF_INET)).
 *
 * Keeps a TCP connection open to the rpf server and reads packets made of
 * MAGIC followed by a 16-bit packet type in network byte order. KEEP_ALIVE
 * packets are logged and otherwise ignored. On NEW_CONNECTION the
 * length-prefixed packet body is read, then a child process is forked that
 * connects to the forward host and hands both sockets to forward(); the
 * parent drops its copy of the server socket and opens a new connection to
 * the server. Read errors, short reads, protocol mismatches and unknown
 * packet types also make the parent reconnect.
 *
 * Only returns on usage/argument errors or when a signal handler or a
 * socket cannot be set up.
 */
int main(int argc, char **argv) {
	if(argc != 3) {
		fprintf(stderr, "Usage: %s <forward-host>:<forward-port> <rpf-server>:<server-port>\n", argv[0]);
		return -1;
	}
	// With SIGPIPE ignored, writing to a closed peer fails with EPIPE
	// instead of terminating the process.
	struct sigaction act = { .sa_handler = SIG_IGN };
	if(sigaction(SIGPIPE, &act, NULL) < 0) {
		perror("sigaction");
		return 1;
	}
	// Reap forwarding children as they exit.
	act.sa_handler = signal_handler;
	if(sigaction(SIGCHLD, &act, NULL) < 0) {
		perror("sigaction");
		return 1;
	}

	struct in_addr forward_host_address, rpf_server_address;
	char *col = strchr(argv[1], ':');
	if(!col) {
		fprintf(stderr, "%s: Missing ':' in forward-host\n", argv[0]);
		return -1;
	}
	int forward_host_port = parse_port_number(col + 1);
	if(forward_host_port < 0) {
		fprintf(stderr, "%s: Wrong port of forward host; port number must be between 1 and 65535\n", argv[0]);
		return -1;
	}
	*col = 0;	// argv[1] now holds only the host part
	if(inet_pton(AF_INET, argv[1], &forward_host_address) < 1) {
		fprintf(stderr, "%s: Invalid forward host address '%s'\n", argv[0], argv[1]);
		return -1;
	}
	col = strchr(argv[2], ':');
	if(!col) {
		fprintf(stderr, "%s: Missing ':' in rpf-server\n", argv[0]);
		return -1;
	}
	int rpf_server_port = parse_port_number(col + 1);
	if(rpf_server_port < 0) {
		fprintf(stderr, "%s: Wrong port of rpf server; port number must be between 1 and 65535\n", argv[0]);
		return -1;
	}
	*col = 0;	// argv[2] now holds only the host part
	if(inet_pton(AF_INET, argv[2], &rpf_server_address) < 1) {
		fprintf(stderr, "%s: Invalid rpf server address '%s'\n", argv[0], argv[2]);
		return -1;
	}

	int fd = -1;
	struct sockaddr_in server_addr = {
		.sin_family = AF_INET,
		.sin_addr = rpf_server_address,
		.sin_port = htons(rpf_server_port)
	};
	while(1) {
		if(fd != -1) close(fd);
		fd = socket(AF_INET, SOCK_STREAM, 0);
		if(fd == -1) {
			perror("socket");
			return 1;
		}
		fprintf(stderr, "%s: Connecting to rpf server %s\n", argv[0], argv[2]);
		if(connect(fd, (struct sockaddr *)&server_addr, sizeof server_addr) < 0) {
			// The socket's state after a failed connect(2) is unspecified, and an
			// interrupted attempt continues asynchronously (a second connect(2) on
			// it reports EALREADY/EISCONN). Always retry with a fresh socket.
			if(errno != EINTR) {
				perror("connect");
				sleep(10);
			}
			continue;
		}
		while(1) {
			char magic[sizeof MAGIC - 1];
			int s = sync_read(fd, magic, sizeof magic);
			fprintf(stderr, "%s debug: " __FILE__ ":%d: s = %d\n", argv[0], __LINE__, s);
			if(s < 0) {
				perror("read");
				sleep(1);
				break;
			}
			if((size_t)s < sizeof magic) {
				fprintf(stderr, "%s: Server closed connection\n", argv[0]);
				break;
			}
			if(memcmp(magic, MAGIC, sizeof magic)) {
				fprintf(stderr, "%s: Protocol mismatch\n", argv[0]);
				sleep(1);
				break;
			}
			uint16_t packet_type;
			s = sync_read(fd, &packet_type, sizeof packet_type);
			fprintf(stderr, "%s debug: " __FILE__ ":%d: s = %d\n", argv[0], __LINE__, s);
			if(s < 0) {
				perror("read");
				sleep(1);
				break;
			}
			if((size_t)s < sizeof packet_type) {
				fprintf(stderr, "%s: Server closed connection\n", argv[0]);
				break;
			}
			fprintf(stderr, "%s debug: " __FILE__ ":%d: packet_type = %hu\n", argv[0], __LINE__, packet_type);
			packet_type = ntohs(packet_type);
			fprintf(stderr, "%s debug: " __FILE__ ":%d: packet_type = %hu\n", argv[0], __LINE__, packet_type);
			if(packet_type == KEEP_ALIVE) {
				fprintf(stderr, "%s: Keep alive from server\n", argv[0]);
				continue;
			}
			if(packet_type != NEW_CONNECTION) {
				fprintf(stderr, "%s: Unknown packet type %hu\n", argv[0], packet_type);
				break;
			}

			fprintf(stderr, "%s debug: " __FILE__ ":%d: server has got a new connection\n", argv[0], __LINE__);
			struct new_connection_packet packet;
			// Zero-filled so a body shorter than the struct leaves defined values behind.
			memset(&packet, 0, sizeof packet);
			s = sync_read(fd, &packet.len, sizeof packet.len);
			fprintf(stderr, "%s debug: " __FILE__ ":%d: s = %d\n", argv[0], __LINE__, s);
			if(s < 0) {
				perror("read");
				sleep(1);
				break;
			}
			if((size_t)s < sizeof packet.len) break;
			size_t packet_len = ntohl(packet.len);
			// The length field counts itself; anything smaller would underflow
			// remain_len, and the stream can no longer be trusted.
			if(packet_len < sizeof packet.len) {
				fprintf(stderr, "%s: Invalid packet length %zu\n", argv[0], packet_len);
				sleep(1);
				break;
			}
			size_t remain_len = packet_len - sizeof packet.len;
			fprintf(stderr, "%s debug: " __FILE__ ":%d: remain_len = %zu\n", argv[0], __LINE__, remain_len);
			if(remain_len > sizeof packet - sizeof packet.len) {
				fprintf(stderr, "%s: Recevied packet too long, skipping\n", argv[0]);
				int r = discard_bytes(fd, remain_len);
				if(r < 0) {
					perror("read");
					sleep(1);
				}
				if(r <= 0) break;
			} else {
				s = sync_read(fd, (char *)&packet + sizeof packet.len, remain_len);
				fprintf(stderr, "%s debug: " __FILE__ ":%d: s = %d\n", argv[0], __LINE__, s);
				// Test s < 0 separately: a negative int compared against a size_t
				// is converted to a huge unsigned value and would hide the error.
				if(s < 0 || (size_t)s < remain_len) {
					if(s < 0) perror("read");
					sleep(1);
					break;
				}
				struct sockaddr_in *client_addr = (struct sockaddr_in *)&packet.addr;
				// ctime() returns NULL for times it cannot represent; the value
				// comes from the server, so don't hand NULL to %s.
				const char *server_time = ctime(&packet.tv.tv_sec);
				fprintf(stderr, "%s: New connection recevied on server time %s from %s port %d\n",
					argv[0], server_time ? server_time : "?", inet_ntoa(client_addr->sin_addr), ntohs(client_addr->sin_port));
			}

			pid_t pid = fork();
			if(pid < 0) {
				perror("fork");
				break;
			}
			if(pid > 0) {
				// The child owns this server connection now; the parent closes its
				// copy at the top of the outer loop and reconnects for the next one.
				break;
			}

			// Child: connect to the forward host and relay through forward().
			struct sockaddr_in forward_addr = {
				.sin_family = AF_INET,
				.sin_addr = forward_host_address,
				.sin_port = htons(forward_host_port)
			};
			fprintf(stderr, "%s: Connecting to %s:%d\n", argv[0], argv[1], forward_host_port);
			int forward_host_socket;
			while(1) {
				forward_host_socket = socket(AF_INET, SOCK_STREAM, 0);
				if(forward_host_socket == -1) {
					perror("socket");
					exit(1);
				}
				if(connect(forward_host_socket, (struct sockaddr *)&forward_addr, sizeof forward_addr) == 0) break;
				if(errno != EINTR) {
					perror("connect");
					exit(1);
				}
				// Retrying connect(2) on an interrupted socket reports
				// EALREADY/EISCONN; start over with a fresh socket.
				close(forward_host_socket);
			}
			exit(forward(fd, forward_host_socket) < 0 ? 1 : 0);
		}
	}
}