aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src/allowedips.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/allowedips.c')
-rw-r--r--src/allowedips.c100
1 files changed, 30 insertions, 70 deletions
diff --git a/src/allowedips.c b/src/allowedips.c
index 30b66f4..bfb6020 100644
--- a/src/allowedips.c
+++ b/src/allowedips.c
@@ -6,18 +6,6 @@
#include "allowedips.h"
#include "peer.h"
-struct allowedips_node {
- struct wg_peer __rcu *peer;
- struct rcu_head rcu;
- struct allowedips_node __rcu *bit[2];
- /* While it may seem scandalous that we waste space for v4,
- * we're alloc'ing to the nearest power of 2 anyway, so this
- * doesn't actually make a difference.
- */
- u8 bits[16] __aligned(__alignof(u64));
- u8 cidr, bit_at_a, bit_at_b;
-};
-
static __always_inline void swap_endian(u8 *dst, const u8 *src, u8 bits)
{
if (bits == 32) {
@@ -37,6 +25,7 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
node->bit_at_a ^= (bits / 8U - 1U) % 8U;
#endif
node->bit_at_b = 7U - (cidr % 8U);
+ node->bitlen = bits;
memcpy(node->bits, src, bits / 8U);
}
#define CHOOSE_NODE(parent, key) \
@@ -69,43 +58,17 @@ static void root_free_rcu(struct rcu_head *rcu)
}
}
-static int
-walk_by_peer(struct allowedips_node __rcu *top, u8 bits,
- struct allowedips_cursor *cursor, struct wg_peer *peer,
- int (*func)(void *ctx, const u8 *ip, u8 cidr, int family),
- void *ctx, struct mutex *lock)
+static void root_remove_peer_lists(struct allowedips_node *root)
{
- const int address_family = bits == 32 ? AF_INET : AF_INET6;
- /* Aligned so it can be treated as u64 */
- u8 ip[16] __aligned(__alignof(u64));
- struct allowedips_node *node;
- int ret;
-
- if (!rcu_access_pointer(top))
- return 0;
-
- if (!cursor->len)
- push_rcu(cursor->stack, top, &cursor->len);
-
- for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]);
- --cursor->len, push_rcu(cursor->stack, node->bit[0], &cursor->len),
- push_rcu(cursor->stack, node->bit[1], &cursor->len)) {
- const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U);
-
- if (rcu_dereference_protected(node->peer,
- lockdep_is_held(lock)) != peer)
- continue;
-
- swap_endian(ip, node->bits, bits);
- memset(ip + cidr_bytes, 0, bits / 8U - cidr_bytes);
- if (node->cidr)
- ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U);
+ struct allowedips_node *node, *stack[128] = { root };
+ unsigned int len = 1;
- ret = func(ctx, ip, node->cidr, address_family);
- if (ret)
- return ret;
+ while (len > 0 && (node = stack[--len])) {
+ push_rcu(stack, node->bit[0], &len);
+ push_rcu(stack, node->bit[1], &len);
+ if (rcu_access_pointer(node->peer))
+ list_del(&node->peer_list);
}
- return 0;
}
static void walk_remove_by_peer(struct allowedips_node __rcu **top,
@@ -145,6 +108,7 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top,
if (rcu_dereference_protected(node->peer,
lockdep_is_held(lock)) == peer) {
RCU_INIT_POINTER(node->peer, NULL);
+ list_del(&node->peer_list);
if (!node->bit[0] || !node->bit[1]) {
rcu_assign_pointer(*nptr, DEREF(
&node->bit[!REF(node->bit[0])]));
@@ -263,12 +227,14 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
if (unlikely(!node))
return -ENOMEM;
RCU_INIT_POINTER(node->peer, peer);
+ list_add_tail(&node->peer_list, &peer->allowedips_list);
copy_and_assign_cidr(node, key, cidr, bits);
rcu_assign_pointer(*trie, node);
return 0;
}
if (node_placement(*trie, key, cidr, bits, &node, lock)) {
rcu_assign_pointer(node->peer, peer);
+ list_move_tail(&node->peer_list, &peer->allowedips_list);
return 0;
}
@@ -276,6 +242,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
if (unlikely(!newnode))
return -ENOMEM;
RCU_INIT_POINTER(newnode->peer, peer);
+ list_add_tail(&newnode->peer_list, &peer->allowedips_list);
copy_and_assign_cidr(newnode, key, cidr, bits);
if (!node) {
@@ -304,6 +271,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
kfree(newnode);
return -ENOMEM;
}
+ INIT_LIST_HEAD(&node->peer_list);
copy_and_assign_cidr(node, newnode->bits, cidr, bits);
rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down);
@@ -326,15 +294,20 @@ void wg_allowedips_init(struct allowedips *table)
void wg_allowedips_free(struct allowedips *table, struct mutex *lock)
{
struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6;
+
++table->seq;
RCU_INIT_POINTER(table->root4, NULL);
RCU_INIT_POINTER(table->root6, NULL);
- if (rcu_access_pointer(old4))
+ if (rcu_access_pointer(old4)) {
+ root_remove_peer_lists(old4);
call_rcu_bh(&rcu_dereference_protected(old4,
lockdep_is_held(lock))->rcu, root_free_rcu);
- if (rcu_access_pointer(old6))
+ }
+ if (rcu_access_pointer(old6)) {
+ root_remove_peer_lists(old6);
call_rcu_bh(&rcu_dereference_protected(old6,
lockdep_is_held(lock))->rcu, root_free_rcu);
+ }
}
int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
@@ -367,29 +340,16 @@ void wg_allowedips_remove_by_peer(struct allowedips *table,
walk_remove_by_peer(&table->root6, peer, lock);
}
-int wg_allowedips_walk_by_peer(struct allowedips *table,
- struct allowedips_cursor *cursor,
- struct wg_peer *peer,
- int (*func)(void *ctx, const u8 *ip, u8 cidr,
- int family),
- void *ctx, struct mutex *lock)
+int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
{
- int ret;
-
- if (!cursor->seq)
- cursor->seq = table->seq;
- else if (cursor->seq != table->seq)
- return 0;
-
- if (!cursor->second_half) {
- ret = walk_by_peer(table->root4, 32, cursor, peer, func, ctx,
- lock);
- if (ret)
- return ret;
- cursor->len = 0;
- cursor->second_half = true;
- }
- return walk_by_peer(table->root6, 128, cursor, peer, func, ctx, lock);
+ const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U);
+ swap_endian(ip, node->bits, node->bitlen);
+ memset(ip + cidr_bytes, 0, node->bitlen / 8U - cidr_bytes);
+ if (node->cidr)
+ ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U);
+
+ *cidr = node->cidr;
+ return node->bitlen == 32 ? AF_INET : AF_INET6;
}
/* Returns a strong reference to a peer */