aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Makefile2
-rw-r--r--common.h4
-rw-r--r--radix-trie.c330
-rw-r--r--radix-trie.h35
-rw-r--r--wg-dynamic-server.c129
5 files changed, 480 insertions, 20 deletions
diff --git a/Makefile b/Makefile
index d7dce6a..575fc38 100644
--- a/Makefile
+++ b/Makefile
@@ -46,7 +46,7 @@ endif
all: wg-dynamic-server wg-dynamic-client
wg-dynamic-client: wg-dynamic-client.o netlink.o common.o
-wg-dynamic-server: wg-dynamic-server.o netlink.o common.o
+wg-dynamic-server: wg-dynamic-server.o netlink.o radix-trie.o common.o
ifneq ($(V),1)
clean:
diff --git a/common.h b/common.h
index afb6beb..3a8411d 100644
--- a/common.h
+++ b/common.h
@@ -9,9 +9,10 @@
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
-
#include <netinet/in.h>
+#include "netlink.h"
+
#define MAX_CONNECTIONS 16
#define MAX_LINESIZE 4096
@@ -53,6 +54,7 @@ struct wg_dynamic_attr {
struct wg_dynamic_request {
enum wg_dynamic_key cmd;
uint32_t version;
+ wg_key pubkey;
unsigned char *buf;
size_t buflen;
struct wg_dynamic_attr *first, *last;
diff --git a/radix-trie.c b/radix-trie.c
new file mode 100644
index 0000000..0e67a52
--- /dev/null
+++ b/radix-trie.c
@@ -0,0 +1,330 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+#define _DEFAULT_SOURCE
+#include <endian.h>
+
+#include <arpa/inet.h>
+#include <errno.h>
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "dbg.h"
+#include "radix-trie.h"
+
+#define MIN(X, Y) (((X) < (Y)) ? (X) : (Y))
+
+struct radix_node {
+ struct radix_node *bit[2];
+ void *data;
+ uint8_t bits[16];
+ uint8_t cidr, bit_at_a, bit_at_b;
+};
+
+// TODO: sort out #ifdef business to make this portable
+static unsigned int fls64(uint64_t a)
+{
+ return __builtin_ctzl(a) + 1;
+}
+
+static unsigned int fls(uint32_t a)
+{
+ return __builtin_ctz(a) + 1;
+}
+
+static unsigned int fls128(uint64_t a, uint64_t b)
+{
+ return a ? fls64(a) + 64U : (b ? fls64(b) : 0);
+}
+
+static void swap_endian(uint8_t *dst, const uint8_t *src, uint8_t bits)
+{
+ if (bits == 32) {
+ *(uint32_t *)dst = be32toh(*(const uint32_t *)src);
+ } else if (bits == 128) {
+ ((uint64_t *)dst)[0] = be64toh(((const uint64_t *)src)[0]);
+ ((uint64_t *)dst)[1] = be64toh(((const uint64_t *)src)[1]);
+ }
+}
+
+static uint8_t common_bits(const struct radix_node *node, const uint8_t *key,
+ uint8_t bits)
+{
+ if (bits == 32)
+ return 32U - fls(*(const uint32_t *)node->bits ^
+ *(const uint32_t *)key);
+ else if (bits == 128)
+ return 128U - fls128(*(const uint64_t *)&node->bits[0] ^
+ *(const uint64_t *)&key[0],
+ *(const uint64_t *)&node->bits[8] ^
+ *(const uint64_t *)&key[8]);
+ return 0;
+}
+
+static struct radix_node *new_node(const uint8_t *key, uint8_t cidr,
+ uint8_t bits)
+{
+ struct radix_node *node;
+
+ node = malloc(sizeof *node);
+ if (!node)
+ fatal("malloc()");
+
+ node->bit[0] = node->bit[1] = node->data = NULL;
+ node->cidr = cidr;
+ node->bit_at_a = cidr / 8U;
+#ifdef __LITTLE_ENDIAN
+ node->bit_at_a ^= (bits / 8U - 1U) % 8U;
+#endif
+ node->bit_at_b = 7U - (cidr % 8U);
+ memcpy(node->bits, key, bits / 8U);
+
+ return node;
+}
+
+static bool prefix_matches(const struct radix_node *node, const uint8_t *key,
+ uint8_t bits)
+{
+ return common_bits(node, key, bits) >= node->cidr;
+}
+
+#define CHOOSE_NODE(parent, key) \
+ parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]
+
+static bool node_placement(struct radix_node *trie, const uint8_t *key,
+ uint8_t cidr, uint8_t bits,
+ struct radix_node **rnode)
+{
+ struct radix_node *node = trie, *parent = NULL;
+ bool exact = false;
+
+ while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) {
+ parent = node;
+ if (parent->cidr == cidr) {
+ exact = true;
+ break;
+ }
+
+ node = CHOOSE_NODE(parent, key);
+ }
+ *rnode = parent;
+ return exact;
+}
+
+static int add(struct radix_node **trie, uint8_t bits, const uint8_t *key,
+ uint8_t cidr, void *data, bool overwrite)
+{
+ struct radix_node *node, *newnode, *down, *parent;
+
+ if (cidr > bits || !data)
+ return -EINVAL;
+
+ if (!*trie) {
+ *trie = new_node(key, cidr, bits);
+ (*trie)->data = data;
+ return 0;
+ }
+
+ if (node_placement(*trie, key, cidr, bits, &node)) {
+ // exact match, so use the existing node
+ if (!overwrite && node->data)
+ return 1;
+
+ node->data = data;
+ return 0;
+ }
+
+ if (!overwrite && node && node->data)
+ return 1;
+
+ newnode = new_node(key, cidr, bits);
+ newnode->data = data;
+
+ if (!node) {
+ down = *trie;
+ } else {
+ down = CHOOSE_NODE(node, key);
+
+ if (!down) {
+ CHOOSE_NODE(node, key) = newnode;
+ return 0;
+ }
+ }
+ cidr = MIN(cidr, common_bits(down, key, bits));
+ parent = node;
+
+ if (newnode->cidr == cidr) {
+ CHOOSE_NODE(newnode, down->bits) = down;
+ if (!parent)
+ *trie = newnode;
+ else
+ CHOOSE_NODE(parent, newnode->bits) = newnode;
+ } else {
+ node = new_node(newnode->bits, cidr, bits);
+
+ CHOOSE_NODE(node, down->bits) = down;
+ CHOOSE_NODE(node, newnode->bits) = newnode;
+ if (!parent)
+ *trie = node;
+ else
+ CHOOSE_NODE(parent, node->bits) = node;
+ }
+
+ return 0;
+}
+
+static void radix_free_nodes(struct radix_node *node)
+{
+ struct radix_node *old, *bottom = node;
+
+ while (node) {
+ while (bottom->bit[0])
+ bottom = bottom->bit[0];
+ bottom->bit[0] = node->bit[1];
+
+ old = node;
+ node = node->bit[0];
+ free(old);
+ }
+}
+
+#ifndef __aligned
+#define __aligned(x) __attribute__((aligned(x)))
+#endif
+
+static int insert_v4(struct radix_node **root, const struct in_addr *ip,
+ uint8_t cidr, void *data, bool overwrite)
+{
+ /* Aligned so it can be passed to fls */
+ uint8_t key[4] __aligned(__alignof(uint32_t));
+
+ swap_endian(key, (const uint8_t *)ip, 32);
+ return add(root, 32, key, cidr, data, overwrite);
+}
+
+static int insert_v6(struct radix_node **root, const struct in6_addr *ip,
+ uint8_t cidr, void *data, bool overwrite)
+{
+ /* Aligned so it can be passed to fls64 */
+ uint8_t key[16] __aligned(__alignof(uint64_t));
+
+ swap_endian(key, (const uint8_t *)ip, 128);
+ return add(root, 128, key, cidr, data, overwrite);
+}
+
+static struct radix_node *find_node(struct radix_node *trie, uint8_t bits,
+ const uint8_t *key)
+{
+ struct radix_node *node = trie, *found = NULL;
+
+ while (node && prefix_matches(node, key, bits)) {
+ if (node->data)
+ found = node;
+ if (node->cidr == bits)
+ break;
+ node = CHOOSE_NODE(node, key);
+ }
+ return found;
+}
+
+static struct radix_node *lookup(struct radix_node *root, uint8_t bits,
+ const void *be_ip)
+{
+ /* Aligned so it can be passed to fls/fls64 */
+ uint8_t ip[16] __aligned(__alignof(uint64_t));
+ struct radix_node *node;
+
+ swap_endian(ip, be_ip, bits);
+ node = find_node(root, bits, ip);
+ return node;
+}
+
+#ifdef DEBUG
+#include <stdio.h>
+void node_to_str(struct radix_node *node, char *buf)
+{
+ struct in6_addr addr;
+ char out[INET6_ADDRSTRLEN];
+ char cidr[5];
+
+ if (!node) {
+ strcpy(buf, "-");
+ return;
+ }
+
+ swap_endian(addr.s6_addr, node->bits, 128);
+ inet_ntop(AF_INET6, &addr, out, sizeof out);
+ snprintf(cidr, sizeof cidr, "/%u", node->cidr);
+ strcpy(buf, out);
+ strcat(buf, cidr);
+}
+
+void debug_print_trie(struct radix_node *root)
+{
+ char parent[INET6_ADDRSTRLEN + 4], child1[INET6_ADDRSTRLEN + 4],
+ child2[INET6_ADDRSTRLEN + 4];
+
+ if (!root)
+ return;
+
+ node_to_str(root, parent);
+ node_to_str(root->bit[0], child1);
+ node_to_str(root->bit[1], child2);
+
+ debug("%s -> %s, %s\n", parent, child1, child2);
+
+ debug_print_trie(root->bit[0]);
+ debug_print_trie(root->bit[1]);
+}
+#endif
+
+void radix_init(struct radix_trie *trie)
+{
+ trie->ip4_root = trie->ip6_root = NULL;
+}
+
+void radix_free(struct radix_trie *trie)
+{
+ radix_free_nodes(trie->ip4_root);
+ radix_free_nodes(trie->ip6_root);
+}
+
+void *radix_find_v4(struct radix_trie *trie, uint8_t bits, const void *be_ip)
+{
+ struct radix_node *found = lookup(trie->ip4_root, bits, be_ip);
+ return found ? found->data : NULL;
+}
+
+void *radix_find_v6(struct radix_trie *trie, uint8_t bits, const void *be_ip)
+{
+ struct radix_node *found = lookup(trie->ip6_root, bits, be_ip);
+ return found ? found->data : NULL;
+}
+
+int radix_insert_v4(struct radix_trie *root, const struct in_addr *ip,
+ uint8_t cidr, void *data)
+{
+ return insert_v4(&root->ip4_root, ip, cidr, data, true);
+}
+
+int radix_insert_v6(struct radix_trie *root, const struct in6_addr *ip,
+ uint8_t cidr, void *data)
+{
+ return insert_v6(&root->ip6_root, ip, cidr, data, true);
+}
+
+int radix_tryinsert_v4(struct radix_trie *root, const struct in_addr *ip,
+ uint8_t cidr, void *data)
+{
+ return insert_v4(&root->ip4_root, ip, cidr, data, false);
+}
+
+int radix_tryinsert_v6(struct radix_trie *root, const struct in6_addr *ip,
+ uint8_t cidr, void *data)
+{
+ return insert_v6(&root->ip6_root, ip, cidr, data, false);
+}
diff --git a/radix-trie.h b/radix-trie.h
new file mode 100644
index 0000000..3d0b00f
--- /dev/null
+++ b/radix-trie.h
@@ -0,0 +1,35 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+#ifndef __RADIX_TRIE_H__
+#define __RADIX_TRIE_H__
+
+#include <arpa/inet.h>
+#include <stdbool.h>
+#include <stdint.h>
+
+struct radix_trie {
+ struct radix_node *ip4_root, *ip6_root;
+};
+
+void radix_init(struct radix_trie *trie);
+void radix_free(struct radix_trie *trie);
+void *radix_find_v4(struct radix_trie *trie, uint8_t bits, const void *be_ip);
+void *radix_find_v6(struct radix_trie *trie, uint8_t bits, const void *be_ip);
+int radix_insert_v4(struct radix_trie *root, const struct in_addr *ip,
+ uint8_t cidr, void *data);
+int radix_insert_v6(struct radix_trie *root, const struct in6_addr *ip,
+ uint8_t cidr, void *data);
+int radix_tryinsert_v4(struct radix_trie *root, const struct in_addr *ip,
+ uint8_t cidr, void *data);
+int radix_tryinsert_v6(struct radix_trie *root, const struct in6_addr *ip,
+ uint8_t cidr, void *data);
+
+#ifdef DEBUG
+void node_to_str(struct radix_node *node, char *buf);
+void debug_print_trie(struct radix_node *root);
+#endif
+
+#endif
diff --git a/wg-dynamic-server.c b/wg-dynamic-server.c
index 89a2d86..861c8f3 100644
--- a/wg-dynamic-server.c
+++ b/wg-dynamic-server.c
@@ -9,6 +9,7 @@
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
+#include <string.h>
#include <time.h>
#include <arpa/inet.h>
@@ -23,13 +24,16 @@
#include "common.h"
#include "dbg.h"
#include "netlink.h"
+#include "radix-trie.h"
#define MAX_RESPONSE_SIZE 8192
static const char *progname;
+static const char *wg_interface;
static struct in6_addr well_known;
static wg_device *device = NULL;
+static struct radix_trie allowedips_trie;
static struct pollfd pollfds[MAX_CONNECTIONS + 1];
struct mnl_cb_data {
@@ -189,25 +193,110 @@ static int get_avail_pollfds()
}
}
-static int accept_connection(int sockfd)
+static void rebuild_allowedips_trie()
+{
+ int ret;
+ wg_peer *peer;
+ wg_allowedip *allowedip;
+
+ radix_free(&allowedips_trie);
+
+ wg_free_device(device);
+ if (wg_get_device(&device, wg_interface))
+ fatal("Unable to access interface %s", wg_interface);
+
+ wg_for_each_peer (device, peer) {
+ wg_for_each_allowedip (peer, allowedip) {
+ if (allowedip->family == AF_INET)
+ ret = radix_insert_v4(&allowedips_trie,
+ &allowedip->ip4,
+ allowedip->cidr, peer);
+ else
+ ret = radix_insert_v6(&allowedips_trie,
+ &allowedip->ip6,
+ allowedip->cidr, peer);
+ if (ret)
+ die("Failed to rebuild allowedips trie\n");
+ }
+ }
+}
+
+static wg_key *addr_to_pubkey(struct sockaddr_storage *addr)
+{
+ wg_peer *peer;
+
+ if (addr->ss_family == AF_INET)
+ peer = radix_find_v4(&allowedips_trie, 32,
+ &((struct sockaddr_in *)addr)->sin_addr);
+ else
+ peer = radix_find_v6(&allowedips_trie, 128,
+ &((struct sockaddr_in6 *)addr)->sin6_addr);
+
+ if (!peer)
+ return NULL;
+
+ return &peer->public_key;
+}
+
+static int accept_connection(int sockfd, wg_key *dest)
{
int fd;
+ wg_key *pubkey;
+ struct sockaddr_storage addr;
+ socklen_t size = sizeof addr;
#ifdef __linux__
- fd = accept4(sockfd, NULL, NULL, SOCK_NONBLOCK);
+ fd = accept4(sockfd, (struct sockaddr *)&addr, &size, SOCK_NONBLOCK);
if (fd < 0)
- fatal("Failed to accept connection");
+ return -errno;
#else
- fd = accept(sockfd, NULL, NULL);
+ fd = accept(sockfd, (struct sockaddr *)&addr, &size);
if (fd < 0)
- fatal("Failed to accept connection");
+ return -errno;
int res = fcntl(fd, F_GETFL, 0);
if (res < 0 || fcntl(fd, F_SETFL, res | O_NONBLOCK) < 0)
fatal("Setting socket to nonblocking failed");
#endif
+ pubkey = addr_to_pubkey(&addr);
+ if (!pubkey) {
+ /* our copy of allowedips is outdated, refresh */
+ rebuild_allowedips_trie();
+ pubkey = addr_to_pubkey(&addr);
+ if (!pubkey) {
+ /* either we lost the race or something is very wrong */
+ close(fd);
+ return -ENOENT;
+ }
+ }
+ memcpy(dest, pubkey, sizeof *pubkey);
+
+ wg_key_b64_string key;
+ char out[INET6_ADDRSTRLEN];
+ wg_key_to_base64(key, *pubkey);
+ inet_ntop(addr.ss_family, &((struct sockaddr_in6 *)&addr)->sin6_addr,
+ out, sizeof(out));
+ debug("%s has pubkey: %s\n", out, key);
+
return fd;
}
+static void accept_incoming(int sockfd, struct wg_dynamic_request *reqs)
+{
+ int n, fd;
+ while ((n = get_avail_pollfds()) >= 0) {
+ fd = accept_connection(sockfd, &reqs[n - 1].pubkey);
+ if (fd < 0) {
+ if (fd == -ENOENT)
+ debug("Failed to match IP to pubkey\n");
+ else if (fd != -EAGAIN && fd != -EWOULDBLOCK)
+ debug("Failed to accept connection: %s\n",
+ strerror(-fd));
+ break;
+ }
+ pollfds[n].fd = fd;
+ }
+}
+
static void close_connection(int *fd, struct wg_dynamic_request *req)
{
if (close(*fd))
@@ -360,7 +449,7 @@ static bool send_response(int fd, struct wg_dynamic_request *req)
static void setup_socket(int *fd)
{
- int val = 1;
+ int val = 1, res;
struct sockaddr_in6 addr = {
.sin6_family = AF_INET6,
.sin6_port = htons(WG_DYNAMIC_PORT),
@@ -372,6 +461,10 @@ static void setup_socket(int *fd)
if (*fd < 0)
fatal("Creating a socket failed");
+ res = fcntl(*fd, F_GETFL, 0);
+ if (res < 0 || fcntl(*fd, F_SETFL, res | O_NONBLOCK) < 0)
+ fatal("Setting socket to nonblocking failed");
+
if (setsockopt(*fd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof val) == -1)
fatal("Setting socket option failed");
@@ -384,6 +477,7 @@ static void setup_socket(int *fd)
static void cleanup()
{
+ radix_free(&allowedips_trie);
wg_free_device(device);
for (int i = 0; i < MAX_CONNECTIONS + 1; ++i) {
@@ -398,8 +492,7 @@ static void cleanup()
int main(int argc, char *argv[])
{
struct wg_dynamic_request reqs[MAX_CONNECTIONS] = { 0 };
- int *sockfd = &pollfds[0].fd, n;
- const char *iface;
+ int *sockfd = &pollfds[0].fd;
progname = argv[0];
inet_pton(AF_INET6, WG_DYNAMIC_ADDR, &well_known);
@@ -414,19 +507,22 @@ int main(int argc, char *argv[])
if (argc != 2)
usage();
- iface = argv[1];
+ radix_init(&allowedips_trie);
+
+ wg_interface = argv[1];
if (atexit(cleanup))
die("Failed to set exit function\n");
- if (wg_get_device(&device, iface))
- fatal("Unable to access interface %s", iface);
+ rebuild_allowedips_trie();
if (!validate_link_local_ip(device->ifindex))
// TODO: assign IP instead?
- die("%s needs to have %s assigned\n", iface, WG_DYNAMIC_ADDR);
+ die("%s needs to have %s assigned\n", wg_interface,
+ WG_DYNAMIC_ADDR);
if (!valid_peer_found(device))
- die("%s has no peers with link-local allowedips\n", iface);
+ die("%s has no peers with link-local allowedips\n",
+ wg_interface);
setup_socket(sockfd);
@@ -435,11 +531,8 @@ int main(int argc, char *argv[])
fatal("Failed to poll() fds");
if (pollfds[0].revents & POLLIN) {
- n = get_avail_pollfds();
- if (n >= 0) {
- pollfds[0].revents = 0;
- pollfds[n].fd = accept_connection(*sockfd);
- }
+ pollfds[0].revents = 0;
+ accept_incoming(*sockfd, reqs);
}
for (int i = 1; i < MAX_CONNECTIONS + 1; ++i) {