diff options
-rw-r--r-- | Makefile | 2 | ||||
-rw-r--r-- | common.h | 4 | ||||
-rw-r--r-- | radix-trie.c | 330 | ||||
-rw-r--r-- | radix-trie.h | 35 | ||||
-rw-r--r-- | wg-dynamic-server.c | 129 |
5 files changed, 480 insertions, 20 deletions
@@ -46,7 +46,7 @@ endif all: wg-dynamic-server wg-dynamic-client wg-dynamic-client: wg-dynamic-client.o netlink.o common.o -wg-dynamic-server: wg-dynamic-server.o netlink.o common.o +wg-dynamic-server: wg-dynamic-server.o netlink.o radix-trie.o common.o ifneq ($(V),1) clean: @@ -9,9 +9,10 @@ #include <stdbool.h> #include <stdint.h> #include <stdlib.h> - #include <netinet/in.h> +#include "netlink.h" + #define MAX_CONNECTIONS 16 #define MAX_LINESIZE 4096 @@ -53,6 +54,7 @@ struct wg_dynamic_attr { struct wg_dynamic_request { enum wg_dynamic_key cmd; uint32_t version; + wg_key pubkey; unsigned char *buf; size_t buflen; struct wg_dynamic_attr *first, *last; diff --git a/radix-trie.c b/radix-trie.c new file mode 100644 index 0000000..0e67a52 --- /dev/null +++ b/radix-trie.c @@ -0,0 +1,330 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +#define _DEFAULT_SOURCE +#include <endian.h> + +#include <arpa/inet.h> +#include <errno.h> +#include <stdbool.h> +#include <stdint.h> +#include <stdlib.h> +#include <string.h> + +#include "dbg.h" +#include "radix-trie.h" + +#define MIN(X, Y) (((X) < (Y)) ? (X) : (Y)) + +struct radix_node { + struct radix_node *bit[2]; + void *data; + uint8_t bits[16]; + uint8_t cidr, bit_at_a, bit_at_b; +}; + +// TODO: sort out #ifdef business to make this portable +static unsigned int fls64(uint64_t a) +{ + return __builtin_ctzl(a) + 1; +} + +static unsigned int fls(uint32_t a) +{ + return __builtin_ctz(a) + 1; +} + +static unsigned int fls128(uint64_t a, uint64_t b) +{ + return a ? fls64(a) + 64U : (b ? fls64(b) : 0); +} + +static void swap_endian(uint8_t *dst, const uint8_t *src, uint8_t bits) +{ + if (bits == 32) { + *(uint32_t *)dst = be32toh(*(const uint32_t *)src); + } else if (bits == 128) { + ((uint64_t *)dst)[0] = be64toh(((const uint64_t *)src)[0]); + ((uint64_t *)dst)[1] = be64toh(((const uint64_t *)src)[1]); + } +} + +static uint8_t common_bits(const struct radix_node *node, const uint8_t *key, + uint8_t bits) +{ + if (bits == 32) + return 32U - fls(*(const uint32_t *)node->bits ^ + *(const uint32_t *)key); + else if (bits == 128) + return 128U - fls128(*(const uint64_t *)&node->bits[0] ^ + *(const uint64_t *)&key[0], + *(const uint64_t *)&node->bits[8] ^ + *(const uint64_t *)&key[8]); + return 0; +} + +static struct radix_node *new_node(const uint8_t *key, uint8_t cidr, + uint8_t bits) +{ + struct radix_node *node; + + node = malloc(sizeof *node); + if (!node) + fatal("malloc()"); + + node->bit[0] = node->bit[1] = node->data = NULL; + node->cidr = cidr; + node->bit_at_a = cidr / 8U; +#ifdef __LITTLE_ENDIAN + node->bit_at_a ^= (bits / 8U - 1U) % 8U; +#endif + node->bit_at_b = 7U - (cidr % 8U); + memcpy(node->bits, key, bits / 8U); + + return node; +} + +static bool prefix_matches(const struct radix_node *node, const uint8_t *key, + uint8_t bits) +{ + return common_bits(node, key, bits) >= node->cidr; +} + +#define CHOOSE_NODE(parent, key) \ + parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1] + +static bool node_placement(struct radix_node *trie, const uint8_t *key, + uint8_t cidr, uint8_t bits, + struct radix_node **rnode) +{ + struct radix_node *node = trie, *parent = NULL; + bool exact = false; + + while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) { + parent = node; + if (parent->cidr == cidr) { + exact = true; + break; + } + + node = CHOOSE_NODE(parent, key); + } + *rnode = parent; + return exact; +} + +static int add(struct radix_node **trie, uint8_t bits, const uint8_t *key, + uint8_t cidr, void *data, bool overwrite) +{ + struct radix_node *node, *newnode, *down, *parent; + + if (cidr > bits || !data) + return -EINVAL; + + if (!*trie) { + *trie = new_node(key, cidr, bits); + (*trie)->data = data; + return 0; + } + + if (node_placement(*trie, key, cidr, bits, &node)) { + // exact match, so use the existing node + if (!overwrite && node->data) + return 1; + + node->data = data; + return 0; + } + + if (!overwrite && node && node->data) + return 1; + + newnode = new_node(key, cidr, bits); + newnode->data = data; + + if (!node) { + down = *trie; + } else { + down = CHOOSE_NODE(node, key); + + if (!down) { + CHOOSE_NODE(node, key) = newnode; + return 0; + } + } + cidr = MIN(cidr, common_bits(down, key, bits)); + parent = node; + + if (newnode->cidr == cidr) { + CHOOSE_NODE(newnode, down->bits) = down; + if (!parent) + *trie = newnode; + else + CHOOSE_NODE(parent, newnode->bits) = newnode; + } else { + node = new_node(newnode->bits, cidr, bits); + + CHOOSE_NODE(node, down->bits) = down; + CHOOSE_NODE(node, newnode->bits) = newnode; + if (!parent) + *trie = node; + else + CHOOSE_NODE(parent, node->bits) = node; + } + + return 0; +} + +static void radix_free_nodes(struct radix_node *node) +{ + struct radix_node *old, *bottom = node; + + while (node) { + while (bottom->bit[0]) + bottom = bottom->bit[0]; + bottom->bit[0] = node->bit[1]; + + old = node; + node = node->bit[0]; + free(old); + } +} + +#ifndef __aligned +#define __aligned(x) __attribute__((aligned(x))) +#endif + +static int insert_v4(struct radix_node **root, const struct in_addr *ip, + uint8_t cidr, void *data, bool overwrite) +{ + /* Aligned so it can be passed to fls */ + uint8_t key[4] __aligned(__alignof(uint32_t)); + + swap_endian(key, (const uint8_t *)ip, 32); + return add(root, 32, key, cidr, data, overwrite); +} + +static int insert_v6(struct radix_node **root, const struct in6_addr *ip, + uint8_t cidr, void *data, bool overwrite) +{ + /* Aligned so it can be passed to fls64 */ + uint8_t key[16] __aligned(__alignof(uint64_t)); + + swap_endian(key, (const uint8_t *)ip, 128); + return add(root, 128, key, cidr, data, overwrite); +} + +static struct radix_node *find_node(struct radix_node *trie, uint8_t bits, + const uint8_t *key) +{ + struct radix_node *node = trie, *found = NULL; + + while (node && prefix_matches(node, key, bits)) { + if (node->data) + found = node; + if (node->cidr == bits) + break; + node = CHOOSE_NODE(node, key); + } + return found; +} + +static struct radix_node *lookup(struct radix_node *root, uint8_t bits, + const void *be_ip) +{ + /* Aligned so it can be passed to fls/fls64 */ + uint8_t ip[16] __aligned(__alignof(uint64_t)); + struct radix_node *node; + + swap_endian(ip, be_ip, bits); + node = find_node(root, bits, ip); + return node; +} + +#ifdef DEBUG +#include <stdio.h> +void node_to_str(struct radix_node *node, char *buf) +{ + struct in6_addr addr; + char out[INET6_ADDRSTRLEN]; + char cidr[5]; + + if (!node) { + strcpy(buf, "-"); + return; + } + + swap_endian(addr.s6_addr, node->bits, 128); + inet_ntop(AF_INET6, &addr, out, sizeof out); + snprintf(cidr, sizeof cidr, "/%u", node->cidr); + strcpy(buf, out); + strcat(buf, cidr); +} + +void debug_print_trie(struct radix_node *root) +{ + char parent[INET6_ADDRSTRLEN + 4], child1[INET6_ADDRSTRLEN + 4], + child2[INET6_ADDRSTRLEN + 4]; + + if (!root) + return; + + node_to_str(root, parent); + node_to_str(root->bit[0], child1); + node_to_str(root->bit[1], child2); + + debug("%s -> %s, %s\n", parent, child1, child2); + + debug_print_trie(root->bit[0]); + debug_print_trie(root->bit[1]); +} +#endif + +void radix_init(struct radix_trie *trie) +{ + trie->ip4_root = trie->ip6_root = NULL; +} + +void radix_free(struct radix_trie *trie) +{ + radix_free_nodes(trie->ip4_root); + radix_free_nodes(trie->ip6_root); +} + +void *radix_find_v4(struct radix_trie *trie, uint8_t bits, const void *be_ip) +{ + struct radix_node *found = lookup(trie->ip4_root, bits, be_ip); + return found ? found->data : NULL; +} + +void *radix_find_v6(struct radix_trie *trie, uint8_t bits, const void *be_ip) +{ + struct radix_node *found = lookup(trie->ip6_root, bits, be_ip); + return found ? found->data : NULL; +} + +int radix_insert_v4(struct radix_trie *root, const struct in_addr *ip, + uint8_t cidr, void *data) +{ + return insert_v4(&root->ip4_root, ip, cidr, data, true); +} + +int radix_insert_v6(struct radix_trie *root, const struct in6_addr *ip, + uint8_t cidr, void *data) +{ + return insert_v6(&root->ip6_root, ip, cidr, data, true); +} + +int radix_tryinsert_v4(struct radix_trie *root, const struct in_addr *ip, + uint8_t cidr, void *data) +{ + return insert_v4(&root->ip4_root, ip, cidr, data, false); +} + +int radix_tryinsert_v6(struct radix_trie *root, const struct in6_addr *ip, + uint8_t cidr, void *data) +{ + return insert_v6(&root->ip6_root, ip, cidr, data, false); +} diff --git a/radix-trie.h b/radix-trie.h new file mode 100644 index 0000000..3d0b00f --- /dev/null +++ b/radix-trie.h @@ -0,0 +1,35 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +#ifndef __RADIX_TRIE_H__ +#define __RADIX_TRIE_H__ + +#include <arpa/inet.h> +#include <stdbool.h> +#include <stdint.h> + +struct radix_trie { + struct radix_node *ip4_root, *ip6_root; +}; + +void radix_init(struct radix_trie *trie); +void radix_free(struct radix_trie *trie); +void *radix_find_v4(struct radix_trie *trie, uint8_t bits, const void *be_ip); +void *radix_find_v6(struct radix_trie *trie, uint8_t bits, const void *be_ip); +int radix_insert_v4(struct radix_trie *root, const struct in_addr *ip, + uint8_t cidr, void *data); +int radix_insert_v6(struct radix_trie *root, const struct in6_addr *ip, + uint8_t cidr, void *data); +int radix_tryinsert_v4(struct radix_trie *root, const struct in_addr *ip, + uint8_t cidr, void *data); +int radix_tryinsert_v6(struct radix_trie *root, const struct in6_addr *ip, + uint8_t cidr, void *data); + +#ifdef DEBUG +void node_to_str(struct radix_node *node, char *buf); +void debug_print_trie(struct radix_node *root); +#endif + +#endif diff --git a/wg-dynamic-server.c b/wg-dynamic-server.c index 89a2d86..861c8f3 100644 --- a/wg-dynamic-server.c +++ b/wg-dynamic-server.c @@ -9,6 +9,7 @@ #include <stdint.h> #include <stdio.h> #include <stdlib.h> +#include <string.h> #include <time.h> #include <arpa/inet.h> @@ -23,13 +24,16 @@ #include "common.h" #include "dbg.h" #include "netlink.h" +#include "radix-trie.h" #define MAX_RESPONSE_SIZE 8192 static const char *progname; +static const char *wg_interface; static struct in6_addr well_known; static wg_device *device = NULL; +static struct radix_trie allowedips_trie; static struct pollfd pollfds[MAX_CONNECTIONS + 1]; struct mnl_cb_data { @@ -189,25 +193,110 @@ static int get_avail_pollfds() } } -static int accept_connection(int sockfd) +static void rebuild_allowedips_trie() +{ + int ret; + wg_peer *peer; + wg_allowedip *allowedip; + + radix_free(&allowedips_trie); + + wg_free_device(device); + if (wg_get_device(&device, wg_interface)) + fatal("Unable to access interface %s", wg_interface); + + wg_for_each_peer (device, peer) { + wg_for_each_allowedip (peer, allowedip) { + if (allowedip->family == AF_INET) + ret = radix_insert_v4(&allowedips_trie, + &allowedip->ip4, + allowedip->cidr, peer); + else + ret = radix_insert_v6(&allowedips_trie, + &allowedip->ip6, + allowedip->cidr, peer); + if (ret) + die("Failed to rebuild allowedips trie\n"); + } + } +} + +static wg_key *addr_to_pubkey(struct sockaddr_storage *addr) +{ + wg_peer *peer; + + if (addr->ss_family == AF_INET) + peer = radix_find_v4(&allowedips_trie, 32, + &((struct sockaddr_in *)addr)->sin_addr); + else + peer = radix_find_v6(&allowedips_trie, 128, + &((struct sockaddr_in6 *)addr)->sin6_addr); + + if (!peer) + return NULL; + + return &peer->public_key; +} + +static int accept_connection(int sockfd, wg_key *dest) { int fd; + wg_key *pubkey; + struct sockaddr_storage addr; + socklen_t size = sizeof addr; #ifdef __linux__ - fd = accept4(sockfd, NULL, NULL, SOCK_NONBLOCK); + fd = accept4(sockfd, (struct sockaddr *)&addr, &size, SOCK_NONBLOCK); if (fd < 0) - fatal("Failed to accept connection"); + return -errno; #else - fd = accept(sockfd, NULL, NULL); + fd = accept(sockfd, (struct sockaddr *)&addr, &size); if (fd < 0) - fatal("Failed to accept connection"); + return -errno; int res = fcntl(fd, F_GETFL, 0); if (res < 0 || fcntl(fd, F_SETFL, res | O_NONBLOCK) < 0) fatal("Setting socket to nonblocking failed"); #endif + pubkey = addr_to_pubkey(&addr); + if (!pubkey) { + /* our copy of allowedips is outdated, refresh */ + rebuild_allowedips_trie(); + pubkey = addr_to_pubkey(&addr); + if (!pubkey) { + /* either we lost the race or something is very wrong */ + close(fd); + return -ENOENT; + } + } + memcpy(dest, pubkey, sizeof *pubkey); + + wg_key_b64_string key; + char out[INET6_ADDRSTRLEN]; + wg_key_to_base64(key, *pubkey); + inet_ntop(addr.ss_family, &((struct sockaddr_in6 *)&addr)->sin6_addr, + out, sizeof(out)); + debug("%s has pubkey: %s\n", out, key); + return fd; } +static void accept_incoming(int sockfd, struct wg_dynamic_request *reqs) +{ + int n, fd; + while ((n = get_avail_pollfds()) >= 0) { + fd = accept_connection(sockfd, &reqs[n - 1].pubkey); + if (fd < 0) { + if (fd == -ENOENT) + debug("Failed to match IP to pubkey\n"); + else if (fd != -EAGAIN && fd != -EWOULDBLOCK) + debug("Failed to accept connection: %s\n", + strerror(-fd)); + break; + } + pollfds[n].fd = fd; + } +} + static void close_connection(int *fd, struct wg_dynamic_request *req) { if (close(*fd)) @@ -360,7 +449,7 @@ static bool send_response(int fd, struct wg_dynamic_request *req) static void setup_socket(int *fd) { - int val = 1; + int val = 1, res; struct sockaddr_in6 addr = { .sin6_family = AF_INET6, .sin6_port = htons(WG_DYNAMIC_PORT), @@ -372,6 +461,10 @@ static void setup_socket(int *fd) if (*fd < 0) fatal("Creating a socket failed"); + res = fcntl(*fd, F_GETFL, 0); + if (res < 0 || fcntl(*fd, F_SETFL, res | O_NONBLOCK) < 0) + fatal("Setting socket to nonblocking failed"); + if (setsockopt(*fd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof val) == -1) fatal("Setting socket option failed"); @@ -384,6 +477,7 @@ static void setup_socket(int *fd) static void cleanup() { + radix_free(&allowedips_trie); wg_free_device(device); for (int i = 0; i < MAX_CONNECTIONS + 1; ++i) { @@ -398,8 +492,7 @@ static void cleanup() int main(int argc, char *argv[]) { struct wg_dynamic_request reqs[MAX_CONNECTIONS] = { 0 }; - int *sockfd = &pollfds[0].fd, n; - const char *iface; + int *sockfd = &pollfds[0].fd; progname = argv[0]; inet_pton(AF_INET6, WG_DYNAMIC_ADDR, &well_known); @@ -414,19 +507,22 @@ int main(int argc, char *argv[]) if (argc != 2) usage(); - iface = argv[1]; + radix_init(&allowedips_trie); + + wg_interface = argv[1]; if (atexit(cleanup)) die("Failed to set exit function\n"); - if (wg_get_device(&device, iface)) - fatal("Unable to access interface %s", iface); + rebuild_allowedips_trie(); if (!validate_link_local_ip(device->ifindex)) // TODO: assign IP instead? - die("%s needs to have %s assigned\n", iface, WG_DYNAMIC_ADDR); + die("%s needs to have %s assigned\n", wg_interface, + WG_DYNAMIC_ADDR); if (!valid_peer_found(device)) - die("%s has no peers with link-local allowedips\n", iface); + die("%s has no peers with link-local allowedips\n", + wg_interface); setup_socket(sockfd); @@ -435,11 +531,8 @@ int main(int argc, char *argv[]) fatal("Failed to poll() fds"); if (pollfds[0].revents & POLLIN) { - n = get_avail_pollfds(); - if (n >= 0) { - pollfds[0].revents = 0; - pollfds[n].fd = accept_connection(*sockfd); - } + pollfds[0].revents = 0; + accept_incoming(*sockfd, reqs); } for (int i = 1; i < MAX_CONNECTIONS + 1; ++i) { |