diff options
Diffstat (limited to '')
-rw-r--r-- | radix-trie.c | 250 |
1 files changed, 144 insertions, 106 deletions
diff --git a/radix-trie.c b/radix-trie.c index af9ae3c..237b8b6 100644 --- a/radix-trie.c +++ b/radix-trie.c @@ -20,19 +20,23 @@ #define __aligned(x) __attribute__((aligned(x))) #endif +enum radix_node_flags { + RNODE_IS_LEAF = 1U << 0, + RNODE_IS_POOLNODE = 1U << 1, + RNODE_IS_SHADOWED = 1U << 2, +}; + struct radix_node { struct radix_node *bit[2]; uint64_t left; uint64_t right; uint8_t bits[16]; - uint8_t cidr, bit_at_a, bit_at_b; - bool is_leaf; + uint8_t cidr, bit_at_a, bit_at_b, flags; }; struct radix_pool { struct radix_node *node; struct radix_pool *next; - bool shadowed; }; static unsigned int fls64(uint64_t x) @@ -92,7 +96,7 @@ static struct radix_node *new_node(const uint8_t *key, uint8_t cidr, node->bit_at_a ^= (bits / 8U - 1U) % 8U; #endif node->bit_at_b = 7U - (cidr % 8U); - node->is_leaf = false; + node->flags = 0; if (bits - cidr > 0 && bits - cidr - 1 < 64) node->left = node->right = 1ULL << (bits - cidr - 1); else @@ -117,26 +121,6 @@ static bool prefix_matches(const struct radix_node *node, const uint8_t *key, #define CHOOSE_NODE(parent, key) \ (parent)->bit[(key[(parent)->bit_at_a] >> (parent)->bit_at_b) & 1] -static bool node_placement(struct radix_node *trie, const uint8_t *key, - uint8_t cidr, uint8_t bits, - struct radix_node **rnode) -{ - struct radix_node *node = trie, *parent = NULL; - bool exact = false; - - while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) { - parent = node; - if (parent->cidr == cidr) { - exact = true; - break; - } - - node = CHOOSE_NODE(parent, key); - } - *rnode = parent; - return exact; -} - static uint64_t subnet_diff(uint8_t *ip1, uint8_t *ip2, uint8_t bits) { if (bits == 32) @@ -145,6 +129,14 @@ static uint64_t subnet_diff(uint8_t *ip1, uint8_t *ip2, uint8_t bits) return *(const uint64_t *)&ip1[8] - *(const uint64_t *)&ip2[8]; } +static uint64_t taken_ips(struct radix_node *node, uint8_t bits) +{ + if ((bits - node->cidr) >= 64) + return 0; + + return (1ULL << (bits - node->cidr)) - (node->left + node->right); +} + static void add_nth(struct radix_node *start, uint8_t bits, uint64_t n, uint8_t *dest) { @@ -201,7 +193,7 @@ static void add_nth(struct radix_node *start, uint8_t bits, uint64_t n, } newnode = new_node(ip, cidr, bits); - newnode->is_leaf = true; + newnode->flags |= RNODE_IS_LEAF; swap_endian(dest, (const uint8_t *)ip, bits); if (!target) { @@ -214,43 +206,59 @@ static void add_nth(struct radix_node *start, uint8_t bits, uint64_t n, CHOOSE_NODE(between, newnode->bits) = newnode; CHOOSE_NODE(parent, between->bits) = between; - between->left -= - (1ULL << (bits - between->bit[0]->cidr)) - - (between->bit[0]->left + between->bit[0]->right); - between->right -= - (1ULL << (bits - between->bit[1]->cidr)) - - (between->bit[1]->left + between->bit[1]->right); + between->left -= taken_ips(between->bit[0], bits); + between->right -= taken_ips(between->bit[1], bits); } } -static int add(struct radix_node **trie, uint8_t bits, const uint8_t *key, - uint8_t cidr, bool is_leaf) +static struct radix_node *add(struct radix_node **trie, uint8_t bits, + const uint8_t *key, uint8_t cidr, uint8_t type) { - struct radix_node *node, *newnode, *down, *parent; + struct radix_node *node = NULL, *newnode, *down, *parent, *tmp = *trie; + bool exact = false, in_pool = false; - if (cidr > bits) - return -EINVAL; + if (cidr > bits) { + errno = EINVAL; + return NULL; + } if (!*trie) { *trie = new_node(key, cidr, bits); - (*trie)->is_leaf = is_leaf; - return 0; + (*trie)->flags = type; + return *trie; } - if (node_placement(*trie, key, cidr, bits, &node)) { - /* exact match, so use the existing node */ - if (node->is_leaf) - return 1; + while (tmp && tmp->cidr <= cidr && prefix_matches(tmp, key, bits)) { + node = tmp; + if (tmp->flags & RNODE_IS_POOLNODE) + in_pool = true; - node->is_leaf = is_leaf; - return 0; + if (node->cidr == cidr) { + exact = true; + break; + } + + tmp = CHOOSE_NODE(node, key); } - if (node && node->is_leaf) - return 1; + if (!in_pool && (type & RNODE_IS_LEAF)) { + errno = ENOENT; + return NULL; + } + + if (exact) { + /* exact match, so use the existing node */ + if (node->flags & type) { + errno = EEXIST; + return NULL; + } + + node->flags = type; + return node; + } newnode = new_node(key, cidr, bits); - newnode->is_leaf = is_leaf; + newnode->flags = type; if (!node) { down = *trie; @@ -259,7 +267,7 @@ static int add(struct radix_node **trie, uint8_t bits, const uint8_t *key, if (!down) { CHOOSE_NODE(node, key) = newnode; - return 0; + return newnode; } } cidr = MIN(cidr, common_bits(down, key, bits)); @@ -276,13 +284,19 @@ static int add(struct radix_node **trie, uint8_t bits, const uint8_t *key, CHOOSE_NODE(node, down->bits) = down; CHOOSE_NODE(node, newnode->bits) = newnode; + + if (CHOOSE_NODE(node, down->bits) == node->bit[0]) + node->left -= taken_ips(down, bits); + else + node->right -= taken_ips(down, bits); + if (!parent) *trie = node; else CHOOSE_NODE(parent, node->bits) = node; } - return 0; + return newnode; } static void radix_free_nodes(struct radix_node *node) @@ -322,15 +336,15 @@ static int insert_v4(struct radix_node **root, const struct in_addr *ip, { /* Aligned so it can be passed to fls */ uint8_t key[4] __aligned(__alignof(uint32_t)); - int ret; swap_endian(key, (const uint8_t *)ip, 32); - ret = add(root, 32, key, cidr, true); - if (!ret) + if (add(root, 32, key, cidr, RNODE_IS_LEAF)) { decrement_radix(*root, 32, (uint8_t *)key); + return 0; + } - return ret; + return -1; } static int insert_v6(struct radix_node **root, const struct in6_addr *ip, @@ -338,41 +352,47 @@ static int insert_v6(struct radix_node **root, const struct in6_addr *ip, { /* Aligned so it can be passed to fls64 */ uint8_t key[16] __aligned(__alignof(uint64_t)); - int ret; swap_endian(key, (const uint8_t *)ip, 128); - ret = add(root, 128, key, cidr, true); - if (!ret) + if (add(root, 128, key, cidr, RNODE_IS_LEAF)) { decrement_radix(*root, 128, (uint8_t *)key); + return 0; + } - return ret; + return -1; } -static int remove_node(struct radix_node *trie, const uint8_t *key, +static int remove_node(struct radix_node **trie, const uint8_t *key, uint8_t bits) { - struct radix_node **node = &trie, **target = NULL; + struct radix_node **node = trie, **target = NULL; + uint64_t *pnodes[127]; + int i = 0; while (*node && prefix_matches(*node, key, bits)) { - if ((*node)->is_leaf) { + if ((*node)->flags & RNODE_IS_LEAF) { target = node; break; } if (CHOOSE_NODE(*node, key) == (*node)->bit[0]) - ++((*node)->left); + pnodes[i++] = &((*node)->left); else - ++((*node)->right); + pnodes[i++] = &((*node)->right); + BUG_ON(i >= 127); node = &CHOOSE_NODE(*node, key); } if (!target) return 1; /* key not found in trie */ + for (int j = 0; j < i; ++j) + ++(*(pnodes[j])); + + free(*node); *target = NULL; - radix_free_nodes(*node); return 0; } @@ -391,46 +411,69 @@ static void totalip_inc(struct ipns *ns, uint8_t bits, uint8_t val) } } -static void totalip_dec(struct ipns *ns, uint8_t bits, uint8_t val) +static void shadow_nodes(struct radix_node *node) { - if (bits == 32) { - BUG_ON(val > 32); - ns->total_ipv4 -= 1ULL << val; - } else if (bits == 128) { - uint64_t tmp = ns->totall_ipv6; - BUG_ON(val > 64); - ns->totall_ipv6 -= (val == 64) ? 0 : 1ULL << val; - if (ns->totall_ipv6 >= tmp) - --ns->totalh_ipv6; + if (!node) + return; + + if (node->flags & RNODE_IS_POOLNODE) { + BUG_ON(node->flags & RNODE_IS_SHADOWED); + node->flags |= RNODE_IS_SHADOWED; + return; } + + if (node->flags & RNODE_IS_LEAF) + return; + + shadow_nodes(node->bit[0]); + shadow_nodes(node->bit[1]); } static int ipp_addpool(struct ipns *ns, struct radix_pool **pool, struct radix_node **root, uint8_t bits, const uint8_t *key, uint8_t cidr) { + struct radix_node **node = root, *newnode; struct radix_pool *newpool; - struct radix_node *node; - bool shadowed = false; - - while (*pool) { - node = (*pool)->node; - - if (common_bits(node, key, bits) >= MIN(cidr, node->cidr)) { - if (cidr > node->cidr) { - shadowed = true; - } else if (cidr < node->cidr && !(*pool)->shadowed) { - (*pool)->shadowed = true; - totalip_dec(ns, bits, bits - cidr); - } else { - return -1; - } + bool shadow = false, good_match = false; + uint8_t flags; + + while (*node && (*node)->cidr <= cidr && + prefix_matches(*node, key, bits)) { + if ((*node)->cidr == cidr) { + good_match = true; + break; } - pool = &(*pool)->next; + if ((*node)->flags & RNODE_IS_POOLNODE) + shadow = true; + + node = &CHOOSE_NODE(*node, key); + } + + flags = RNODE_IS_POOLNODE | (shadow ? RNODE_IS_SHADOWED : 0); + + if (good_match) { + if ((*node)->flags & RNODE_IS_POOLNODE) + return -1; /* already exists */ + + BUG_ON((*node)->flags & RNODE_IS_SHADOWED); + (*node)->flags |= flags; + + newnode = *node; + } else { + newnode = add(node, bits, key, cidr, flags); + if (newnode->bit[0]) + newnode->left -= taken_ips(newnode->bit[0], bits); + + if (newnode->bit[1]) + newnode->right -= taken_ips(newnode->bit[1], bits); } - BUG_ON(add(root, bits, key, cidr, false)); + if (!shadow) { + shadow_nodes(newnode->bit[0]); + shadow_nodes(newnode->bit[1]); + } if (bits == 32) { /* TODO: insert network address (0) and broadcast address (255) @@ -438,22 +481,15 @@ static int ipp_addpool(struct ipns *ns, struct radix_pool **pool, /* TODO: special case /31 ?, see RFC 3021 */ } - if (!shadowed) + if (!shadow) totalip_inc(ns, bits, bits - cidr); newpool = malloc(sizeof *newpool); if (!newpool) fatal("malloc()"); - node = *root; - while (node->cidr != cidr) { - node = CHOOSE_NODE(node, key); - - BUG_ON(!node || !prefix_matches(node, key, bits)); - } - newpool->node = node; - newpool->shadowed = shadowed; - newpool->next = NULL; + newpool->node = newnode; + newpool->next = *pool; *pool = newpool; return 0; @@ -498,8 +534,10 @@ static void debug_print_trie(struct radix_node *root, uint8_t bits) node_to_str(root->bit[0], child1, bits); node_to_str(root->bit[1], child2, bits); - debug("%s (%zu, %zu) -> %s, %s\n", parent, root->left, root->right, - child1, child2); + debug("%s (%zu, %zu, %c%c%c) -> %s, %s\n", parent, root->left, + root->right, root->flags & RNODE_IS_LEAF ? 'l' : '-', + root->flags & RNODE_IS_POOLNODE ? 'p' : '-', + root->flags & RNODE_IS_SHADOWED ? 's' : '-', child1, child2); debug_print_trie(root->bit[0], bits); debug_print_trie(root->bit[1], bits); @@ -569,7 +607,7 @@ int ipp_del_v4(struct ipns *ns, const struct in_addr *ip, uint8_t cidr) int ret; swap_endian(key, (const uint8_t *)ip, 32); - ret = remove_node(ns->ip4_root, key, cidr); + ret = remove_node(&ns->ip4_root, key, cidr); if (!ret) ++ns->total_ipv4; @@ -582,7 +620,7 @@ int ipp_del_v6(struct ipns *ns, const struct in6_addr *ip, uint8_t cidr) int ret; swap_endian(key, (const uint8_t *)ip, 128); - ret = remove_node(ns->ip6_root, key, cidr); + ret = remove_node(&ns->ip6_root, key, cidr); if (!ret) { ++ns->totall_ipv6; if (ns->totall_ipv6 == 0) @@ -631,7 +669,7 @@ void ipp_addnth_v4(struct ipns *ns, struct in_addr *dest, uint32_t index) struct radix_pool *current; for (current = ns->ip4_pools; current; current = current->next) { - if (current->shadowed) + if (current->node->flags & RNODE_IS_SHADOWED) continue; if (index < current->node->left + current->node->right) @@ -653,7 +691,7 @@ void ipp_addnth_v6(struct ipns *ns, struct in6_addr *dest, uint32_t index_low, uint64_t tmp; for (current = ns->ip6_pools; current; current = current->next) { - if (current->shadowed || + if (current->node->flags & RNODE_IS_SHADOWED || (current->node->left == 0 && current->node->right == 0)) continue; |