aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorThomas Gschwantner <tharre3@gmail.com>2019-12-03 01:41:41 +0100
committerThomas Gschwantner <tharre3@gmail.com>2019-12-11 06:22:17 +0100
commit7a05c232f8eb66565001022708159127b78f3ce2 (patch)
tree9557db7f399aa733e48e41d38b9ee66aa5b44594
parentradix-trie: fix bug where /64 pools would overflow (diff)
downloadwg-dynamic-7a05c232f8eb66565001022708159127b78f3ce2.tar.xz
wg-dynamic-7a05c232f8eb66565001022708159127b78f3ce2.zip
radix-trie: fix issues related to shadow-/counting
Previously it was possible that pools would not be correctly shadowed and/or the left or right counters were not updated properly. To fix that, every node now has flags indicating what type of node it is, and if it's shadowed. Furthermore, We determine if a poolnode needs to be (un)shadowed by walking the trie now. remove_node() now also only modifies the left right counters if the target node was actually found.
-rw-r--r--radix-trie.c250
1 files changed, 144 insertions, 106 deletions
diff --git a/radix-trie.c b/radix-trie.c
index af9ae3c..237b8b6 100644
--- a/radix-trie.c
+++ b/radix-trie.c
@@ -20,19 +20,23 @@
#define __aligned(x) __attribute__((aligned(x)))
#endif
+enum radix_node_flags {
+ RNODE_IS_LEAF = 1U << 0,
+ RNODE_IS_POOLNODE = 1U << 1,
+ RNODE_IS_SHADOWED = 1U << 2,
+};
+
struct radix_node {
struct radix_node *bit[2];
uint64_t left;
uint64_t right;
uint8_t bits[16];
- uint8_t cidr, bit_at_a, bit_at_b;
- bool is_leaf;
+ uint8_t cidr, bit_at_a, bit_at_b, flags;
};
struct radix_pool {
struct radix_node *node;
struct radix_pool *next;
- bool shadowed;
};
static unsigned int fls64(uint64_t x)
@@ -92,7 +96,7 @@ static struct radix_node *new_node(const uint8_t *key, uint8_t cidr,
node->bit_at_a ^= (bits / 8U - 1U) % 8U;
#endif
node->bit_at_b = 7U - (cidr % 8U);
- node->is_leaf = false;
+ node->flags = 0;
if (bits - cidr > 0 && bits - cidr - 1 < 64)
node->left = node->right = 1ULL << (bits - cidr - 1);
else
@@ -117,26 +121,6 @@ static bool prefix_matches(const struct radix_node *node, const uint8_t *key,
#define CHOOSE_NODE(parent, key) \
(parent)->bit[(key[(parent)->bit_at_a] >> (parent)->bit_at_b) & 1]
-static bool node_placement(struct radix_node *trie, const uint8_t *key,
- uint8_t cidr, uint8_t bits,
- struct radix_node **rnode)
-{
- struct radix_node *node = trie, *parent = NULL;
- bool exact = false;
-
- while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) {
- parent = node;
- if (parent->cidr == cidr) {
- exact = true;
- break;
- }
-
- node = CHOOSE_NODE(parent, key);
- }
- *rnode = parent;
- return exact;
-}
-
static uint64_t subnet_diff(uint8_t *ip1, uint8_t *ip2, uint8_t bits)
{
if (bits == 32)
@@ -145,6 +129,14 @@ static uint64_t subnet_diff(uint8_t *ip1, uint8_t *ip2, uint8_t bits)
return *(const uint64_t *)&ip1[8] - *(const uint64_t *)&ip2[8];
}
+static uint64_t taken_ips(struct radix_node *node, uint8_t bits)
+{
+ if ((bits - node->cidr) >= 64)
+ return 0;
+
+ return (1ULL << (bits - node->cidr)) - (node->left + node->right);
+}
+
static void add_nth(struct radix_node *start, uint8_t bits, uint64_t n,
uint8_t *dest)
{
@@ -201,7 +193,7 @@ static void add_nth(struct radix_node *start, uint8_t bits, uint64_t n,
}
newnode = new_node(ip, cidr, bits);
- newnode->is_leaf = true;
+ newnode->flags |= RNODE_IS_LEAF;
swap_endian(dest, (const uint8_t *)ip, bits);
if (!target) {
@@ -214,43 +206,59 @@ static void add_nth(struct radix_node *start, uint8_t bits, uint64_t n,
CHOOSE_NODE(between, newnode->bits) = newnode;
CHOOSE_NODE(parent, between->bits) = between;
- between->left -=
- (1ULL << (bits - between->bit[0]->cidr)) -
- (between->bit[0]->left + between->bit[0]->right);
- between->right -=
- (1ULL << (bits - between->bit[1]->cidr)) -
- (between->bit[1]->left + between->bit[1]->right);
+ between->left -= taken_ips(between->bit[0], bits);
+ between->right -= taken_ips(between->bit[1], bits);
}
}
-static int add(struct radix_node **trie, uint8_t bits, const uint8_t *key,
- uint8_t cidr, bool is_leaf)
+static struct radix_node *add(struct radix_node **trie, uint8_t bits,
+ const uint8_t *key, uint8_t cidr, uint8_t type)
{
- struct radix_node *node, *newnode, *down, *parent;
+ struct radix_node *node = NULL, *newnode, *down, *parent, *tmp = *trie;
+ bool exact = false, in_pool = false;
- if (cidr > bits)
- return -EINVAL;
+ if (cidr > bits) {
+ errno = EINVAL;
+ return NULL;
+ }
if (!*trie) {
*trie = new_node(key, cidr, bits);
- (*trie)->is_leaf = is_leaf;
- return 0;
+ (*trie)->flags = type;
+ return *trie;
}
- if (node_placement(*trie, key, cidr, bits, &node)) {
- /* exact match, so use the existing node */
- if (node->is_leaf)
- return 1;
+ while (tmp && tmp->cidr <= cidr && prefix_matches(tmp, key, bits)) {
+ node = tmp;
+ if (tmp->flags & RNODE_IS_POOLNODE)
+ in_pool = true;
- node->is_leaf = is_leaf;
- return 0;
+ if (node->cidr == cidr) {
+ exact = true;
+ break;
+ }
+
+ tmp = CHOOSE_NODE(node, key);
}
- if (node && node->is_leaf)
- return 1;
+ if (!in_pool && (type & RNODE_IS_LEAF)) {
+ errno = ENOENT;
+ return NULL;
+ }
+
+ if (exact) {
+ /* exact match, so use the existing node */
+ if (node->flags & type) {
+ errno = EEXIST;
+ return NULL;
+ }
+
+ node->flags = type;
+ return node;
+ }
newnode = new_node(key, cidr, bits);
- newnode->is_leaf = is_leaf;
+ newnode->flags = type;
if (!node) {
down = *trie;
@@ -259,7 +267,7 @@ static int add(struct radix_node **trie, uint8_t bits, const uint8_t *key,
if (!down) {
CHOOSE_NODE(node, key) = newnode;
- return 0;
+ return newnode;
}
}
cidr = MIN(cidr, common_bits(down, key, bits));
@@ -276,13 +284,19 @@ static int add(struct radix_node **trie, uint8_t bits, const uint8_t *key,
CHOOSE_NODE(node, down->bits) = down;
CHOOSE_NODE(node, newnode->bits) = newnode;
+
+ if (CHOOSE_NODE(node, down->bits) == node->bit[0])
+ node->left -= taken_ips(down, bits);
+ else
+ node->right -= taken_ips(down, bits);
+
if (!parent)
*trie = node;
else
CHOOSE_NODE(parent, node->bits) = node;
}
- return 0;
+ return newnode;
}
static void radix_free_nodes(struct radix_node *node)
@@ -322,15 +336,15 @@ static int insert_v4(struct radix_node **root, const struct in_addr *ip,
{
/* Aligned so it can be passed to fls */
uint8_t key[4] __aligned(__alignof(uint32_t));
- int ret;
swap_endian(key, (const uint8_t *)ip, 32);
- ret = add(root, 32, key, cidr, true);
- if (!ret)
+ if (add(root, 32, key, cidr, RNODE_IS_LEAF)) {
decrement_radix(*root, 32, (uint8_t *)key);
+ return 0;
+ }
- return ret;
+ return -1;
}
static int insert_v6(struct radix_node **root, const struct in6_addr *ip,
@@ -338,41 +352,47 @@ static int insert_v6(struct radix_node **root, const struct in6_addr *ip,
{
/* Aligned so it can be passed to fls64 */
uint8_t key[16] __aligned(__alignof(uint64_t));
- int ret;
swap_endian(key, (const uint8_t *)ip, 128);
- ret = add(root, 128, key, cidr, true);
- if (!ret)
+ if (add(root, 128, key, cidr, RNODE_IS_LEAF)) {
decrement_radix(*root, 128, (uint8_t *)key);
+ return 0;
+ }
- return ret;
+ return -1;
}
-static int remove_node(struct radix_node *trie, const uint8_t *key,
+static int remove_node(struct radix_node **trie, const uint8_t *key,
uint8_t bits)
{
- struct radix_node **node = &trie, **target = NULL;
+ struct radix_node **node = trie, **target = NULL;
+ uint64_t *pnodes[127];
+ int i = 0;
while (*node && prefix_matches(*node, key, bits)) {
- if ((*node)->is_leaf) {
+ if ((*node)->flags & RNODE_IS_LEAF) {
target = node;
break;
}
if (CHOOSE_NODE(*node, key) == (*node)->bit[0])
- ++((*node)->left);
+ pnodes[i++] = &((*node)->left);
else
- ++((*node)->right);
+ pnodes[i++] = &((*node)->right);
+ BUG_ON(i >= 127);
node = &CHOOSE_NODE(*node, key);
}
if (!target)
return 1; /* key not found in trie */
+ for (int j = 0; j < i; ++j)
+ ++(*(pnodes[j]));
+
+ free(*node);
*target = NULL;
- radix_free_nodes(*node);
return 0;
}
@@ -391,46 +411,69 @@ static void totalip_inc(struct ipns *ns, uint8_t bits, uint8_t val)
}
}
-static void totalip_dec(struct ipns *ns, uint8_t bits, uint8_t val)
+static void shadow_nodes(struct radix_node *node)
{
- if (bits == 32) {
- BUG_ON(val > 32);
- ns->total_ipv4 -= 1ULL << val;
- } else if (bits == 128) {
- uint64_t tmp = ns->totall_ipv6;
- BUG_ON(val > 64);
- ns->totall_ipv6 -= (val == 64) ? 0 : 1ULL << val;
- if (ns->totall_ipv6 >= tmp)
- --ns->totalh_ipv6;
+ if (!node)
+ return;
+
+ if (node->flags & RNODE_IS_POOLNODE) {
+ BUG_ON(node->flags & RNODE_IS_SHADOWED);
+ node->flags |= RNODE_IS_SHADOWED;
+ return;
}
+
+ if (node->flags & RNODE_IS_LEAF)
+ return;
+
+ shadow_nodes(node->bit[0]);
+ shadow_nodes(node->bit[1]);
}
static int ipp_addpool(struct ipns *ns, struct radix_pool **pool,
struct radix_node **root, uint8_t bits,
const uint8_t *key, uint8_t cidr)
{
+ struct radix_node **node = root, *newnode;
struct radix_pool *newpool;
- struct radix_node *node;
- bool shadowed = false;
-
- while (*pool) {
- node = (*pool)->node;
-
- if (common_bits(node, key, bits) >= MIN(cidr, node->cidr)) {
- if (cidr > node->cidr) {
- shadowed = true;
- } else if (cidr < node->cidr && !(*pool)->shadowed) {
- (*pool)->shadowed = true;
- totalip_dec(ns, bits, bits - cidr);
- } else {
- return -1;
- }
+ bool shadow = false, good_match = false;
+ uint8_t flags;
+
+ while (*node && (*node)->cidr <= cidr &&
+ prefix_matches(*node, key, bits)) {
+ if ((*node)->cidr == cidr) {
+ good_match = true;
+ break;
}
- pool = &(*pool)->next;
+ if ((*node)->flags & RNODE_IS_POOLNODE)
+ shadow = true;
+
+ node = &CHOOSE_NODE(*node, key);
+ }
+
+ flags = RNODE_IS_POOLNODE | (shadow ? RNODE_IS_SHADOWED : 0);
+
+ if (good_match) {
+ if ((*node)->flags & RNODE_IS_POOLNODE)
+ return -1; /* already exists */
+
+ BUG_ON((*node)->flags & RNODE_IS_SHADOWED);
+ (*node)->flags |= flags;
+
+ newnode = *node;
+ } else {
+ newnode = add(node, bits, key, cidr, flags);
+ if (newnode->bit[0])
+ newnode->left -= taken_ips(newnode->bit[0], bits);
+
+ if (newnode->bit[1])
+ newnode->right -= taken_ips(newnode->bit[1], bits);
}
- BUG_ON(add(root, bits, key, cidr, false));
+ if (!shadow) {
+ shadow_nodes(newnode->bit[0]);
+ shadow_nodes(newnode->bit[1]);
+ }
if (bits == 32) {
/* TODO: insert network address (0) and broadcast address (255)
@@ -438,22 +481,15 @@ static int ipp_addpool(struct ipns *ns, struct radix_pool **pool,
/* TODO: special case /31 ?, see RFC 3021 */
}
- if (!shadowed)
+ if (!shadow)
totalip_inc(ns, bits, bits - cidr);
newpool = malloc(sizeof *newpool);
if (!newpool)
fatal("malloc()");
- node = *root;
- while (node->cidr != cidr) {
- node = CHOOSE_NODE(node, key);
-
- BUG_ON(!node || !prefix_matches(node, key, bits));
- }
- newpool->node = node;
- newpool->shadowed = shadowed;
- newpool->next = NULL;
+ newpool->node = newnode;
+ newpool->next = *pool;
*pool = newpool;
return 0;
@@ -498,8 +534,10 @@ static void debug_print_trie(struct radix_node *root, uint8_t bits)
node_to_str(root->bit[0], child1, bits);
node_to_str(root->bit[1], child2, bits);
- debug("%s (%zu, %zu) -> %s, %s\n", parent, root->left, root->right,
- child1, child2);
+ debug("%s (%zu, %zu, %c%c%c) -> %s, %s\n", parent, root->left,
+ root->right, root->flags & RNODE_IS_LEAF ? 'l' : '-',
+ root->flags & RNODE_IS_POOLNODE ? 'p' : '-',
+ root->flags & RNODE_IS_SHADOWED ? 's' : '-', child1, child2);
debug_print_trie(root->bit[0], bits);
debug_print_trie(root->bit[1], bits);
@@ -569,7 +607,7 @@ int ipp_del_v4(struct ipns *ns, const struct in_addr *ip, uint8_t cidr)
int ret;
swap_endian(key, (const uint8_t *)ip, 32);
- ret = remove_node(ns->ip4_root, key, cidr);
+ ret = remove_node(&ns->ip4_root, key, cidr);
if (!ret)
++ns->total_ipv4;
@@ -582,7 +620,7 @@ int ipp_del_v6(struct ipns *ns, const struct in6_addr *ip, uint8_t cidr)
int ret;
swap_endian(key, (const uint8_t *)ip, 128);
- ret = remove_node(ns->ip6_root, key, cidr);
+ ret = remove_node(&ns->ip6_root, key, cidr);
if (!ret) {
++ns->totall_ipv6;
if (ns->totall_ipv6 == 0)
@@ -631,7 +669,7 @@ void ipp_addnth_v4(struct ipns *ns, struct in_addr *dest, uint32_t index)
struct radix_pool *current;
for (current = ns->ip4_pools; current; current = current->next) {
- if (current->shadowed)
+ if (current->node->flags & RNODE_IS_SHADOWED)
continue;
if (index < current->node->left + current->node->right)
@@ -653,7 +691,7 @@ void ipp_addnth_v6(struct ipns *ns, struct in6_addr *dest, uint32_t index_low,
uint64_t tmp;
for (current = ns->ip6_pools; current; current = current->next) {
- if (current->shadowed ||
+ if (current->node->flags & RNODE_IS_SHADOWED ||
(current->node->left == 0 && current->node->right == 0))
continue;