aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src/routingtable.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/routingtable.c')
-rw-r--r--src/routingtable.c76
1 files changed, 46 insertions, 30 deletions
diff --git a/src/routingtable.c b/src/routingtable.c
index e884383..6cbea3d 100644
--- a/src/routingtable.c
+++ b/src/routingtable.c
@@ -25,40 +25,38 @@ static void node_free_rcu(struct rcu_head *rcu)
{
kfree(container_of(rcu, struct routing_table_node, rcu));
}
-#define push(p, lock) ({ \
+
+#define push(stack, p, len) ({ \
if (rcu_access_pointer(p)) { \
BUG_ON(len >= 128); \
- stack[len++] = lock ? rcu_dereference_protected(p, lockdep_is_held((struct mutex *)lock)) : rcu_dereference_bh(p); \
+ stack[len++] = rcu_dereference_protected(p, lockdep_is_held(lock)); \
} \
true; \
})
-#define walk_prep \
- struct routing_table_node *stack[128], *node; \
- unsigned int len;
-#define walk(top, lock) for (len = 0, push(top, lock); len > 0 && (node = stack[--len]) && push(node->bit[0], lock) && push(node->bit[1], lock);)
-
static void free_root_node(struct routing_table_node __rcu *top, struct mutex *lock)
{
- walk_prep;
+ struct routing_table_node *stack[128], *node;
+ unsigned int len;
- walk (top, lock)
+ for (len = 0, push(stack, top, len); len > 0 && (node = stack[--len]) && push(stack, node->bit[0], len) && push(stack, node->bit[1], len);)
call_rcu_bh(&node->rcu, node_free_rcu);
}
-static int walk_ips_by_peer(struct routing_table_node __rcu *top, int family, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family), struct mutex *maybe_lock)
+static int walk_by_peer(struct routing_table_node __rcu *top, int family, struct routing_table_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock)
{
+ struct routing_table_node *node;
int ret;
- union nf_inet_addr ip = { .all = { 0 } };
- walk_prep;
- if (unlikely(!peer))
+ if (!rcu_access_pointer(top))
return 0;
- walk (top, maybe_lock) {
+ if (!cursor->len)
+ push(cursor->stack, top, cursor->len);
+
+ for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]); --cursor->len, push(cursor->stack, node->bit[0], cursor->len), push(cursor->stack, node->bit[1], cursor->len)) {
if (node->peer != peer)
continue;
- memcpy(ip.all, node->bits, family == AF_INET6 ? 16 : 4);
- ret = func(ctx, ip, node->cidr, family);
+ ret = func(ctx, node->bits, node->cidr, family);
if (ret)
return ret;
}
@@ -234,37 +232,55 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u
void routing_table_init(struct routing_table *table)
{
- memset(table, 0, sizeof(struct routing_table));
+ table->root4 = table->root6 = NULL;
+ table->seq = 1;
}
-void routing_table_free(struct routing_table *table, struct mutex *mutex)
+void routing_table_free(struct routing_table *table, struct mutex *lock)
{
- free_root_node(table->root4, mutex);
+ ++table->seq;
+ free_root_node(table->root4, lock);
rcu_assign_pointer(table->root4, NULL);
- free_root_node(table->root6, mutex);
+ free_root_node(table->root6, lock);
rcu_assign_pointer(table->root6, NULL);
}
-int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *mutex)
+int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock)
{
- return add(&table->root4, 32, (const u8 *)ip, cidr, peer, mutex);
+ ++table->seq;
+ return add(&table->root4, 32, (const u8 *)ip, cidr, peer, lock);
}
-int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *mutex)
+int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock)
{
- return add(&table->root6, 128, (const u8 *)ip, cidr, peer, mutex);
+ ++table->seq;
+ return add(&table->root6, 128, (const u8 *)ip, cidr, peer, lock);
}
-void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer, struct mutex *mutex)
+void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer, struct mutex *lock)
{
- walk_remove_by_peer(&table->root4, peer, mutex);
- walk_remove_by_peer(&table->root6, peer, mutex);
+ ++table->seq;
+ walk_remove_by_peer(&table->root4, peer, lock);
+ walk_remove_by_peer(&table->root6, peer, lock);
}
-int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family), struct mutex *mutex)
+int routing_table_walk_by_peer(struct routing_table *table, struct routing_table_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock)
{
- return walk_ips_by_peer(table->root4, AF_INET, ctx, peer, func, mutex) ?:
- walk_ips_by_peer(table->root6, AF_INET6, ctx, peer, func, mutex);
+ int ret;
+
+ if (!cursor->seq)
+ cursor->seq = table->seq;
+ else if (cursor->seq != table->seq)
+ return 0;
+
+ if (!cursor->second_half) {
+ ret = walk_by_peer(table->root4, AF_INET, cursor, peer, func, ctx, lock);
+ if (ret)
+ return ret;
+ cursor->len = 0;
+ cursor->second_half = true;
+ }
+ return walk_by_peer(table->root6, AF_INET6, cursor, peer, func, ctx, lock);
}
/* Returns a strong reference to a peer */