From eb3ec870ba48f4be7bbfe30a2a722e53e5f25e1f Mon Sep 17 00:00:00 2001 From: Thomas Gschwantner Date: Mon, 25 Feb 2019 09:41:19 +0100 Subject: Implement a radix-trie for storing ip=pubkey --- Makefile | 2 +- radix-trie.c | 330 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ radix-trie.h | 35 +++++++ 3 files changed, 366 insertions(+), 1 deletion(-) create mode 100644 radix-trie.c create mode 100644 radix-trie.h diff --git a/Makefile b/Makefile index d7dce6a..575fc38 100644 --- a/Makefile +++ b/Makefile @@ -46,7 +46,7 @@ endif all: wg-dynamic-server wg-dynamic-client wg-dynamic-client: wg-dynamic-client.o netlink.o common.o -wg-dynamic-server: wg-dynamic-server.o netlink.o common.o +wg-dynamic-server: wg-dynamic-server.o netlink.o radix-trie.o common.o ifneq ($(V),1) clean: diff --git a/radix-trie.c b/radix-trie.c new file mode 100644 index 0000000..0e67a52 --- /dev/null +++ b/radix-trie.c @@ -0,0 +1,330 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +#define _DEFAULT_SOURCE +#include + +#include +#include +#include +#include +#include +#include + +#include "dbg.h" +#include "radix-trie.h" + +#define MIN(X, Y) (((X) < (Y)) ? (X) : (Y)) + +struct radix_node { + struct radix_node *bit[2]; + void *data; + uint8_t bits[16]; + uint8_t cidr, bit_at_a, bit_at_b; +}; + +// TODO: sort out #ifdef business to make this portable +static unsigned int fls64(uint64_t a) +{ + return __builtin_ctzl(a) + 1; +} + +static unsigned int fls(uint32_t a) +{ + return __builtin_ctz(a) + 1; +} + +static unsigned int fls128(uint64_t a, uint64_t b) +{ + return a ? fls64(a) + 64U : (b ? fls64(b) : 0); +} + +static void swap_endian(uint8_t *dst, const uint8_t *src, uint8_t bits) +{ + if (bits == 32) { + *(uint32_t *)dst = be32toh(*(const uint32_t *)src); + } else if (bits == 128) { + ((uint64_t *)dst)[0] = be64toh(((const uint64_t *)src)[0]); + ((uint64_t *)dst)[1] = be64toh(((const uint64_t *)src)[1]); + } +} + +static uint8_t common_bits(const struct radix_node *node, const uint8_t *key, + uint8_t bits) +{ + if (bits == 32) + return 32U - fls(*(const uint32_t *)node->bits ^ + *(const uint32_t *)key); + else if (bits == 128) + return 128U - fls128(*(const uint64_t *)&node->bits[0] ^ + *(const uint64_t *)&key[0], + *(const uint64_t *)&node->bits[8] ^ + *(const uint64_t *)&key[8]); + return 0; +} + +static struct radix_node *new_node(const uint8_t *key, uint8_t cidr, + uint8_t bits) +{ + struct radix_node *node; + + node = malloc(sizeof *node); + if (!node) + fatal("malloc()"); + + node->bit[0] = node->bit[1] = node->data = NULL; + node->cidr = cidr; + node->bit_at_a = cidr / 8U; +#ifdef __LITTLE_ENDIAN + node->bit_at_a ^= (bits / 8U - 1U) % 8U; +#endif + node->bit_at_b = 7U - (cidr % 8U); + memcpy(node->bits, key, bits / 8U); + + return node; +} + +static bool prefix_matches(const struct radix_node *node, const uint8_t *key, + uint8_t bits) +{ + return common_bits(node, key, bits) >= node->cidr; +} + +#define CHOOSE_NODE(parent, key) \ + parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1] + +static bool node_placement(struct radix_node *trie, const uint8_t *key, + uint8_t cidr, uint8_t bits, + struct radix_node **rnode) +{ + struct radix_node *node = trie, *parent = NULL; + bool exact = false; + + while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) { + parent = node; + if (parent->cidr == cidr) { + exact = true; + break; + } + + node = CHOOSE_NODE(parent, key); + } + *rnode = parent; + return exact; +} + +static int add(struct radix_node **trie, uint8_t bits, const uint8_t *key, + uint8_t cidr, void *data, bool overwrite) +{ + struct radix_node *node, *newnode, *down, *parent; + + if (cidr > bits || !data) + return -EINVAL; + + if (!*trie) { + *trie = new_node(key, cidr, bits); + (*trie)->data = data; + return 0; + } + + if (node_placement(*trie, key, cidr, bits, &node)) { + // exact match, so use the existing node + if (!overwrite && node->data) + return 1; + + node->data = data; + return 0; + } + + if (!overwrite && node && node->data) + return 1; + + newnode = new_node(key, cidr, bits); + newnode->data = data; + + if (!node) { + down = *trie; + } else { + down = CHOOSE_NODE(node, key); + + if (!down) { + CHOOSE_NODE(node, key) = newnode; + return 0; + } + } + cidr = MIN(cidr, common_bits(down, key, bits)); + parent = node; + + if (newnode->cidr == cidr) { + CHOOSE_NODE(newnode, down->bits) = down; + if (!parent) + *trie = newnode; + else + CHOOSE_NODE(parent, newnode->bits) = newnode; + } else { + node = new_node(newnode->bits, cidr, bits); + + CHOOSE_NODE(node, down->bits) = down; + CHOOSE_NODE(node, newnode->bits) = newnode; + if (!parent) + *trie = node; + else + CHOOSE_NODE(parent, node->bits) = node; + } + + return 0; +} + +static void radix_free_nodes(struct radix_node *node) +{ + struct radix_node *old, *bottom = node; + + while (node) { + while (bottom->bit[0]) + bottom = bottom->bit[0]; + bottom->bit[0] = node->bit[1]; + + old = node; + node = node->bit[0]; + free(old); + } +} + +#ifndef __aligned +#define __aligned(x) __attribute__((aligned(x))) +#endif + +static int insert_v4(struct radix_node **root, const struct in_addr *ip, + uint8_t cidr, void *data, bool overwrite) +{ + /* Aligned so it can be passed to fls */ + uint8_t key[4] __aligned(__alignof(uint32_t)); + + swap_endian(key, (const uint8_t *)ip, 32); + return add(root, 32, key, cidr, data, overwrite); +} + +static int insert_v6(struct radix_node **root, const struct in6_addr *ip, + uint8_t cidr, void *data, bool overwrite) +{ + /* Aligned so it can be passed to fls64 */ + uint8_t key[16] __aligned(__alignof(uint64_t)); + + swap_endian(key, (const uint8_t *)ip, 128); + return add(root, 128, key, cidr, data, overwrite); +} + +static struct radix_node *find_node(struct radix_node *trie, uint8_t bits, + const uint8_t *key) +{ + struct radix_node *node = trie, *found = NULL; + + while (node && prefix_matches(node, key, bits)) { + if (node->data) + found = node; + if (node->cidr == bits) + break; + node = CHOOSE_NODE(node, key); + } + return found; +} + +static struct radix_node *lookup(struct radix_node *root, uint8_t bits, + const void *be_ip) +{ + /* Aligned so it can be passed to fls/fls64 */ + uint8_t ip[16] __aligned(__alignof(uint64_t)); + struct radix_node *node; + + swap_endian(ip, be_ip, bits); + node = find_node(root, bits, ip); + return node; +} + +#ifdef DEBUG +#include +void node_to_str(struct radix_node *node, char *buf) +{ + struct in6_addr addr; + char out[INET6_ADDRSTRLEN]; + char cidr[5]; + + if (!node) { + strcpy(buf, "-"); + return; + } + + swap_endian(addr.s6_addr, node->bits, 128); + inet_ntop(AF_INET6, &addr, out, sizeof out); + snprintf(cidr, sizeof cidr, "/%u", node->cidr); + strcpy(buf, out); + strcat(buf, cidr); +} + +void debug_print_trie(struct radix_node *root) +{ + char parent[INET6_ADDRSTRLEN + 4], child1[INET6_ADDRSTRLEN + 4], + child2[INET6_ADDRSTRLEN + 4]; + + if (!root) + return; + + node_to_str(root, parent); + node_to_str(root->bit[0], child1); + node_to_str(root->bit[1], child2); + + debug("%s -> %s, %s\n", parent, child1, child2); + + debug_print_trie(root->bit[0]); + debug_print_trie(root->bit[1]); +} +#endif + +void radix_init(struct radix_trie *trie) +{ + trie->ip4_root = trie->ip6_root = NULL; +} + +void radix_free(struct radix_trie *trie) +{ + radix_free_nodes(trie->ip4_root); + radix_free_nodes(trie->ip6_root); +} + +void *radix_find_v4(struct radix_trie *trie, uint8_t bits, const void *be_ip) +{ + struct radix_node *found = lookup(trie->ip4_root, bits, be_ip); + return found ? found->data : NULL; +} + +void *radix_find_v6(struct radix_trie *trie, uint8_t bits, const void *be_ip) +{ + struct radix_node *found = lookup(trie->ip6_root, bits, be_ip); + return found ? found->data : NULL; +} + +int radix_insert_v4(struct radix_trie *root, const struct in_addr *ip, + uint8_t cidr, void *data) +{ + return insert_v4(&root->ip4_root, ip, cidr, data, true); +} + +int radix_insert_v6(struct radix_trie *root, const struct in6_addr *ip, + uint8_t cidr, void *data) +{ + return insert_v6(&root->ip6_root, ip, cidr, data, true); +} + +int radix_tryinsert_v4(struct radix_trie *root, const struct in_addr *ip, + uint8_t cidr, void *data) +{ + return insert_v4(&root->ip4_root, ip, cidr, data, false); +} + +int radix_tryinsert_v6(struct radix_trie *root, const struct in6_addr *ip, + uint8_t cidr, void *data) +{ + return insert_v6(&root->ip6_root, ip, cidr, data, false); +} diff --git a/radix-trie.h b/radix-trie.h new file mode 100644 index 0000000..3d0b00f --- /dev/null +++ b/radix-trie.h @@ -0,0 +1,35 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +#ifndef __RADIX_TRIE_H__ +#define __RADIX_TRIE_H__ + +#include +#include +#include + +struct radix_trie { + struct radix_node *ip4_root, *ip6_root; +}; + +void radix_init(struct radix_trie *trie); +void radix_free(struct radix_trie *trie); +void *radix_find_v4(struct radix_trie *trie, uint8_t bits, const void *be_ip); +void *radix_find_v6(struct radix_trie *trie, uint8_t bits, const void *be_ip); +int radix_insert_v4(struct radix_trie *root, const struct in_addr *ip, + uint8_t cidr, void *data); +int radix_insert_v6(struct radix_trie *root, const struct in6_addr *ip, + uint8_t cidr, void *data); +int radix_tryinsert_v4(struct radix_trie *root, const struct in_addr *ip, + uint8_t cidr, void *data); +int radix_tryinsert_v6(struct radix_trie *root, const struct in6_addr *ip, + uint8_t cidr, void *data); + +#ifdef DEBUG +void node_to_str(struct radix_node *node, char *buf); +void debug_print_trie(struct radix_node *root); +#endif + +#endif -- cgit v1.2.3-59-g8ed1b