aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--common.c51
-rw-r--r--common.h5
-rw-r--r--wg-dynamic-client.c11
-rw-r--r--wg-dynamic-server.c199
4 files changed, 143 insertions, 123 deletions
diff --git a/common.c b/common.c
index 5be1226..09a02e2 100644
--- a/common.c
+++ b/common.c
@@ -202,25 +202,6 @@ static ssize_t parse_line(unsigned char *buf, size_t len,
return line_len;
}
-void free_wg_dynamic_request(struct wg_dynamic_request *req)
-{
- struct wg_dynamic_attr *prev, *cur = req->first;
-
- while (cur) {
- prev = cur;
- cur = cur->next;
- free(prev);
- }
-
- req->cmd = WGKEY_UNKNOWN;
- req->version = 0;
- free(req->buf);
- req->buf = NULL;
- req->buflen = 0;
- req->first = NULL;
- req->last = NULL;
-}
-
static int parse_request(struct wg_dynamic_request *req, unsigned char *buf,
size_t len)
{
@@ -273,7 +254,7 @@ static int parse_request(struct wg_dynamic_request *req, unsigned char *buf,
return 1;
}
-bool handle_request(int fd, struct wg_dynamic_request *req,
+bool handle_request(struct wg_dynamic_request *req,
bool (*success)(int, struct wg_dynamic_request *),
bool (*error)(int, int))
{
@@ -282,14 +263,14 @@ bool handle_request(int fd, struct wg_dynamic_request *req,
unsigned char buf[RECV_BUFSIZE + MAX_LINESIZE];
while (1) {
- bytes = read(fd, buf, RECV_BUFSIZE);
+ bytes = read(req->fd, buf, RECV_BUFSIZE);
if (bytes < 0) {
if (errno == EWOULDBLOCK || errno == EAGAIN)
break;
// TODO: handle EINTR
- debug("Reading from socket %d failed: %s\n", fd,
+ debug("Reading from socket %d failed: %s\n", req->fd,
strerror(errno));
return true;
} else if (bytes == 0) {
@@ -299,9 +280,9 @@ bool handle_request(int fd, struct wg_dynamic_request *req,
ret = parse_request(req, buf, bytes);
if (ret < 0)
- return error(fd, -ret);
+ return error(req->fd, -ret);
else if (ret == 0)
- return success(fd, req);
+ return success(req->fd, req);
}
return false;
@@ -367,13 +348,27 @@ uint32_t current_time()
return tp.tv_sec;
}
-void close_connection(int *fd, struct wg_dynamic_request *req)
+void close_connection(struct wg_dynamic_request *req)
{
- if (close(*fd))
+ struct wg_dynamic_attr *prev, *cur = req->first;
+
+ if (close(req->fd))
debug("Failed to close socket\n");
- *fd = -1;
- free_wg_dynamic_request(req);
+ while (cur) {
+ prev = cur;
+ cur = cur->next;
+ free(prev);
+ }
+
+ req->cmd = WGKEY_UNKNOWN;
+ req->version = 0;
+ req->fd = -1;
+ free(req->buf);
+ req->buf = NULL;
+ req->buflen = 0;
+ req->first = NULL;
+ req->last = NULL;
}
bool is_link_local(unsigned char *addr)
diff --git a/common.h b/common.h
index d0f8ffd..b2bd054 100644
--- a/common.h
+++ b/common.h
@@ -59,6 +59,7 @@ struct wg_dynamic_attr {
struct wg_dynamic_request {
enum wg_dynamic_key cmd;
uint32_t version;
+ int fd;
wg_key pubkey;
unsigned char *buf;
size_t buflen;
@@ -77,7 +78,7 @@ struct wg_combined_ip {
#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))
void free_wg_dynamic_request(struct wg_dynamic_request *req);
-bool handle_request(int fd, struct wg_dynamic_request *req,
+bool handle_request(struct wg_dynamic_request *req,
bool (*success)(int, struct wg_dynamic_request *),
bool (*error)(int, int));
size_t send_message(int fd, unsigned char *buf, size_t *len);
@@ -85,7 +86,7 @@ void send_later(struct wg_dynamic_request *req, unsigned char *const buf,
size_t msglen);
int print_to_buf(char *buf, size_t bufsize, size_t len, char *fmt, ...);
uint32_t current_time();
-void close_connection(int *fd, struct wg_dynamic_request *req);
+void close_connection(struct wg_dynamic_request *req);
bool is_link_local(unsigned char *addr);
void iface_get_all_addrs(uint8_t family, mnl_cb_t data_cb, void *cb_data);
int data_attr_cb(const struct nlattr *attr, void *data);
diff --git a/wg-dynamic-client.c b/wg-dynamic-client.c
index 2224a22..8dcbd80 100644
--- a/wg-dynamic-client.c
+++ b/wg-dynamic-client.c
@@ -390,13 +390,6 @@ static bool handle_response(int fd, struct wg_dynamic_request *req)
return true;
}
-static bool read_response(int fd, struct wg_dynamic_request *req,
- bool (*success)(int, struct wg_dynamic_request *),
- bool (*error)(int, int))
-{
- return handle_request(fd, req, success, error);
-}
-
int main(int argc __attribute__((unused)), char *argv[] __attribute__((unused)))
{
int *fd = &our_fd;
@@ -446,9 +439,9 @@ int main(int argc __attribute__((unused)), char *argv[] __attribute__((unused)))
request_ip(*fd, &our_lease);
- while (!read_response(*fd, &req, handle_response, handle_error))
+ while (!handle_request(&req, handle_response, handle_error))
;
- close_connection(fd, &req);
+ close_connection(&req);
}
return 0;
diff --git a/wg-dynamic-server.c b/wg-dynamic-server.c
index 323ac76..8a552a8 100644
--- a/wg-dynamic-server.c
+++ b/wg-dynamic-server.c
@@ -15,7 +15,7 @@
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
-#include <poll.h>
+#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
@@ -33,7 +33,10 @@ static const char *wg_interface;
static struct in6_addr well_known;
static wg_device *device = NULL;
-static struct pollfd pollfds[MAX_CONNECTIONS + 1];
+static struct wg_dynamic_request requests[MAX_CONNECTIONS] = { 0 };
+
+static int sockfd = -1;
+static int epollfd = -1;
KHASH_MAP_INIT_INT64(allowedht, wg_key *)
khash_t(allowedht) * allowedips_ht;
@@ -117,17 +120,6 @@ static bool valid_peer_found(wg_device *device)
return false;
}
-static int get_avail_pollfds()
-{
- for (int nfds = 1;; ++nfds) {
- if (nfds >= MAX_CONNECTIONS + 1)
- return -1;
-
- if (pollfds[nfds].fd < 0)
- return nfds;
- }
-}
-
static void rebuild_allowedips_ht()
{
wg_peer *peer;
@@ -230,23 +222,6 @@ static int accept_connection(int sockfd, wg_key *dest)
return fd;
}
-static void accept_incoming(int sockfd, struct wg_dynamic_request *reqs)
-{
- int n, fd;
- while ((n = get_avail_pollfds()) >= 0) {
- fd = accept_connection(sockfd, &reqs[n - 1].pubkey);
- if (fd < 0) {
- if (fd == -ENOENT)
- debug("Failed to match IP to pubkey\n");
- else if (fd != -EAGAIN && fd != -EWOULDBLOCK)
- debug("Failed to accept connection: %s\n",
- strerror(-fd));
- break;
- }
- pollfds[n].fd = fd;
- }
-}
-
static bool send_error(int fd, int ret)
{
UNUSED(fd);
@@ -398,7 +373,7 @@ static bool send_response(int fd, struct wg_dynamic_request *req)
return false;
}
-static void setup_socket(int *fd)
+static void setup_socket()
{
int val = 1, res;
struct sockaddr_in6 addr = {
@@ -408,21 +383,21 @@ static void setup_socket(int *fd)
.sin6_scope_id = device->ifindex,
};
- *fd = socket(AF_INET6, SOCK_STREAM, 0);
- if (*fd < 0)
+ sockfd = socket(AF_INET6, SOCK_STREAM, 0);
+ if (sockfd < 0)
fatal("Creating a socket failed");
- res = fcntl(*fd, F_GETFL, 0);
- if (res < 0 || fcntl(*fd, F_SETFL, res | O_NONBLOCK) < 0)
+ res = fcntl(sockfd, F_GETFL, 0);
+ if (res < 0 || fcntl(sockfd, F_SETFL, res | O_NONBLOCK) < 0)
fatal("Setting socket to nonblocking failed");
- if (setsockopt(*fd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof val) == -1)
+ if (setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof val))
fatal("Setting socket option failed");
- if (bind(*fd, (struct sockaddr *)&addr, sizeof(addr)) == -1)
+ if (bind(sockfd, (struct sockaddr *)&addr, sizeof(addr)) == -1)
fatal("Binding socket failed");
- if (listen(*fd, SOMAXCONN) == -1)
+ if (listen(sockfd, SOMAXCONN) == -1)
fatal("Listening to socket failed");
}
@@ -432,39 +407,31 @@ static void cleanup()
kh_destroy(allowedht, allowedips_ht);
wg_free_device(device);
- for (int i = 0; i < MAX_CONNECTIONS + 1; ++i) {
- if (pollfds[i].fd < 0)
+ if (sockfd >= 0)
+ close(sockfd);
+
+ if (epollfd >= 0)
+ close(epollfd);
+
+ for (int i = 0; i < MAX_CONNECTIONS; ++i) {
+ if (requests[i].fd < 0)
continue;
- if (close(pollfds[i].fd))
- debug("Failed to close fd %d\n", pollfds[i].fd);
+ close_connection(&requests[i]);
}
}
-int main(int argc, char *argv[])
+static void setup()
{
- struct wg_dynamic_request reqs[MAX_CONNECTIONS] = { 0 };
- int *sockfd = &pollfds[0].fd;
-
- progname = argv[0];
if (inet_pton(AF_INET6, WG_DYNAMIC_ADDR, &well_known) != 1)
fatal("inet_pton()");
- for (int i = 0; i < MAX_CONNECTIONS + 1; ++i) {
- pollfds[i] = (struct pollfd){
- .fd = -1,
- .events = POLLIN,
- };
- }
-
- if (argc != 2)
- usage();
-
+ leases_init("leases_file");
allowedips_ht = kh_init(allowedht);
- leases_init("leases_file");
+ for (int i = 0; i < MAX_CONNECTIONS; ++i)
+ requests[i].fd = -1;
- wg_interface = argv[1];
if (atexit(cleanup))
die("Failed to set exit function\n");
@@ -479,43 +446,107 @@ int main(int argc, char *argv[])
die("%s has no peers with link-local allowedips\n",
wg_interface);
- setup_socket(sockfd);
+ setup_socket(&sockfd);
+}
- while (1) {
- if (poll(pollfds, MAX_CONNECTIONS + 1, -1) == -1)
- fatal("Failed to poll() fds");
+static int get_avail_request()
+{
+ for (int nfds = 0;; ++nfds) {
+ if (nfds >= MAX_CONNECTIONS)
+ return -1;
- if (pollfds[0].revents & POLLIN) {
- pollfds[0].revents = 0;
- accept_incoming(*sockfd, reqs);
- }
+ if (requests[nfds].fd < 0)
+ return nfds;
+ }
+}
- for (int i = 1; i < MAX_CONNECTIONS + 1; ++i) {
- size_t off;
+static void accept_incoming(int sockfd, int epollfd,
+ struct wg_dynamic_request *requests)
+{
+ int n, fd;
+ struct epoll_event ev;
- if (!(pollfds[i].revents & POLLOUT))
+ while ((n = get_avail_request()) >= 0) {
+ fd = accept_connection(sockfd, &requests[n].pubkey);
+ if (fd < 0) {
+ if (fd == -ENOENT) {
+ debug("Failed to match IP to pubkey\n");
continue;
+ } else if (fd != -EAGAIN && fd != -EWOULDBLOCK) {
+ debug("Failed to accept connection: %s\n",
+ strerror(-fd));
+ continue;
+ }
- off = send_message(pollfds[i].fd, reqs[i - 1].buf,
- &reqs[i - 1].buflen);
- if (reqs[i - 1].buflen)
- memmove(reqs[i - 1].buf, reqs[i - 1].buf + off,
- reqs[i - 1].buflen);
- else
- close_connection(&pollfds[i].fd, &reqs[i - 1]);
+ break;
}
- for (int i = 1; i < MAX_CONNECTIONS + 1; ++i) {
- if (pollfds[i].fd < 0 || !pollfds[i].revents & POLLIN)
+ ev.events = EPOLLIN | EPOLLET;
+ ev.data.ptr = &requests[n];
+ if (epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev) == -1)
+ fatal("epoll_ctl()");
+
+ requests[n].fd = fd;
+ }
+}
+
+static void handle_event(struct wg_dynamic_request *req, uint32_t events)
+{
+ if (!req) {
+ accept_incoming(sockfd, epollfd, requests);
+ return;
+ }
+
+ if (events & EPOLLIN) {
+ if (handle_request(req, send_response, send_error))
+ close_connection(req);
+ }
+
+ if (events & EPOLLOUT) {
+ size_t off = send_message(req->fd, req->buf, &req->buflen);
+ if (req->buflen)
+ memmove(req->buf, req->buf + off, req->buflen);
+ else
+ close_connection(req);
+ }
+}
+
+static void poll_loop()
+{
+ struct epoll_event ev, events[MAX_CONNECTIONS];
+ epollfd = epoll_create1(0);
+ if (epollfd == -1)
+ fatal("epoll_create1()");
+
+ ev.events = EPOLLIN | EPOLLET;
+ ev.data.ptr = NULL;
+ if (epoll_ctl(epollfd, EPOLL_CTL_ADD, sockfd, &ev) == -1)
+ fatal("epoll_ctl()");
+
+ while (1) {
+ int nfds = epoll_wait(epollfd, events, MAX_CONNECTIONS, -1);
+ if (nfds == -1) {
+ if (errno == EINTR)
continue;
- if (handle_request(pollfds[i].fd, &reqs[i - 1],
- send_response, send_error))
- close_connection(&pollfds[i].fd, &reqs[i - 1]);
- else if (reqs[i - 1].buf)
- pollfds[i].events |= POLLOUT;
+ fatal("epoll_wait()");
}
+
+ for (int i = 0; i < nfds; ++i)
+ handle_event(events[i].data.ptr, events[i].events);
}
+}
+
+int main(int argc, char *argv[])
+{
+ progname = argv[0];
+ if (argc != 2)
+ usage();
+
+ wg_interface = argv[1];
+ setup();
+
+ poll_loop();
return 0;
}