aboutsummaryrefslogtreecommitdiffstats
path: root/radix-trie.c
diff options
context:
space:
mode:
Diffstat (limited to 'radix-trie.c')
-rw-r--r--radix-trie.c250
1 files changed, 144 insertions, 106 deletions
diff --git a/radix-trie.c b/radix-trie.c
index af9ae3c..237b8b6 100644
--- a/radix-trie.c
+++ b/radix-trie.c
@@ -20,19 +20,23 @@
#define __aligned(x) __attribute__((aligned(x)))
#endif
+enum radix_node_flags {
+ RNODE_IS_LEAF = 1U << 0,
+ RNODE_IS_POOLNODE = 1U << 1,
+ RNODE_IS_SHADOWED = 1U << 2,
+};
+
struct radix_node {
struct radix_node *bit[2];
uint64_t left;
uint64_t right;
uint8_t bits[16];
- uint8_t cidr, bit_at_a, bit_at_b;
- bool is_leaf;
+ uint8_t cidr, bit_at_a, bit_at_b, flags;
};
struct radix_pool {
struct radix_node *node;
struct radix_pool *next;
- bool shadowed;
};
static unsigned int fls64(uint64_t x)
@@ -92,7 +96,7 @@ static struct radix_node *new_node(const uint8_t *key, uint8_t cidr,
node->bit_at_a ^= (bits / 8U - 1U) % 8U;
#endif
node->bit_at_b = 7U - (cidr % 8U);
- node->is_leaf = false;
+ node->flags = 0;
if (bits - cidr > 0 && bits - cidr - 1 < 64)
node->left = node->right = 1ULL << (bits - cidr - 1);
else
@@ -117,26 +121,6 @@ static bool prefix_matches(const struct radix_node *node, const uint8_t *key,
#define CHOOSE_NODE(parent, key) \
(parent)->bit[(key[(parent)->bit_at_a] >> (parent)->bit_at_b) & 1]
-static bool node_placement(struct radix_node *trie, const uint8_t *key,
- uint8_t cidr, uint8_t bits,
- struct radix_node **rnode)
-{
- struct radix_node *node = trie, *parent = NULL;
- bool exact = false;
-
- while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) {
- parent = node;
- if (parent->cidr == cidr) {
- exact = true;
- break;
- }
-
- node = CHOOSE_NODE(parent, key);
- }
- *rnode = parent;
- return exact;
-}
-
static uint64_t subnet_diff(uint8_t *ip1, uint8_t *ip2, uint8_t bits)
{
if (bits == 32)
@@ -145,6 +129,14 @@ static uint64_t subnet_diff(uint8_t *ip1, uint8_t *ip2, uint8_t bits)
return *(const uint64_t *)&ip1[8] - *(const uint64_t *)&ip2[8];
}
+static uint64_t taken_ips(struct radix_node *node, uint8_t bits)
+{
+ if ((bits - node->cidr) >= 64)
+ return 0;
+
+ return (1ULL << (bits - node->cidr)) - (node->left + node->right);
+}
+
static void add_nth(struct radix_node *start, uint8_t bits, uint64_t n,
uint8_t *dest)
{
@@ -201,7 +193,7 @@ static void add_nth(struct radix_node *start, uint8_t bits, uint64_t n,
}
newnode = new_node(ip, cidr, bits);
- newnode->is_leaf = true;
+ newnode->flags |= RNODE_IS_LEAF;
swap_endian(dest, (const uint8_t *)ip, bits);
if (!target) {
@@ -214,43 +206,59 @@ static void add_nth(struct radix_node *start, uint8_t bits, uint64_t n,
CHOOSE_NODE(between, newnode->bits) = newnode;
CHOOSE_NODE(parent, between->bits) = between;
- between->left -=
- (1ULL << (bits - between->bit[0]->cidr)) -
- (between->bit[0]->left + between->bit[0]->right);
- between->right -=
- (1ULL << (bits - between->bit[1]->cidr)) -
- (between->bit[1]->left + between->bit[1]->right);
+ between->left -= taken_ips(between->bit[0], bits);
+ between->right -= taken_ips(between->bit[1], bits);
}
}
-static int add(struct radix_node **trie, uint8_t bits, const uint8_t *key,
- uint8_t cidr, bool is_leaf)
+static struct radix_node *add(struct radix_node **trie, uint8_t bits,
+ const uint8_t *key, uint8_t cidr, uint8_t type)
{
- struct radix_node *node, *newnode, *down, *parent;
+ struct radix_node *node = NULL, *newnode, *down, *parent, *tmp = *trie;
+ bool exact = false, in_pool = false;
- if (cidr > bits)
- return -EINVAL;
+ if (cidr > bits) {
+ errno = EINVAL;
+ return NULL;
+ }
if (!*trie) {
*trie = new_node(key, cidr, bits);
- (*trie)->is_leaf = is_leaf;
- return 0;
+ (*trie)->flags = type;
+ return *trie;
}
- if (node_placement(*trie, key, cidr, bits, &node)) {
- /* exact match, so use the existing node */
- if (node->is_leaf)
- return 1;
+ while (tmp && tmp->cidr <= cidr && prefix_matches(tmp, key, bits)) {
+ node = tmp;
+ if (tmp->flags & RNODE_IS_POOLNODE)
+ in_pool = true;
- node->is_leaf = is_leaf;
- return 0;
+ if (node->cidr == cidr) {
+ exact = true;
+ break;
+ }
+
+ tmp = CHOOSE_NODE(node, key);
}
- if (node && node->is_leaf)
- return 1;
+ if (!in_pool && (type & RNODE_IS_LEAF)) {
+ errno = ENOENT;
+ return NULL;
+ }
+
+ if (exact) {
+ /* exact match, so use the existing node */
+ if (node->flags & type) {
+ errno = EEXIST;
+ return NULL;
+ }
+
+ node->flags = type;
+ return node;
+ }
newnode = new_node(key, cidr, bits);
- newnode->is_leaf = is_leaf;
+ newnode->flags = type;
if (!node) {
down = *trie;
@@ -259,7 +267,7 @@ static int add(struct radix_node **trie, uint8_t bits, const uint8_t *key,
if (!down) {
CHOOSE_NODE(node, key) = newnode;
- return 0;
+ return newnode;
}
}
cidr = MIN(cidr, common_bits(down, key, bits));
@@ -276,13 +284,19 @@ static int add(struct radix_node **trie, uint8_t bits, const uint8_t *key,
CHOOSE_NODE(node, down->bits) = down;
CHOOSE_NODE(node, newnode->bits) = newnode;
+
+ if (CHOOSE_NODE(node, down->bits) == node->bit[0])
+ node->left -= taken_ips(down, bits);
+ else
+ node->right -= taken_ips(down, bits);
+
if (!parent)
*trie = node;
else
CHOOSE_NODE(parent, node->bits) = node;
}
- return 0;
+ return newnode;
}
static void radix_free_nodes(struct radix_node *node)
@@ -322,15 +336,15 @@ static int insert_v4(struct radix_node **root, const struct in_addr *ip,
{
/* Aligned so it can be passed to fls */
uint8_t key[4] __aligned(__alignof(uint32_t));
- int ret;
swap_endian(key, (const uint8_t *)ip, 32);
- ret = add(root, 32, key, cidr, true);
- if (!ret)
+ if (add(root, 32, key, cidr, RNODE_IS_LEAF)) {
decrement_radix(*root, 32, (uint8_t *)key);
+ return 0;
+ }
- return ret;
+ return -1;
}
static int insert_v6(struct radix_node **root, const struct in6_addr *ip,
@@ -338,41 +352,47 @@ static int insert_v6(struct radix_node **root, const struct in6_addr *ip,
{
/* Aligned so it can be passed to fls64 */
uint8_t key[16] __aligned(__alignof(uint64_t));
- int ret;
swap_endian(key, (const uint8_t *)ip, 128);
- ret = add(root, 128, key, cidr, true);
- if (!ret)
+ if (add(root, 128, key, cidr, RNODE_IS_LEAF)) {
decrement_radix(*root, 128, (uint8_t *)key);
+ return 0;
+ }
- return ret;
+ return -1;
}
-static int remove_node(struct radix_node *trie, const uint8_t *key,
+static int remove_node(struct radix_node **trie, const uint8_t *key,
uint8_t bits)
{
- struct radix_node **node = &trie, **target = NULL;
+ struct radix_node **node = trie, **target = NULL;
+ uint64_t *pnodes[127];
+ int i = 0;
while (*node && prefix_matches(*node, key, bits)) {
- if ((*node)->is_leaf) {
+ if ((*node)->flags & RNODE_IS_LEAF) {
target = node;
break;
}
if (CHOOSE_NODE(*node, key) == (*node)->bit[0])
- ++((*node)->left);
+ pnodes[i++] = &((*node)->left);
else
- ++((*node)->right);
+ pnodes[i++] = &((*node)->right);
+ BUG_ON(i >= 127);
node = &CHOOSE_NODE(*node, key);
}
if (!target)
return 1; /* key not found in trie */
+ for (int j = 0; j < i; ++j)
+ ++(*(pnodes[j]));
+
+ free(*node);
*target = NULL;
- radix_free_nodes(*node);
return 0;
}
@@ -391,46 +411,69 @@ static void totalip_inc(struct ipns *ns, uint8_t bits, uint8_t val)
}
}
-static void totalip_dec(struct ipns *ns, uint8_t bits, uint8_t val)
+static void shadow_nodes(struct radix_node *node)
{
- if (bits == 32) {
- BUG_ON(val > 32);
- ns->total_ipv4 -= 1ULL << val;
- } else if (bits == 128) {
- uint64_t tmp = ns->totall_ipv6;
- BUG_ON(val > 64);
- ns->totall_ipv6 -= (val == 64) ? 0 : 1ULL << val;
- if (ns->totall_ipv6 >= tmp)
- --ns->totalh_ipv6;
+ if (!node)
+ return;
+
+ if (node->flags & RNODE_IS_POOLNODE) {
+ BUG_ON(node->flags & RNODE_IS_SHADOWED);
+ node->flags |= RNODE_IS_SHADOWED;
+ return;
}
+
+ if (node->flags & RNODE_IS_LEAF)
+ return;
+
+ shadow_nodes(node->bit[0]);
+ shadow_nodes(node->bit[1]);
}
static int ipp_addpool(struct ipns *ns, struct radix_pool **pool,
struct radix_node **root, uint8_t bits,
const uint8_t *key, uint8_t cidr)
{
+ struct radix_node **node = root, *newnode;
struct radix_pool *newpool;
- struct radix_node *node;
- bool shadowed = false;
-
- while (*pool) {
- node = (*pool)->node;
-
- if (common_bits(node, key, bits) >= MIN(cidr, node->cidr)) {
- if (cidr > node->cidr) {
- shadowed = true;
- } else if (cidr < node->cidr && !(*pool)->shadowed) {
- (*pool)->shadowed = true;
- totalip_dec(ns, bits, bits - cidr);
- } else {
- return -1;
- }
+ bool shadow = false, good_match = false;
+ uint8_t flags;
+
+ while (*node && (*node)->cidr <= cidr &&
+ prefix_matches(*node, key, bits)) {
+ if ((*node)->cidr == cidr) {
+ good_match = true;
+ break;
}
- pool = &(*pool)->next;
+ if ((*node)->flags & RNODE_IS_POOLNODE)
+ shadow = true;
+
+ node = &CHOOSE_NODE(*node, key);
+ }
+
+ flags = RNODE_IS_POOLNODE | (shadow ? RNODE_IS_SHADOWED : 0);
+
+ if (good_match) {
+ if ((*node)->flags & RNODE_IS_POOLNODE)
+ return -1; /* already exists */
+
+ BUG_ON((*node)->flags & RNODE_IS_SHADOWED);
+ (*node)->flags |= flags;
+
+ newnode = *node;
+ } else {
+ newnode = add(node, bits, key, cidr, flags);
+ if (newnode->bit[0])
+ newnode->left -= taken_ips(newnode->bit[0], bits);
+
+ if (newnode->bit[1])
+ newnode->right -= taken_ips(newnode->bit[1], bits);
}
- BUG_ON(add(root, bits, key, cidr, false));
+ if (!shadow) {
+ shadow_nodes(newnode->bit[0]);
+ shadow_nodes(newnode->bit[1]);
+ }
if (bits == 32) {
/* TODO: insert network address (0) and broadcast address (255)
@@ -438,22 +481,15 @@ static int ipp_addpool(struct ipns *ns, struct radix_pool **pool,
/* TODO: special case /31 ?, see RFC 3021 */
}
- if (!shadowed)
+ if (!shadow)
totalip_inc(ns, bits, bits - cidr);
newpool = malloc(sizeof *newpool);
if (!newpool)
fatal("malloc()");
- node = *root;
- while (node->cidr != cidr) {
- node = CHOOSE_NODE(node, key);
-
- BUG_ON(!node || !prefix_matches(node, key, bits));
- }
- newpool->node = node;
- newpool->shadowed = shadowed;
- newpool->next = NULL;
+ newpool->node = newnode;
+ newpool->next = *pool;
*pool = newpool;
return 0;
@@ -498,8 +534,10 @@ static void debug_print_trie(struct radix_node *root, uint8_t bits)
node_to_str(root->bit[0], child1, bits);
node_to_str(root->bit[1], child2, bits);
- debug("%s (%zu, %zu) -> %s, %s\n", parent, root->left, root->right,
- child1, child2);
+ debug("%s (%zu, %zu, %c%c%c) -> %s, %s\n", parent, root->left,
+ root->right, root->flags & RNODE_IS_LEAF ? 'l' : '-',
+ root->flags & RNODE_IS_POOLNODE ? 'p' : '-',
+ root->flags & RNODE_IS_SHADOWED ? 's' : '-', child1, child2);
debug_print_trie(root->bit[0], bits);
debug_print_trie(root->bit[1], bits);
@@ -569,7 +607,7 @@ int ipp_del_v4(struct ipns *ns, const struct in_addr *ip, uint8_t cidr)
int ret;
swap_endian(key, (const uint8_t *)ip, 32);
- ret = remove_node(ns->ip4_root, key, cidr);
+ ret = remove_node(&ns->ip4_root, key, cidr);
if (!ret)
++ns->total_ipv4;
@@ -582,7 +620,7 @@ int ipp_del_v6(struct ipns *ns, const struct in6_addr *ip, uint8_t cidr)
int ret;
swap_endian(key, (const uint8_t *)ip, 128);
- ret = remove_node(ns->ip6_root, key, cidr);
+ ret = remove_node(&ns->ip6_root, key, cidr);
if (!ret) {
++ns->totall_ipv6;
if (ns->totall_ipv6 == 0)
@@ -631,7 +669,7 @@ void ipp_addnth_v4(struct ipns *ns, struct in_addr *dest, uint32_t index)
struct radix_pool *current;
for (current = ns->ip4_pools; current; current = current->next) {
- if (current->shadowed)
+ if (current->node->flags & RNODE_IS_SHADOWED)
continue;
if (index < current->node->left + current->node->right)
@@ -653,7 +691,7 @@ void ipp_addnth_v6(struct ipns *ns, struct in6_addr *dest, uint32_t index_low,
uint64_t tmp;
for (current = ns->ip6_pools; current; current = current->next) {
- if (current->shadowed ||
+ if (current->node->flags & RNODE_IS_SHADOWED ||
(current->node->left == 0 && current->node->right == 0))
continue;