blob: 55b742cbbb21b93548eb2ce429a28441dfb18947 [file] [log] [blame] [raw]
/*
copyright: Boaz segev, 2016-2017
license: MIT
Feel free to copy, use and enjoy according to the license provided.
*/
#include "spnlock.inc"
#include "fio_list.h"
#include "bscrypt.h"
#include "pubsub.h"
#include "websockets.h"
#include <arpa/inet.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#if !defined(__BIG_ENDIAN__) && !defined(__LITTLE_ENDIAN__)
#include <endian.h>
#if !defined(__BIG_ENDIAN__) && !defined(__LITTLE_ENDIAN__) && \
__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define __BIG_ENDIAN__
#endif
#endif
/*******************************************************************************
Buffer management - update to change the way the buffer is handled.
*/
/** A message buffer: a data pointer together with its capacity in bytes. */
struct buffer_s {
/** start of the allocated memory (NULL when no memory is held). */
void *data;
/** capacity of `data`, in bytes. */
size_t size;
};
/* The three buffer functions are declared weak - presumably so an application
 * can supply its own implementation; confirm against the build setup. */
#pragma weak create_ws_buffer
/** returns a buffer_s struct holding a newly allocated initial buffer. */
struct buffer_s create_ws_buffer(ws_s *owner);
#pragma weak resize_ws_buffer
/** returns a buffer_s struct whose buffer is larger than the one passed in. */
struct buffer_s resize_ws_buffer(ws_s *owner, struct buffer_s);
#pragma weak free_ws_buffer
/** releases an existing buffer. */
void free_ws_buffer(ws_s *owner, struct buffer_s);
/** Sets the initial buffer size. (4Kb)*/
#define WS_INITIAL_BUFFER_SIZE 4096UL
/*******************************************************************************
Buffer management - simple implementation...
Since Websocket connections have a long life expectancy, optimizing this part of
the code probably wouldn't offer a high performance boost.
*/
// Buffer sizes grow in 4,096 byte (4Kb) steps: the result is the next multiple
// of 4,096 that is strictly greater than `size` (an exact multiple is bumped by
// a whole step). The expansion is fully parenthesized so it can be used safely
// inside larger expressions.
#define round_up_buffer_size(size) ((((size) >> 12) + 1) << 12)
/**
 * Allocates the initial message buffer for a websocket.
 *
 * `owner` is unused by this default implementation.
 *
 * Returns a buffer_s describing the allocation. When the allocation fails,
 * `data` is NULL and `size` is reported as 0, so a caller comparing a required
 * length against `size` never mistakes a missing buffer for usable capacity.
 */
struct buffer_s create_ws_buffer(ws_s *owner) {
  (void)(owner);
  struct buffer_s buff;
  buff.size = round_up_buffer_size(WS_INITIAL_BUFFER_SIZE);
  buff.data = malloc(buff.size);
  if (!buff.data)
    buff.size = 0; /* no memory: do not advertise capacity we don't have */
  return buff;
}
/**
 * Grows an existing buffer by one size step.
 *
 * On success the returned buffer_s holds the (possibly moved) memory and its
 * new capacity. On failure the old memory is released through
 * `free_ws_buffer` and an empty buffer (`data == NULL`, `size == 0`) is
 * returned.
 */
struct buffer_s resize_ws_buffer(ws_s *owner, struct buffer_s buff) {
  buff.size = round_up_buffer_size(buff.size);
  void *grown = realloc(buff.data, buff.size);
  if (grown) {
    buff.data = grown;
    return buff;
  }
  /* realloc left the old block untouched - hand it back before giving up */
  free_ws_buffer(owner, buff);
  buff.size = 0;
  buff.data = NULL;
  return buff;
}
/**
 * Releases the memory held by `buff`. A buffer without memory (NULL data) is
 * accepted and ignored. `owner` is unused by this default implementation.
 */
void free_ws_buffer(ws_s *owner, struct buffer_s buff) {
  (void)(owner);
  /* free(NULL) is a no-op, so no explicit NULL test is required */
  free(buff.data);
}
#undef round_up_buffer_size
/*******************************************************************************
Create/Destroy the websocket object (prototypes)
*/
static ws_s *new_websocket();
static void destroy_ws(ws_s *ws);
/*******************************************************************************
The Websocket object (protocol + parser)
*/
/* The per-connection websocket object. The protocol_s member is the first
 * field, and the code in this file casts between `protocol_s *` and `ws_s *`.
 */
struct Websocket {
/** The Websocket protocol */
protocol_s protocol;
/** connection data */
intptr_t fd;
/** callbacks */
void (*on_message)(ws_s *ws, char *data, size_t size, uint8_t is_text);
void (*on_shutdown)(ws_s *ws);
void (*on_ready)(ws_s *ws);
void (*on_close)(ws_s *ws);
/** Opaque user data. */
void *udata;
/** The maximum websocket message size */
size_t max_msg_size;
/** active pub/sub subscriptions */
fio_list_s subscriptions;
/** message buffer. */
struct buffer_s buffer;
/** message length (how much of the buffer actually used). */
size_t length;
/** for fragmenting the `on_data` parsing and message handling. */
size_t resume_from;
/** parser. */
struct {
/* raw bytes of the extended payload length, read back as a 16 or 64 bit
 * number once all of them were collected. */
union {
unsigned len1 : 16;
unsigned long long len2 : 64;
char bytes[8];
} psize;
/* payload length of the frame being parsed. */
size_t length;
/* how many payload bytes of the current frame were consumed so far. */
size_t received;
/* the frame's 4 byte masking key. */
char mask[4];
/* `head` is the first byte of the current frame; `head2` keeps a copy of
 * the first frame's head while a fragmented message is being collected. */
struct {
unsigned op_code : 4;
unsigned rsv3 : 1;
unsigned rsv2 : 1;
unsigned rsv1 : 1;
unsigned fin : 1;
} head, head2;
/* the second byte of the frame: 7 bit size indicator and the mask flag. */
struct {
unsigned size : 7;
unsigned masked : 1;
} sdata;
/* progress flags / counters for the incremental header parsing. */
struct {
unsigned has_mask : 1;
unsigned at_mask : 2;
unsigned has_len : 1;
unsigned at_len : 3;
unsigned has_head : 1;
} state;
/* non-zero when acting as a client: outgoing frames are masked and
 * unmasked incoming frames are accepted. */
unsigned client : 1;
} parser;
};
/**
The Websocket Protocol Identifying String. Used for the `each` function.

Protocol objects are recognized by comparing their `service` pointer with this
pointer (pointer identity, not string contents).
*/
char *WEBSOCKET_ID_STR = "websockets";
/**
A thread localized buffer used for reading and parsing data from the socket.

`pos` is the parser's current offset into `buffer`.
*/
#define WEBSOCKET_READ_MAX 4096
static __thread struct {
int pos;
char buffer[WEBSOCKET_READ_MAX];
} read_buffer;
/* *****************************************************************************
Create/Destroy the websocket subscription objects
***************************************************************************** */
/** A list node tying one pub/sub subscription to the websocket that owns it. */
typedef struct {
/* link in the owning websocket's `subscriptions` list. */
fio_list_s node;
/* the pub/sub subscription handle. */
pubsub_sub_pt sub;
} subscription_s;
/**
 * Wraps the pub/sub subscription `sub` in a list node and links it into the
 * websocket's subscription list.
 *
 * Returns the new node, or NULL when memory could not be allocated. On failure
 * the subscription is cancelled, so a failed call leaves no active
 * subscription without a node to track it.
 */
static inline subscription_s *create_subscription(ws_s *ws, pubsub_sub_pt sub) {
  subscription_s *s = malloc(sizeof(*s));
  if (!s) {
    /* nobody would be able to unsubscribe later - cancel it right away */
    pubsub_unsubscribe(sub);
    return NULL;
  }
  s->sub = sub;
  fio_list_add(&ws->subscriptions, &s->node);
  return s;
}
/** Unlinks a subscription node from its list and releases its memory. The
 * pub/sub subscription itself is not cancelled here. */
static inline void free_subscription(subscription_s *s) {
  /* detach first, so the list never points at freed memory */
  fio_list_remove(&(s->node));
  free(s);
}
/** Cancels every pub/sub subscription owned by the websocket and releases the
 * list nodes that tracked them. */
static inline void clear_subscriptions(ws_s *ws) {
  subscription_s *s;
  fio_list_for_each(subscription_s, node, s, ws->subscriptions) {
    pubsub_unsubscribe(s->sub);
    /* unlink and release the node */
    fio_list_remove(&s->node);
    free(s);
  }
}
/*******************************************************************************
The Websocket Protocol implementation
*/
/* Fetches the protocol object attached to `fd` and casts it to `ws_s *`. The
 * cast is unchecked; the macro does not verify the protocol is a websocket. */
#define ws_protocol(fd) ((ws_s *)(server_get_protocol(fd)))
/** Protocol `ping` callback: sends an empty websocket ping frame (bytes 0x89
 * 0x00) from a static buffer that requires no deallocation. */
static void ws_ping(intptr_t fd, protocol_s *ws) {
  sock_write2(.uuid = fd, .buffer = "\x89\x00", .length = 2, .move = 1,
              .dealloc = SOCK_DEALLOC_NOOP);
  (void)(ws);
}
/** Protocol `on_close` callback: destroys the websocket object that was
 * attached to the connection. */
static void on_close(intptr_t uuid, protocol_s *_ws) {
  (void)uuid;
  destroy_ws((ws_s *)_ws);
}
/** Protocol `on_ready` callback: forwards the event to the user's `on_ready`
 * callback when the protocol object is a websocket that has one. */
static void on_ready(intptr_t fduuid, protocol_s *ws) {
  (void)(fduuid);
  if (!ws || ws->service != WEBSOCKET_ID_STR)
    return;
  if (!((ws_s *)ws)->on_ready)
    return;
  ((ws_s *)ws)->on_ready((ws_s *)ws);
}
// static void on_open(intptr_t fd, protocol_s *ws, void *callback) {
// (void)(fd);
// if (callback && ws && ws->service == WEBSOCKET_ID_STR)
// ((void (*)(void *))callback)(ws);
// }
/** Protocol `on_shutdown` callback: forwards the event to the user's
 * `on_shutdown` callback, if one was set. */
static void on_shutdown(intptr_t fd, protocol_s *ws) {
  (void)(fd);
  if (!ws)
    return;
  if (!((ws_s *)ws)->on_shutdown)
    return;
  ((ws_s *)ws)->on_shutdown((ws_s *)ws);
}
/* later */
static void websocket_write_impl(intptr_t fd, void *data, size_t len, char text,
char first, char last, char client);
static size_t websocket_encode(void *buff, void *data, size_t len, char text,
char first, char last, char client);
static void on_data(intptr_t sockfd, protocol_s *_ws);
/** Adapter that lets `on_data` be scheduled as a deferred task; the extra task
 * argument is ignored. */
static void on_data_def(intptr_t sockfd, protocol_s *_ws, void *arg) {
  (void)arg;
  on_data(sockfd, _ws);
}
/* read data from the socket, parse it and invoke the websocket events.
 *
 * The parser is incremental: all of its state lives in `ws->parser`, so a
 * frame (or even a frame header) may be split across any number of reads. The
 * stages are, in order: head byte, size/mask byte, extended length bytes,
 * masking key, payload. Payload of text / binary / continuation frames is
 * appended to `ws->buffer`; once a frame with the `fin` flag is complete the
 * `on_message` callback is invoked.
 *
 * When more data remains in the read buffer after a frame was completed, the
 * function schedules itself (through `on_data_def`) and returns.
 *
 * NOTE(review): when resuming, parsing restarts with `read_buffer.pos = 0` and
 * `read_buffer` is thread local - confirm the deferred task is guaranteed to
 * see the same buffer contents and does not re-parse frames already handled.
 */
static void on_data(intptr_t sockfd, protocol_s *_ws) {
#define ws ((ws_s *)_ws)
if (ws == NULL || ws->protocol.service != WEBSOCKET_ID_STR)
return;
ssize_t len = 0;
ssize_t data_len = 0;
if (ws->resume_from) {
// a previous call was fragmented: skip the read and continue parsing
sock_touch(sockfd);
len = ws->resume_from;
ws->resume_from = 0;
goto resume_parsing;
}
while ((len = sock_read(sockfd, read_buffer.buffer, WEBSOCKET_READ_MAX)) >
0) {
resume_parsing:
data_len = 0;
read_buffer.pos = 0;
while (read_buffer.pos < len) {
// collect the frame's head
if (!ws->parser.state.has_head) {
ws->parser.state.has_head = 1;
*((char *)(&(ws->parser.head))) = read_buffer.buffer[read_buffer.pos];
// save a copy if it's the first head in a fragmented message
if (!(*(char *)(&ws->parser.head2))) {
ws->parser.head2 = ws->parser.head;
}
// advance
read_buffer.pos++;
// go back to the `while` head, to review if there's more data
continue;
}
// save the mask and size information
if (!ws->parser.state.at_len && !ws->parser.state.has_len) {
// uint8_t tmp = ws->parser.sdata.masked;
*((char *)(&(ws->parser.sdata))) = read_buffer.buffer[read_buffer.pos];
// ws->parser.sdata.masked |= tmp;
// set length
// (at_len is the index of the first extended length byte to fill: 7 for a
// 64 bit length, 1 for a 16 bit length, 0 when there is no extended length)
ws->parser.state.at_len = (ws->parser.sdata.size == 127
? 7
: ws->parser.sdata.size == 126 ? 1 : 0);
if (!ws->parser.state.at_len) {
ws->parser.length = ws->parser.sdata.size;
ws->parser.state.has_len = 1;
}
read_buffer.pos++;
continue;
}
// check if we need to collect the length data
if (!ws->parser.state.has_len) {
// avoiding a loop so we don't mixup the meaning of "continue" and
// "break"
collect_len:
////////// NOTICE: Network Byte Order requires us to translate the data
#ifdef __BIG_ENDIAN__
if ((ws->parser.state.at_len == 1 && ws->parser.sdata.size == 126) ||
(ws->parser.state.at_len == 7 && ws->parser.sdata.size == 127)) {
ws->parser.psize.bytes[ws->parser.state.at_len] =
read_buffer.buffer[read_buffer.pos++];
ws->parser.state.has_len = 1;
ws->parser.length = (ws->parser.sdata.size == 126)
? ws->parser.psize.len1
: ws->parser.psize.len2;
} else {
ws->parser.psize.bytes[ws->parser.state.at_len++] =
read_buffer.buffer[read_buffer.pos++];
if (read_buffer.pos < len)
goto collect_len;
}
#else
// bytes arrive most significant first and are stored from the highest
// index down, so index 0 receives the last (least significant) byte
if (ws->parser.state.at_len == 0) {
ws->parser.psize.bytes[ws->parser.state.at_len] =
read_buffer.buffer[read_buffer.pos++];
ws->parser.state.has_len = 1;
ws->parser.length = (ws->parser.sdata.size == 126)
? ws->parser.psize.len1
: ws->parser.psize.len2;
} else {
ws->parser.psize.bytes[ws->parser.state.at_len--] =
read_buffer.buffer[read_buffer.pos++];
if (read_buffer.pos < len)
goto collect_len;
}
#endif
// check message size limit
// NOTE(review): this check lives inside the extended length branch, so
// frames shorter than 126 bytes never reach it - confirm that is intended.
if (ws->max_msg_size <
ws->length + (ws->parser.length - ws->parser.received)) {
// close connection!
fprintf(stderr, "ERROR Websocket: Payload too big, review limits.\n");
sock_close(sockfd);
return;
}
continue;
}
// check that the data is masked and that we didn't collect the mask yet
if (ws->parser.sdata.masked && !(ws->parser.state.has_mask)) {
// avoiding a loop so we don't mixup the meaning of "continue" and "break"
collect_mask:
if (ws->parser.state.at_mask == 3) {
ws->parser.mask[ws->parser.state.at_mask] =
read_buffer.buffer[read_buffer.pos++];
ws->parser.state.has_mask = 1;
ws->parser.state.at_mask = 0;
} else {
ws->parser.mask[ws->parser.state.at_mask++] =
read_buffer.buffer[read_buffer.pos++];
if (read_buffer.pos < len)
goto collect_mask;
else
continue;
}
// since it's possible that there's no more data (0 length frame),
// we don't use `continue` (check while loop) and we process what we
// have.
}
// Now that we know everything about the frame, let's collect the data
// How much data in the buffer is part of the frame?
data_len = len - read_buffer.pos;
if (data_len + ws->parser.received > ws->parser.length)
data_len = ws->parser.length - ws->parser.received;
// a note about unmasking: since ws->parser.state.at_mask is only 2 bits,
// it will wrap around (i.e. 3++ == 0), so no modulus is required :-)
// unmask:
if (ws->parser.sdata.masked) {
for (int i = 0; i < data_len; i++) {
read_buffer.buffer[i + read_buffer.pos] ^=
ws->parser.mask[ws->parser.state.at_mask++];
}
} else if (ws->parser.client == 0) {
// enforce masking unless acting as client, also for security reasons...
fprintf(stderr, "ERROR Websockets: unmasked frame, disconnecting.\n");
sock_close(sockfd);
return;
}
// Copy the data to the Websocket buffer - only if it's a user message
// (op-code 1 or 2, or a continuation frame of a message that started as one)
if (data_len &&
(ws->parser.head.op_code == 1 || ws->parser.head.op_code == 2 ||
(!ws->parser.head.op_code && (ws->parser.head2.op_code == 1 ||
ws->parser.head2.op_code == 2)))) {
// review and resize the buffer's capacity - it can only grow.
if (ws->length + ws->parser.length - ws->parser.received >
ws->buffer.size) {
ws->buffer = resize_ws_buffer(ws, ws->buffer);
if (!ws->buffer.data) {
// no memory.
websocket_close(ws);
return;
}
}
// copy here
memcpy((uint8_t *)ws->buffer.data + ws->length,
read_buffer.buffer + read_buffer.pos, data_len);
ws->length += data_len;
}
// set the frame's data received so far (copied or not)
ws->parser.received += data_len;
// check that we have collected the whole of the frame.
if (ws->parser.length > ws->parser.received) {
read_buffer.pos += data_len;
// fprintf(stderr, "%p websocket has %lu out of %lu\n", (void *)ws,
// ws->parser.received, ws->parser.length);
continue;
}
// we have the whole frame, time to process the data.
// pings, pongs and other non-user messages are handled independently.
if (ws->parser.head.op_code == 0 || ws->parser.head.op_code == 1 ||
ws->parser.head.op_code == 2) {
/* a user data frame - make sure we got the `fin` flag, or an error
 * occurred */
if (!ws->parser.head.fin) {
/* This frame was a partial message. */
goto reset_state;
}
/* This was the last frame */
if (ws->on_message) /* call the on_message callback */
ws->on_message(ws, ws->buffer.data, ws->length,
ws->parser.head2.op_code == 1);
goto reset_parser;
} else if (ws->parser.head.op_code == 8) {
/* op-code == close */
websocket_close(ws);
if (ws->parser.head2.op_code == ws->parser.head.op_code)
goto reset_parser;
} else if (ws->parser.head.op_code == 9) {
/* ping */
// write Pong - including ping data...
// NOTE(review): only the bytes of the current read are echoed; a ping
// payload split across reads would be echoed partially - confirm.
websocket_write_impl(sockfd, read_buffer.buffer + read_buffer.pos,
data_len, 10, 1, 1, ws->parser.client);
if (ws->parser.head2.op_code == ws->parser.head.op_code)
goto reset_parser;
} else if (ws->parser.head.op_code == 10) {
/* pong */
// do nothing... almost
if (ws->parser.head2.op_code == ws->parser.head.op_code)
goto reset_parser;
} else if (ws->parser.head.op_code > 2 && ws->parser.head.op_code < 8) {
/* future control frames. ignore. */
if (ws->parser.head2.op_code == ws->parser.head.op_code)
goto reset_parser;
} else {
/* unknown op-code: treat as a protocol error */
// fprintf(stderr, "%p websocket reached a WTF?! state..."
// "op1: %i , op2: %i\n",
// (void *)ws, ws->parser.head.op_code,
// ws->parser.head2.op_code);
fprintf(stderr, "ERROR Websockets: protocol error, disconnecting.\n");
sock_close(sockfd);
return;
}
// reset_parser: the whole message is done - also drop the message data and
// the multi-frame state. reset_state (below) only resets per-frame state.
reset_parser:
ws->length = 0;
// clear the parser's multi-frame state
*((char *)(&(ws->parser.head2))) = 0;
ws->parser.sdata.masked = 0;
reset_state:
// move the pos marker along - in case we have more than one frame in the
// buffer
read_buffer.pos += data_len;
// reset parser state
ws->parser.state.has_len = 0;
ws->parser.state.at_len = 0;
ws->parser.state.has_mask = 0;
ws->parser.state.at_mask = 0;
ws->parser.state.has_head = 0;
ws->parser.sdata.size = 0;
*((char *)(&(ws->parser.head))) = 0;
ws->parser.received = ws->parser.length = ws->parser.psize.len2 =
data_len = 0;
if (read_buffer.pos < len) {
/* we just finished an on_message callback, let's fragment the event. */
ws->resume_from = len;
facil_defer(.task = on_data_def, .task_type = FIO_PR_LOCK_TASK,
.uuid = sockfd);
return;
}
}
}
#undef ws
}
/*******************************************************************************
Create/Destroy the websocket object
*/
/**
 * Allocates a websocket object and initializes its protocol callbacks and its
 * (empty) subscription list. Every other field is zeroed.
 *
 * Returns NULL when memory could not be allocated.
 */
static ws_s *new_websocket(void) {
  // allocate the protocol object
  ws_s *ws = malloc(sizeof(*ws));
  if (!ws)
    return NULL; /* let the caller refuse the connection */
  *ws = (ws_s){
      .protocol.service = WEBSOCKET_ID_STR,
      .protocol.ping = ws_ping,
      .protocol.on_data = on_data,
      .protocol.on_close = on_close,
      .protocol.on_ready = on_ready,
      .protocol.on_shutdown = on_shutdown,
      .subscriptions = FIO_LIST_INIT_STATIC(ws->subscriptions),
  };
  return ws;
}
/**
 * Tears down a websocket object: cancels its subscriptions, invokes the
 * user's `on_close` callback (if any), then releases the message buffer and
 * the object itself. `ws` must not be used afterwards.
 */
static void destroy_ws(ws_s *ws) {
  /* subscriptions go first, before the user is told the connection closed */
  clear_subscriptions(ws);
  if (ws->on_close) {
    ws->on_close(ws);
  }
  free_ws_buffer(ws, ws->buffer);
  free(ws);
}
/*******************************************************************************
Writing to the Websocket
*/
/* Largest payload written as a single frame; longer messages are fragmented. */
#define WS_MAX_FRAME_SIZE 65532 // should be less than `unsigned short`
// clang-format off
#if !defined(__BIG_ENDIAN__) && !defined(__LITTLE_ENDIAN__)
# if defined(__has_include)
# if __has_include(<endian.h>)
# include <endian.h>
# elif __has_include(<sys/endian.h>)
# include <sys/endian.h>
# endif
# endif
# if !defined(__BIG_ENDIAN__) && !defined(__LITTLE_ENDIAN__) && \
__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
# define __BIG_ENDIAN__
# endif
#endif
// clang-format on
/* `bswap64` converts a 64 bit integer between host and network byte order:
 * on big endian hosts it leaves the value as is, otherwise it reverses the
 * byte order. */
#ifdef __BIG_ENDIAN__
/** byte swap 64 bit integer */
#define bswap64(i) (i)
#else
// TODO: check for __builtin_bswap64
/** byte swap 64 bit integer */
#define bswap64(i) \
((((i)&0xFFULL) << 56) | (((i)&0xFF00ULL) << 40) | \
(((i)&0xFF0000ULL) << 24) | (((i)&0xFF000000ULL) << 8) | \
(((i)&0xFF00000000ULL) >> 8) | (((i)&0xFF0000000000ULL) >> 24) | \
(((i)&0xFF000000000000ULL) >> 40) | (((i)&0xFF00000000000000ULL) >> 56))
#endif
/**
 * Writes a client-side masked payload to `dest`.
 *
 * Layout written: 4 bytes of masking key at `dest`, followed by `len` bytes
 * of `data` XOR-ed with the key (key byte `i & 3` for payload byte `i`).
 * `dest` must have room for `len + 4` bytes.
 *
 * The key must be stored BEFORE the payload position is advanced, otherwise
 * the payload overwrites the key and the first four bytes stay unwritten.
 */
static void websocket_mask(void *dest, void *data, size_t len) {
  uint8_t *out = dest;
  const uint8_t *in = data;
  /* a semi-random 4 byte mask (unsigned math, so the shift can't overflow) */
  uint32_t mask =
      ((uint32_t)rand() << 7) ^ (uint32_t)((uintptr_t)dest >> 13);
  uint8_t key[4];
  memcpy(key, &mask, 4);
  /* place mask at head of data */
  memcpy(out, key, 4);
  out += 4;
  /* TODO: optimize this */
  for (size_t i = 0; i < len; i++) {
    out[i] = in[i] ^ key[i & 3];
  }
}

/**
 * Encodes a single websocket frame into `buff`.
 *
 * - `data` / `len`: the payload.
 * - `text`: zero for a binary frame, non-zero for a text frame. For payloads
 *   shorter than 126 bytes a non-zero value is used verbatim as the op-code
 *   (its low 4 bits), which is how control frames such as pong (10) are
 *   produced.
 * - `first`: non-zero for the first frame of a message; otherwise the frame is
 *   a continuation frame (op-code 0).
 * - `last`: the low bit becomes the `fin` flag.
 * - `client`: the low bit becomes the `masked` flag; any non-zero value causes
 *   the payload to be masked.
 *
 * `buff` must have room for the payload plus up to 14 bytes (10 byte header
 * and 4 byte masking key).
 *
 * Returns the total number of bytes written to `buff`.
 *
 * The header is written byte by byte, so the output does not depend on the
 * compiler's bit-field layout, the host's byte order or the alignment of
 * `buff`.
 */
static size_t websocket_encode(void *buff, void *data, size_t len, char text,
                               char first, char last, char client) {
  uint8_t *out = buff;
  size_t head_len;
  unsigned op_code;
  if (!first)
    op_code = 0; /* continuation frame */
  else if (!text)
    op_code = 2; /* binary */
  else if (len < 126)
    op_code = (unsigned)text & 0xFU; /* text, or an explicit op-code */
  else
    op_code = 1; /* text */
  const uint8_t mask_flag = (uint8_t)((client & 1) << 7);

  out[0] = (uint8_t)(((last & 1) << 7) | op_code);
  if (len < 126) {
    out[1] = (uint8_t)(mask_flag | len);
    head_len = 2;
  } else if (len < (1UL << 16)) {
    /* head is 4 bytes: 16 bit length in network byte order */
    out[1] = (uint8_t)(mask_flag | 126);
    out[2] = (uint8_t)(len >> 8);
    out[3] = (uint8_t)(len & 0xFF);
    head_len = 4;
  } else {
    /* Really Long Message: 64 bit length in network byte order */
    out[1] = (uint8_t)(mask_flag | 127);
    for (int i = 0; i < 8; i++) {
      out[2 + i] = (uint8_t)((uint64_t)len >> (56 - (8 * i)));
    }
    head_len = 10;
  }
  if (client) {
    websocket_mask(out + head_len, data, len);
    return len + head_len + 4;
  }
  memcpy(out + head_len, data, len);
  return len + head_len;
}
/**
 * Encodes `data` as one or more websocket frames and queues them for sending
 * on `fd`.
 *
 * - Payloads larger than WS_MAX_FRAME_SIZE are split into a sequence of
 *   frames (frame fragmentation).
 * - Frames that fit in a socket buffer packet are encoded directly into one.
 * - Larger frames are encoded into a heap buffer whose ownership is handed to
 *   the socket layer (`.move = 1`).
 *
 * `text`, `first`, `last` and `client` are passed through to
 * `websocket_encode`.
 *
 * If the heap buffer cannot be allocated the connection is closed, since
 * silently dropping a frame would corrupt the message stream.
 */
static void websocket_write_impl(intptr_t fd, void *data, size_t len,
                                 char text, /* TODO: add client masking */
                                 char first, char last, char client) {
  if (len < 126) {
    sock_buffer_s *sbuff = sock_buffer_checkout();
    sbuff->len =
        websocket_encode(sbuff->buf, data, len, text, first, last, client);
    sock_buffer_send(fd, sbuff);
  } else if (len <= WS_MAX_FRAME_SIZE) {
    if (len >= BUFFER_PACKET_SIZE - 8) { // can't fit in sock buffer
      /* the head is 4 bytes and a client frame adds a 4 byte masking key */
      void *buff = malloc(len + 8);
      if (!buff) {
        /* no memory */
        sock_close(fd);
        return;
      }
      len = websocket_encode(buff, data, len, text, first, last, client);
      sock_write2(.uuid = fd, .buffer = buff, .length = len, .move = 1);
    } else {
      sock_buffer_s *sbuff = sock_buffer_checkout();
      sbuff->len =
          websocket_encode(sbuff->buf, data, len, text, first, last, client);
      sock_buffer_send(fd, sbuff);
    }
  } else {
    /* frame fragmentation is better for large data than large frames */
    while (len > WS_MAX_FRAME_SIZE) {
      websocket_write_impl(fd, data, WS_MAX_FRAME_SIZE, text, first, 0, client);
      data = ((uint8_t *)data) + WS_MAX_FRAME_SIZE;
      first = 0;
      len -= WS_MAX_FRAME_SIZE;
    }
    websocket_write_impl(fd, data, len, text, first, 1, client);
  }
  return;
}
/* *****************************************************************************
UTF-8 testing. This part was practically copied from:
https://stackoverflow.com/a/22135005/4025095
and
http://bjoern.hoehrmann.de/utf-8/decoder/dfa
***************************************************************************** */
/* Copyright (c) 2008-2009 Bjoern Hoehrmann <bjoern@hoehrmann.de> */
/* See http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ for details. */
#define UTF8_ACCEPT 0
#define UTF8_REJECT 1
/* The decoder's table: the first 256 entries map a byte to its character
 * class, the remaining entries map (state, class) to the next state. */
static const uint8_t utf8d[] = {
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0, // 00..1f
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0, // 20..3f
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0, // 40..5f
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
    0,   0,   0,   0,   0,   0,   0,   0,   0,   0, // 60..7f
    1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
    1,   1,   1,   1,   1,   9,   9,   9,   9,   9,   9,
    9,   9,   9,   9,   9,   9,   9,   9,   9,   9, // 80..9f
    7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
    7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
    7,   7,   7,   7,   7,   7,   7,   7,   7,   7, // a0..bf
    8,   8,   2,   2,   2,   2,   2,   2,   2,   2,   2,
    2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
    2,   2,   2,   2,   2,   2,   2,   2,   2,   2, // c0..df
    0xa, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3,
    0x3, 0x3, 0x4, 0x3, 0x3, // e0..ef
    0xb, 0x6, 0x6, 0x6, 0x5, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8,
    0x8, 0x8, 0x8, 0x8, 0x8, // f0..ff
    0x0, 0x1, 0x2, 0x3, 0x5, 0x8, 0x7, 0x1, 0x1, 0x1, 0x4,
    0x6, 0x1, 0x1, 0x1, 0x1, // s0..s0
    1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
    1,   1,   1,   1,   1,   1,   0,   1,   1,   1,   1,
    1,   0,   1,   0,   1,   1,   1,   1,   1,   1, // s1..s2
    1,   2,   1,   1,   1,   1,   1,   2,   1,   2,   1,
    1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
    1,   2,   1,   1,   1,   1,   1,   1,   1,   1, // s3..s4
    1,   2,   1,   1,   1,   1,   1,   1,   1,   2,   1,
    1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
    1,   3,   1,   3,   1,   1,   1,   1,   1,   1, // s5..s6
    1,   3,   1,   1,   1,   1,   1,   3,   1,   3,   1,
    1,   1,   1,   1,   1,   1,   3,   1,   1,   1,   1,
    1,   1,   1,   1,   1,   1,   1,   1,   1,   1, // s7..s8
};

/**
 * Tests whether the `len` bytes at `str` form valid UTF-8.
 *
 * Returns 1 when the data is valid (the decoder finishes in its accepting
 * state - this includes an empty input) and 0 when an invalid byte is met or
 * the data ends in the middle of a multi-byte sequence.
 */
static inline uint32_t validate_utf8(uint8_t *str, size_t len) {
  uint32_t state = UTF8_ACCEPT;
  for (const uint8_t *end = str + len; str != end; ++str) {
    const uint32_t type = utf8d[*str];
    state = utf8d[256 + state * 16 + type];
    if (state == UTF8_REJECT)
      return 0;
  }
  return state == UTF8_ACCEPT;
}
/* *****************************************************************************
Subscription handling
***************************************************************************** */
/**
 * Pub/sub handler that writes the message straight to the websocket stored in
 * `udata1`, auto-detecting the frame type: messages shorter than 32Kb that
 * pass UTF-8 validation are sent as text, everything else as binary.
 *
 * When the connection can't be locked the message is deferred, unless the
 * connection is gone (EBADF), in which case the message is dropped.
 */
static void websocket_on_pubsub_message_direct(pubsub_message_s *msg) {
  protocol_s *pr =
      facil_protocol_try_lock((intptr_t)msg->udata1, FIO_PR_LOCK_WRITE);
  if (pr) {
    uint8_t as_text = 0;
    if (msg->msg.len < (2 << 14))
      as_text = validate_utf8((uint8_t *)msg->msg.data, msg->msg.len);
    websocket_write((ws_s *)pr, msg->msg.data, msg->msg.len, as_text);
    facil_protocol_unlock(pr, FIO_PR_LOCK_WRITE);
    return;
  }
  if (errno != EBADF)
    pubsub_defer(msg);
}
/**
 * Pub/sub handler that writes the message to the websocket stored in `udata1`
 * as a text frame. The message is deferred when the connection can't be
 * locked and dropped when the connection is gone (EBADF).
 */
static void websocket_on_pubsub_message_direct_txt(pubsub_message_s *msg) {
  protocol_s *pr =
      facil_protocol_try_lock((intptr_t)msg->udata1, FIO_PR_LOCK_WRITE);
  if (pr) {
    websocket_write((ws_s *)pr, msg->msg.data, msg->msg.len, 1);
    facil_protocol_unlock(pr, FIO_PR_LOCK_WRITE);
    return;
  }
  if (errno != EBADF)
    pubsub_defer(msg);
}
/**
 * Pub/sub handler that writes the message to the websocket stored in `udata1`
 * as a binary frame. The message is deferred when the connection can't be
 * locked and dropped when the connection is gone (EBADF).
 */
static void websocket_on_pubsub_message_direct_bin(pubsub_message_s *msg) {
  protocol_s *pr =
      facil_protocol_try_lock((intptr_t)msg->udata1, FIO_PR_LOCK_WRITE);
  if (pr) {
    websocket_write((ws_s *)pr, msg->msg.data, msg->msg.len, 0);
    facil_protocol_unlock(pr, FIO_PR_LOCK_WRITE);
    return;
  }
  if (errno != EBADF)
    pubsub_defer(msg);
}
/**
 * Pub/sub handler used for subscriptions that were created with an
 * `on_message` callback. It locks the connection as a task and writes the
 * message as a binary frame; the message is deferred when the connection
 * can't be locked and dropped when the connection is gone (EBADF).
 *
 * NOTE(review): the user's `on_message` callback is not invoked here - confirm
 * whether that is intended.
 */
static void websocket_on_pubsub_message(pubsub_message_s *msg) {
  protocol_s *pr =
      facil_protocol_try_lock((intptr_t)msg->udata1, FIO_PR_LOCK_TASK);
  if (pr) {
    websocket_write((ws_s *)pr, msg->msg.data, msg->msg.len, 0);
    facil_protocol_unlock(pr, FIO_PR_LOCK_TASK);
    return;
  }
  if (errno != EBADF)
    pubsub_defer(msg);
}
/**
 * Subscribes the websocket `args.ws` to a pub/sub channel.
 *
 * The message handler is picked from the arguments: the callback based
 * handler when `on_message` is set, otherwise the forced binary / forced text
 * handler, otherwise the auto-detecting direct writer.
 *
 * Returns a subscription ID on success and 0 on failure.
 */
#undef websocket_subscribe
uintptr_t websocket_subscribe(struct websocket_subscribe_s args) {
  pubsub_sub_pt sub = pubsub_subscribe(
          .engine = args.engine,
          .channel =
              {
                  .name = (char *)args.channel.name, .len = args.channel.len,
              },
          .use_pattern = args.use_pattern,
          .on_message =
              (!args.on_message
                   ? (args.force_binary
                          ? websocket_on_pubsub_message_direct_bin
                          : (args.force_text
                                 ? websocket_on_pubsub_message_direct_txt
                                 : websocket_on_pubsub_message_direct))
                   : websocket_on_pubsub_message),
          .udata1 = (void *)args.ws->fd, .udata2 = args.udata);
  /* the list node's address doubles as the subscription ID */
  return sub ? (uintptr_t)create_subscription(args.ws, sub) : 0;
}
/**
 * Looks up a subscription matching the arguments (the handler is derived from
 * them exactly as it is when subscribing) among the websocket's own
 * subscriptions.
 *
 * Returns the existing subscription's ID (if exists) or 0 (no subscription).
 */
#undef websocket_find_sub
uintptr_t websocket_find_sub(struct websocket_subscribe_s args) {
  pubsub_sub_pt sub = pubsub_find_sub(
          .engine = args.engine,
          .channel =
              {
                  .name = (char *)args.channel.name, .len = args.channel.len,
              },
          .use_pattern = args.use_pattern,
          .on_message =
              (!args.on_message
                   ? (args.force_binary
                          ? websocket_on_pubsub_message_direct_bin
                          : (args.force_text
                                 ? websocket_on_pubsub_message_direct_txt
                                 : websocket_on_pubsub_message_direct))
                   : websocket_on_pubsub_message),
          .udata1 = (void *)args.ws->fd, .udata2 = args.udata);
  if (sub) {
    /* map the pub/sub handle back to the node this websocket tracks */
    subscription_s *s;
    fio_list_for_each(subscription_s, node, s, args.ws->subscriptions) {
      if (s->sub == sub)
        return (uintptr_t)s;
    }
  }
  return 0;
}
/**
 * Unsubscribes from a channel.
 *
 * `subscription_id` is matched against the websocket's own subscriptions; an
 * ID that doesn't belong to this websocket is ignored.
 */
void websocket_unsubscribe(ws_s *ws, uintptr_t subscription_id) {
  subscription_s *const wanted = (subscription_s *)subscription_id;
  subscription_s *s;
  fio_list_for_each(subscription_s, node, s, ws->subscriptions) {
    if (s == wanted) {
      pubsub_unsubscribe(s->sub);
      free_subscription(s);
      return;
    }
  }
}
/*******************************************************************************
The API implementation
*/
/**
 * The upgrade: answers an HTTP request with a websocket handshake and, when
 * the handshake is accepted, attaches a new websocket protocol object to the
 * connection.
 *
 * Either `settings.request` or `settings.response` must be set. A
 * `max_msg_size` of 0 defaults to 262144 bytes and a `timeout` of 0 defaults
 * to 40 seconds.
 *
 * The request is refused (HTTP status 400) when the websocket object can't be
 * allocated, when the `sec-websocket-version` header doesn't start with "13"
 * or when the `sec-websocket-key` header is missing. When a websocket object
 * was already created for a refused request it is destroyed, which invokes
 * the `on_close` callback.
 *
 * Returns 0 on success and -1 on failure.
 */
#undef websocket_upgrade
ssize_t websocket_upgrade(websocket_settings_s settings) {
  // A static data used for all websocket connections.
  static char ws_key_accpt_str[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
  // require either a request or a response.
  if (((uintptr_t)settings.request | (uintptr_t)settings.response) ==
      (uintptr_t)NULL)
    return -1;
  if (settings.max_msg_size == 0)
    settings.max_msg_size = 262144; /** defaults to ~250KB */
  if (settings.timeout == 0)
    settings.timeout = 40; /* defaults to 40 seconds */
  // make sure we have a response object.
  http_response_s *response = settings.response;
  if (response == NULL) {
    /* initialize a default upgrade response */
    response = http_response_create(settings.request);
  } else
    settings.request = response->request;
  // allocate the protocol object (TODO: (maybe) pooling)
  ws_s *ws = new_websocket();
  if (!ws)
    goto refuse;
  // setup the socket-server data
  ws->fd = response->request->fd;
  // Setup ws callbacks
  ws->on_close = settings.on_close;
  ws->on_message = settings.on_message;
  ws->on_ready = settings.on_ready;
  ws->on_shutdown = settings.on_shutdown;
  // setup any user data
  ws->udata = settings.udata;
  // buffer limits
  ws->max_msg_size = settings.max_msg_size;
  const char *recv_str;
  recv_str =
      http_request_header_find(settings.request, "sec-websocket-version", 21)
          .value;
  if (recv_str == NULL || recv_str[0] != '1' || recv_str[1] != '3')
    goto refuse;
  http_header_s sec_h =
      http_request_header_find(settings.request, "sec-websocket-key", 17);
  if (sec_h.value == NULL)
    goto refuse;
  // websocket extensions (none)
  // the accept Base64 Hash - we need to compute this one and set it
  // the client's unique string
  // use the SHA1 methods provided to concat the client string and hash
  sha1_s sha1;
  sha1 = bscrypt_sha1_init();
  bscrypt_sha1_write(&sha1, sec_h.value, sec_h.value_len);
  bscrypt_sha1_write(&sha1, ws_key_accpt_str, sizeof(ws_key_accpt_str) - 1);
  // base encode the data
  char websockets_key[32];
  int len =
      bscrypt_base64_encode(websockets_key, bscrypt_sha1_result(&sha1), 20);
  // websocket extensions (none)
  // upgrade taking place, make sure the upgrade headers are valid for the
  // response.
  response->status = 101;
  http_response_write_header(response, .name = "Connection", .name_len = 10,
                             .value = "Upgrade", .value_len = 7);
  http_response_write_header(response, .name = "Upgrade", .name_len = 7,
                             .value = "websocket", .value_len = 9);
  http_response_write_header(response, .name = "sec-websocket-version",
                             .name_len = 21, .value = "13", .value_len = 2);
  // set the string's length and encoding
  http_response_write_header(response, .name = "Sec-WebSocket-Accept",
                             .name_len = 20, .value = websockets_key,
                             .value_len = len);
  // // inform about 0 extension support
  // sec_h = http_request_header_find(settings.request,
  //                                  "sec-websocket-extensions", 24);
  // if (recv_str != NULL)
  //   http_response_write_header(response, .name =
  //   "Sec-Websocket-Extensions",
  //                              .name_length = 24);
  goto cleanup;
refuse:
  // set the negative response
  response->status = 400;
cleanup:
  if (response->status == 101) {
    // send the response
    http_response_finish(response);
    // we have an active websocket connection - prep the connection buffer
    ws->buffer = create_ws_buffer(ws);
    // update the timeout
    facil_set_timeout(ws->fd, settings.timeout);
    // call the on_open callback
    if (settings.on_open)
      settings.on_open(ws);
    // update the protocol object, cleaning up the old one
    facil_attach(ws->fd, (protocol_s *)ws);
    return 0;
  }
  http_response_finish(response);
  // `ws` is NULL when the refusal was caused by a failed allocation
  if (ws)
    destroy_ws(ws);
  return -1;
}
#define websocket_upgrade(...)                                                 \
  websocket_upgrade((websocket_settings_s){__VA_ARGS__})
/** Returns the opaque user data pointer stored in the websocket object. */
void *websocket_udata(ws_s *ws) {
  return ws->udata;
}
/** Returns the process specific connection's UUID (see `libsock`). */
intptr_t websocket_uuid(ws_s *ws) {
  return ws->fd;
}
/** Replaces the opaque user data associated with the websocket.
 * Returns the value that was stored before the call (NULL if none). */
void *websocket_udata_set(ws_s *ws, void *udata) {
  void *const previous = ws->udata;
  ws->udata = udata;
  return previous;
}
/** Writes `size` bytes of `data` to the websocket as a single message (text
 * when `is_text` is non-zero, binary otherwise).
 * Returns -1 when the connection is no longer valid (0 on success). */
int websocket_write(ws_s *ws, void *data, size_t size, uint8_t is_text) {
  if (!sock_isvalid(ws->fd))
    return -1;
  websocket_write_impl(ws->fd, data, size, is_text, 1, 1, ws->parser.client);
  return 0;
}
/** Closes a websocket connection: queues an empty close frame (bytes 0x88
 * 0x00) and then closes the socket. */
void websocket_close(ws_s *ws) {
  sock_write2(.uuid = ws->fd, .buffer = "\x88\x00", .length = 2, .move = 1,
              .dealloc = SOCK_DEALLOC_NOOP);
  sock_close(ws->fd);
}
/**
Returns the number of connections whose protocol is identified as a websocket.
*/
size_t websocket_count(void) {
  return facil_count(WEBSOCKET_ID_STR);
}
/*******************************************************************************
Each Implementation
*/
/** A task container: carries the arguments of `websocket_each` to the
 * per-connection task and to the completion handler. */
struct WSTask {
/* called once for every websocket connection. */
void (*task)(ws_s *, void *);
/* optional - called once after all the tasks were performed. */
void (*on_finish)(ws_s *, void *);
/* opaque argument passed to both callbacks. */
void *arg;
};
/** Runs the user's task (stored in the WSTask container) for a single
 * websocket connection. */
static void perform_ws_task(intptr_t fd, protocol_s *ws_, void *tsk_) {
  struct WSTask *container = tsk_;
  (void)(fd);
  container->task((ws_s *)(ws_), container->arg);
}
/**
 * Clears away a websocket task: invokes the optional `on_finish` callback and
 * releases the task container.
 *
 * The originating connection is locked for the duration of the callback. When
 * the lock is busy the whole function is rescheduled; when the connection is
 * gone (EBADF) the callback is invoked with a NULL websocket.
 */
static void finish_ws_task(intptr_t fd, void *arg) {
  struct WSTask *tsk = arg;
  if (!tsk->on_finish) {
    free(tsk);
    return;
  }
  protocol_s *ws = facil_protocol_try_lock(fd, FIO_PR_LOCK_TASK);
  if (!ws && errno != EBADF) {
    /* busy - try again later */
    defer((void (*)(void *, void *))finish_ws_task, (void *)fd, arg);
    return;
  }
  tsk->on_finish((ws_s *)ws, tsk->arg);
  if (ws)
    facil_protocol_unlock(ws, FIO_PR_LOCK_TASK);
  free(tsk);
}
/**
Performs a task on each websocket connection that shares the same process
(except the originating `ws_s` connection which is allowed to be NULL).

When the task container can't be allocated no task is performed; the
`on_finish` callback (if any) is still scheduled, with a NULL websocket, so
the caller is always notified.
*/
#undef websocket_each
void websocket_each(struct websocket_each_args_s args) {
  struct WSTask *tsk = malloc(sizeof(*tsk));
  if (!tsk) {
    /* no memory */
    if (args.on_finish)
      defer((void (*)(void *, void *))args.on_finish, NULL, args.arg);
    return;
  }
  tsk->arg = args.arg;
  tsk->on_finish = args.on_finish;
  tsk->task = args.task;
  facil_each(.origin = (args.origin ? args.origin->fd : -1),
             .service = WEBSOCKET_ID_STR, .task = perform_ws_task, .arg = tsk,
             .on_complete = finish_ws_task);
}
/*******************************************************************************
Multi-Write (direct broadcast) Implementation
*/
/* A reference counted, pre-encoded frame shared by every connection a
 * broadcast is written to. The encoded frame is stored in `buffer`, right
 * after the header fields; the header is recovered from the buffer's address
 * by subtracting the struct's size. */
struct websocket_multi_write {
/* optional filter - a connection is written to only when it returns non-zero */
uint8_t (*if_callback)(ws_s *ws_to, void *arg);
/* optional - called once the last reference was released */
void (*on_finished)(ws_s *ws_origin, void *arg);
/* the originating connection's uuid (-1 when there is none) */
intptr_t origin;
/* opaque argument passed to both callbacks */
void *arg;
/* protects `count` */
spn_lock_i lock;
/* ... we need to have padding for pointer arithmetic... */
uint8_t as_client;
/* ... we need to have padding for pointer arithmetic... */
/* reference count */
size_t count;
/* length of the encoded frame stored in `buffer` */
size_t length;
uint8_t buffer[]; /* starts on border alignment */
};
/** Fallback for the deferred `on_finished` event, registered for the case
 * where the regular task can't run: the callback receives a NULL websocket.
 * The shared frame is released afterwards. */
static void ws_mw_defered_on_finish_fb(intptr_t fd, void *arg) {
  struct websocket_multi_write *mw = arg;
  (void)(fd);
  if (mw->on_finished) {
    mw->on_finished(NULL, mw->arg);
  }
  free(mw);
}
/** The deferred `on_finished` event: invokes the callback with the
 * originating websocket and releases the shared frame. */
static void ws_mw_defered_on_finish(intptr_t fd, protocol_s *ws, void *arg) {
  struct websocket_multi_write *mw = arg;
  (void)fd;
  if (mw->on_finished)
    mw->on_finished((ws_s *)ws, mw->arg);
  free(mw);
}
/**
 * Drops one reference to a shared frame. `buff` is the address of the frame's
 * `buffer` member; the owning header is located right before it.
 *
 * When the last reference is dropped, the `on_finished` event is scheduled
 * (it releases the memory) or, when there is no callback, the memory is
 * released immediately.
 */
static void ws_reduce_or_free_multi_write(void *buff) {
  struct websocket_multi_write *mw = (void *)((uintptr_t)buff - sizeof(*mw));
  spn_lock(&mw->lock);
  mw->count -= 1;
  const size_t remaining = mw->count;
  spn_unlock(&mw->lock);
  if (remaining)
    return; /* still in use by other writes */
  if (!mw->on_finished) {
    free(mw);
    return;
  }
  facil_defer(.uuid = mw->origin, .task = ws_mw_defered_on_finish,
              .arg = mw, .fallback = ws_mw_defered_on_finish_fb,
              .task_type = FIO_PR_LOCK_WRITE);
}
/** Completion handler of the broadcast: drops the reference that was held
 * while the connections were being iterated. */
static void ws_finish_multi_write(intptr_t fd, void *arg) {
  (void)(fd);
  ws_reduce_or_free_multi_write(((struct websocket_multi_write *)arg)->buffer);
}
/**
 * Writes the shared frame to a single connection. Connections whose client /
 * server role differs from the one the frame was encoded for are skipped.
 *
 * A reference is added for the write; the socket layer drops it through
 * `ws_reduce_or_free_multi_write` when it is done with the buffer.
 */
static void ws_direct_multi_write(intptr_t fd, protocol_s *_ws, void *arg) {
  struct websocket_multi_write *multi = arg;
  if (((ws_s *)(_ws))->parser.client == multi->as_client) {
    spn_lock(&multi->lock);
    multi->count += 1;
    spn_unlock(&multi->lock);
    sock_write2(.uuid = fd, .buffer = multi->buffer, .length = multi->length,
                .dealloc = ws_reduce_or_free_multi_write, .move = 1);
  }
}
/** Filtered variant of the broadcast write: the frame is written only to
 * connections with a matching client / server role for which the user's
 * filter callback returns non-zero. */
static void ws_check_multi_write(intptr_t fd, protocol_s *_ws, void *arg) {
  struct websocket_multi_write *multi = arg;
  if (((ws_s *)(_ws))->parser.client == multi->as_client) {
    if (multi->if_callback((void *)_ws, multi->arg))
      ws_direct_multi_write(fd, _ws, arg);
  }
}
/**
 * Broadcasts a message to the websocket connections of this process.
 *
 * The message is encoded once into a shared, reference counted frame, which
 * is then written to every connection (optionally limited by `args.filter`).
 *
 * Returns 0 when the broadcast was scheduled and -1 when there is no data or
 * no memory. On a memory failure the `on_finished` callback (if any) is still
 * scheduled; when there is no data it is not.
 */
#undef websocket_write_each
int websocket_write_each(struct websocket_write_each_args_s args) {
  if (!args.data || !args.length)
    return -1;
  struct websocket_multi_write *multi =
      malloc(sizeof(*multi) + args.length + 16 /* max head size + 2 */);
  if (!multi) {
    if (args.on_finished)
      defer((void (*)(void *, void *))args.on_finished, NULL, args.arg);
    return -1;
  }
  /* encode first: the frame goes into the flexible buffer, which the struct
   * assignment below does not touch */
  const size_t encoded = websocket_encode(multi->buffer, args.data, args.length,
                                          args.is_text, 1, 1, args.as_client);
  *multi = (struct websocket_multi_write){
      .length = encoded,
      .if_callback = args.filter,
      .on_finished = args.on_finished,
      .arg = args.arg,
      .origin = (args.origin ? args.origin->fd : -1),
      .as_client = args.as_client,
      .lock = SPN_LOCK_INIT,
      .count = 1,
  };
  facil_each(.origin = multi->origin, .service = WEBSOCKET_ID_STR,
             .task_type = FIO_PR_LOCK_WRITE,
             .task =
                 (args.filter ? ws_check_multi_write : ws_direct_multi_write),
             .arg = multi, .on_complete = ws_finish_multi_write);
  return 0;
}