diff options
Diffstat (limited to 'radix-trie.c')
-rw-r--r-- | radix-trie.c | 82 |
1 files changed, 58 insertions, 24 deletions
diff --git a/radix-trie.c b/radix-trie.c index 062490f..18f656c 100644 --- a/radix-trie.c +++ b/radix-trie.c @@ -32,6 +32,7 @@ struct radix_node { struct radix_pool { struct radix_node *node; struct radix_pool *next; + bool shadowed; }; static unsigned int fls64(uint64_t x) @@ -376,15 +377,55 @@ static int remove_node(struct radix_node *trie, const uint8_t *key, return 0; } -static int ipp_addpool(struct radix_pool **pool, struct radix_node **root, - uint8_t bits, const uint8_t *key, uint8_t cidr) +static void totalip_inc(struct ip_pool *ipp, uint8_t bits, uint8_t val) +{ + if (bits == 32) { + BUG_ON(val >= 32); + ipp->total_ipv4 += 1ULL << val; + } else if (bits == 128) { + uint64_t tmp = ipp->totall_ipv6; + BUG_ON(val > 64); + ipp->totall_ipv6 += (val == 64) ? 0 : 1ULL << val; + if (ipp->totall_ipv6 <= tmp) + ++ipp->totalh_ipv6; + } +} + +static void totalip_dec(struct ip_pool *ipp, uint8_t bits, uint8_t val) +{ + if (bits == 32) { + BUG_ON(val >= 32); + ipp->total_ipv4 -= 1ULL << val; + } else if (bits == 128) { + uint64_t tmp = ipp->totall_ipv6; + BUG_ON(val > 64); + ipp->totall_ipv6 -= (val == 64) ? 0 : 1ULL << val; + if (ipp->totall_ipv6 >= tmp) + --ipp->totalh_ipv6; + } +} + +static int ipp_addpool(struct ip_pool *ipp, struct radix_pool **pool, + struct radix_node **root, uint8_t bits, + const uint8_t *key, uint8_t cidr) { struct radix_pool *newpool; struct radix_node *node; + bool shadowed = false; while (*pool) { - if (common_bits((*pool)->node, key, bits) >= cidr) - return -1; + node = (*pool)->node; + + if (common_bits(node, key, bits) >= MIN(cidr, node->cidr)) { + if (cidr > node->cidr) { + shadowed = true; + } else if (cidr < node->cidr && !(*pool)->shadowed) { + (*pool)->shadowed = true; + totalip_dec(ipp, bits, bits - cidr); + } else { + return -1; + } + } pool = &(*pool)->next; } @@ -397,6 +438,9 @@ static int ipp_addpool(struct radix_pool **pool, struct radix_node **root, /* TODO: special case /31 ?, see RFC 3021 */ } + if (!shadowed) + totalip_inc(ipp, bits, bits - cidr); + newpool = malloc(sizeof *newpool); if (!newpool) fatal("malloc()"); @@ -408,6 +452,7 @@ static int ipp_addpool(struct radix_pool **pool, struct radix_node **root, BUG_ON(!node || !prefix_matches(node, key, bits)); } newpool->node = node; + newpool->shadowed = shadowed; newpool->next = NULL; *pool = newpool; @@ -547,41 +592,26 @@ int ipp_del_v6(struct ip_pool *pool, const struct in6_addr *ip, uint8_t cidr) return ret; } -int ipp_addpool_v4(struct ip_pool *pool, const struct in_addr *ip, uint8_t cidr) +int ipp_addpool_v4(struct ip_pool *ipp, const struct in_addr *ip, uint8_t cidr) { uint8_t key[4] __aligned(__alignof(uint32_t)); - int ret; if (cidr <= 0 || cidr >= 32) return -1; swap_endian(key, (const uint8_t *)ip, 32); - ret = ipp_addpool(&pool->ip4_pool, &pool->ip4_root, 32, key, cidr); - if (!ret) - pool->total_ipv4 += 1 << (32 - cidr); - - return ret; + return ipp_addpool(ipp, &ipp->ip4_pool, &ipp->ip4_root, 32, key, cidr); } -int ipp_addpool_v6(struct ip_pool *pool, const struct in6_addr *ip, - uint8_t cidr) +int ipp_addpool_v6(struct ip_pool *ipp, const struct in6_addr *ip, uint8_t cidr) { uint8_t key[16] __aligned(__alignof(uint64_t)); - int ret; if (cidr < 64 || cidr >= 128) return -1; swap_endian(key, (const uint8_t *)ip, 128); - ret = ipp_addpool(&pool->ip6_pool, &pool->ip6_root, 128, key, cidr); - if (!ret) { - uint64_t tmp = pool->totall_ipv6; - pool->totall_ipv6 += (cidr <= 64) ? 0 : 1 << (128 - cidr); - if (pool->totall_ipv6 <= tmp) - ++pool->totalh_ipv6; - } - - return ret; + return ipp_addpool(ipp, &ipp->ip6_pool, &ipp->ip6_root, 128, key, cidr); } void ipp_addnth_v4(struct ip_pool *pool, struct in_addr *dest, uint32_t index) @@ -589,6 +619,9 @@ void ipp_addnth_v4(struct ip_pool *pool, struct in_addr *dest, uint32_t index) struct radix_pool *current = pool->ip4_pool; for (current = pool->ip4_pool; current; current = current->next) { + if (current->shadowed) + continue; + if (index < current->node->left + current->node->right) break; @@ -608,7 +641,8 @@ void ipp_addnth_v6(struct ip_pool *pool, struct in6_addr *dest, uint64_t tmp; while (current) { - if (current->node->left == 0 && current->node->right == 0) { + if (current->shadowed || + (current->node->left == 0 && current->node->right == 0)) { current = current->next; continue; } |