From 5b14f8e69f5bfa85708b190c34479bd539297bfe Mon Sep 17 00:00:00 2001 From: Thomas Gschwantner Date: Fri, 14 Jun 2019 20:12:25 +0200 Subject: Use epoll() instead of poll() This enables us to later use the timeout parameter of epoll_wait() to timely remove expired leases. --- common.c | 51 ++++++-------- common.h | 5 +- wg-dynamic-client.c | 11 +-- wg-dynamic-server.c | 199 ++++++++++++++++++++++++++++++---------------------- 4 files changed, 143 insertions(+), 123 deletions(-) diff --git a/common.c b/common.c index 5be1226..09a02e2 100644 --- a/common.c +++ b/common.c @@ -202,25 +202,6 @@ static ssize_t parse_line(unsigned char *buf, size_t len, return line_len; } -void free_wg_dynamic_request(struct wg_dynamic_request *req) -{ - struct wg_dynamic_attr *prev, *cur = req->first; - - while (cur) { - prev = cur; - cur = cur->next; - free(prev); - } - - req->cmd = WGKEY_UNKNOWN; - req->version = 0; - free(req->buf); - req->buf = NULL; - req->buflen = 0; - req->first = NULL; - req->last = NULL; -} - static int parse_request(struct wg_dynamic_request *req, unsigned char *buf, size_t len) { @@ -273,7 +254,7 @@ static int parse_request(struct wg_dynamic_request *req, unsigned char *buf, return 1; } -bool handle_request(int fd, struct wg_dynamic_request *req, +bool handle_request(struct wg_dynamic_request *req, bool (*success)(int, struct wg_dynamic_request *), bool (*error)(int, int)) { @@ -282,14 +263,14 @@ bool handle_request(int fd, struct wg_dynamic_request *req, unsigned char buf[RECV_BUFSIZE + MAX_LINESIZE]; while (1) { - bytes = read(fd, buf, RECV_BUFSIZE); + bytes = read(req->fd, buf, RECV_BUFSIZE); if (bytes < 0) { if (errno == EWOULDBLOCK || errno == EAGAIN) break; // TODO: handle EINTR - debug("Reading from socket %d failed: %s\n", fd, + debug("Reading from socket %d failed: %s\n", req->fd, strerror(errno)); return true; } else if (bytes == 0) { @@ -299,9 +280,9 @@ bool handle_request(int fd, struct wg_dynamic_request *req, ret = parse_request(req, buf, bytes); if (ret < 0) - return error(fd, -ret); + return error(req->fd, -ret); else if (ret == 0) - return success(fd, req); + return success(req->fd, req); } return false; @@ -367,13 +348,27 @@ uint32_t current_time() return tp.tv_sec; } -void close_connection(int *fd, struct wg_dynamic_request *req) +void close_connection(struct wg_dynamic_request *req) { - if (close(*fd)) + struct wg_dynamic_attr *prev, *cur = req->first; + + if (close(req->fd)) debug("Failed to close socket\n"); - *fd = -1; - free_wg_dynamic_request(req); + while (cur) { + prev = cur; + cur = cur->next; + free(prev); + } + + req->cmd = WGKEY_UNKNOWN; + req->version = 0; + req->fd = -1; + free(req->buf); + req->buf = NULL; + req->buflen = 0; + req->first = NULL; + req->last = NULL; } bool is_link_local(unsigned char *addr) diff --git a/common.h b/common.h index d0f8ffd..b2bd054 100644 --- a/common.h +++ b/common.h @@ -59,6 +59,7 @@ struct wg_dynamic_attr { struct wg_dynamic_request { enum wg_dynamic_key cmd; uint32_t version; + int fd; wg_key pubkey; unsigned char *buf; size_t buflen; @@ -77,7 +78,7 @@ struct wg_combined_ip { #define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0])) void free_wg_dynamic_request(struct wg_dynamic_request *req); -bool handle_request(int fd, struct wg_dynamic_request *req, +bool handle_request(struct wg_dynamic_request *req, bool (*success)(int, struct wg_dynamic_request *), bool (*error)(int, int)); size_t send_message(int fd, unsigned char *buf, size_t *len); @@ -85,7 +86,7 @@ void send_later(struct wg_dynamic_request *req, unsigned char *const buf, size_t msglen); int print_to_buf(char *buf, size_t bufsize, size_t len, char *fmt, ...); uint32_t current_time(); -void close_connection(int *fd, struct wg_dynamic_request *req); +void close_connection(struct wg_dynamic_request *req); bool is_link_local(unsigned char *addr); void iface_get_all_addrs(uint8_t family, mnl_cb_t data_cb, void *cb_data); int data_attr_cb(const struct nlattr *attr, void *data); diff --git a/wg-dynamic-client.c b/wg-dynamic-client.c index 2224a22..8dcbd80 100644 --- a/wg-dynamic-client.c +++ b/wg-dynamic-client.c @@ -390,13 +390,6 @@ static bool handle_response(int fd, struct wg_dynamic_request *req) return true; } -static bool read_response(int fd, struct wg_dynamic_request *req, - bool (*success)(int, struct wg_dynamic_request *), - bool (*error)(int, int)) -{ - return handle_request(fd, req, success, error); -} - int main(int argc __attribute__((unused)), char *argv[] __attribute__((unused))) { int *fd = &our_fd; @@ -446,9 +439,9 @@ int main(int argc __attribute__((unused)), char *argv[] __attribute__((unused))) request_ip(*fd, &our_lease); - while (!read_response(*fd, &req, handle_response, handle_error)) + while (!handle_request(&req, handle_response, handle_error)) ; - close_connection(fd, &req); + close_connection(&req); } return 0; diff --git a/wg-dynamic-server.c b/wg-dynamic-server.c index 323ac76..8a552a8 100644 --- a/wg-dynamic-server.c +++ b/wg-dynamic-server.c @@ -15,7 +15,7 @@ #include #include #include -#include +#include #include #include @@ -33,7 +33,10 @@ static const char *wg_interface; static struct in6_addr well_known; static wg_device *device = NULL; -static struct pollfd pollfds[MAX_CONNECTIONS + 1]; +static struct wg_dynamic_request requests[MAX_CONNECTIONS] = { 0 }; + +static int sockfd = -1; +static int epollfd = -1; KHASH_MAP_INIT_INT64(allowedht, wg_key *) khash_t(allowedht) * allowedips_ht; @@ -117,17 +120,6 @@ static bool valid_peer_found(wg_device *device) return false; } -static int get_avail_pollfds() -{ - for (int nfds = 1;; ++nfds) { - if (nfds >= MAX_CONNECTIONS + 1) - return -1; - - if (pollfds[nfds].fd < 0) - return nfds; - } -} - static void rebuild_allowedips_ht() { wg_peer *peer; @@ -230,23 +222,6 @@ static int accept_connection(int sockfd, wg_key *dest) return fd; } -static void accept_incoming(int sockfd, struct wg_dynamic_request *reqs) -{ - int n, fd; - while ((n = get_avail_pollfds()) >= 0) { - fd = accept_connection(sockfd, &reqs[n - 1].pubkey); - if (fd < 0) { - if (fd == -ENOENT) - debug("Failed to match IP to pubkey\n"); - else if (fd != -EAGAIN && fd != -EWOULDBLOCK) - debug("Failed to accept connection: %s\n", - strerror(-fd)); - break; - } - pollfds[n].fd = fd; - } -} - static bool send_error(int fd, int ret) { UNUSED(fd); @@ -398,7 +373,7 @@ static bool send_response(int fd, struct wg_dynamic_request *req) return false; } -static void setup_socket(int *fd) +static void setup_socket() { int val = 1, res; struct sockaddr_in6 addr = { @@ -408,21 +383,21 @@ static void setup_socket(int *fd) .sin6_scope_id = device->ifindex, }; - *fd = socket(AF_INET6, SOCK_STREAM, 0); - if (*fd < 0) + sockfd = socket(AF_INET6, SOCK_STREAM, 0); + if (sockfd < 0) fatal("Creating a socket failed"); - res = fcntl(*fd, F_GETFL, 0); - if (res < 0 || fcntl(*fd, F_SETFL, res | O_NONBLOCK) < 0) + res = fcntl(sockfd, F_GETFL, 0); + if (res < 0 || fcntl(sockfd, F_SETFL, res | O_NONBLOCK) < 0) fatal("Setting socket to nonblocking failed"); - if (setsockopt(*fd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof val) == -1) + if (setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof val)) fatal("Setting socket option failed"); - if (bind(*fd, (struct sockaddr *)&addr, sizeof(addr)) == -1) + if (bind(sockfd, (struct sockaddr *)&addr, sizeof(addr)) == -1) fatal("Binding socket failed"); - if (listen(*fd, SOMAXCONN) == -1) + if (listen(sockfd, SOMAXCONN) == -1) fatal("Listening to socket failed"); } @@ -432,39 +407,31 @@ static void cleanup() kh_destroy(allowedht, allowedips_ht); wg_free_device(device); - for (int i = 0; i < MAX_CONNECTIONS + 1; ++i) { - if (pollfds[i].fd < 0) + if (sockfd >= 0) + close(sockfd); + + if (epollfd >= 0) + close(epollfd); + + for (int i = 0; i < MAX_CONNECTIONS; ++i) { + if (requests[i].fd < 0) continue; - if (close(pollfds[i].fd)) - debug("Failed to close fd %d\n", pollfds[i].fd); + close_connection(&requests[i]); } } -int main(int argc, char *argv[]) +static void setup() { - struct wg_dynamic_request reqs[MAX_CONNECTIONS] = { 0 }; - int *sockfd = &pollfds[0].fd; - - progname = argv[0]; if (inet_pton(AF_INET6, WG_DYNAMIC_ADDR, &well_known) != 1) fatal("inet_pton()"); - for (int i = 0; i < MAX_CONNECTIONS + 1; ++i) { - pollfds[i] = (struct pollfd){ - .fd = -1, - .events = POLLIN, - }; - } - - if (argc != 2) - usage(); - + leases_init("leases_file"); allowedips_ht = kh_init(allowedht); - leases_init("leases_file"); + for (int i = 0; i < MAX_CONNECTIONS; ++i) + requests[i].fd = -1; - wg_interface = argv[1]; if (atexit(cleanup)) die("Failed to set exit function\n"); @@ -479,43 +446,107 @@ int main(int argc, char *argv[]) die("%s has no peers with link-local allowedips\n", wg_interface); - setup_socket(sockfd); + setup_socket(&sockfd); +} - while (1) { - if (poll(pollfds, MAX_CONNECTIONS + 1, -1) == -1) - fatal("Failed to poll() fds"); +static int get_avail_request() +{ + for (int nfds = 0;; ++nfds) { + if (nfds >= MAX_CONNECTIONS) + return -1; - if (pollfds[0].revents & POLLIN) { - pollfds[0].revents = 0; - accept_incoming(*sockfd, reqs); - } + if (requests[nfds].fd < 0) + return nfds; + } +} - for (int i = 1; i < MAX_CONNECTIONS + 1; ++i) { - size_t off; +static void accept_incoming(int sockfd, int epollfd, + struct wg_dynamic_request *requests) +{ + int n, fd; + struct epoll_event ev; - if (!(pollfds[i].revents & POLLOUT)) + while ((n = get_avail_request()) >= 0) { + fd = accept_connection(sockfd, &requests[n].pubkey); + if (fd < 0) { + if (fd == -ENOENT) { + debug("Failed to match IP to pubkey\n"); continue; + } else if (fd != -EAGAIN && fd != -EWOULDBLOCK) { + debug("Failed to accept connection: %s\n", + strerror(-fd)); + continue; + } - off = send_message(pollfds[i].fd, reqs[i - 1].buf, - &reqs[i - 1].buflen); - if (reqs[i - 1].buflen) - memmove(reqs[i - 1].buf, reqs[i - 1].buf + off, - reqs[i - 1].buflen); - else - close_connection(&pollfds[i].fd, &reqs[i - 1]); + break; } - for (int i = 1; i < MAX_CONNECTIONS + 1; ++i) { - if (pollfds[i].fd < 0 || !pollfds[i].revents & POLLIN) + ev.events = EPOLLIN | EPOLLET; + ev.data.ptr = &requests[n]; + if (epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev) == -1) + fatal("epoll_ctl()"); + + requests[n].fd = fd; + } +} + +static void handle_event(struct wg_dynamic_request *req, uint32_t events) +{ + if (!req) { + accept_incoming(sockfd, epollfd, requests); + return; + } + + if (events & EPOLLIN) { + if (handle_request(req, send_response, send_error)) + close_connection(req); + } + + if (events & EPOLLOUT) { + size_t off = send_message(req->fd, req->buf, &req->buflen); + if (req->buflen) + memmove(req->buf, req->buf + off, req->buflen); + else + close_connection(req); + } +} + +static void poll_loop() +{ + struct epoll_event ev, events[MAX_CONNECTIONS]; + epollfd = epoll_create1(0); + if (epollfd == -1) + fatal("epoll_create1()"); + + ev.events = EPOLLIN | EPOLLET; + ev.data.ptr = NULL; + if (epoll_ctl(epollfd, EPOLL_CTL_ADD, sockfd, &ev) == -1) + fatal("epoll_ctl()"); + + while (1) { + int nfds = epoll_wait(epollfd, events, MAX_CONNECTIONS, -1); + if (nfds == -1) { + if (errno == EINTR) continue; - if (handle_request(pollfds[i].fd, &reqs[i - 1], - send_response, send_error)) - close_connection(&pollfds[i].fd, &reqs[i - 1]); - else if (reqs[i - 1].buf) - pollfds[i].events |= POLLOUT; + fatal("epoll_wait()"); } + + for (int i = 0; i < nfds; ++i) + handle_event(events[i].data.ptr, events[i].events); } +} + +int main(int argc, char *argv[]) +{ + progname = argv[0]; + if (argc != 2) + usage(); + + wg_interface = argv[1]; + setup(); + + poll_loop(); return 0; } -- cgit v1.2.3-59-g8ed1b