summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2015-10-07 06:17:17 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2015-10-07 06:43:49 +0200
commit95be258957a104eecca671c3de83daa74dc28779 (patch)
treeaefaa2a40f9931a07e213a9fe01f94389fac52e8
parentPatricia trie implementation (diff)
downloadkernel-routing-table-patricia.tar.xz
kernel-routing-table-patricia.zip
Simplify and rewritepatricia
-rw-r--r--routing-table.c167
-rw-r--r--routing-table.h2
2 files changed, 64 insertions, 105 deletions
diff --git a/routing-table.c b/routing-table.c
index f589552..69c9b0e 100644
--- a/routing-table.c
+++ b/routing-table.c
@@ -1,15 +1,5 @@
/* Copyright 2011 OmniTI Computer Consulting, Inc.
* Copyright 2015 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
- *
- * TODO: Replace this PATRICIA-trie with an LC-trie. This implementation here is
- * based on some random code I took the pains to clean-up. It's certainly less
- * awful than the original now, but I can't really vouch for its performance or
- * correctness. Something needs to be built from scratch. If anybody would like
- * to do this for me, I'd be very pleased. Just keep the API from the .h file.
- * This also uses a rwlock, when it should use RCU, which is quite unfortunate.
- * I don't like having recursive algorithms in the kernel either.
- *
- * In short: this is awful; somebody please fix it.
*/
#include "routing-table.h"
@@ -17,12 +7,12 @@
struct routing_table_node {
struct routing_table_node *bit[2];
void *value;
- uint32_t bits[4];
uint8_t cidr;
bool incidental;
+ uint8_t bits[];
};
-#define BIT_AT(k, b) ((k[(b - 1) / 32] >> (31 - ((b - 1) % 32))) & 0x1)
+#define BIT_AT(k, b) ((k[(b - 1) / 8] >> (7 - ((b - 1) % 8))) & 1)
static void drop_node(struct routing_table_node *node)
{
@@ -35,46 +25,41 @@ static void drop_node(struct routing_table_node *node)
kfree(node);
}
-static inline bool match(struct routing_table_node *node, uint32_t *key, uint8_t match_len)
+static inline bool match(struct routing_table_node *node, uint8_t *key, uint8_t match_len)
{
- register unsigned int i, m = (match_len - 1) / 32;
- if (!match_len)
+ uint8_t full_blocks_to_match = match_len / 8;
+ uint8_t bits_leftover = match_len % 8;
+ uint8_t mask;
+ uint8_t *a = node->bits, *b = key;
+ if (memcmp(a, b, full_blocks_to_match))
+ return false;
+ if (!bits_leftover)
return true;
- for (i = 0; i <= m; i++) {
- if (i < m) { /* we're matching a whole word */
- if (node->bits[i] != key[i])
- return false;
- } else {
- uint32_t mask = ((match_len % 32) == 0) ? 0xffffffff : ~(0xffffffff >> (match_len % 32));
- if ((node->bits[i] & mask) != (key[i] & mask))
- return false;
- }
- }
- return true;
+ mask = ~(0xff >> bits_leftover);
+ return (a[full_blocks_to_match] & mask) == (b[full_blocks_to_match] & mask);
}
-static inline int common_bits(struct routing_table_node *node, uint32_t *key, uint8_t match_len)
+static inline uint8_t common_bits(struct routing_table_node *node, uint8_t *key, uint8_t match_len)
{
- /* Largest common mask */
- unsigned int i;
- uint32_t cidr = 0;
- const unsigned int max_cidr = (match_len > node->cidr) ? match_len : node->cidr;
- for (i = 0; i < 4; i++) {
- uint32_t mask = 0, trymask;
- while (mask != 0xffffffff && cidr < max_cidr) {
- trymask = (mask >> 1) | 0x80000000;
- if ((node->bits[i] & trymask) != (key[i] & trymask))
- break;
- mask = trymask;
- ++cidr;
- }
- if (mask != 0xffffffff || cidr >= max_cidr)
+ uint8_t max = (((match_len > node->cidr) ? match_len : node->cidr) + 7) / 8;
+ uint8_t bits = 0;
+ uint8_t i, mask;
+ uint8_t *a = node->bits, *b = key;
+ for (i = 0; i < max; ++i, bits += 8) {
+ if (a[i] != b[i])
break;
}
- return cidr;
+ if (i == max)
+ return bits;
+ for (mask = 128; mask > 0; mask /= 2, ++bits) {
+ if ((a[i] & mask) != (b[i] & mask))
+ return bits;
+ }
+ BUG();
+ return bits;
}
-static int remove(struct routing_table_node **trie, uint32_t *key, uint8_t cidr)
+static int remove(struct routing_table_node **trie, uint8_t *key, uint8_t cidr)
{
struct routing_table_node *parent = NULL, *node;
node = *trie;
@@ -120,15 +105,25 @@ static bool walk_remove_by_value(struct routing_table_node **nptr, void *value)
return ret;
}
-static bool find_node(struct routing_table_node **trie, uint32_t *key, uint8_t cidr, struct routing_table_node **rnode, struct routing_table_node **explicit_container)
+static struct routing_table_node *find_node(struct routing_table_node *trie, uint8_t bits, uint8_t *key)
+{
+ struct routing_table_node *node = trie, *found = NULL;
+ while (node && match(node, key, node->cidr)) {
+ if (!node->incidental)
+ found = node;
+ if (node->cidr == bits)
+ break;
+ node = node->bit[BIT_AT(key, node->cidr + 1)];
+ }
+ return found;
+}
+
+static bool node_placement(struct routing_table_node *trie, uint8_t *key, uint8_t cidr, struct routing_table_node **rnode)
{
bool exact = false;
- struct routing_table_node *parent = NULL, *node, *explicit_node = NULL;
- node = *trie;
+ struct routing_table_node *parent = NULL, *node = trie;
while (node && node->cidr <= cidr && match(node, key, node->cidr)) {
parent = node;
- if (!node->incidental)
- explicit_node = node;
if (parent->cidr == cidr) {
exact = true;
break;
@@ -137,38 +132,36 @@ static bool find_node(struct routing_table_node **trie, uint32_t *key, uint8_t c
}
if (rnode)
*rnode = parent;
- if (explicit_container)
- *explicit_container = explicit_node;
return exact;
}
-static int add(struct routing_table_node **trie, uint32_t *key, uint8_t cidr, void *value)
+static int add(struct routing_table_node **trie, uint8_t bits, uint8_t *key, uint8_t cidr, void *value)
{
struct routing_table_node *node, *parent, *down, *newnode;
int bits_in_common;
if (!*trie) {
- node = kzalloc(sizeof(*node), GFP_KERNEL);
+ node = kzalloc(sizeof(*node) + (bits + 7) / 8, GFP_KERNEL);
if (!node)
return -ENOMEM;
node->value = value;
- memcpy(node->bits, key, 4 * ((cidr + 31) / 32));
+ memcpy(node->bits, key, (bits + 7) / 8);
node->cidr = cidr;
*trie = node;
return 0;
}
- if (find_node(trie, key, cidr, &node, NULL)) {
+ if (node_placement(*trie, key, cidr, &node)) {
/* exact match */
node->incidental = false;
node->value = value;
return 0;
}
- newnode = kzalloc(sizeof(*node), GFP_KERNEL);
+ newnode = kzalloc(sizeof(*node) + (bits + 7) / 8, GFP_KERNEL);
if (!newnode)
return -ENOMEM;
newnode->value = value;
- memcpy(newnode->bits, key, 4 * ((cidr + 31) / 32));
+ memcpy(newnode->bits, key, (bits + 7) / 8);
newnode->cidr = cidr;
if (!node)
@@ -198,14 +191,14 @@ static int add(struct routing_table_node **trie, uint32_t *key, uint8_t cidr, vo
parent->bit[BIT_AT(newnode->bits, plen)] = newnode;
} else {
/* reparent */
- node = kzalloc(sizeof(*node), GFP_KERNEL);
+ node = kzalloc(sizeof(*node) + (bits + 7) / 8, GFP_KERNEL);
if (!node) {
kfree(newnode);
return -ENOMEM;
}
node->cidr = bits_in_common;
node->incidental = true;
- memcpy(node->bits, newnode->bits, sizeof(node->bits));
+ memcpy(node->bits, newnode->bits, (bits + 7) / 8);
node->bit[BIT_AT(down->bits, node->cidr + 1)] = down;
node->bit[BIT_AT(newnode->bits, node->cidr + 1)] = newnode;
if (!parent)
@@ -222,10 +215,7 @@ static int walk_ips(struct routing_table_node *node, uint8_t ip_version, void *c
union nf_inet_addr ip;
if (!node)
return 0;
- ip.all[0] = htonl(node->bits[0]);
- ip.all[1] = htonl(node->bits[1]);
- ip.all[2] = htonl(node->bits[2]);
- ip.all[3] = htonl(node->bits[3]);
+ memcpy(ip.all, node->bits, sizeof(ip.all));
ret = func(ctx, node->value, ip, node->cidr, ip_version);
if (ret < 0)
return ret;
@@ -245,10 +235,7 @@ static int walk_ips_by_value(struct routing_table_node *node, uint8_t ip_version
if (!node)
return 0;
if (node->value == value) {
- ip.all[0] = htonl(node->bits[0]);
- ip.all[1] = htonl(node->bits[1]);
- ip.all[2] = htonl(node->bits[2]);
- ip.all[3] = htonl(node->bits[3]);
+ memcpy(ip.all, node->bits, sizeof(ip.all));
ret = func(ctx, ip, node->cidr, ip_version);
if (ret < 0)
return ret;
@@ -278,38 +265,22 @@ void routing_table_free(struct routing_table *table)
int routing_table_insert_v4(struct routing_table *table, struct in_addr ip, uint8_t cidr, void *value)
{
- uint32_t *ia = (uint32_t *)&ip;
- uint32_t mask;
int ret;
if (cidr > 32)
return -EINVAL;
- *ia = ntohl(*ia);
- mask = (cidr == 32) ? 0xffffffff : ~(0xffffffff >> cidr);
- *ia &= mask;
write_lock_bh(&table->lock);
- ret = add(&table->root4, ia, cidr, value);
+ ret = add(&table->root4, 32, (uint8_t *)&ip, cidr, value);
write_unlock_bh(&table->lock);
return ret;
}
int routing_table_insert_v6(struct routing_table *table, struct in6_addr ip, uint8_t cidr, void *value)
{
- uint32_t *ia = (uint32_t *)&ip;
- uint32_t mask;
- int splen;
- unsigned int i;
int ret;
if (cidr > 128)
return -EINVAL;
- for (i = 0; i < 4; i++) {
- splen = cidr - (i * 32);
- mask = 0;
- if (splen >= 0)
- mask = (splen >= 32) ? 0xffffffff : ~(0xffffffff >> splen);
- ia[i] = ntohl(ia[i]) & mask;
- }
write_lock_bh(&table->lock);
- ret = add(&table->root6, ia, cidr, value);
+ ret = add(&table->root6, 128, (uint8_t *)&ip, cidr, value);
write_unlock_bh(&table->lock);
return ret;
}
@@ -317,11 +288,10 @@ int routing_table_insert_v6(struct routing_table *table, struct in6_addr ip, uin
void *routing_table_lookup_v4(struct routing_table *table, struct in_addr ip)
{
void *value = NULL;
- struct routing_table_node *node = NULL;
- uint32_t *ia = (uint32_t *)&ip;
- *ia = ntohl(*ia);
+ struct routing_table_node *node;
+
read_lock_bh(&table->lock);
- find_node(&table->root4, ia, 32, NULL, &node);
+ node = find_node(table->root4, 32, (uint8_t *)&ip);
if (node)
value = node->value;
read_unlock_bh(&table->lock);
@@ -331,14 +301,10 @@ void *routing_table_lookup_v4(struct routing_table *table, struct in_addr ip)
void *routing_table_lookup_v6(struct routing_table *table, struct in6_addr ip)
{
void *value = NULL;
- struct routing_table_node *node = NULL;
- uint32_t *ia = (uint32_t *)&ip;
- ia[0] = ntohl(ia[0]);
- ia[1] = ntohl(ia[1]);
- ia[2] = ntohl(ia[2]);
- ia[3] = ntohl(ia[3]);
+ struct routing_table_node *node;
+
read_lock_bh(&table->lock);
- find_node(&table->root6, ia, 128, NULL, &node);
+ node = find_node(table->root6, 128, (uint8_t *)&ip);
if (node)
value = node->value;
read_unlock_bh(&table->lock);
@@ -348,10 +314,8 @@ void *routing_table_lookup_v6(struct routing_table *table, struct in6_addr ip)
int routing_table_remove_v4(struct routing_table *table, struct in_addr ip, uint8_t cidr)
{
int ret;
- uint32_t *ia = (uint32_t *)&ip;
- *ia = ntohl(*ia);
write_lock_bh(&table->lock);
- ret = remove(&table->root4, ia, cidr);
+ ret = remove(&table->root4, (uint8_t *)&ip, cidr);
write_unlock_bh(&table->lock);
return ret;
}
@@ -359,13 +323,8 @@ int routing_table_remove_v4(struct routing_table *table, struct in_addr ip, uint
int routing_table_remove_v6(struct routing_table *table, struct in6_addr ip, uint8_t cidr)
{
int ret;
- uint32_t *ia = (uint32_t *)&ip;
- ia[0] = ntohl(ia[0]);
- ia[1] = ntohl(ia[1]);
- ia[2] = ntohl(ia[2]);
- ia[3] = ntohl(ia[3]);
write_lock_bh(&table->lock);
- ret = remove(&table->root6, ia, cidr);
+ ret = remove(&table->root6, (uint8_t *)&ip, cidr);
write_unlock_bh(&table->lock);
return ret;
}
diff --git a/routing-table.h b/routing-table.h
index 2025985..4b624d3 100644
--- a/routing-table.h
+++ b/routing-table.h
@@ -3,7 +3,7 @@
#ifndef ROUTINGTABLE_H
#define ROUTINGTABLE_H
-#include <linux/list.h>
+#include <linux/spinlock.h>
#include <linux/ip.h>
#include <linux/ipv6.h>