/* SPDX-License-Identifier: Unlicense */

/*
 * UDP round-trip measurement server.
 *
 * A client opens a connection with FT_HANDSHAKE carrying a tick period. The
 * server then sends FT_SYN to that connection once per tick period and
 * answers each FT_SYNACK with FT_ACK, recording the SYN->SYNACK round-trip
 * it observed and the round-trip the client reported. FT_RST closes a
 * connection.
 */

#include "debug.h"
#include "proto_io.h"
#include "proto.h"
#include "timespec/timespec.h"

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-function"
#define OPTPARSE_IMPLEMENTATION
#define OPTPARSE_API static
#include "optparse/optparse.h"
#pragma GCC diagnostic pop

#include <arpa/inet.h>
#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <limits.h>
#include <netinet/in.h>
#include <poll.h>
#include <signal.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <time.h>
#include <unistd.h>

/* Accepted range of the tick period requested in FT_HANDSHAKE, in ms. */
#define TICK_PERIOD_MIN_MS 1u
#define TICK_PERIOD_MAX_MS 10000u
/* Size of the connection table; connection ids are indices into it. */
#define CONNECTIONS_MAX 256u
/* Longest poll() sleep when no SYN is due sooner, in ms. */
#define POLL_TIMEOUT_MAX_MS 10000

/* Per-connection state; one slot of g_connections. */
struct connection {
    bool exist;                     /* Slot is in use. */
    bool prev_is_lost;              /* Echoed in FT_SYN. NOTE(review): nothing
                                       in this file ever sets it to true. */
    uint16_t port;                  /* Peer port copied from the handshake. */
    uint32_t connection_id;         /* Index of this slot in g_connections. */
    uint32_t addr;                  /* Peer address copied from the handshake. */
    struct timespec last_syn;       /* Deadline of the latest SYN; handshake
                                       time until the first SYN is sent. */
    struct timespec tick_period;    /* Interval between SYNs. */
    uint32_t tick;                  /* Sequence number, advanced per SYNACK. */
    uint32_t server_roundtrip_prev; /* Last SYN->SYNACK time seen here, ms. */
    uint32_t client_roundtrip_prev; /* Round-trip reported by the client (the
                                       debug log prints it as ms). */
    uint32_t repeat;                /* NOTE(review): unused in this file. */
};

static struct connection g_connections[CONNECTIONS_MAX] = {{0}};

/* Set by the SIGINT handler and polled by the main loop. volatile
 * sig_atomic_t is the only object type a signal handler may portably
 * write to. */
static volatile sig_atomic_t g_should_exit = 0;

static void sigintHandler(int a)
{
    (void) a;
    g_should_exit = 1;
}

/* Handle a frame that belongs to an existing connection.
 *
 * Only FT_SYNACK is acted upon: the client's reported round-trip is stored,
 * an FT_ACK carrying the current tick is sent back (reusing the received
 * frame's addressing), the server-side round-trip since last_syn is
 * recorded, and the tick advances. Other frame types are ignored.
 *
 * self  - connection the frame was routed to.
 * frame - received frame.
 * now   - CLOCK_MONOTONIC time sampled by the caller after poll() returned.
 * fd    - socket used for the reply.
 */
static void ConnectionHandleFrame(
        struct connection *self, struct frame const *frame, struct timespec now, int fd)
{
    if (frame->type != FT_SYNACK) {
        return;
    }
    if (g_log_level >= LOG_LEVEL_TRACE) {
        fprintf(stderr, "FT_SYNACK(tick=%u, lost=%d, rt_prev=%u, repeat=%u)\n",
                frame->syn.tick, frame->syn.prev_is_lost,
                frame->syn.roundtrip_prev, frame->syn.repeat);
    }
    self->client_roundtrip_prev = frame->syn.roundtrip_prev;

    struct frame outgoing = *frame;
    outgoing.type = FT_ACK;
    outgoing.ack = (struct frame_data_ack){
        .tick = self->tick,
    };
    SendFrame(fd, &outgoing);

    long const srt_ms = timespec_to_ms(timespec_sub(now, self->last_syn));
    // last_syn is only ever set to a time that was <= `now` when it was set.
    assert(srt_ms >= 0);
    self->server_roundtrip_prev = (uint32_t)srt_ms;
    if (g_log_level >= LOG_LEVEL_DEBUG) {
        fprintf(stderr,
                "%5" PRIu32 " server_rt %5lu ms, client_rt %5" PRIu32 " ms\n",
                self->tick, (unsigned long)srt_ms, frame->syn.roundtrip_prev);
    }
    self->tick++;
}

/* Send the periodic FT_SYN for `self` on `fd`, carrying the current tick,
 * the lost flag and the previous server-side round-trip. */
static void ConnectionHandleSynExpiration(struct connection *self, int fd)
{
    struct frame const outgoing = {
        .addr = self->addr,
        .port = self->port,
        .connection_id = self->connection_id,
        .type = FT_SYN,
        .syn = {
            .tick = self->tick,
            .prev_is_lost = self->prev_is_lost,
            .roundtrip_prev = self->server_roundtrip_prev,
        },
    };
    SendFrame(fd, &outgoing);
}

/* Dispatch one received frame.
 *
 * Frames whose connection_id is not below CONNECTIONS_MAX are dropped
 * (handshakes included). FT_HANDSHAKE takes the first free slot and is
 * answered with FT_HANDSHAKE_RESP carrying the new connection id, or with
 * FT_RST when the requested tick period is out of range or the table is
 * full. Any other frame goes to the addressed connection if it exists;
 * FT_RST additionally frees that slot.
 */
static void HandleFrame(struct frame const *frame, struct timespec now, int fd)
{
    uint32_t const connection_id = frame->connection_id;
    if (connection_id >= CONNECTIONS_MAX) {
        return;
    }
    if (frame->type == FT_HANDSHAKE) {
        if (g_log_level >= LOG_LEVEL_DEBUG) {
            fprintf(stderr, "FT_HANDSHAKE(tick_period=%" PRIu32 ")\n",
                    frame->handshake.tick_period);
        }
        struct frame outgoing = *frame;
        bool const is_tick_period_valid =
            frame->handshake.tick_period >= TICK_PERIOD_MIN_MS &&
            frame->handshake.tick_period <= TICK_PERIOD_MAX_MS;
        if (!is_tick_period_valid) {
            outgoing.type = FT_RST;
            SendFrame(fd, &outgoing);
            return;
        }
        for (size_t i = 0; i < CONNECTIONS_MAX; i++) {
            struct connection *const conn = &g_connections[i];
            if (conn->exist) {
                continue;
            }
            *conn = (struct connection){
                .exist = true,
                .tick_period = timespec_from_ms(frame->handshake.tick_period),
                .addr = frame->addr,
                .port = frame->port,
                .connection_id = (uint32_t)i,
                .last_syn = now,
            };
            outgoing.type = FT_HANDSHAKE_RESP;
            outgoing.connection_id = conn->connection_id;
            SendFrame(fd, &outgoing);
            if (g_log_level >= LOG_LEVEL_DEBUG) {
                // Log the id just allocated, not the one the client sent.
                fprintf(stderr, "New connection id=%" PRIu32 "\n",
                        conn->connection_id);
            }
            return;
        }
        // Connection table is full.
        outgoing.type = FT_RST;
        SendFrame(fd, &outgoing);
        return;
    }

    struct connection *const conn = &g_connections[connection_id];
    if (conn->exist) {
        ConnectionHandleFrame(conn, frame, now, fd);
    }
    if (frame->type == FT_RST) {
        conn->exist = false;
        if (g_log_level >= LOG_LEVEL_DEBUG) {
            fprintf(stderr, "Connection closed id=%" PRIu32 "\n", connection_id);
        }
    }
}

/* Send FT_SYN on every live connection whose tick period has elapsed and
 * return how long poll() may sleep before the next SYN is due: a value in
 * [0, POLL_TIMEOUT_MAX_MS] milliseconds. */
static int ServiceConnections(struct timespec now, int fd)
{
    long timeout_ms = POLL_TIMEOUT_MAX_MS;
    for (size_t i = 0; i < CONNECTIONS_MAX; i++) {
        struct connection *const conn = &g_connections[i];
        if (!conn->exist) {
            continue;
        }
        struct timespec deadline = timespec_add(conn->last_syn, conn->tick_period);
        if (timespec_ge(now, deadline)) {
            ConnectionHandleSynExpiration(conn, fd);
            // Using the deadline instead of `now` here keeps the SYN rate
            // consistent, although it might have a little jitter.
            conn->last_syn = deadline;
            // The deadline just served is in the past; sleep until the next
            // one. Computing the timeout from the old deadline produced a
            // negative value, which poll() treats as "wait forever".
            deadline = timespec_add(conn->last_syn, conn->tick_period);
        }
        long remaining_ms = timespec_to_ms(timespec_sub(deadline, now));
        if (remaining_ms < 0) {
            // Behind schedule by more than one period: wake immediately and
            // catch up on the next pass.
            remaining_ms = 0;
        }
        if (remaining_ms < timeout_ms) {
            timeout_ms = remaining_ms;
        }
    }
    return (int)timeout_ms;
}

static void PrintUsage(FILE *s, const char *argv0)
{
    // Please, keep all lines in 80 columns range when printed.
    fprintf(s,
            "Usage: %s [options]\n"
            "Options:\n"
            " -h, --help, Show this message.\n"
            " -d, --debug-level=<level>,\n"
            " Set id of debug log level: 0 means no logs, \n"
            " 1 means info, 2 means some debug,\n"
            " 3 is for maximum trace\n"
            , argv0);
}

/* Parse options, bind a non-blocking UDP socket on 0.0.0.0:50037 and run
 * the poll() loop until SIGINT. Returns EXIT_SUCCESS on a clean shutdown
 * and non-zero on setup or fatal runtime errors. */
int main(int argc, char *argv[])
{
    (void) argc;
    struct optparse_long longopts[] = {
        {"help", 'h', OPTPARSE_NONE},
        {"debug-level", 'd', OPTPARSE_REQUIRED},
        {0},
    };
    struct optparse options;
    optparse_init(&options, argv);

    // Parse opts
    int option;
    while ((option = optparse_long(&options, longopts, NULL)) != -1) {
        switch (option) {
        case 'h':
            PrintUsage(stdout, argv[0]);
            return EXIT_SUCCESS;
        case 'd': {
            // strtol instead of atoi so that garbage like "-d x" is reported
            // instead of silently meaning level 0.
            char *end = NULL;
            errno = 0;
            long const level = strtol(options.optarg, &end, 10);
            if (end == options.optarg || *end != '\0' || errno == ERANGE
                    || level < 0 || level > INT_MAX) {
                fprintf(stderr, "main: invalid debug level \"%s\"\n",
                        options.optarg);
                return EXIT_FAILURE;
            }
            g_log_level = (int)level;
            break;
        }
        case '?':
            fprintf(stderr, "main: optparse_long: Error: \"%s\"\n", options.errmsg);
            return EXIT_FAILURE;
        default:
            break;
        }
    }

    int const fd = socket(AF_INET, SOCK_DGRAM, 0);
    if (fd < 0) {
        perror("socket failed");
        return 1;
    }
    // The receive loop below drains the socket until ReceiveFrame() fails,
    // which only terminates if the socket is non-blocking.
    int const fd_flags = fcntl(fd, F_GETFL, 0);
    if (fd_flags < 0 || fcntl(fd, F_SETFL, fd_flags | O_NONBLOCK) < 0) {
        perror("fcntl failed");
        close(fd);
        return 1;
    }
    uint32_t const addr = INADDR_ANY;
    uint16_t const port = 50037u;
    struct sockaddr_in const serveraddr = {
        .sin_family = AF_INET,
        .sin_port = htons(port),
        .sin_addr.s_addr = htonl(addr),
    };
    struct sigaction sa = {
        .sa_handler = sigintHandler,
    };
    sigemptyset(&sa.sa_mask);
    if (sigaction(SIGINT, &sa, NULL) < 0) {
        perror("sigaction failed");
        close(fd);
        return 1;
    }
    if (bind(fd, (struct sockaddr *)&serveraddr, sizeof(serveraddr)) < 0) {
        perror("bind failed");
        close(fd);
        return 1;
    }
    {
        uint8_t const a0 = addr >> 24;
        uint8_t const a1 = addr >> 16;
        uint8_t const a2 = addr >> 8;
        uint8_t const a3 = (uint8_t)addr;
        if (g_log_level >= LOG_LEVEL_INFO) {
            fprintf(stderr, "Server is listening on %u.%u.%u.%u:%u\n",
                    a0, a1, a2, a3, port);
        }
    }

    struct pollfd fds[1] = {
        { .fd = fd, .events = POLLIN, },
    };
    int timeout_ms = POLL_TIMEOUT_MAX_MS;
    int exit_code = EXIT_SUCCESS;
    while (!g_should_exit) {
        int const pollret = poll(fds, sizeof(fds) / sizeof(*fds), timeout_ms);
        // Saved before clock_gettime() gets a chance to overwrite it.
        int const poll_err = errno;
        struct timespec now;
        if (-1 == clock_gettime(CLOCK_MONOTONIC, &now)) {
            int const err = errno;
            fprintf(stderr,
                    "clock_gettime(CLOCK_MONOTONIC, &now) failed (%d): \"%s\"\n",
                    err, strerror(err));
            // `now` is indeterminate; every timing decision below would use
            // garbage, so stop instead.
            exit_code = EXIT_FAILURE;
            break;
        }
        if (pollret > 0) {
            if (fds[0].revents & POLLIN) {
                // Drain every queued datagram; a non-zero return from
                // ReceiveFrame() (presumably EAGAIN on the non-blocking
                // socket, or an error) ends the drain.
                for (;;) {
                    struct frame frame = { .type = FT_NONE, };
                    if (ReceiveFrame(fds[0].fd, &frame)) {
                        break;
                    }
                    HandleFrame(&frame, now, fds[0].fd);
                }
            }
            if (fds[0].revents & POLLRDNORM) {
                fprintf(stderr, "POLLRDNORM, ignoring...\n");
            }
            if (fds[0].revents & POLLRDBAND) {
                fprintf(stderr, "POLLRDBAND, ignoring...\n");
            }
            if (fds[0].revents & POLLPRI) {
                fprintf(stderr, "POLLPRI, ignoring...\n");
            }
            if (fds[0].revents & POLLHUP) {
                fprintf(stderr, "POLLHUP, ignoring...\n");
            }
            if (fds[0].revents & POLLERR) {
                fprintf(stderr, "POLLERR, ignoring...\n");
            }
            if (fds[0].revents & POLLNVAL) {
                fprintf(stderr, "POLLNVAL, exiting...\n");
                exit_code = 1;
                break;
            }
        } else if (pollret < 0) {
            if (poll_err == EINTR) {
                // SIGINT is the only signal this program installs a handler for.
                fprintf(stderr, "SIGINT received\n");
                break;
            }
            fprintf(stderr, "poll() failed (%d): \"%s\"\n",
                    poll_err, strerror(poll_err));
        }
        timeout_ms = ServiceConnections(now, fd);
    }
    close(fd);
    return exit_code;
}