aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--radix-trie.c359
-rw-r--r--radix-trie.h39
2 files changed, 328 insertions, 70 deletions
diff --git a/radix-trie.c b/radix-trie.c
index f314913..fb39047 100644
--- a/radix-trie.c
+++ b/radix-trie.c
@@ -18,11 +18,22 @@
#define MIN(X, Y) (((X) < (Y)) ? (X) : (Y))
+#ifndef __aligned
+#define __aligned(x) __attribute__((aligned(x)))
+#endif
+
struct radix_node {
struct radix_node *bit[2];
- void *data;
+ uint64_t left;
+ uint64_t right;
uint8_t bits[16];
uint8_t cidr, bit_at_a, bit_at_b;
+ bool is_leaf;
+};
+
+struct radix_pool {
+ struct radix_node *node;
+ struct radix_pool *next;
};
static unsigned int fls64(uint64_t x)
@@ -69,19 +80,31 @@ static struct radix_node *new_node(const uint8_t *key, uint8_t cidr,
uint8_t bits)
{
struct radix_node *node;
+ uint64_t mask;
node = malloc(sizeof *node);
if (!node)
fatal("malloc()");
- node->bit[0] = node->bit[1] = node->data = NULL;
+ node->bit[0] = node->bit[1] = NULL;
node->cidr = cidr;
node->bit_at_a = cidr / 8U;
#ifdef __LITTLE_ENDIAN
node->bit_at_a ^= (bits / 8U - 1U) % 8U;
#endif
node->bit_at_b = 7U - (cidr % 8U);
+ node->is_leaf = false;
+ if (bits - cidr > 0 && bits - cidr - 1 < 64)
+ node->left = node->right = 1ULL << (bits - cidr - 1);
+ else
+ node->left = node->right = 0;
+
memcpy(node->bits, key, bits / 8U);
+ mask = (bits - cidr) >= 64 ? 0 : 0xFFFFFFFFFFFFFFFF << (bits - cidr);
+ if (bits == 32)
+ *(uint32_t *)node->bits &= mask;
+ else
+ *(uint64_t *)&node->bits[8] &= mask;
return node;
}
@@ -115,34 +138,120 @@ static bool node_placement(struct radix_node *trie, const uint8_t *key,
return exact;
}
+static uint64_t subnet_diff(uint8_t *ip1, uint8_t *ip2, uint8_t bits)
+{
+ if (bits == 32)
+ return *(const uint32_t *)ip1 - *(const uint32_t *)ip2;
+ else
+ return *(const uint64_t *)&ip1[8] - *(const uint64_t *)&ip2[8];
+}
+
+static void add_nth(struct radix_node *start, uint8_t bits, uint64_t n,
+ uint8_t *dest)
+{
+ struct radix_node *target = start, *parent, *newnode, *between;
+ uint8_t ip[16] __aligned(__alignof(uint64_t));
+ uint8_t cidr = bits;
+ uint64_t result, free_ips, diff;
+
+ BUG_ON(n > target->left + target->right - 1);
+
+ do {
+ parent = target;
+
+ if (n >= parent->left) {
+ target = parent->bit[1];
+ BUG_ON(!parent->right);
+ --(parent->right);
+
+ n += (1ULL << (bits - parent->cidr - 1)) - parent->left;
+ } else {
+ target = parent->bit[0];
+ BUG_ON(!parent->left);
+ --(parent->left);
+ }
+
+ if (!target)
+ break;
+
+ /* check if target has a suitable ip range */
+ free_ips = target->left + target->right;
+ diff = subnet_diff(target->bits, parent->bits, bits);
+ if (n < diff) {
+ /* can't go down, or we'd skip too many ips */
+ break;
+ } else if (n >= diff + free_ips) {
+ /* can't go down, we want a higher ip */
+ n += (1ULL << (bits - target->cidr)) - free_ips;
+ break;
+ } else {
+ /* match; subtract skipped ips */
+ n -= diff;
+ }
+ } while (1);
+
+ if (bits == 32) {
+ result = *(const uint32_t *)parent->bits + n;
+ BUG_ON(result > UINT32_MAX);
+
+ memcpy(ip, &result, 4);
+ } else {
+ result = *(const uint64_t *)&parent->bits[8] + n;
+ memcpy(ip, &parent->bits, 8);
+ memcpy(ip + 8, &result, 8);
+ }
+
+ newnode = new_node(ip, cidr, bits);
+ newnode->is_leaf = true;
+ swap_endian(dest, (const uint8_t *)ip, bits);
+
+ if (!target) {
+ CHOOSE_NODE(parent, newnode->bits) = newnode;
+ } else {
+ cidr = MIN(cidr, common_bits(target, ip, bits));
+ between = new_node(newnode->bits, cidr, bits);
+
+ CHOOSE_NODE(between, target->bits) = target;
+ CHOOSE_NODE(between, newnode->bits) = newnode;
+ CHOOSE_NODE(parent, between->bits) = between;
+
+ between->left -=
+ (1ULL << (bits - between->bit[0]->cidr)) -
+ (between->bit[0]->left + between->bit[0]->right);
+ between->right -=
+ (1ULL << (bits - between->bit[1]->cidr)) -
+ (between->bit[1]->left + between->bit[1]->right);
+ }
+}
+
static int add(struct radix_node **trie, uint8_t bits, const uint8_t *key,
- uint8_t cidr, void *data, bool overwrite)
+ uint8_t cidr, bool is_leaf)
{
struct radix_node *node, *newnode, *down, *parent;
- if (cidr > bits || !data)
+ if (cidr > bits)
return -EINVAL;
if (!*trie) {
*trie = new_node(key, cidr, bits);
- (*trie)->data = data;
+ (*trie)->is_leaf = is_leaf;
return 0;
}
if (node_placement(*trie, key, cidr, bits, &node)) {
- // exact match, so use the existing node
- if (!overwrite && node->data)
+ /* exact match, so use the existing node */
+ if (node->is_leaf)
return 1;
- node->data = data;
+ node->is_leaf = is_leaf;
return 0;
}
- if (!overwrite && node && node->data)
+ if (node && node->is_leaf)
return 1;
newnode = new_node(key, cidr, bits);
- newnode->data = data;
+ newnode->is_leaf = is_leaf;
if (!node) {
down = *trie;
@@ -192,28 +301,54 @@ static void radix_free_nodes(struct radix_node *node)
}
}
-#ifndef __aligned
-#define __aligned(x) __attribute__((aligned(x)))
-#endif
+static void decrement_radix(struct radix_node *trie, uint8_t bits,
+ const uint8_t *key)
+{
+ struct radix_node *node = trie;
+
+ while (node && prefix_matches(node, key, bits)) {
+ if (node->cidr == bits)
+ break;
+
+ if (CHOOSE_NODE(node, key) == node->bit[0])
+ --(node->left);
+ else
+ --(node->right);
+
+ node = CHOOSE_NODE(node, key);
+ }
+}
static int insert_v4(struct radix_node **root, const struct in_addr *ip,
- uint8_t cidr, void *data, bool overwrite)
+ uint8_t cidr)
{
/* Aligned so it can be passed to fls */
uint8_t key[4] __aligned(__alignof(uint32_t));
+ int ret;
swap_endian(key, (const uint8_t *)ip, 32);
- return add(root, 32, key, cidr, data, overwrite);
+
+ ret = add(root, 32, key, cidr, true);
+ if (!ret)
+ decrement_radix(*root, 32, (uint8_t *)key);
+
+ return ret;
}
static int insert_v6(struct radix_node **root, const struct in6_addr *ip,
- uint8_t cidr, void *data, bool overwrite)
+ uint8_t cidr)
{
/* Aligned so it can be passed to fls64 */
uint8_t key[16] __aligned(__alignof(uint64_t));
+ int ret;
swap_endian(key, (const uint8_t *)ip, 128);
- return add(root, 128, key, cidr, data, overwrite);
+
+ ret = add(root, 128, key, cidr, true);
+ if (!ret)
+ decrement_radix(*root, 128, (uint8_t *)key);
+
+ return ret;
}
static struct radix_node *find_node(struct radix_node *trie, uint8_t bits,
@@ -222,8 +357,7 @@ static struct radix_node *find_node(struct radix_node *trie, uint8_t bits,
struct radix_node *node = trie, *found = NULL;
while (node && prefix_matches(node, key, bits)) {
- if (node->data)
- found = node;
+ found = node;
if (node->cidr == bits)
break;
node = CHOOSE_NODE(node, key);
@@ -231,16 +365,36 @@ static struct radix_node *find_node(struct radix_node *trie, uint8_t bits,
return found;
}
-static struct radix_node *lookup(struct radix_node *root, uint8_t bits,
- const void *be_ip)
+static int ipp_addpool(struct radix_pool **pool, struct radix_node **root,
+ uint8_t bits, const uint8_t *key, uint8_t cidr)
{
- /* Aligned so it can be passed to fls/fls64 */
- uint8_t ip[16] __aligned(__alignof(uint64_t));
- struct radix_node *node;
+ struct radix_pool *newpool;
- swap_endian(ip, be_ip, bits);
- node = find_node(root, bits, ip);
- return node;
+ while (*pool) {
+ if (common_bits((*pool)->node, key, bits) >= cidr)
+ return -1;
+
+ pool = &(*pool)->next;
+ }
+
+ BUG_ON(add(root, bits, key, cidr, false));
+
+ if (bits == 32) {
+ /* TODO: insert network address (0) and broadcast address (255)
+ * into the pool, so they can't be used */
+ /* TODO: special case /31 ?, see RFC 3021 */
+ }
+
+ newpool = malloc(sizeof *newpool);
+ if (!newpool)
+ fatal("malloc()");
+
+ newpool->node = find_node(*root, bits, key);
+ BUG_ON(!newpool->node);
+ newpool->next = NULL;
+ *pool = newpool;
+
+ return 0;
}
#ifdef DEBUG
@@ -282,66 +436,161 @@ static void debug_print_trie(struct radix_node *root, uint8_t bits)
node_to_str(root->bit[0], child1, bits);
node_to_str(root->bit[1], child2, bits);
- debug("%s -> %s, %s\n", parent, child1, child2);
+ debug("%s (%zu, %zu) -> %s, %s\n", parent, root->left, root->right,
+ child1, child2);
debug_print_trie(root->bit[0], bits);
debug_print_trie(root->bit[1], bits);
}
-void debug_print_trie_v4(struct radix_trie *trie)
+void debug_print_trie_v4(struct ip_pool *pool)
{
- debug_print_trie(trie->ip4_root, 32);
+ debug_print_trie(pool->ip4_root, 32);
}
-void debug_print_trie_v6(struct radix_trie *trie)
+void debug_print_trie_v6(struct ip_pool *pool)
{
- debug_print_trie(trie->ip6_root, 128);
+ debug_print_trie(pool->ip6_root, 128);
}
#endif
-void radix_init(struct radix_trie *trie)
+void ipp_init(struct ip_pool *pool)
{
- trie->ip4_root = trie->ip6_root = NULL;
+ pool->ip4_root = pool->ip6_root = NULL;
+ pool->ip4_pool = pool->ip6_pool = NULL;
}
-void radix_free(struct radix_trie *trie)
+void ipp_free(struct ip_pool *pool)
{
- radix_free_nodes(trie->ip4_root);
- radix_free_nodes(trie->ip6_root);
+ struct radix_pool *next;
+
+ radix_free_nodes(pool->ip4_root);
+ radix_free_nodes(pool->ip6_root);
+
+ for (struct radix_pool *cur = pool->ip4_pool; cur; cur = cur->next) {
+ next = cur->next;
+ free(cur);
+ cur = next;
+ }
+
+ for (struct radix_pool *cur = pool->ip6_pool; cur; cur = cur->next) {
+ next = cur->next;
+ free(cur);
+ cur = next;
+ }
}
-void *radix_find_v4(struct radix_trie *trie, const void *be_ip)
+int ipp_add_v4(struct ip_pool *pool, const struct in_addr *ip, uint8_t cidr)
{
- struct radix_node *found = lookup(trie->ip4_root, 32, be_ip);
- return found ? found->data : NULL;
+ return insert_v4(&pool->ip4_root, ip, cidr);
}
-void *radix_find_v6(struct radix_trie *trie, const void *be_ip)
+int ipp_add_v6(struct ip_pool *pool, const struct in6_addr *ip, uint8_t cidr)
{
- struct radix_node *found = lookup(trie->ip6_root, 128, be_ip);
- return found ? found->data : NULL;
+ return insert_v6(&pool->ip6_root, ip, cidr);
}
-int radix_insert_v4(struct radix_trie *root, const struct in_addr *ip,
- uint8_t cidr, void *data)
+int ipp_addpool_v4(struct ip_pool *pool, const struct in_addr *ip, uint8_t cidr)
{
- return insert_v4(&root->ip4_root, ip, cidr, data, true);
+ uint8_t key[4] __aligned(__alignof(uint32_t));
+
+ if (cidr <= 0 || cidr >= 32)
+ return -1;
+
+ swap_endian(key, (const uint8_t *)ip, 32);
+ return ipp_addpool(&pool->ip4_pool, &pool->ip4_root, 32, key, cidr);
}
-int radix_insert_v6(struct radix_trie *root, const struct in6_addr *ip,
- uint8_t cidr, void *data)
+int ipp_addpool_v6(struct ip_pool *pool, const struct in6_addr *ip,
+ uint8_t cidr)
{
- return insert_v6(&root->ip6_root, ip, cidr, data, true);
+ uint8_t key[16] __aligned(__alignof(uint64_t));
+
+ if (cidr <= 0 || cidr < 64 || cidr >= 128)
+ return -1;
+
+ swap_endian(key, (const uint8_t *)ip, 128);
+ return ipp_addpool(&pool->ip6_pool, &pool->ip6_root, 128, key, cidr);
}
-int radix_tryinsert_v4(struct radix_trie *root, const struct in_addr *ip,
- uint8_t cidr, void *data)
+uint32_t ipp_gettotal_v4(struct ip_pool *pool)
{
- return insert_v4(&root->ip4_root, ip, cidr, data, false);
+ struct radix_pool *current = pool->ip4_pool;
+ uint32_t total = 0;
+
+ for (current = pool->ip4_pool; current; current = current->next)
+ total += current->node->left + current->node->right;
+
+ return total;
}
-int radix_tryinsert_v6(struct radix_trie *root, const struct in6_addr *ip,
- uint8_t cidr, void *data)
+uint64_t ipp_gettotal_v6(struct ip_pool *pool, uint32_t *high)
{
- return insert_v6(&root->ip6_root, ip, cidr, data, false);
+ struct radix_pool *current = pool->ip6_pool;
+ uint64_t t_low = 0, tmp;
+ uint32_t t_high = 0;
+
+ while (current) {
+ if (current->node->left == 0 && current->node->right == 0) {
+ current = current->next;
+ continue;
+ }
+
+ tmp = t_low + current->node->left + current->node->right;
+ if (tmp <= t_low)
+ ++t_high;
+
+ t_low = tmp;
+ current = current->next;
+ }
+
+ *high = t_high;
+ return t_low;
+}
+
+void ipp_addnth_v4(struct ip_pool *pool, struct in_addr *dest, uint32_t index)
+{
+ struct radix_pool *current = pool->ip4_pool;
+
+ for (current = pool->ip4_pool; current; current = current->next) {
+ if (index < current->node->left + current->node->right)
+ break;
+
+ index -= current->node->left + current->node->right;
+ }
+
+ BUG_ON(!current);
+
+ add_nth(current->node, 32, index, (uint8_t *)&dest->s_addr);
+}
+
+void ipp_addnth_v6(struct ip_pool *pool, struct in6_addr *dest,
+ uint32_t index_low, uint64_t index_high)
+{
+ struct radix_pool *current = pool->ip6_pool;
+ uint64_t tmp;
+
+ while (current) {
+ if (current->node->left == 0 && current->node->right == 0) {
+ current = current->next;
+ continue;
+ }
+
+ if (index_high == 0 &&
+ index_low < (current->node->left + current->node->right))
+ break;
+
+ tmp = index_low - (current->node->left + current->node->right);
+ if (tmp >= index_low) {
+ BUG_ON(index_high == 0);
+ --index_high;
+ }
+ index_low = tmp;
+
+ current = current->next;
+ }
+
+ BUG_ON(!pool || index_high);
+
+ add_nth(current->node, 128, index_low, (uint8_t *)&dest->s6_addr);
}
diff --git a/radix-trie.h b/radix-trie.h
index eafda95..6ffaccf 100644
--- a/radix-trie.h
+++ b/radix-trie.h
@@ -10,27 +10,36 @@
#include <stdbool.h>
#include <stdint.h>
-struct radix_trie {
+struct ip_pool {
struct radix_node *ip4_root, *ip6_root;
+ struct radix_pool *ip4_pool, *ip6_pool;
};
-void radix_init(struct radix_trie *trie);
-void radix_free(struct radix_trie *trie);
-void *radix_find_v4(struct radix_trie *trie, const void *be_ip);
-void *radix_find_v6(struct radix_trie *trie, const void *be_ip);
-int radix_insert_v4(struct radix_trie *root, const struct in_addr *ip,
- uint8_t cidr, void *data);
-int radix_insert_v6(struct radix_trie *root, const struct in6_addr *ip,
- uint8_t cidr, void *data);
-int radix_tryinsert_v4(struct radix_trie *root, const struct in_addr *ip,
- uint8_t cidr, void *data);
-int radix_tryinsert_v6(struct radix_trie *root, const struct in6_addr *ip,
- uint8_t cidr, void *data);
+void ipp_init(struct ip_pool *pool);
+void ipp_free(struct ip_pool *pool);
+
+int ipp_add_v4(struct ip_pool *pool, const struct in_addr *ip, uint8_t cidr);
+int ipp_add_v6(struct ip_pool *pool, const struct in6_addr *ip, uint8_t cidr);
+
+uint32_t ipp_gettotal_v4(struct ip_pool *pool);
+uint64_t ipp_gettotal_v6(struct ip_pool *pool, uint32_t *high);
+
+void ipp_addnth_v4(struct ip_pool *pool, struct in_addr *dest, uint32_t index);
+void ipp_addnth_v6(struct ip_pool *pool, struct in6_addr *dest,
+ uint32_t index_low, uint64_t index_high);
+
+int ipp_addpool_v4(struct ip_pool *pool, const struct in_addr *ip,
+ uint8_t cidr);
+int ipp_addpool_v6(struct ip_pool *pool, const struct in6_addr *ip,
+ uint8_t cidr);
+
+int ipp_removepool_v4(struct ip_pool *pool, const struct in_addr *ip);
+int ipp_removepool_v6(struct ip_pool *pool, const struct in6_addr *ip);
#ifdef DEBUG
void node_to_str(struct radix_node *node, char *buf, uint8_t bits);
-void debug_print_trie_v4(struct radix_trie *trie);
-void debug_print_trie_v6(struct radix_trie *trie);
+void debug_print_trie_v4(struct ip_pool *pool);
+void debug_print_trie_v6(struct ip_pool *pool);
#endif
#endif