aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--src/config.c17
-rw-r--r--src/routingtable.c411
-rw-r--r--src/routingtable.h8
-rw-r--r--src/selftest/routing-table.h133
-rw-r--r--src/selftest/routingtable.h504
5 files changed, 634 insertions, 439 deletions
diff --git a/src/config.c b/src/config.c
index 2736377..a5a25c9 100644
--- a/src/config.c
+++ b/src/config.c
@@ -209,20 +209,6 @@ static inline int use_data(struct data_remaining *data, size_t size)
return 0;
}
-static int calculate_ipmasks_size(void *ctx, struct wireguard_peer *peer, union nf_inet_addr ip, u8 cidr, int family)
-{
- size_t *count = ctx;
- *count += sizeof(struct wgipmask);
- return 0;
-}
-
-static size_t calculate_peers_size(struct wireguard_device *wg)
-{
- size_t len = peer_total_count(wg) * sizeof(struct wgpeer);
- routing_table_walk_ips(&wg->peer_routing_table, &len, calculate_ipmasks_size);
- return len;
-}
-
static int populate_ipmask(void *ctx, union nf_inet_addr ip, u8 cidr, int family)
{
int ret;
@@ -305,7 +291,8 @@ int config_get_device(struct wireguard_device *wg, void __user *user_device)
mutex_lock(&wg->device_update_lock);
if (!user_device) {
- ret = calculate_peers_size(wg);
+ ret = peer_total_count(wg) * sizeof(struct wgpeer)
+ + routing_table_count_nodes(&wg->peer_routing_table) * sizeof(struct wgipmask);
goto out;
}
diff --git a/src/routingtable.c b/src/routingtable.c
index 1de7727..f9c3eff 100644
--- a/src/routingtable.c
+++ b/src/routingtable.c
@@ -7,16 +7,10 @@ struct routing_table_node {
struct routing_table_node __rcu *bit[2];
struct rcu_head rcu;
struct wireguard_peer *peer;
- u8 cidr;
- u8 bit_at_a, bit_at_b;
- bool incidental;
- u8 bits[];
+ u8 cidr, bit_at_a, bit_at_b;
+ u8 bits[] __aligned(__alignof__(u64));
};
-static inline u8 bit_at(const u8 *key, u8 a, u8 b)
-{
- return (key[a] >> b) & 1;
-}
static inline void copy_and_assign_cidr(struct routing_table_node *node, const u8 *src, u8 cidr)
{
memcpy(node->bits, src, (cidr + 7) / 8);
@@ -25,67 +19,77 @@ static inline void copy_and_assign_cidr(struct routing_table_node *node, const u
node->bit_at_a = cidr / 8;
node->bit_at_b = 7 - (cidr % 8);
}
+#define choose_node(parent, key) parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]
-/* Non-recursive RCU expansion of:
- *
- * free_node(node)
- * {
- * if (!node)
- * return;
- * free_node(node->bit[0]);
- * free_node(node->bit[1]);
- * kfree_rcu_bh(node);
- * }
- */
static void node_free_rcu(struct rcu_head *rcu)
{
kfree(container_of(rcu, struct routing_table_node, rcu));
}
-#define ref(p) rcu_access_pointer(p)
-#define push(p) do { BUG_ON(len >= 128); stack[len++] = rcu_dereference_protected(p, lockdep_is_held(lock)); } while (0)
-static void free_node(struct routing_table_node *top, struct mutex *lock)
+#define push(p, lock) ({ \
+ if (rcu_access_pointer(p)) { \
+ BUG_ON(len >= 128); \
+ stack[len++] = lock ? rcu_dereference_protected(p, lockdep_is_held((struct mutex *)lock)) : rcu_dereference_bh(p); \
+ } \
+ true; \
+})
+#define walk_prep \
+ struct routing_table_node *stack[128], *node; \
+ unsigned int len;
+#define walk(top, lock) for (len = 0, push(top, lock); len > 0 && (node = stack[--len]) && push(node->bit[0], lock) && push(node->bit[1], lock);)
+
+static void free_root_node(struct routing_table_node __rcu *top, struct mutex *lock)
{
- struct routing_table_node *stack[128];
- struct routing_table_node *node = NULL;
- struct routing_table_node *prev = NULL;
- unsigned int len = 0;
+ walk_prep;
+ walk (top, lock)
+ call_rcu_bh(&node->rcu, node_free_rcu);
+}
- if (!top)
- return;
+static size_t count_nodes(struct routing_table_node __rcu *top)
+{
+ size_t ret = 0;
+ walk_prep;
+ walk (top, NULL) {
+ if (node->peer)
+ ++ret;
+ }
+ return ret;
+}
- stack[len++] = top;
- while (len > 0) {
- node = stack[len - 1];
- if (!prev || ref(prev->bit[0]) == node || ref(prev->bit[1]) == node) {
- if (ref(node->bit[0]))
- push(node->bit[0]);
- else if (ref(node->bit[1]))
- push(node->bit[1]);
- } else if (ref(node->bit[0]) == prev) {
- if (ref(node->bit[1]))
- push(node->bit[1]);
- } else {
- call_rcu_bh(&node->rcu, node_free_rcu);
- --len;
- }
- prev = node;
+static int walk_ips_by_peer(struct routing_table_node __rcu *top, int family, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family), struct mutex *maybe_lock)
+{
+ int ret;
+ union nf_inet_addr ip = { .all = { 0 } };
+ walk_prep;
+
+ if (unlikely(!peer))
+ return 0;
+
+ walk (top, maybe_lock) {
+ if (node->peer != peer)
+ continue;
+ memcpy(ip.all, node->bits, family == AF_INET6 ? 16 : 4);
+ ret = func(ctx, ip, node->cidr, family);
+ if (ret)
+ return ret;
}
+ return 0;
}
#undef push
-#define push(p) do { BUG_ON(len >= 128); stack[len++] = p; } while (0)
-static bool walk_remove_by_peer(struct routing_table_node __rcu **top, struct wireguard_peer *peer, struct mutex *lock)
+
+#define ref(p) rcu_access_pointer(p)
+#define deref(p) rcu_dereference_protected(*p, lockdep_is_held(lock))
+#define push(p) ({ BUG_ON(len >= 128); stack[len++] = p; })
+static void walk_remove_by_peer(struct routing_table_node __rcu **top, struct wireguard_peer *peer, struct mutex *lock)
{
- struct routing_table_node __rcu **stack[128];
- struct routing_table_node __rcu **nptr;
- struct routing_table_node *node = NULL;
- struct routing_table_node *prev = NULL;
- unsigned int len = 0;
- bool ret = false;
-
- stack[len++] = top;
- while (len > 0) {
+ struct routing_table_node __rcu **stack[128], **nptr, *node, *prev;
+ unsigned int len;
+
+ if (unlikely(!peer || !ref(*top)))
+ return;
+
+ for (prev = NULL, len = 0, push(top); len > 0; prev = node) {
nptr = stack[len - 1];
- node = rcu_dereference_protected(*nptr, lockdep_is_held(lock));
+ node = deref(nptr);
if (!node) {
--len;
continue;
@@ -100,111 +104,76 @@ static bool walk_remove_by_peer(struct routing_table_node __rcu **top, struct wi
push(&node->bit[1]);
} else {
if (node->peer == peer) {
- ret = true;
node->peer = NULL;
- node->incidental = true;
if (!node->bit[0] || !node->bit[1]) {
- /* collapse (even if both are null) */
- rcu_assign_pointer(*nptr, rcu_dereference_protected(node->bit[!node->bit[0]], lockdep_is_held(lock)));
- rcu_assign_pointer(node->bit[0], NULL);
- rcu_assign_pointer(node->bit[1], NULL);
- free_node(node, lock);
+ rcu_assign_pointer(*nptr, deref(&node->bit[!ref(node->bit[0])]));
+ call_rcu_bh(&node->rcu, node_free_rcu);
+ node = deref(nptr);
}
}
--len;
}
- prev = node;
}
-
- return ret;
}
#undef ref
+#undef deref
#undef push
-static inline bool match(const struct routing_table_node *node, const u8 *key, u8 match_len)
+static inline unsigned int fls128(u64 a, u64 b)
{
- u8 full_blocks_to_match = match_len / 8;
- u8 bits_leftover = match_len % 8;
- u8 mask;
- const u8 *a = node->bits, *b = key;
- if (memcmp(a, b, full_blocks_to_match))
- return false;
- if (!bits_leftover)
- return true;
- mask = ~(0xff >> bits_leftover);
- return (a[full_blocks_to_match] & mask) == (b[full_blocks_to_match] & mask);
+ return a ? fls64(a) + 64 : fls64(b);
}
-static inline u8 common_bits(const struct routing_table_node *node, const u8 *key, u8 match_len)
+static inline u8 common_bits(const struct routing_table_node *node, const u8 *key, u8 bits)
{
- u8 max = (((match_len > node->cidr) ? match_len : node->cidr) + 7) / 8;
- u8 bits = 0;
- u8 i, mask;
- const u8 *a = node->bits, *b = key;
- for (i = 0; i < max; ++i, bits += 8) {
- if (a[i] != b[i])
- break;
- }
- if (i == max)
- return bits;
- for (mask = 128; mask > 0; mask /= 2, ++bits) {
- if ((a[i] & mask) != (b[i] & mask))
- return bits;
- }
+ if (bits == 32)
+ return 32 - fls(be32_to_cpu(*(const __be32 *)node->bits ^ *(const __be32 *)key));
+ else if (bits == 128)
+ return 128 - fls128(be64_to_cpu(*(const __be64 *)&node->bits[0] ^ *(const __be64 *)&key[0]), be64_to_cpu(*(const __be64 *)&node->bits[8] ^ *(const __be64 *)&key[8]));
BUG();
- return bits;
-}
-
-static int remove(struct routing_table_node __rcu **trie, const u8 *key, u8 cidr, struct mutex *lock)
-{
- struct routing_table_node *parent = NULL, *node;
- node = rcu_dereference_protected(*trie, lockdep_is_held(lock));
- while (node && node->cidr <= cidr && match(node, key, node->cidr)) {
- if (node->cidr == cidr) {
- /* exact match */
- node->incidental = true;
- node->peer = NULL;
- if (!node->bit[0] || !node->bit[1]) {
- /* collapse (even if both are null) */
- if (parent)
- rcu_assign_pointer(parent->bit[bit_at(key, parent->bit_at_a, parent->bit_at_b)],
- rcu_dereference_protected(node->bit[(!node->bit[0]) ? 1 : 0], lockdep_is_held(lock)));
- rcu_assign_pointer(node->bit[0], NULL);
- rcu_assign_pointer(node->bit[1], NULL);
- free_node(node, lock);
- }
- return 0;
- }
- parent = node;
- node = rcu_dereference_protected(parent->bit[bit_at(key, parent->bit_at_a, parent->bit_at_b)], lockdep_is_held(lock));
- }
- return -ENOENT;
+ return 0;
}
static inline struct routing_table_node *find_node(struct routing_table_node *trie, u8 bits, const u8 *key)
{
struct routing_table_node *node = trie, *found = NULL;
- while (node && match(node, key, node->cidr)) {
- if (!node->incidental)
+
+ while (node && common_bits(node, key, bits) >= node->cidr) {
+ if (node->peer)
found = node;
if (node->cidr == bits)
break;
- node = rcu_dereference_bh(node->bit[bit_at(key, node->bit_at_a, node->bit_at_b)]);
+ node = rcu_dereference_bh(choose_node(node, key));
}
return found;
}
-static inline bool node_placement(struct routing_table_node __rcu *trie, const u8 *key, u8 cidr, struct routing_table_node **rnode, struct mutex *lock)
+/* Returns a strong reference to a peer */
+static inline struct wireguard_peer *lookup(struct routing_table_node __rcu *root, u8 bits, const void *ip)
+{
+ struct wireguard_peer *peer = NULL;
+ struct routing_table_node *node;
+
+ rcu_read_lock_bh();
+ node = find_node(rcu_dereference_bh(root), bits, ip);
+ if (node)
+ peer = peer_get(node->peer);
+ rcu_read_unlock_bh();
+ return peer;
+}
+
+static inline bool node_placement(struct routing_table_node __rcu *trie, const u8 *key, u8 cidr, u8 bits, struct routing_table_node **rnode, struct mutex *lock)
{
bool exact = false;
struct routing_table_node *parent = NULL, *node = rcu_dereference_protected(trie, lockdep_is_held(lock));
- while (node && node->cidr <= cidr && match(node, key, node->cidr)) {
+
+ while (node && node->cidr <= cidr && common_bits(node, key, bits) >= node->cidr) {
parent = node;
if (parent->cidr == cidr) {
exact = true;
break;
}
- node = rcu_dereference_protected(parent->bit[bit_at(key, parent->bit_at_a, parent->bit_at_b)], lockdep_is_held(lock));
+ node = rcu_dereference_protected(choose_node(parent, key), lockdep_is_held(lock));
}
if (rnode)
*rnode = parent;
@@ -224,9 +193,7 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u
rcu_assign_pointer(*trie, node);
return 0;
}
- if (node_placement(*trie, key, cidr, &node, lock)) {
- /* exact match */
- node->incidental = false;
+ if (node_placement(*trie, key, cidr, bits, &node, lock)) {
node->peer = peer;
return 0;
}
@@ -239,112 +206,40 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u
if (!node)
down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
- else
- down = rcu_dereference_protected(node->bit[bit_at(key, node->bit_at_a, node->bit_at_b)], lockdep_is_held(lock));
- if (!down) {
- rcu_assign_pointer(node->bit[bit_at(key, node->bit_at_a, node->bit_at_b)], newnode);
- return 0;
+ else {
+ down = rcu_dereference_protected(choose_node(node, key), lockdep_is_held(lock));
+ if (!down) {
+ rcu_assign_pointer(choose_node(node, key), newnode);
+ return 0;
+ }
}
- /* here we must be inserting between node and down */
- cidr = min(cidr, common_bits(down, key, cidr));
+ cidr = min(cidr, common_bits(down, key, bits));
parent = node;
- /* we either need to make a new branch above down and newnode
- * or newnode can be the branch. newnode can be the branch if
- * its cidr == bits_in_common */
if (newnode->cidr == cidr) {
- /* newnode can be the branch */
- rcu_assign_pointer(newnode->bit[bit_at(down->bits, newnode->bit_at_a, newnode->bit_at_b)], down);
+ rcu_assign_pointer(choose_node(newnode, down->bits), down);
if (!parent)
rcu_assign_pointer(*trie, newnode);
else
- rcu_assign_pointer(parent->bit[bit_at(newnode->bits, parent->bit_at_a, parent->bit_at_b)], newnode);
+ rcu_assign_pointer(choose_node(parent, newnode->bits), newnode);
} else {
- /* reparent */
node = kzalloc(sizeof(*node) + (bits + 7) / 8, GFP_KERNEL);
if (!node) {
kfree(newnode);
return -ENOMEM;
}
- node->incidental = true;
copy_and_assign_cidr(node, newnode->bits, cidr);
- rcu_assign_pointer(node->bit[bit_at(down->bits, node->bit_at_a, node->bit_at_b)], down);
- rcu_assign_pointer(node->bit[bit_at(newnode->bits, node->bit_at_a, node->bit_at_b)], newnode);
+ rcu_assign_pointer(choose_node(node, down->bits), down);
+ rcu_assign_pointer(choose_node(node, newnode->bits), newnode);
if (!parent)
rcu_assign_pointer(*trie, node);
else
- rcu_assign_pointer(parent->bit[bit_at(node->bits, parent->bit_at_a, parent->bit_at_b)], node);
+ rcu_assign_pointer(choose_node(parent, node->bits), node);
}
return 0;
}
-#define push(p) do { \
- struct routing_table_node *next = (maybe_lock ? rcu_dereference_protected(p, lockdep_is_held(maybe_lock)) : rcu_dereference_bh(p)); \
- if (next) { \
- BUG_ON(len >= 128); \
- stack[len++] = next; \
- } \
-} while (0)
-static int walk_ips(struct routing_table_node *top, int family, void *ctx, int (*func)(void *ctx, struct wireguard_peer *peer, union nf_inet_addr ip, u8 cidr, int family), struct mutex *maybe_lock)
-{
- int ret;
- union nf_inet_addr ip = { .all = { 0 } };
- struct routing_table_node *stack[128];
- struct routing_table_node *node;
- unsigned int len = 0;
- struct wireguard_peer *peer;
-
- if (!top)
- return 0;
-
- stack[len++] = top;
- while (len > 0) {
- node = stack[--len];
-
- peer = peer_get(node->peer);
- if (peer) {
- memcpy(ip.all, node->bits, family == AF_INET6 ? 16 : 4);
- ret = func(ctx, peer, ip, node->cidr, family);
- peer_put(peer);
- if (ret)
- return ret;
- }
-
- push(node->bit[0]);
- push(node->bit[1]);
- }
- return 0;
-}
-static int walk_ips_by_peer(struct routing_table_node *top, int family, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family), struct mutex *maybe_lock)
-{
- int ret;
- union nf_inet_addr ip = { .all = { 0 } };
- struct routing_table_node *stack[128];
- struct routing_table_node *node;
- unsigned int len = 0;
-
- if (!top)
- return 0;
-
- stack[len++] = top;
- while (len > 0) {
- node = stack[--len];
-
- if (node->peer == peer) {
- memcpy(ip.all, node->bits, family == AF_INET6 ? 16 : 4);
- ret = func(ctx, ip, node->cidr, family);
- if (ret)
- return ret;
- }
-
- push(node->bit[0]);
- push(node->bit[1]);
- }
- return 0;
-}
-#undef push
-
void routing_table_init(struct routing_table *table)
{
memset(table, 0, sizeof(struct routing_table));
@@ -354,9 +249,9 @@ void routing_table_init(struct routing_table *table)
void routing_table_free(struct routing_table *table)
{
mutex_lock(&table->table_update_lock);
- free_node(rcu_dereference_protected(table->root4, lockdep_is_held(&table->table_update_lock)), &table->table_update_lock);
+ free_root_node(table->root4, &table->table_update_lock);
rcu_assign_pointer(table->root4, NULL);
- free_node(rcu_dereference_protected(table->root6, lockdep_is_held(&table->table_update_lock)), &table->table_update_lock);
+ free_root_node(table->root6, &table->table_update_lock);
rcu_assign_pointer(table->root6, NULL);
mutex_unlock(&table->table_update_lock);
}
@@ -364,7 +259,7 @@ void routing_table_free(struct routing_table *table)
int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer)
{
int ret;
- if (cidr > 32)
+ if (unlikely(cidr > 32 || !peer))
return -EINVAL;
mutex_lock(&table->table_update_lock);
ret = add(&table->root4, 32, (const u8 *)ip, cidr, peer, &table->table_update_lock);
@@ -375,7 +270,7 @@ int routing_table_insert_v4(struct routing_table *table, const struct in_addr *i
int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer)
{
int ret;
- if (cidr > 128)
+ if (unlikely(cidr > 128 || !peer))
return -EINVAL;
mutex_lock(&table->table_update_lock);
ret = add(&table->root6, 128, (const u8 *)ip, cidr, peer, &table->table_update_lock);
@@ -383,73 +278,19 @@ int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *
return ret;
}
-/* Returns a strong reference to a peer */
-inline struct wireguard_peer *routing_table_lookup_v4(struct routing_table *table, const struct in_addr *ip)
+void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer)
{
- struct wireguard_peer *peer = NULL;
- struct routing_table_node *node;
-
- rcu_read_lock_bh();
- node = find_node(rcu_dereference_bh(table->root4), 32, (const u8 *)ip);
- if (node)
- peer = peer_get(node->peer);
- rcu_read_unlock_bh();
- return peer;
-}
-
-/* Returns a strong reference to a peer */
-inline struct wireguard_peer *routing_table_lookup_v6(struct routing_table *table, const struct in6_addr *ip)
-{
- struct wireguard_peer *peer = NULL;
- struct routing_table_node *node;
-
- rcu_read_lock_bh();
- node = find_node(rcu_dereference_bh(table->root6), 128, (const u8 *)ip);
- if (node)
- peer = peer_get(node->peer);
- rcu_read_unlock_bh();
- return peer;
-}
-
-int routing_table_remove_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr)
-{
- int ret;
- mutex_lock(&table->table_update_lock);
- ret = remove(&table->root4, (const u8 *)ip, cidr, &table->table_update_lock);
- mutex_unlock(&table->table_update_lock);
- return ret;
-}
-
-int routing_table_remove_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr)
-{
- int ret;
mutex_lock(&table->table_update_lock);
- ret = remove(&table->root6, (const u8 *)ip, cidr, &table->table_update_lock);
+ walk_remove_by_peer(&table->root4, peer, &table->table_update_lock);
+ walk_remove_by_peer(&table->root6, peer, &table->table_update_lock);
mutex_unlock(&table->table_update_lock);
- return ret;
}
-int routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer)
+size_t routing_table_count_nodes(struct routing_table *table)
{
- bool found;
- mutex_lock(&table->table_update_lock);
- found = walk_remove_by_peer(&table->root4, peer, &table->table_update_lock) | walk_remove_by_peer(&table->root6, peer, &table->table_update_lock);
- mutex_unlock(&table->table_update_lock);
- return found ? 0 : -EINVAL;
-}
-
-/* Calls func with a strong reference to each peer, before putting it when the function has completed.
- * It's thus up to the caller to call peer_put on it if it's going to be used elsewhere after or stored. */
-int routing_table_walk_ips(struct routing_table *table, void *ctx, int (*func)(void *ctx, struct wireguard_peer *peer, union nf_inet_addr ip, u8 cidr, int family))
-{
- int ret;
- rcu_read_lock_bh();
- ret = walk_ips(rcu_dereference_bh(table->root4), AF_INET, ctx, func, NULL);
- rcu_read_unlock_bh();
- if (ret)
- return ret;
+ size_t ret;
rcu_read_lock_bh();
- ret = walk_ips(rcu_dereference_bh(table->root6), AF_INET6, ctx, func, NULL);
+ ret = count_nodes(table->root4) + count_nodes(table->root6);
rcu_read_unlock_bh();
return ret;
}
@@ -458,12 +299,12 @@ int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struc
{
int ret;
rcu_read_lock_bh();
- ret = walk_ips_by_peer(rcu_dereference_bh(table->root4), AF_INET, ctx, peer, func, NULL);
+ ret = walk_ips_by_peer(table->root4, AF_INET, ctx, peer, func, NULL);
rcu_read_unlock_bh();
if (ret)
return ret;
rcu_read_lock_bh();
- ret = walk_ips_by_peer(rcu_dereference_bh(table->root6), AF_INET6, ctx, peer, func, NULL);
+ ret = walk_ips_by_peer(table->root6, AF_INET6, ctx, peer, func, NULL);
rcu_read_unlock_bh();
return ret;
}
@@ -472,12 +313,12 @@ int routing_table_walk_ips_by_peer_sleepable(struct routing_table *table, void *
{
int ret;
mutex_lock(&table->table_update_lock);
- ret = walk_ips_by_peer(rcu_dereference_protected(table->root4, lockdep_is_held(&table->table_update_lock)), AF_INET, ctx, peer, func, &table->table_update_lock);
+ ret = walk_ips_by_peer(table->root4, AF_INET, ctx, peer, func, &table->table_update_lock);
mutex_unlock(&table->table_update_lock);
if (ret)
return ret;
mutex_lock(&table->table_update_lock);
- ret = walk_ips_by_peer(rcu_dereference_protected(table->root6, lockdep_is_held(&table->table_update_lock)), AF_INET6, ctx, peer, func, &table->table_update_lock);
+ ret = walk_ips_by_peer(table->root6, AF_INET6, ctx, peer, func, &table->table_update_lock);
mutex_unlock(&table->table_update_lock);
return ret;
}
@@ -499,9 +340,9 @@ struct wireguard_peer *routing_table_lookup_dst(struct routing_table *table, str
if (unlikely(!has_valid_ip_header(skb)))
return NULL;
if (ip_hdr(skb)->version == 4)
- return routing_table_lookup_v4(table, (struct in_addr *)&ip_hdr(skb)->daddr);
+ return lookup(table->root4, 32, &ip_hdr(skb)->daddr);
else if (ip_hdr(skb)->version == 6)
- return routing_table_lookup_v6(table, &ipv6_hdr(skb)->daddr);
+ return lookup(table->root6, 128, &ipv6_hdr(skb)->daddr);
return NULL;
}
@@ -511,10 +352,10 @@ struct wireguard_peer *routing_table_lookup_src(struct routing_table *table, str
if (unlikely(!has_valid_ip_header(skb)))
return NULL;
if (ip_hdr(skb)->version == 4)
- return routing_table_lookup_v4(table, (struct in_addr *)&ip_hdr(skb)->saddr);
+ return lookup(table->root4, 32, &ip_hdr(skb)->saddr);
else if (ip_hdr(skb)->version == 6)
- return routing_table_lookup_v6(table, &ipv6_hdr(skb)->saddr);
+ return lookup(table->root6, 128, &ipv6_hdr(skb)->saddr);
return NULL;
}
-#include "selftest/routing-table.h"
+#include "selftest/routingtable.h"
diff --git a/src/routingtable.h b/src/routingtable.h
index adcc632..4fdf410 100644
--- a/src/routingtable.h
+++ b/src/routingtable.h
@@ -20,16 +20,12 @@ void routing_table_init(struct routing_table *table);
void routing_table_free(struct routing_table *table);
int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer);
int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer);
-int routing_table_remove_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr);
-int routing_table_remove_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr);
-int routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer);
-int routing_table_walk_ips(struct routing_table *table, void *ctx, int (*func)(void *ctx, struct wireguard_peer *peer, union nf_inet_addr ip, u8 cidr, int family));
+void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer);
+size_t routing_table_count_nodes(struct routing_table *table);
int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family));
int routing_table_walk_ips_by_peer_sleepable(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family));
/* These return a strong reference to a peer: */
-struct wireguard_peer *routing_table_lookup_v4(struct routing_table *table, const struct in_addr *ip);
-struct wireguard_peer *routing_table_lookup_v6(struct routing_table *table, const struct in6_addr *ip);
struct wireguard_peer *routing_table_lookup_dst(struct routing_table *table, struct sk_buff *skb);
struct wireguard_peer *routing_table_lookup_src(struct routing_table *table, struct sk_buff *skb);
diff --git a/src/selftest/routing-table.h b/src/selftest/routing-table.h
deleted file mode 100644
index a603401..0000000
--- a/src/selftest/routing-table.h
+++ /dev/null
@@ -1,133 +0,0 @@
-/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
-
-#ifdef DEBUG
-static inline struct in_addr *ip4(u8 a, u8 b, u8 c, u8 d)
-{
- static struct in_addr ip;
- u8 *split = (u8 *)&ip;
- split[0] = a;
- split[1] = b;
- split[2] = c;
- split[3] = d;
- return &ip;
-}
-static inline struct in6_addr *ip6(u32 a, u32 b, u32 c, u32 d)
-{
- static struct in6_addr ip;
- __be32 *split = (__be32 *)&ip;
- split[0] = cpu_to_be32(a);
- split[1] = cpu_to_be32(b);
- split[2] = cpu_to_be32(c);
- split[3] = cpu_to_be32(d);
- return &ip;
-}
-
-bool routing_table_selftest(void)
-{
- struct routing_table t;
- struct wireguard_peer *a = NULL, *b = NULL, *c = NULL, *d = NULL, *e = NULL, *f = NULL, *g = NULL, *h = NULL;
- size_t i = 0;
- bool success = false;
- struct in6_addr ip;
- __be64 part;
-
- routing_table_init(&t);
-#define init_peer(name) do { name = kzalloc(sizeof(struct wireguard_peer), GFP_KERNEL); if (!name) goto free; kref_init(&name->refcount); } while (0)
- init_peer(a);
- init_peer(b);
- init_peer(c);
- init_peer(d);
- init_peer(e);
- init_peer(f);
- init_peer(g);
- init_peer(h);
-#undef init_peer
-
-#define insert(version, mem, ipa, ipb, ipc, ipd, cidr) routing_table_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), cidr, mem)
- insert(4, a, 192, 168, 4, 0, 24);
- insert(4, b, 192, 168, 4, 4, 32);
- insert(4, c, 192, 168, 0, 0, 16);
- insert(4, d, 192, 95, 5, 64, 27);
- insert(4, c, 192, 95, 5, 65, 27); /* replaces previous entry, and maskself is required */
- insert(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128);
- insert(6, c, 0x26075300, 0x60006b00, 0, 0, 64);
- insert(4, e, 0, 0, 0, 0, 0);
- insert(6, e, 0, 0, 0, 0, 0);
- insert(6, f, 0, 0, 0, 0, 0); /* replaces previous entry */
- insert(6, g, 0x24046800, 0, 0, 0, 32);
- insert(6, h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64); /* maskself is required */
- insert(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128);
- insert(4, g, 64, 15, 112, 0, 20);
- insert(4, h, 64, 15, 123, 211, 25); /* maskself is required */
- insert(4, a, 10, 0, 0, 0, 25);
- insert(4, b, 10, 0, 0, 128, 25);
- insert(4, a, 10, 1, 0, 0, 30);
- insert(4, b, 10, 1, 0, 4, 30);
- insert(4, c, 10, 1, 0, 8, 29);
- insert(4, d, 10, 1, 0, 16, 29);
-#undef insert
-
- success = true;
-#define test(version, mem, ipa, ipb, ipc, ipd) do { \
- bool _s = routing_table_lookup_v##version(&t, ip##version(ipa, ipb, ipc, ipd)) == mem; \
- ++i; \
- if (!_s) { \
- pr_info("routing table self-test %zu: FAIL\n", i); \
- success = false; \
- } \
-} while (0)
- test(4, a, 192, 168, 4, 20);
- test(4, a, 192, 168, 4, 0);
- test(4, b, 192, 168, 4, 4);
- test(4, c, 192, 168, 200, 182);
- test(4, c, 192, 95, 5, 68);
- test(4, e, 192, 95, 5, 96);
- test(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543);
- test(6, c, 0x26075300, 0x60006b00, 0, 0xc02e01ee);
- test(6, f, 0x26075300, 0x60006b01, 0, 0);
- test(6, g, 0x24046800, 0x40040806, 0, 0x1006);
- test(6, g, 0x24046800, 0x40040806, 0x1234, 0x5678);
- test(6, f, 0x240467ff, 0x40040806, 0x1234, 0x5678);
- test(6, f, 0x24046801, 0x40040806, 0x1234, 0x5678);
- test(6, h, 0x24046800, 0x40040800, 0x1234, 0x5678);
- test(6, h, 0x24046800, 0x40040800, 0, 0);
- test(6, h, 0x24046800, 0x40040800, 0x10101010, 0x10101010);
- test(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef);
- test(4, g, 64, 15, 116, 26);
- test(4, g, 64, 15, 127, 3);
- test(4, g, 64, 15, 123, 1);
- test(4, h, 64, 15, 123, 128);
- test(4, h, 64, 15, 123, 129);
- test(4, a, 10, 0, 0, 52);
- test(4, b, 10, 0, 0, 220);
- test(4, a, 10, 1, 0, 2);
- test(4, b, 10, 1, 0, 6);
- test(4, c, 10, 1, 0, 10);
- test(4, d, 10, 1, 0, 20);
-#undef test
-
- /* These will hit the BUG_ON(len >= 128) in free_node if something goes wrong. */
- for (i = 0; i < 128; ++i) {
- part = cpu_to_be64(~(1LLU << (i % 64)));
- memset(&ip, 0xff, 16);
- memcpy((u8 *)&ip + (i < 64) * 8, &part, 8);
- routing_table_insert_v6(&t, &ip, 128, a);
- }
-
- if (success)
- pr_info("routing table self-tests: pass\n");
-
-free:
- routing_table_free(&t);
- kfree(a);
- kfree(b);
- kfree(c);
- kfree(d);
- kfree(e);
- kfree(f);
- kfree(g);
- kfree(h);
-
- return success;
-}
-#endif
diff --git a/src/selftest/routingtable.h b/src/selftest/routingtable.h
new file mode 100644
index 0000000..0915e65
--- /dev/null
+++ b/src/selftest/routingtable.h
@@ -0,0 +1,504 @@
+/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
+
+#ifdef DEBUG
+
+#ifdef DEBUG_PRINT_TRIE_GRAPHVIZ
+#include <linux/siphash.h>
+static void print_node(struct routing_table_node *node, u8 bits)
+{
+ u32 color = 0;
+ char *style = "dotted";
+ char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n";
+ char *fmt_declaration = KERN_DEBUG "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n";
+ if (bits == 32) {
+ fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n";
+ fmt_declaration = KERN_DEBUG "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n";
+ } else if (bits == 128) {
+ fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n";
+ fmt_declaration = KERN_DEBUG "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n";
+ }
+ if (node->peer) {
+ hsiphash_key_t key = { 0 };
+ memcpy(&key, &node->peer, sizeof(node->peer));
+ color = hsiphash_1u32(0xdeadbeef, &key) % 200 << 16 | hsiphash_1u32(0xbabecafe, &key) % 200 << 8 | hsiphash_1u32(0xabad1dea, &key) % 200;
+ style = "bold";
+ }
+ printk(fmt_declaration, node->bits, node->cidr, style, color);
+ if (node->bit[0]) {
+ printk(fmt_connection, node->bits, node->cidr, node->bit[0]->bits, node->bit[0]->cidr);
+ print_node(node->bit[0], bits);
+ }
+ if (node->bit[1]) {
+ printk(fmt_connection, node->bits, node->cidr, node->bit[1]->bits, node->bit[1]->cidr);
+ print_node(node->bit[1], bits);
+ }
+}
+static void print_tree(struct routing_table_node *top, u8 bits)
+{
+ printk(KERN_DEBUG "digraph trie {\n");
+ print_node(top, bits);
+ printk(KERN_DEBUG "}\n");
+}
+#endif
+
+#ifdef DEBUG_RANDOM_TRIE
+#define NUM_PEERS 2000
+#define NUM_RAND_ROUTES 400
+#define NUM_MUTATED_ROUTES 100
+#define NUM_QUERIES (NUM_RAND_ROUTES * NUM_MUTATED_ROUTES * 30)
+#include <linux/random.h>
+struct horrible_routing_table {
+ struct hlist_head head;
+};
+struct horrible_routing_table_node {
+ struct hlist_node table;
+ union nf_inet_addr ip;
+ union nf_inet_addr mask;
+ uint8_t ip_version;
+ void *value;
+};
+static void horrible_routing_table_init(struct horrible_routing_table *table)
+{
+ INIT_HLIST_HEAD(&table->head);
+}
+static void horrible_routing_table_free(struct horrible_routing_table *table)
+{
+ struct hlist_node *h;
+ struct horrible_routing_table_node *node;
+ hlist_for_each_entry_safe(node, h, &table->head, table) {
+ hlist_del(&node->table);
+ kfree(node);
+ };
+}
+static inline union nf_inet_addr horrible_cidr_to_mask(uint8_t cidr)
+{
+ union nf_inet_addr mask;
+ memset(&mask, 0x00, 128 / 8);
+ memset(&mask, 0xff, cidr / 8);
+ if (cidr % 32)
+ mask.all[cidr / 32] = htonl((0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL);
+ return mask;
+}
+static inline uint8_t horrible_mask_to_cidr(union nf_inet_addr subnet)
+{
+ return hweight32(subnet.all[0])
+ + hweight32(subnet.all[1])
+ + hweight32(subnet.all[2])
+ + hweight32(subnet.all[3]);
+}
+static inline void horrible_mask_self(struct horrible_routing_table_node *node)
+{
+ if (node->ip_version == 4)
+ node->ip.ip &= node->mask.ip;
+ else if (node->ip_version == 6) {
+ node->ip.ip6[0] &= node->mask.ip6[0];
+ node->ip.ip6[1] &= node->mask.ip6[1];
+ node->ip.ip6[2] &= node->mask.ip6[2];
+ node->ip.ip6[3] &= node->mask.ip6[3];
+ }
+}
+static inline bool horrible_match_v4(const struct horrible_routing_table_node *node, struct in_addr *ip)
+{
+ return (ip->s_addr & node->mask.ip) == node->ip.ip;
+}
+static inline bool horrible_match_v6(const struct horrible_routing_table_node *node, struct in6_addr *ip)
+{
+ return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == node->ip.ip6[0] &&
+ (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == node->ip.ip6[1] &&
+ (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == node->ip.ip6[2] &&
+ (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3];
+}
+static void horrible_insert_ordered(struct horrible_routing_table *table, struct horrible_routing_table_node *node)
+{
+ struct horrible_routing_table_node *other = NULL, *where = NULL;
+ uint8_t my_cidr = horrible_mask_to_cidr(node->mask);
+ hlist_for_each_entry(other, &table->head, table) {
+ if (!memcmp(&other->mask, &node->mask, sizeof(union nf_inet_addr)) &&
+ !memcmp(&other->ip, &node->ip, sizeof(union nf_inet_addr)) &&
+ other->ip_version == node->ip_version) {
+ other->value = node->value;
+ kfree(node);
+ return;
+ }
+ where = other;
+ if (horrible_mask_to_cidr(other->mask) <= my_cidr)
+ break;
+ }
+ if (!other && !where)
+ hlist_add_head(&node->table, &table->head);
+ else if (!other)
+ hlist_add_behind(&node->table, &where->table);
+ else
+ hlist_add_before(&node->table, &where->table);
+}
+static int horrible_routing_table_insert_v4(struct horrible_routing_table *table, struct in_addr *ip, uint8_t cidr, void *value)
+{
+ struct horrible_routing_table_node *node = kzalloc(sizeof(struct horrible_routing_table_node), GFP_KERNEL);
+ if (!node)
+ return -ENOMEM;
+ node->ip.in = *ip;
+ node->mask = horrible_cidr_to_mask(cidr);
+ node->ip_version = 4;
+ node->value = value;
+ horrible_mask_self(node);
+ horrible_insert_ordered(table, node);
+ return 0;
+}
+static int horrible_routing_table_insert_v6(struct horrible_routing_table *table, struct in6_addr *ip, uint8_t cidr, void *value)
+{
+ struct horrible_routing_table_node *node = kzalloc(sizeof(struct horrible_routing_table_node), GFP_KERNEL);
+ if (!node)
+ return -ENOMEM;
+ node->ip.in6 = *ip;
+ node->mask = horrible_cidr_to_mask(cidr);
+ node->ip_version = 6;
+ node->value = value;
+ horrible_mask_self(node);
+ horrible_insert_ordered(table, node);
+ return 0;
+}
+static void *horrible_routing_table_lookup_v4(struct horrible_routing_table *table, struct in_addr *ip)
+{
+ struct horrible_routing_table_node *node;
+ void *ret = NULL;
+ hlist_for_each_entry(node, &table->head, table) {
+ if (node->ip_version != 4)
+ continue;
+ if (horrible_match_v4(node, ip)) {
+ ret = node->value;
+ break;
+ }
+ };
+ return ret;
+}
+static void *horrible_routing_table_lookup_v6(struct horrible_routing_table *table, struct in6_addr *ip)
+{
+ struct horrible_routing_table_node *node;
+ void *ret = NULL;
+ hlist_for_each_entry(node, &table->head, table) {
+ if (node->ip_version != 6)
+ continue;
+ if (horrible_match_v6(node, ip)) {
+ ret = node->value;
+ break;
+ }
+ };
+ return ret;
+}
+
+static bool randomized_test(void)
+{
+ bool ret = false;
+ unsigned int i, j, k, mutate_amount, cidr;
+ struct wireguard_peer **peers, *peer;
+ struct routing_table t;
+ struct horrible_routing_table h;
+ u8 ip[16], mutate_mask[16], mutated[16];
+
+ routing_table_init(&t);
+ horrible_routing_table_init(&h);
+
+ peers = kcalloc(NUM_PEERS, sizeof(struct wireguard_peer *), GFP_KERNEL);
+ if (!peers) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ for (i = 0; i < NUM_PEERS; ++i) {
+ peers[i] = kzalloc(sizeof(struct wireguard_peer), GFP_KERNEL);
+ if (!peers[i]) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ kref_init(&peers[i]->refcount);
+ }
+
+ for (i = 0; i < NUM_RAND_ROUTES; ++i) {
+ prandom_bytes(ip, 4);
+ cidr = prandom_u32_max(32) + 1;
+ peer = peers[prandom_u32_max(NUM_PEERS)];
+ if (routing_table_insert_v4(&t, (struct in_addr *)ip, cidr, peer) < 0) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ if (horrible_routing_table_insert_v4(&h, (struct in_addr *)ip, cidr, peer) < 0) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ for (j = 0; j < NUM_MUTATED_ROUTES; ++j) {
+ memcpy(mutated, ip, 4);
+ prandom_bytes(mutate_mask, 4);
+ mutate_amount = prandom_u32_max(32);
+ for (k = 0; k < mutate_amount / 8; ++k)
+ mutate_mask[k] = 0xff;
+ mutate_mask[k] = 0xff << ((8 - (mutate_amount % 8)) % 8);
+ for (; k < 4; ++k)
+ mutate_mask[k] = 0;
+ for (k = 0; k < 4; ++k)
+ mutated[k] = (mutated[k] & mutate_mask[k]) | (~mutate_mask[k] & prandom_u32_max(256));
+ cidr = prandom_u32_max(32) + 1;
+ peer = peers[prandom_u32_max(NUM_PEERS)];
+ if (routing_table_insert_v4(&t, (struct in_addr *)mutated, cidr, peer) < 0) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ if (horrible_routing_table_insert_v4(&h, (struct in_addr *)mutated, cidr, peer)) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ }
+ }
+
+ for (i = 0; i < NUM_RAND_ROUTES; ++i) {
+ prandom_bytes(ip, 16);
+ cidr = prandom_u32_max(128) + 1;
+ peer = peers[prandom_u32_max(NUM_PEERS)];
+ if (routing_table_insert_v6(&t, (struct in6_addr *)ip, cidr, peer) < 0) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ if (horrible_routing_table_insert_v6(&h, (struct in6_addr *)ip, cidr, peer) < 0) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ for (j = 0; j < NUM_MUTATED_ROUTES; ++j) {
+ memcpy(mutated, ip, 16);
+ prandom_bytes(mutate_mask, 16);
+ mutate_amount = prandom_u32_max(128);
+ for (k = 0; k < mutate_amount / 8; ++k)
+ mutate_mask[k] = 0xff;
+ mutate_mask[k] = 0xff << ((8 - (mutate_amount % 8)) % 8);
+ for (; k < 4; ++k)
+ mutate_mask[k] = 0;
+ for (k = 0; k < 4; ++k)
+ mutated[k] = (mutated[k] & mutate_mask[k]) | (~mutate_mask[k] & prandom_u32_max(256));
+ cidr = prandom_u32_max(128) + 1;
+ peer = peers[prandom_u32_max(NUM_PEERS)];
+ if (routing_table_insert_v6(&t, (struct in6_addr *)mutated, cidr, peer) < 0) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ if (horrible_routing_table_insert_v6(&h, (struct in6_addr *)mutated, cidr, peer)) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ }
+ }
+
+#ifdef DEBUG_PRINT_TRIE_GRAPHVIZ
+ print_tree(t.root4, 32);
+ print_tree(t.root6, 128);
+#endif
+
+ for (i = 0; i < NUM_QUERIES; ++i) {
+ prandom_bytes(ip, 4);
+ if (lookup(t.root4, 32, ip) != horrible_routing_table_lookup_v4(&h, (struct in_addr *)ip)) {
+ pr_info("routing table random self-test: FAIL\n");
+ goto free;
+ }
+ }
+
+ for (i = 0; i < NUM_QUERIES; ++i) {
+ prandom_bytes(ip, 16);
+ if (lookup(t.root6, 128, ip) != horrible_routing_table_lookup_v6(&h, (struct in6_addr *)ip)) {
+ pr_info("routing table random self-test: FAIL\n");
+ goto free;
+ }
+ }
+ ret = true;
+
+free:
+ routing_table_free(&t);
+ horrible_routing_table_free(&h);
+ if (peers) {
+ for (i = 0; i < NUM_PEERS; ++i)
+ kfree(peers[i]);
+ }
+ kfree(peers);
+ return ret;
+}
+#endif
+
+static inline struct in_addr *ip4(u8 a, u8 b, u8 c, u8 d)
+{
+ static struct in_addr ip;
+ u8 *split = (u8 *)&ip;
+ split[0] = a;
+ split[1] = b;
+ split[2] = c;
+ split[3] = d;
+ return &ip;
+}
+static inline struct in6_addr *ip6(u32 a, u32 b, u32 c, u32 d)
+{
+ static struct in6_addr ip;
+ __be32 *split = (__be32 *)&ip;
+ split[0] = cpu_to_be32(a);
+ split[1] = cpu_to_be32(b);
+ split[2] = cpu_to_be32(c);
+ split[3] = cpu_to_be32(d);
+ return &ip;
+}
+
+#define init_peer(name) do { \
+ name = kzalloc(sizeof(struct wireguard_peer), GFP_KERNEL); \
+ if (!name) { \
+ pr_info("routing table self-test: out of memory\n"); \
+ goto free; \
+ } \
+ kref_init(&name->refcount); \
+} while (0)
+
+#define insert(version, mem, ipa, ipb, ipc, ipd, cidr) \
+ routing_table_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), cidr, mem)
+
+#define maybe_fail \
+ ++i; \
+ if (!_s) { \
+ pr_info("routing table self-test %zu: FAIL\n", i); \
+ success = false; \
+ }
+
+#define test(version, mem, ipa, ipb, ipc, ipd) do { \
+ bool _s = lookup(t.root##version, version == 4 ? 32 : 128, ip##version(ipa, ipb, ipc, ipd)) == mem; \
+ maybe_fail \
+} while (0)
+
+#define test_negative(version, mem, ipa, ipb, ipc, ipd) do { \
+ bool _s = lookup(t.root##version, version == 4 ? 32 : 128, ip##version(ipa, ipb, ipc, ipd)) != mem; \
+ maybe_fail \
+} while (0)
+
+bool routing_table_selftest(void)
+{
+ struct routing_table t;
+ struct wireguard_peer *a = NULL, *b = NULL, *c = NULL, *d = NULL, *e = NULL, *f = NULL, *g = NULL, *h = NULL;
+ size_t i = 0;
+ bool success = false;
+ struct in6_addr ip;
+ __be64 part;
+
+ routing_table_init(&t);
+ init_peer(a);
+ init_peer(b);
+ init_peer(c);
+ init_peer(d);
+ init_peer(e);
+ init_peer(f);
+ init_peer(g);
+ init_peer(h);
+
+ insert(4, a, 192, 168, 4, 0, 24);
+ insert(4, b, 192, 168, 4, 4, 32);
+ insert(4, c, 192, 168, 0, 0, 16);
+ insert(4, d, 192, 95, 5, 64, 27);
+ insert(4, c, 192, 95, 5, 65, 27); /* replaces previous entry, and maskself is required */
+ insert(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128);
+ insert(6, c, 0x26075300, 0x60006b00, 0, 0, 64);
+ insert(4, e, 0, 0, 0, 0, 0);
+ insert(6, e, 0, 0, 0, 0, 0);
+ insert(6, f, 0, 0, 0, 0, 0); /* replaces previous entry */
+ insert(6, g, 0x24046800, 0, 0, 0, 32);
+ insert(6, h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64); /* maskself is required */
+ insert(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128);
+ insert(6, c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128);
+ insert(6, b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98);
+ insert(4, g, 64, 15, 112, 0, 20);
+ insert(4, h, 64, 15, 123, 211, 25); /* maskself is required */
+ insert(4, a, 10, 0, 0, 0, 25);
+ insert(4, b, 10, 0, 0, 128, 25);
+ insert(4, a, 10, 1, 0, 0, 30);
+ insert(4, b, 10, 1, 0, 4, 30);
+ insert(4, c, 10, 1, 0, 8, 29);
+ insert(4, d, 10, 1, 0, 16, 29);
+
+#ifdef DEBUG_PRINT_TRIE_GRAPHVIZ
+ print_tree(t.root4, 32);
+ print_tree(t.root6, 128);
+#endif
+
+ success = true;
+
+ test(4, a, 192, 168, 4, 20);
+ test(4, a, 192, 168, 4, 0);
+ test(4, b, 192, 168, 4, 4);
+ test(4, c, 192, 168, 200, 182);
+ test(4, c, 192, 95, 5, 68);
+ test(4, e, 192, 95, 5, 96);
+ test(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543);
+ test(6, c, 0x26075300, 0x60006b00, 0, 0xc02e01ee);
+ test(6, f, 0x26075300, 0x60006b01, 0, 0);
+ test(6, g, 0x24046800, 0x40040806, 0, 0x1006);
+ test(6, g, 0x24046800, 0x40040806, 0x1234, 0x5678);
+ test(6, f, 0x240467ff, 0x40040806, 0x1234, 0x5678);
+ test(6, f, 0x24046801, 0x40040806, 0x1234, 0x5678);
+ test(6, h, 0x24046800, 0x40040800, 0x1234, 0x5678);
+ test(6, h, 0x24046800, 0x40040800, 0, 0);
+ test(6, h, 0x24046800, 0x40040800, 0x10101010, 0x10101010);
+ test(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef);
+ test(4, g, 64, 15, 116, 26);
+ test(4, g, 64, 15, 127, 3);
+ test(4, g, 64, 15, 123, 1);
+ test(4, h, 64, 15, 123, 128);
+ test(4, h, 64, 15, 123, 129);
+ test(4, a, 10, 0, 0, 52);
+ test(4, b, 10, 0, 0, 220);
+ test(4, a, 10, 1, 0, 2);
+ test(4, b, 10, 1, 0, 6);
+ test(4, c, 10, 1, 0, 10);
+ test(4, d, 10, 1, 0, 20);
+
+ insert(4, a, 1, 0, 0, 0, 32);
+ insert(4, a, 64, 0, 0, 0, 32);
+ insert(4, a, 128, 0, 0, 0, 32);
+ insert(4, a, 192, 0, 0, 0, 32);
+ insert(4, a, 255, 0, 0, 0, 32);
+ routing_table_remove_by_peer(&t, a);
+ test_negative(4, a, 1, 0, 0, 0);
+ test_negative(4, a, 64, 0, 0, 0);
+ test_negative(4, a, 128, 0, 0, 0);
+ test_negative(4, a, 192, 0, 0, 0);
+ test_negative(4, a, 255, 0, 0, 0);
+
+ routing_table_free(&t);
+ routing_table_init(&t);
+ insert(4, a, 192, 168, 0, 0, 16);
+ insert(4, a, 192, 168, 0, 0, 24);
+ routing_table_remove_by_peer(&t, a);
+ test_negative(4, a, 192, 168, 0, 1);
+
+ /* These will hit the BUG_ON(len >= 128) in free_node if something goes wrong. */
+ for (i = 0; i < 128; ++i) {
+ part = cpu_to_be64(~(1LLU << (i % 64)));
+ memset(&ip, 0xff, 16);
+ memcpy((u8 *)&ip + (i < 64) * 8, &part, 8);
+ routing_table_insert_v6(&t, &ip, 128, a);
+ }
+
+#ifdef DEBUG_RANDOM_TRIE
+ if (success)
+ success = randomized_test();
+#endif
+
+ if (success)
+ pr_info("routing table self-tests: pass\n");
+
+free:
+ routing_table_free(&t);
+ kfree(a);
+ kfree(b);
+ kfree(c);
+ kfree(d);
+ kfree(e);
+ kfree(f);
+ kfree(g);
+ kfree(h);
+
+ return success;
+}
+#undef test_negative
+#undef test
+#undef remove
+#undef insert
+#undef init_peer
+
+#endif