diff options
Diffstat (limited to '')
-rw-r--r-- | wg-dynamic-server.c | 199 |
1 files changed, 115 insertions, 84 deletions
diff --git a/wg-dynamic-server.c b/wg-dynamic-server.c index 323ac76..8a552a8 100644 --- a/wg-dynamic-server.c +++ b/wg-dynamic-server.c @@ -15,7 +15,7 @@ #include <arpa/inet.h> #include <fcntl.h> #include <netdb.h> -#include <poll.h> +#include <sys/epoll.h> #include <sys/socket.h> #include <sys/types.h> @@ -33,7 +33,10 @@ static const char *wg_interface; static struct in6_addr well_known; static wg_device *device = NULL; -static struct pollfd pollfds[MAX_CONNECTIONS + 1]; +static struct wg_dynamic_request requests[MAX_CONNECTIONS] = { 0 }; + +static int sockfd = -1; +static int epollfd = -1; KHASH_MAP_INIT_INT64(allowedht, wg_key *) khash_t(allowedht) * allowedips_ht; @@ -117,17 +120,6 @@ static bool valid_peer_found(wg_device *device) return false; } -static int get_avail_pollfds() -{ - for (int nfds = 1;; ++nfds) { - if (nfds >= MAX_CONNECTIONS + 1) - return -1; - - if (pollfds[nfds].fd < 0) - return nfds; - } -} - static void rebuild_allowedips_ht() { wg_peer *peer; @@ -230,23 +222,6 @@ static int accept_connection(int sockfd, wg_key *dest) return fd; } -static void accept_incoming(int sockfd, struct wg_dynamic_request *reqs) -{ - int n, fd; - while ((n = get_avail_pollfds()) >= 0) { - fd = accept_connection(sockfd, &reqs[n - 1].pubkey); - if (fd < 0) { - if (fd == -ENOENT) - debug("Failed to match IP to pubkey\n"); - else if (fd != -EAGAIN && fd != -EWOULDBLOCK) - debug("Failed to accept connection: %s\n", - strerror(-fd)); - break; - } - pollfds[n].fd = fd; - } -} - static bool send_error(int fd, int ret) { UNUSED(fd); @@ -398,7 +373,7 @@ static bool send_response(int fd, struct wg_dynamic_request *req) return false; } -static void setup_socket(int *fd) +static void setup_socket() { int val = 1, res; struct sockaddr_in6 addr = { @@ -408,21 +383,21 @@ static void setup_socket(int *fd) .sin6_scope_id = device->ifindex, }; - *fd = socket(AF_INET6, SOCK_STREAM, 0); - if (*fd < 0) + sockfd = socket(AF_INET6, SOCK_STREAM, 0); + if (sockfd < 0) fatal("Creating a socket failed"); - res = fcntl(*fd, F_GETFL, 0); - if (res < 0 || fcntl(*fd, F_SETFL, res | O_NONBLOCK) < 0) + res = fcntl(sockfd, F_GETFL, 0); + if (res < 0 || fcntl(sockfd, F_SETFL, res | O_NONBLOCK) < 0) fatal("Setting socket to nonblocking failed"); - if (setsockopt(*fd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof val) == -1) + if (setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof val)) fatal("Setting socket option failed"); - if (bind(*fd, (struct sockaddr *)&addr, sizeof(addr)) == -1) + if (bind(sockfd, (struct sockaddr *)&addr, sizeof(addr)) == -1) fatal("Binding socket failed"); - if (listen(*fd, SOMAXCONN) == -1) + if (listen(sockfd, SOMAXCONN) == -1) fatal("Listening to socket failed"); } @@ -432,39 +407,31 @@ static void cleanup() kh_destroy(allowedht, allowedips_ht); wg_free_device(device); - for (int i = 0; i < MAX_CONNECTIONS + 1; ++i) { - if (pollfds[i].fd < 0) + if (sockfd >= 0) + close(sockfd); + + if (epollfd >= 0) + close(epollfd); + + for (int i = 0; i < MAX_CONNECTIONS; ++i) { + if (requests[i].fd < 0) continue; - if (close(pollfds[i].fd)) - debug("Failed to close fd %d\n", pollfds[i].fd); + close_connection(&requests[i]); } } -int main(int argc, char *argv[]) +static void setup() { - struct wg_dynamic_request reqs[MAX_CONNECTIONS] = { 0 }; - int *sockfd = &pollfds[0].fd; - - progname = argv[0]; if (inet_pton(AF_INET6, WG_DYNAMIC_ADDR, &well_known) != 1) fatal("inet_pton()"); - for (int i = 0; i < MAX_CONNECTIONS + 1; ++i) { - pollfds[i] = (struct pollfd){ - .fd = -1, - .events = POLLIN, - }; - } - - if (argc != 2) - usage(); - + leases_init("leases_file"); allowedips_ht = kh_init(allowedht); - leases_init("leases_file"); + for (int i = 0; i < MAX_CONNECTIONS; ++i) + requests[i].fd = -1; - wg_interface = argv[1]; if (atexit(cleanup)) die("Failed to set exit function\n"); @@ -479,43 +446,107 @@ int main(int argc, char *argv[]) die("%s has no peers with link-local allowedips\n", wg_interface); - setup_socket(sockfd); + setup_socket(&sockfd); +} - while (1) { - if (poll(pollfds, MAX_CONNECTIONS + 1, -1) == -1) - fatal("Failed to poll() fds"); +static int get_avail_request() +{ + for (int nfds = 0;; ++nfds) { + if (nfds >= MAX_CONNECTIONS) + return -1; - if (pollfds[0].revents & POLLIN) { - pollfds[0].revents = 0; - accept_incoming(*sockfd, reqs); - } + if (requests[nfds].fd < 0) + return nfds; + } +} - for (int i = 1; i < MAX_CONNECTIONS + 1; ++i) { - size_t off; +static void accept_incoming(int sockfd, int epollfd, + struct wg_dynamic_request *requests) +{ + int n, fd; + struct epoll_event ev; - if (!(pollfds[i].revents & POLLOUT)) + while ((n = get_avail_request()) >= 0) { + fd = accept_connection(sockfd, &requests[n].pubkey); + if (fd < 0) { + if (fd == -ENOENT) { + debug("Failed to match IP to pubkey\n"); continue; + } else if (fd != -EAGAIN && fd != -EWOULDBLOCK) { + debug("Failed to accept connection: %s\n", + strerror(-fd)); + continue; + } - off = send_message(pollfds[i].fd, reqs[i - 1].buf, - &reqs[i - 1].buflen); - if (reqs[i - 1].buflen) - memmove(reqs[i - 1].buf, reqs[i - 1].buf + off, - reqs[i - 1].buflen); - else - close_connection(&pollfds[i].fd, &reqs[i - 1]); + break; } - for (int i = 1; i < MAX_CONNECTIONS + 1; ++i) { - if (pollfds[i].fd < 0 || !pollfds[i].revents & POLLIN) + ev.events = EPOLLIN | EPOLLET; + ev.data.ptr = &requests[n]; + if (epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev) == -1) + fatal("epoll_ctl()"); + + requests[n].fd = fd; + } +} + +static void handle_event(struct wg_dynamic_request *req, uint32_t events) +{ + if (!req) { + accept_incoming(sockfd, epollfd, requests); + return; + } + + if (events & EPOLLIN) { + if (handle_request(req, send_response, send_error)) + close_connection(req); + } + + if (events & EPOLLOUT) { + size_t off = send_message(req->fd, req->buf, &req->buflen); + if (req->buflen) + memmove(req->buf, req->buf + off, req->buflen); + else + close_connection(req); + } +} + +static void poll_loop() +{ + struct epoll_event ev, events[MAX_CONNECTIONS]; + epollfd = epoll_create1(0); + if (epollfd == -1) + fatal("epoll_create1()"); + + ev.events = EPOLLIN | EPOLLET; + ev.data.ptr = NULL; + if (epoll_ctl(epollfd, EPOLL_CTL_ADD, sockfd, &ev) == -1) + fatal("epoll_ctl()"); + + while (1) { + int nfds = epoll_wait(epollfd, events, MAX_CONNECTIONS, -1); + if (nfds == -1) { + if (errno == EINTR) continue; - if (handle_request(pollfds[i].fd, &reqs[i - 1], - send_response, send_error)) - close_connection(&pollfds[i].fd, &reqs[i - 1]); - else if (reqs[i - 1].buf) - pollfds[i].events |= POLLOUT; + fatal("epoll_wait()"); } + + for (int i = 0; i < nfds; ++i) + handle_event(events[i].data.ptr, events[i].events); } +} + +int main(int argc, char *argv[]) +{ + progname = argv[0]; + if (argc != 2) + usage(); + + wg_interface = argv[1]; + setup(); + + poll_loop(); return 0; } |