From 571ef898b92639052314cb5be63609bd97c19206 Mon Sep 17 00:00:00 2001 From: Thomas Gschwantner Date: Sat, 13 Apr 2019 22:17:18 +0200 Subject: radix-trie: implement ipp_addrnd_* and related --- radix-trie.c | 359 ++++++++++++++++++++++++++++++++++++++++++++++++++--------- radix-trie.h | 39 ++++--- 2 files changed, 328 insertions(+), 70 deletions(-) diff --git a/radix-trie.c b/radix-trie.c index f314913..fb39047 100644 --- a/radix-trie.c +++ b/radix-trie.c @@ -18,11 +18,22 @@ #define MIN(X, Y) (((X) < (Y)) ? (X) : (Y)) +#ifndef __aligned +#define __aligned(x) __attribute__((aligned(x))) +#endif + struct radix_node { struct radix_node *bit[2]; - void *data; + uint64_t left; + uint64_t right; uint8_t bits[16]; uint8_t cidr, bit_at_a, bit_at_b; + bool is_leaf; +}; + +struct radix_pool { + struct radix_node *node; + struct radix_pool *next; }; static unsigned int fls64(uint64_t x) @@ -69,19 +80,31 @@ static struct radix_node *new_node(const uint8_t *key, uint8_t cidr, uint8_t bits) { struct radix_node *node; + uint64_t mask; node = malloc(sizeof *node); if (!node) fatal("malloc()"); - node->bit[0] = node->bit[1] = node->data = NULL; + node->bit[0] = node->bit[1] = NULL; node->cidr = cidr; node->bit_at_a = cidr / 8U; #ifdef __LITTLE_ENDIAN node->bit_at_a ^= (bits / 8U - 1U) % 8U; #endif node->bit_at_b = 7U - (cidr % 8U); + node->is_leaf = false; + if (bits - cidr > 0 && bits - cidr - 1 < 64) + node->left = node->right = 1ULL << (bits - cidr - 1); + else + node->left = node->right = 0; + memcpy(node->bits, key, bits / 8U); + mask = (bits - cidr) >= 64 ? 0 : 0xFFFFFFFFFFFFFFFF << (bits - cidr); + if (bits == 32) + *(uint32_t *)node->bits &= mask; + else + *(uint64_t *)&node->bits[8] &= mask; return node; } @@ -115,34 +138,120 @@ static bool node_placement(struct radix_node *trie, const uint8_t *key, return exact; } +static uint64_t subnet_diff(uint8_t *ip1, uint8_t *ip2, uint8_t bits) +{ + if (bits == 32) + return *(const uint32_t *)ip1 - *(const uint32_t *)ip2; + else + return *(const uint64_t *)&ip1[8] - *(const uint64_t *)&ip2[8]; +} + +static void add_nth(struct radix_node *start, uint8_t bits, uint64_t n, + uint8_t *dest) +{ + struct radix_node *target = start, *parent, *newnode, *between; + uint8_t ip[16] __aligned(__alignof(uint64_t)); + uint8_t cidr = bits; + uint64_t result, free_ips, diff; + + BUG_ON(n > target->left + target->right - 1); + + do { + parent = target; + + if (n >= parent->left) { + target = parent->bit[1]; + BUG_ON(!parent->right); + --(parent->right); + + n += (1ULL << (bits - parent->cidr - 1)) - parent->left; + } else { + target = parent->bit[0]; + BUG_ON(!parent->left); + --(parent->left); + } + + if (!target) + break; + + /* check if target has a suitable ip range */ + free_ips = target->left + target->right; + diff = subnet_diff(target->bits, parent->bits, bits); + if (n < diff) { + /* can't go down, or we'd skip too many ips */ + break; + } else if (n >= diff + free_ips) { + /* can't go down, we want a higher ip */ + n += (1ULL << (bits - target->cidr)) - free_ips; + break; + } else { + /* match; subtract skipped ips */ + n -= diff; + } + } while (1); + + if (bits == 32) { + result = *(const uint32_t *)parent->bits + n; + BUG_ON(result > UINT32_MAX); + + memcpy(ip, &result, 4); + } else { + result = *(const uint64_t *)&parent->bits[8] + n; + memcpy(ip, &parent->bits, 8); + memcpy(ip + 8, &result, 8); + } + + newnode = new_node(ip, cidr, bits); + newnode->is_leaf = true; + swap_endian(dest, (const uint8_t *)ip, bits); + + if (!target) { + CHOOSE_NODE(parent, newnode->bits) = newnode; + } else { + cidr = MIN(cidr, common_bits(target, ip, bits)); + between = new_node(newnode->bits, cidr, bits); + + CHOOSE_NODE(between, target->bits) = target; + CHOOSE_NODE(between, newnode->bits) = newnode; + CHOOSE_NODE(parent, between->bits) = between; + + between->left -= + (1ULL << (bits - between->bit[0]->cidr)) - + (between->bit[0]->left + between->bit[0]->right); + between->right -= + (1ULL << (bits - between->bit[1]->cidr)) - + (between->bit[1]->left + between->bit[1]->right); + } +} + static int add(struct radix_node **trie, uint8_t bits, const uint8_t *key, - uint8_t cidr, void *data, bool overwrite) + uint8_t cidr, bool is_leaf) { struct radix_node *node, *newnode, *down, *parent; - if (cidr > bits || !data) + if (cidr > bits) return -EINVAL; if (!*trie) { *trie = new_node(key, cidr, bits); - (*trie)->data = data; + (*trie)->is_leaf = is_leaf; return 0; } if (node_placement(*trie, key, cidr, bits, &node)) { - // exact match, so use the existing node - if (!overwrite && node->data) + /* exact match, so use the existing node */ + if (node->is_leaf) return 1; - node->data = data; + node->is_leaf = is_leaf; return 0; } - if (!overwrite && node && node->data) + if (node && node->is_leaf) return 1; newnode = new_node(key, cidr, bits); - newnode->data = data; + newnode->is_leaf = is_leaf; if (!node) { down = *trie; @@ -192,28 +301,54 @@ static void radix_free_nodes(struct radix_node *node) } } -#ifndef __aligned -#define __aligned(x) __attribute__((aligned(x))) -#endif +static void decrement_radix(struct radix_node *trie, uint8_t bits, + const uint8_t *key) +{ + struct radix_node *node = trie; + + while (node && prefix_matches(node, key, bits)) { + if (node->cidr == bits) + break; + + if (CHOOSE_NODE(node, key) == node->bit[0]) + --(node->left); + else + --(node->right); + + node = CHOOSE_NODE(node, key); + } +} static int insert_v4(struct radix_node **root, const struct in_addr *ip, - uint8_t cidr, void *data, bool overwrite) + uint8_t cidr) { /* Aligned so it can be passed to fls */ uint8_t key[4] __aligned(__alignof(uint32_t)); + int ret; swap_endian(key, (const uint8_t *)ip, 32); - return add(root, 32, key, cidr, data, overwrite); + + ret = add(root, 32, key, cidr, true); + if (!ret) + decrement_radix(*root, 32, (uint8_t *)key); + + return ret; } static int insert_v6(struct radix_node **root, const struct in6_addr *ip, - uint8_t cidr, void *data, bool overwrite) + uint8_t cidr) { /* Aligned so it can be passed to fls64 */ uint8_t key[16] __aligned(__alignof(uint64_t)); + int ret; swap_endian(key, (const uint8_t *)ip, 128); - return add(root, 128, key, cidr, data, overwrite); + + ret = add(root, 128, key, cidr, true); + if (!ret) + decrement_radix(*root, 128, (uint8_t *)key); + + return ret; } static struct radix_node *find_node(struct radix_node *trie, uint8_t bits, @@ -222,8 +357,7 @@ static struct radix_node *find_node(struct radix_node *trie, uint8_t bits, struct radix_node *node = trie, *found = NULL; while (node && prefix_matches(node, key, bits)) { - if (node->data) - found = node; + found = node; if (node->cidr == bits) break; node = CHOOSE_NODE(node, key); @@ -231,16 +365,36 @@ static struct radix_node *find_node(struct radix_node *trie, uint8_t bits, return found; } -static struct radix_node *lookup(struct radix_node *root, uint8_t bits, - const void *be_ip) +static int ipp_addpool(struct radix_pool **pool, struct radix_node **root, + uint8_t bits, const uint8_t *key, uint8_t cidr) { - /* Aligned so it can be passed to fls/fls64 */ - uint8_t ip[16] __aligned(__alignof(uint64_t)); - struct radix_node *node; + struct radix_pool *newpool; - swap_endian(ip, be_ip, bits); - node = find_node(root, bits, ip); - return node; + while (*pool) { + if (common_bits((*pool)->node, key, bits) >= cidr) + return -1; + + pool = &(*pool)->next; + } + + BUG_ON(add(root, bits, key, cidr, false)); + + if (bits == 32) { + /* TODO: insert network address (0) and broadcast address (255) + * into the pool, so they can't be used */ + /* TODO: special case /31 ?, see RFC 3021 */ + } + + newpool = malloc(sizeof *newpool); + if (!newpool) + fatal("malloc()"); + + newpool->node = find_node(*root, bits, key); + BUG_ON(!newpool->node); + newpool->next = NULL; + *pool = newpool; + + return 0; } #ifdef DEBUG @@ -282,66 +436,161 @@ static void debug_print_trie(struct radix_node *root, uint8_t bits) node_to_str(root->bit[0], child1, bits); node_to_str(root->bit[1], child2, bits); - debug("%s -> %s, %s\n", parent, child1, child2); + debug("%s (%zu, %zu) -> %s, %s\n", parent, root->left, root->right, + child1, child2); debug_print_trie(root->bit[0], bits); debug_print_trie(root->bit[1], bits); } -void debug_print_trie_v4(struct radix_trie *trie) +void debug_print_trie_v4(struct ip_pool *pool) { - debug_print_trie(trie->ip4_root, 32); + debug_print_trie(pool->ip4_root, 32); } -void debug_print_trie_v6(struct radix_trie *trie) +void debug_print_trie_v6(struct ip_pool *pool) { - debug_print_trie(trie->ip6_root, 128); + debug_print_trie(pool->ip6_root, 128); } #endif -void radix_init(struct radix_trie *trie) +void ipp_init(struct ip_pool *pool) { - trie->ip4_root = trie->ip6_root = NULL; + pool->ip4_root = pool->ip6_root = NULL; + pool->ip4_pool = pool->ip6_pool = NULL; } -void radix_free(struct radix_trie *trie) +void ipp_free(struct ip_pool *pool) { - radix_free_nodes(trie->ip4_root); - radix_free_nodes(trie->ip6_root); + struct radix_pool *next; + + radix_free_nodes(pool->ip4_root); + radix_free_nodes(pool->ip6_root); + + for (struct radix_pool *cur = pool->ip4_pool; cur; cur = cur->next) { + next = cur->next; + free(cur); + cur = next; + } + + for (struct radix_pool *cur = pool->ip6_pool; cur; cur = cur->next) { + next = cur->next; + free(cur); + cur = next; + } } -void *radix_find_v4(struct radix_trie *trie, const void *be_ip) +int ipp_add_v4(struct ip_pool *pool, const struct in_addr *ip, uint8_t cidr) { - struct radix_node *found = lookup(trie->ip4_root, 32, be_ip); - return found ? found->data : NULL; + return insert_v4(&pool->ip4_root, ip, cidr); } -void *radix_find_v6(struct radix_trie *trie, const void *be_ip) +int ipp_add_v6(struct ip_pool *pool, const struct in6_addr *ip, uint8_t cidr) { - struct radix_node *found = lookup(trie->ip6_root, 128, be_ip); - return found ? found->data : NULL; + return insert_v6(&pool->ip6_root, ip, cidr); } -int radix_insert_v4(struct radix_trie *root, const struct in_addr *ip, - uint8_t cidr, void *data) +int ipp_addpool_v4(struct ip_pool *pool, const struct in_addr *ip, uint8_t cidr) { - return insert_v4(&root->ip4_root, ip, cidr, data, true); + uint8_t key[4] __aligned(__alignof(uint32_t)); + + if (cidr <= 0 || cidr >= 32) + return -1; + + swap_endian(key, (const uint8_t *)ip, 32); + return ipp_addpool(&pool->ip4_pool, &pool->ip4_root, 32, key, cidr); } -int radix_insert_v6(struct radix_trie *root, const struct in6_addr *ip, - uint8_t cidr, void *data) +int ipp_addpool_v6(struct ip_pool *pool, const struct in6_addr *ip, + uint8_t cidr) { - return insert_v6(&root->ip6_root, ip, cidr, data, true); + uint8_t key[16] __aligned(__alignof(uint64_t)); + + if (cidr <= 0 || cidr < 64 || cidr >= 128) + return -1; + + swap_endian(key, (const uint8_t *)ip, 128); + return ipp_addpool(&pool->ip6_pool, &pool->ip6_root, 128, key, cidr); } -int radix_tryinsert_v4(struct radix_trie *root, const struct in_addr *ip, - uint8_t cidr, void *data) +uint32_t ipp_gettotal_v4(struct ip_pool *pool) { - return insert_v4(&root->ip4_root, ip, cidr, data, false); + struct radix_pool *current = pool->ip4_pool; + uint32_t total = 0; + + for (current = pool->ip4_pool; current; current = current->next) + total += current->node->left + current->node->right; + + return total; } -int radix_tryinsert_v6(struct radix_trie *root, const struct in6_addr *ip, - uint8_t cidr, void *data) +uint64_t ipp_gettotal_v6(struct ip_pool *pool, uint32_t *high) { - return insert_v6(&root->ip6_root, ip, cidr, data, false); + struct radix_pool *current = pool->ip6_pool; + uint64_t t_low = 0, tmp; + uint32_t t_high = 0; + + while (current) { + if (current->node->left == 0 && current->node->right == 0) { + current = current->next; + continue; + } + + tmp = t_low + current->node->left + current->node->right; + if (tmp <= t_low) + ++t_high; + + t_low = tmp; + current = current->next; + } + + *high = t_high; + return t_low; +} + +void ipp_addnth_v4(struct ip_pool *pool, struct in_addr *dest, uint32_t index) +{ + struct radix_pool *current = pool->ip4_pool; + + for (current = pool->ip4_pool; current; current = current->next) { + if (index < current->node->left + current->node->right) + break; + + index -= current->node->left + current->node->right; + } + + BUG_ON(!current); + + add_nth(current->node, 32, index, (uint8_t *)&dest->s_addr); +} + +void ipp_addnth_v6(struct ip_pool *pool, struct in6_addr *dest, + uint32_t index_low, uint64_t index_high) +{ + struct radix_pool *current = pool->ip6_pool; + uint64_t tmp; + + while (current) { + if (current->node->left == 0 && current->node->right == 0) { + current = current->next; + continue; + } + + if (index_high == 0 && + index_low < (current->node->left + current->node->right)) + break; + + tmp = index_low - (current->node->left + current->node->right); + if (tmp >= index_low) { + BUG_ON(index_high == 0); + --index_high; + } + index_low = tmp; + + current = current->next; + } + + BUG_ON(!pool || index_high); + + add_nth(current->node, 128, index_low, (uint8_t *)&dest->s6_addr); } diff --git a/radix-trie.h b/radix-trie.h index eafda95..6ffaccf 100644 --- a/radix-trie.h +++ b/radix-trie.h @@ -10,27 +10,36 @@ #include #include -struct radix_trie { +struct ip_pool { struct radix_node *ip4_root, *ip6_root; + struct radix_pool *ip4_pool, *ip6_pool; }; -void radix_init(struct radix_trie *trie); -void radix_free(struct radix_trie *trie); -void *radix_find_v4(struct radix_trie *trie, const void *be_ip); -void *radix_find_v6(struct radix_trie *trie, const void *be_ip); -int radix_insert_v4(struct radix_trie *root, const struct in_addr *ip, - uint8_t cidr, void *data); -int radix_insert_v6(struct radix_trie *root, const struct in6_addr *ip, - uint8_t cidr, void *data); -int radix_tryinsert_v4(struct radix_trie *root, const struct in_addr *ip, - uint8_t cidr, void *data); -int radix_tryinsert_v6(struct radix_trie *root, const struct in6_addr *ip, - uint8_t cidr, void *data); +void ipp_init(struct ip_pool *pool); +void ipp_free(struct ip_pool *pool); + +int ipp_add_v4(struct ip_pool *pool, const struct in_addr *ip, uint8_t cidr); +int ipp_add_v6(struct ip_pool *pool, const struct in6_addr *ip, uint8_t cidr); + +uint32_t ipp_gettotal_v4(struct ip_pool *pool); +uint64_t ipp_gettotal_v6(struct ip_pool *pool, uint32_t *high); + +void ipp_addnth_v4(struct ip_pool *pool, struct in_addr *dest, uint32_t index); +void ipp_addnth_v6(struct ip_pool *pool, struct in6_addr *dest, + uint32_t index_low, uint64_t index_high); + +int ipp_addpool_v4(struct ip_pool *pool, const struct in_addr *ip, + uint8_t cidr); +int ipp_addpool_v6(struct ip_pool *pool, const struct in6_addr *ip, + uint8_t cidr); + +int ipp_removepool_v4(struct ip_pool *pool, const struct in_addr *ip); +int ipp_removepool_v6(struct ip_pool *pool, const struct in6_addr *ip); #ifdef DEBUG void node_to_str(struct radix_node *node, char *buf, uint8_t bits); -void debug_print_trie_v4(struct radix_trie *trie); -void debug_print_trie_v6(struct radix_trie *trie); +void debug_print_trie_v4(struct ip_pool *pool); +void debug_print_trie_v6(struct ip_pool *pool); #endif #endif -- cgit v1.2.3-59-g8ed1b