blob: 1e8c03e52eedacd71a6592f34d7dd9a331764ac8 [file] [log] [blame] [raw]
/*
* Copyright 2015-2016 Rivoreo
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation, either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*/
/* TCP port the server listens on when no -p option is given. */
#define DEFAULT_PORT 3446
/* Selects the chunked stream implementation of ecies_read()/ecies_write().
 * It is tested with #ifdef, so only its presence matters, not its value. */
#define CHUNKED 1
/* Maximum payload bytes per chunk; chunk buffers are sized
 * CHUNK_SIZE + ECIES_CHUNK_OVERHEAD. */
#define CHUNK_SIZE (8*1024)
#include "ecc.h"
#include "ecies-chacha20.h"
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
#include "syncrw.h"
#include <stdio.h>
#include <errno.h>
#include <string.h>
#include <stdint.h>
#include <stdlib.h>
#include <fcntl.h>
/*
 * Read an ECIES-encrypted stream from fd and decrypt it into buffer.
 *
 * privkey  private key handed to the ECIES decryption routines
 * fd       descriptor the encrypted data is read from
 * buffer   receives the decrypted bytes
 * count    capacity of buffer; no more than count bytes are stored in it
 *
 * Returns the number of bytes decrypted (the ECIES_decrypt_start result plus
 * the bytes copied into buffer), or -1 on failure. errno is set to ENODATA
 * when the stream ends inside the start header and to EBADMSG when a
 * decryption routine reports failure; when the underlying read fails, errno
 * is left as that read set it.
 */
int ecies_read(const ECIES_privkey_t *privkey, int fd, void *buffer, size_t count) {
#ifdef CHUNKED
	int r = 0;
	ECIES_stream_t stm;
	{
		ECIES_byte_t encrypted_data[ECIES_START_OVERHEAD];
		int s = sync_read(fd, encrypted_data, sizeof encrypted_data);
		if(s < 0) return -1;
		if((size_t)s < sizeof encrypted_data) {
			errno = ENODATA;
			return -1;
		}
		s = ECIES_decrypt_start(&stm, encrypted_data, privkey);
		if(s < 0) {
			errno = EBADMSG;
			return -1;
		}
		// NOTE(review): r is also used as the offset into buffer below, yet
		// nothing is copied for these s bytes. Presumably
		// ECIES_decrypt_start returns 0 on success -- confirm.
		r += s;
	}
	{
		ECIES_byte_t encrypted_data[CHUNK_SIZE + ECIES_CHUNK_OVERHEAD];
		do {
			int s = sync_read(fd, encrypted_data, sizeof encrypted_data);
			if(s < 0) return -1;
			// A short read means the stream ended inside this chunk.
			int is_last = (size_t)s < sizeof encrypted_data;
			// Nothing but (part of) the per-chunk overhead arrived: no payload.
			if(is_last && (size_t)s <= (size_t)ECIES_CHUNK_OVERHEAD) return r;
			// ecies_write() emits a final chunk shorter than CHUNK_SIZE, so a
			// short chunk is decrypted as well instead of being discarded.
			s = ECIES_decrypt_chunk(&stm, encrypted_data, s - ECIES_CHUNK_OVERHEAD);
			if(s < 0) {
				errno = EBADMSG;
				return -1;
			}
			// Never store more than the caller's buffer can hold.
			if((size_t)s > count) s = (int)count;
			memcpy((char *)buffer + r, encrypted_data, s);
			r += s;
			count -= s;
			if(is_last) break;
		} while(count > 0);
	}
	return r;
#else
	char encrypted_data[count + ECIES_OVERHEAD];
	size_t read_count = 0;
	do {
		ssize_t s = read(fd, encrypted_data + read_count, sizeof encrypted_data - read_count);
		if(s < 0) {
			// read() reports interruption through errno, not its return value.
			if(errno == EINTR) continue;
			return -1;
		}
		if(!s) break;
		read_count += (size_t)s;
	} while(read_count < sizeof encrypted_data);
	// Less than the fixed overhead arrived: there is no payload to decrypt and
	// the subtraction below would wrap around.
	if(read_count < (size_t)ECIES_OVERHEAD) {
		errno = ENODATA;
		return -1;
	}
	count = read_count - ECIES_OVERHEAD;
	ECIES_decrypt(buffer, count, encrypted_data, privkey);
	return count;
#endif
}
/*
 * Encrypt count bytes from buffer as an ECIES stream and write it to fd.
 *
 * pubkey  public key handed to ECIES_encrypt_start
 * fd      descriptor the encrypted stream is written to
 * buffer  plaintext to send
 * count   number of plaintext bytes in buffer
 *
 * The stream consists of a start header of ECIES_START_OVERHEAD bytes
 * followed by chunks of at most CHUNK_SIZE payload bytes, each written with
 * ECIES_CHUNK_OVERHEAD extra bytes.
 *
 * Returns the number of bytes written to fd, or -1 on failure. errno is
 * ENOBUFS when count is below ECIES_START_OVERHEAD, ENOSYS when the file is
 * built without CHUNKED, and otherwise whatever the failed write left.
 */
int ecies_write(const ECIES_pubkey_t *pubkey, int fd, const void *buffer, size_t count) {
#ifdef CHUNKED
	int r = 0;
	ECIES_stream_t stm;
	// NOTE(review): this rejects plaintext shorter than the start header;
	// kept as in the original -- confirm that this is intended.
	if(count < ECIES_START_OVERHEAD) {
		errno = ENOBUFS;
		return -1;
	}
	{
		ECIES_byte_t encrypted_data[ECIES_START_OVERHEAD];
		ECIES_encrypt_start(&stm, encrypted_data, pubkey);
		if(sync_write(fd, encrypted_data, sizeof encrypted_data) < 0) return -1;
		r += sizeof encrypted_data;
	}
	{
		ECIES_byte_t encrypted_data[CHUNK_SIZE + ECIES_CHUNK_OVERHEAD];
		// Next unsent plaintext byte; advanced after every chunk so that each
		// chunk carries new data.
		const char *next = buffer;
		do {
			size_t chunk_size = count < CHUNK_SIZE ? count : CHUNK_SIZE;
			// The chunk is encrypted in place inside the staging buffer.
			memcpy(encrypted_data, next, chunk_size);
			ECIES_encrypt_chunk(&stm, encrypted_data, chunk_size);
			int s = sync_write(fd, encrypted_data, chunk_size + ECIES_CHUNK_OVERHEAD);
			if(s < 0) return -1;
			r += s;
			next += chunk_size;
			count -= chunk_size;
		} while(count > 0);
	}
	return r;
#else
	// No non-chunked writer exists; fail explicitly rather than falling off
	// the end of a function that must return a value.
	(void)pubkey;
	(void)fd;
	(void)buffer;
	(void)count;
	errno = ENOSYS;
	return -1;
#endif
}
/* Write the one-line command synopsis to standard error.
 * name is the program name shown at the start of the synopsis. */
static void print_usage(const char *name)
{
	FILE *out = stderr;

	(void)fprintf(out,
		"Usage: %s -k <server-key> [-b <bind-address>] [-p <port>]\n",
		name);
}
/*
 * Load the server key pair from the file at path.
 *
 * The file is read as three consecutive fields of ECIES_KEY_SIZE bytes each,
 * stored in this order: privkey->k, pubkey->x, pubkey->y.
 *
 * Returns 0 on success and -1 on failure. On failure errno is either the
 * error from open()/the read, or EBADMSG when the file is too short. The
 * descriptor is closed on every path.
 */
static int read_server_key(const char *path, ECIES_privkey_t *privkey, ECIES_pubkey_t *pubkey) {
	int fd = open(path, O_RDONLY);
	if(fd == -1) return -1;

	void *const fields[] = { &privkey->k, &pubkey->x, &pubkey->y };
	int result = 0;
	for(size_t i = 0; i < sizeof fields / sizeof fields[0]; i++) {
		int s = sync_read(fd, fields[i], ECIES_KEY_SIZE);
		// Test the sign first so the check holds even if ECIES_KEY_SIZE has
		// an unsigned type.
		if(s < 0 || (size_t)s < (size_t)ECIES_KEY_SIZE) {
			if(s >= 0) errno = EBADMSG;
			result = -1;
			break;
		}
	}

	// close() may overwrite errno; keep the value that describes the failure.
	int saved_errno = errno;
	close(fd);
	if(result < 0) errno = saved_errno;
	return result;
}
/* Close the socket stored in client_fds[index], stop watching it in fdset
 * and mark the slot as free. */
static void drop_client(int *client_fds, int index, fd_set *fdset) {
	int cfd = client_fds[index];
	close(cfd);
	FD_CLR(cfd, fdset);
	client_fds[index] = -1;
}

/*
 * Server entry point.
 *
 * Options: -k <server-key> (required), -b <bind-address>, -p <port>, -h.
 * Loads the key pair, listens on the requested IPv4 address and port, and
 * serves clients from a single select() loop. Each client message is a
 * 32-bit length in network byte order followed by an ECIES stream; the
 * decrypted text is printed to standard error.
 *
 * Returns 0 after -h, 1 on a reported error and -1 on a usage error.
 */
int main(int argc, char **argv) {
	const char *server_key_path = NULL;
	ECIES_privkey_t privkey;
	ECIES_pubkey_t pubkey;
	struct in_addr address = { .s_addr = htonl(INADDR_ANY) };
	int port = DEFAULT_PORT;
	while(1) {
		int c = getopt(argc, argv, "b:k:p:h");
		if(c == -1) break;
		switch(c) {
			case 'b':
				if(inet_pton(AF_INET, optarg, &address) < 1) {
					fprintf(stderr, "%s: Invalid address '%s'\n", argv[0], optarg);
					return 1;
				}
				break;
			case 'k':
				server_key_path = optarg;
				break;
			case 'p': {
				// strtol lets trailing junk and out-of-range values be rejected.
				char *end;
				long value = strtol(optarg, &end, 10);
				if(end == optarg || *end || value < 1 || value > 65535) {
					fprintf(stderr, "%s: Invalid port '%s'\n", argv[0], optarg);
					return 1;
				}
				port = (int)value;
				break;
			}
			case 'h':
				// Asking for help is a complete request; do not start the server.
				print_usage(argv[0]);
				return 0;
			default:
				return -1;
		}
	}
	if(!server_key_path) {
		fprintf(stderr, "%s: need server key path\n", argv[0]);
		print_usage(argv[0]);
		return -1;
	}
	if(read_server_key(server_key_path, &privkey, &pubkey) < 0) {
		perror(server_key_path);
		return 1;
	}
	ECIES_set_symmetric_crypt_functions(ChaCha20_key_bytes, ChaCha20_ctr_crypt, ChaCha20_cbc_mac, ChaCha20_davies_meyer);
	ECIES_set_symmetric_crypt_nonce_location(ChaCha20_nonce, &ChaCha20_nonce_size);

	int fd = socket(AF_INET, SOCK_STREAM, 0);
	if(fd == -1) {
		perror("socket");
		return 1;
	}
	int reuseaddr = 1;
	if(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuseaddr, sizeof reuseaddr) < 0) perror("setsockopt");
	struct sockaddr_in listen_addr = { .sin_family = AF_INET };
	listen_addr.sin_addr = address;
	listen_addr.sin_port = htons(port);
	while(bind(fd, (struct sockaddr *)&listen_addr, sizeof listen_addr) < 0) {
		if(errno == EAGAIN || errno == EINTR) continue;
		perror("bind");
		return 1;
	}
	if(listen(fd, 64) < 0) {
		perror("listen");
		return 1;
	}

	// fdset is the master set of watched descriptors; select() works on a copy.
	fd_set fdset;
	FD_ZERO(&fdset);
	FD_SET(fd, &fdset);
	int maxfd = fd;
	// Slot table of connected clients; -1 marks a free slot.
	int client_fds[FD_SETSIZE];
	int i;
	for(i=0; i<FD_SETSIZE; i++) client_fds[i] = -1;

	while(1) {
		fd_set rfdset = fdset;
		int n = select(maxfd + 1, &rfdset, NULL, NULL, NULL);
		if(n < 0) {
			if(errno == EINTR) continue;
			perror("select");
			return 1;
		}
		if(FD_ISSET(fd, &rfdset)) {
			struct sockaddr_in client_addr;
			socklen_t addr_len = sizeof client_addr;
			int cfd;
			do {
				cfd = accept(fd, (struct sockaddr *)&client_addr, &addr_len);
			} while(cfd == -1 && errno == EINTR);
			if(cfd == -1) {
				// perror() may change errno, so capture it first.
				int accept_errno = errno;
				perror("accept");
				// Out of descriptors and nothing else to serve: back off
				// instead of spinning on the listening socket.
				if(accept_errno == EMFILE && n < 2) sleep(1);
			} else {
				fprintf(stderr, "connection from %s port %d fd %d\n", inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port), cfd);
				int slot = -1;
				// A descriptor >= FD_SETSIZE must never be passed to FD_SET.
				if(cfd < FD_SETSIZE) {
					for(i = 0; i < FD_SETSIZE; i++) {
						if(client_fds[i] == -1) {
							slot = i;
							break;
						}
					}
				}
				if(slot == -1) {
					fprintf(stderr, "warning: cannot add fd %d to set, too many clients\n", cfd);
					close(cfd);
				} else {
					client_fds[slot] = cfd;
					FD_SET(cfd, &fdset);
					if(cfd > maxfd) maxfd = cfd;
					fprintf(stderr, "client %d fd %d\n", slot, cfd);
				}
			}
			n--;
		}
		// n counts the ready descriptors still to be handled.
		for(i=0; n && i<FD_SETSIZE; i++) {
			int cfd = client_fds[i];
			if(cfd == -1) continue;
			if(!FD_ISSET(cfd, &rfdset)) continue;
			n--;

			uint32_t msg_len;
			int s = sync_read(cfd, &msg_len, sizeof msg_len);
			if(s < 0) {
				perror("read");
				// A failed socket stays readable; dropping it keeps select()
				// from reporting it forever.
				drop_client(client_fds, i, &fdset);
				continue;
			}
			if((size_t)s < sizeof msg_len) {
				// End of stream, either cleanly (0 bytes) or inside the header.
				fprintf(stderr, "client %d fd %d closed\n", i, cfd);
				drop_client(client_fds, i, &fdset);
				continue;
			}
			// Convert to host order before the value is compared or used.
			msg_len = ntohl(msg_len);
			if(msg_len > 8192) {
				fprintf(stderr, "client %d fd %d want send %u bytes, that is too large."
					" disconnecting.\n", i, cfd, (unsigned int)msg_len);
				drop_client(client_fds, i, &fdset);
				continue;
			}
			fprintf(stderr, "need to read %u bytes from client %d fd %d\n",
				(unsigned int)msg_len, i, cfd);
			// One extra byte holds the terminating NUL added below.
			char buffer[msg_len + 1];
			s = ecies_read(&privkey, cfd, buffer, msg_len);
			if(s < 0) {
				perror("ecies_read");
				continue;
			}
			// Keep the terminator inside the buffer whatever count was reported.
			if((uint32_t)s > msg_len) s = (int)msg_len;
			buffer[s] = 0;
			fprintf(stderr, "got %d bytes from client %d fd %d\nstring: \"%s\"\n",
				s, i, cfd, buffer);
		}
	}
}