diff options
author | 2015-10-07 06:17:17 +0200 | |
---|---|---|
committer | 2015-10-07 06:43:49 +0200 | |
commit | 95be258957a104eecca671c3de83daa74dc28779 (patch) | |
tree | aefaa2a40f9931a07e213a9fe01f94389fac52e8 | |
parent | Patricia trie implementation (diff) | |
download | kernel-routing-table-patricia.tar.xz kernel-routing-table-patricia.zip |
Simplify and rewritepatricia
-rw-r--r-- | routing-table.c | 167 | ||||
-rw-r--r-- | routing-table.h | 2 |
2 files changed, 64 insertions, 105 deletions
diff --git a/routing-table.c b/routing-table.c index f589552..69c9b0e 100644 --- a/routing-table.c +++ b/routing-table.c @@ -1,15 +1,5 @@ /* Copyright 2011 OmniTI Computer Consulting, Inc. * Copyright 2015 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. - * - * TODO: Replace this PATRICIA-trie with an LC-trie. This implementation here is - * based on some random code I took the pains to clean-up. It's certainly less - * awful than the original now, but I can't really vouch for its performance or - * correctness. Something needs to be built from scratch. If anybody would like - * to do this for me, I'd be very pleased. Just keep the API from the .h file. - * This also uses a rwlock, when it should use RCU, which is quite unfortunate. - * I don't like having recursive algorithms in the kernel either. - * - * In short: this is awful; somebody please fix it. */ #include "routing-table.h" @@ -17,12 +7,12 @@ struct routing_table_node { struct routing_table_node *bit[2]; void *value; - uint32_t bits[4]; uint8_t cidr; bool incidental; + uint8_t bits[]; }; -#define BIT_AT(k, b) ((k[(b - 1) / 32] >> (31 - ((b - 1) % 32))) & 0x1) +#define BIT_AT(k, b) ((k[(b - 1) / 8] >> (7 - ((b - 1) % 8))) & 1) static void drop_node(struct routing_table_node *node) { @@ -35,46 +25,41 @@ static void drop_node(struct routing_table_node *node) kfree(node); } -static inline bool match(struct routing_table_node *node, uint32_t *key, uint8_t match_len) +static inline bool match(struct routing_table_node *node, uint8_t *key, uint8_t match_len) { - register unsigned int i, m = (match_len - 1) / 32; - if (!match_len) + uint8_t full_blocks_to_match = match_len / 8; + uint8_t bits_leftover = match_len % 8; + uint8_t mask; + uint8_t *a = node->bits, *b = key; + if (memcmp(a, b, full_blocks_to_match)) + return false; + if (!bits_leftover) return true; - for (i = 0; i <= m; i++) { - if (i < m) { /* we're matching a whole word */ - if (node->bits[i] != key[i]) - return false; - } else { - uint32_t mask = ((match_len % 32) == 0) ? 0xffffffff : ~(0xffffffff >> (match_len % 32)); - if ((node->bits[i] & mask) != (key[i] & mask)) - return false; - } - } - return true; + mask = ~(0xff >> bits_leftover); + return (a[full_blocks_to_match] & mask) == (b[full_blocks_to_match] & mask); } -static inline int common_bits(struct routing_table_node *node, uint32_t *key, uint8_t match_len) +static inline uint8_t common_bits(struct routing_table_node *node, uint8_t *key, uint8_t match_len) { - /* Largest common mask */ - unsigned int i; - uint32_t cidr = 0; - const unsigned int max_cidr = (match_len > node->cidr) ? match_len : node->cidr; - for (i = 0; i < 4; i++) { - uint32_t mask = 0, trymask; - while (mask != 0xffffffff && cidr < max_cidr) { - trymask = (mask >> 1) | 0x80000000; - if ((node->bits[i] & trymask) != (key[i] & trymask)) - break; - mask = trymask; - ++cidr; - } - if (mask != 0xffffffff || cidr >= max_cidr) + uint8_t max = (((match_len > node->cidr) ? match_len : node->cidr) + 7) / 8; + uint8_t bits = 0; + uint8_t i, mask; + uint8_t *a = node->bits, *b = key; + for (i = 0; i < max; ++i, bits += 8) { + if (a[i] != b[i]) break; } - return cidr; + if (i == max) + return bits; + for (mask = 128; mask > 0; mask /= 2, ++bits) { + if ((a[i] & mask) != (b[i] & mask)) + return bits; + } + BUG(); + return bits; } -static int remove(struct routing_table_node **trie, uint32_t *key, uint8_t cidr) +static int remove(struct routing_table_node **trie, uint8_t *key, uint8_t cidr) { struct routing_table_node *parent = NULL, *node; node = *trie; @@ -120,15 +105,25 @@ static bool walk_remove_by_value(struct routing_table_node **nptr, void *value) return ret; } -static bool find_node(struct routing_table_node **trie, uint32_t *key, uint8_t cidr, struct routing_table_node **rnode, struct routing_table_node **explicit_container) +static struct routing_table_node *find_node(struct routing_table_node *trie, uint8_t bits, uint8_t *key) +{ + struct routing_table_node *node = trie, *found = NULL; + while (node && match(node, key, node->cidr)) { + if (!node->incidental) + found = node; + if (node->cidr == bits) + break; + node = node->bit[BIT_AT(key, node->cidr + 1)]; + } + return found; +} + +static bool node_placement(struct routing_table_node *trie, uint8_t *key, uint8_t cidr, struct routing_table_node **rnode) { bool exact = false; - struct routing_table_node *parent = NULL, *node, *explicit_node = NULL; - node = *trie; + struct routing_table_node *parent = NULL, *node = trie; while (node && node->cidr <= cidr && match(node, key, node->cidr)) { parent = node; - if (!node->incidental) - explicit_node = node; if (parent->cidr == cidr) { exact = true; break; @@ -137,38 +132,36 @@ static bool find_node(struct routing_table_node **trie, uint32_t *key, uint8_t c } if (rnode) *rnode = parent; - if (explicit_container) - *explicit_container = explicit_node; return exact; } -static int add(struct routing_table_node **trie, uint32_t *key, uint8_t cidr, void *value) +static int add(struct routing_table_node **trie, uint8_t bits, uint8_t *key, uint8_t cidr, void *value) { struct routing_table_node *node, *parent, *down, *newnode; int bits_in_common; if (!*trie) { - node = kzalloc(sizeof(*node), GFP_KERNEL); + node = kzalloc(sizeof(*node) + (bits + 7) / 8, GFP_KERNEL); if (!node) return -ENOMEM; node->value = value; - memcpy(node->bits, key, 4 * ((cidr + 31) / 32)); + memcpy(node->bits, key, (bits + 7) / 8); node->cidr = cidr; *trie = node; return 0; } - if (find_node(trie, key, cidr, &node, NULL)) { + if (node_placement(*trie, key, cidr, &node)) { /* exact match */ node->incidental = false; node->value = value; return 0; } - newnode = kzalloc(sizeof(*node), GFP_KERNEL); + newnode = kzalloc(sizeof(*node) + (bits + 7) / 8, GFP_KERNEL); if (!newnode) return -ENOMEM; newnode->value = value; - memcpy(newnode->bits, key, 4 * ((cidr + 31) / 32)); + memcpy(newnode->bits, key, (bits + 7) / 8); newnode->cidr = cidr; if (!node) @@ -198,14 +191,14 @@ static int add(struct routing_table_node **trie, uint32_t *key, uint8_t cidr, vo parent->bit[BIT_AT(newnode->bits, plen)] = newnode; } else { /* reparent */ - node = kzalloc(sizeof(*node), GFP_KERNEL); + node = kzalloc(sizeof(*node) + (bits + 7) / 8, GFP_KERNEL); if (!node) { kfree(newnode); return -ENOMEM; } node->cidr = bits_in_common; node->incidental = true; - memcpy(node->bits, newnode->bits, sizeof(node->bits)); + memcpy(node->bits, newnode->bits, (bits + 7) / 8); node->bit[BIT_AT(down->bits, node->cidr + 1)] = down; node->bit[BIT_AT(newnode->bits, node->cidr + 1)] = newnode; if (!parent) @@ -222,10 +215,7 @@ static int walk_ips(struct routing_table_node *node, uint8_t ip_version, void *c union nf_inet_addr ip; if (!node) return 0; - ip.all[0] = htonl(node->bits[0]); - ip.all[1] = htonl(node->bits[1]); - ip.all[2] = htonl(node->bits[2]); - ip.all[3] = htonl(node->bits[3]); + memcpy(ip.all, node->bits, sizeof(ip.all)); ret = func(ctx, node->value, ip, node->cidr, ip_version); if (ret < 0) return ret; @@ -245,10 +235,7 @@ static int walk_ips_by_value(struct routing_table_node *node, uint8_t ip_version if (!node) return 0; if (node->value == value) { - ip.all[0] = htonl(node->bits[0]); - ip.all[1] = htonl(node->bits[1]); - ip.all[2] = htonl(node->bits[2]); - ip.all[3] = htonl(node->bits[3]); + memcpy(ip.all, node->bits, sizeof(ip.all)); ret = func(ctx, ip, node->cidr, ip_version); if (ret < 0) return ret; @@ -278,38 +265,22 @@ void routing_table_free(struct routing_table *table) int routing_table_insert_v4(struct routing_table *table, struct in_addr ip, uint8_t cidr, void *value) { - uint32_t *ia = (uint32_t *)&ip; - uint32_t mask; int ret; if (cidr > 32) return -EINVAL; - *ia = ntohl(*ia); - mask = (cidr == 32) ? 0xffffffff : ~(0xffffffff >> cidr); - *ia &= mask; write_lock_bh(&table->lock); - ret = add(&table->root4, ia, cidr, value); + ret = add(&table->root4, 32, (uint8_t *)&ip, cidr, value); write_unlock_bh(&table->lock); return ret; } int routing_table_insert_v6(struct routing_table *table, struct in6_addr ip, uint8_t cidr, void *value) { - uint32_t *ia = (uint32_t *)&ip; - uint32_t mask; - int splen; - unsigned int i; int ret; if (cidr > 128) return -EINVAL; - for (i = 0; i < 4; i++) { - splen = cidr - (i * 32); - mask = 0; - if (splen >= 0) - mask = (splen >= 32) ? 0xffffffff : ~(0xffffffff >> splen); - ia[i] = ntohl(ia[i]) & mask; - } write_lock_bh(&table->lock); - ret = add(&table->root6, ia, cidr, value); + ret = add(&table->root6, 128, (uint8_t *)&ip, cidr, value); write_unlock_bh(&table->lock); return ret; } @@ -317,11 +288,10 @@ int routing_table_insert_v6(struct routing_table *table, struct in6_addr ip, uin void *routing_table_lookup_v4(struct routing_table *table, struct in_addr ip) { void *value = NULL; - struct routing_table_node *node = NULL; - uint32_t *ia = (uint32_t *)&ip; - *ia = ntohl(*ia); + struct routing_table_node *node; + read_lock_bh(&table->lock); - find_node(&table->root4, ia, 32, NULL, &node); + node = find_node(table->root4, 32, (uint8_t *)&ip); if (node) value = node->value; read_unlock_bh(&table->lock); @@ -331,14 +301,10 @@ void *routing_table_lookup_v4(struct routing_table *table, struct in_addr ip) void *routing_table_lookup_v6(struct routing_table *table, struct in6_addr ip) { void *value = NULL; - struct routing_table_node *node = NULL; - uint32_t *ia = (uint32_t *)&ip; - ia[0] = ntohl(ia[0]); - ia[1] = ntohl(ia[1]); - ia[2] = ntohl(ia[2]); - ia[3] = ntohl(ia[3]); + struct routing_table_node *node; + read_lock_bh(&table->lock); - find_node(&table->root6, ia, 128, NULL, &node); + node = find_node(table->root6, 128, (uint8_t *)&ip); if (node) value = node->value; read_unlock_bh(&table->lock); @@ -348,10 +314,8 @@ void *routing_table_lookup_v6(struct routing_table *table, struct in6_addr ip) int routing_table_remove_v4(struct routing_table *table, struct in_addr ip, uint8_t cidr) { int ret; - uint32_t *ia = (uint32_t *)&ip; - *ia = ntohl(*ia); write_lock_bh(&table->lock); - ret = remove(&table->root4, ia, cidr); + ret = remove(&table->root4, (uint8_t *)&ip, cidr); write_unlock_bh(&table->lock); return ret; } @@ -359,13 +323,8 @@ int routing_table_remove_v4(struct routing_table *table, struct in_addr ip, uint int routing_table_remove_v6(struct routing_table *table, struct in6_addr ip, uint8_t cidr) { int ret; - uint32_t *ia = (uint32_t *)&ip; - ia[0] = ntohl(ia[0]); - ia[1] = ntohl(ia[1]); - ia[2] = ntohl(ia[2]); - ia[3] = ntohl(ia[3]); write_lock_bh(&table->lock); - ret = remove(&table->root6, ia, cidr); + ret = remove(&table->root6, (uint8_t *)&ip, cidr); write_unlock_bh(&table->lock); return ret; } diff --git a/routing-table.h b/routing-table.h index 2025985..4b624d3 100644 --- a/routing-table.h +++ b/routing-table.h @@ -3,7 +3,7 @@ #ifndef ROUTINGTABLE_H #define ROUTINGTABLE_H -#include <linux/list.h> +#include <linux/spinlock.h> #include <linux/ip.h> #include <linux/ipv6.h> |