-rw-r--r--  src/allowedips.c           100
-rw-r--r--  src/allowedips.h            33
-rw-r--r--  src/netlink.c               61
-rw-r--r--  src/peer.c                   1
-rw-r--r--  src/peer.h                   1
-rw-r--r--  src/selftest/allowedips.c   93
6 files changed, 129 insertions(+), 160 deletions(-)
diff --git a/src/allowedips.c b/src/allowedips.c
index 30b66f4..bfb6020 100644
--- a/src/allowedips.c
+++ b/src/allowedips.c
@@ -6,18 +6,6 @@
#include "allowedips.h"
#include "peer.h"
-struct allowedips_node {
- struct wg_peer __rcu *peer;
- struct rcu_head rcu;
- struct allowedips_node __rcu *bit[2];
- /* While it may seem scandalous that we waste space for v4,
- * we're alloc'ing to the nearest power of 2 anyway, so this
- * doesn't actually make a difference.
- */
- u8 bits[16] __aligned(__alignof(u64));
- u8 cidr, bit_at_a, bit_at_b;
-};
-
static __always_inline void swap_endian(u8 *dst, const u8 *src, u8 bits)
{
if (bits == 32) {
@@ -37,6 +25,7 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
node->bit_at_a ^= (bits / 8U - 1U) % 8U;
#endif
node->bit_at_b = 7U - (cidr % 8U);
+ node->bitlen = bits;
memcpy(node->bits, src, bits / 8U);
}
#define CHOOSE_NODE(parent, key) \
@@ -69,43 +58,17 @@ static void root_free_rcu(struct rcu_head *rcu)
}
}
-static int
-walk_by_peer(struct allowedips_node __rcu *top, u8 bits,
- struct allowedips_cursor *cursor, struct wg_peer *peer,
- int (*func)(void *ctx, const u8 *ip, u8 cidr, int family),
- void *ctx, struct mutex *lock)
+static void root_remove_peer_lists(struct allowedips_node *root)
{
- const int address_family = bits == 32 ? AF_INET : AF_INET6;
- /* Aligned so it can be treated as u64 */
- u8 ip[16] __aligned(__alignof(u64));
- struct allowedips_node *node;
- int ret;
-
- if (!rcu_access_pointer(top))
- return 0;
-
- if (!cursor->len)
- push_rcu(cursor->stack, top, &cursor->len);
-
- for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]);
- --cursor->len, push_rcu(cursor->stack, node->bit[0], &cursor->len),
- push_rcu(cursor->stack, node->bit[1], &cursor->len)) {
- const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U);
-
- if (rcu_dereference_protected(node->peer,
- lockdep_is_held(lock)) != peer)
- continue;
-
- swap_endian(ip, node->bits, bits);
- memset(ip + cidr_bytes, 0, bits / 8U - cidr_bytes);
- if (node->cidr)
- ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U);
+ struct allowedips_node *node, *stack[128] = { root };
+ unsigned int len = 1;
- ret = func(ctx, ip, node->cidr, address_family);
- if (ret)
- return ret;
+ while (len > 0 && (node = stack[--len])) {
+ push_rcu(stack, node->bit[0], &len);
+ push_rcu(stack, node->bit[1], &len);
+ if (rcu_access_pointer(node->peer))
+ list_del(&node->peer_list);
}
- return 0;
}
static void walk_remove_by_peer(struct allowedips_node __rcu **top,
@@ -145,6 +108,7 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top,
if (rcu_dereference_protected(node->peer,
lockdep_is_held(lock)) == peer) {
RCU_INIT_POINTER(node->peer, NULL);
+ list_del(&node->peer_list);
if (!node->bit[0] || !node->bit[1]) {
rcu_assign_pointer(*nptr, DEREF(
&node->bit[!REF(node->bit[0])]));
@@ -263,12 +227,14 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
if (unlikely(!node))
return -ENOMEM;
RCU_INIT_POINTER(node->peer, peer);
+ list_add_tail(&node->peer_list, &peer->allowedips_list);
copy_and_assign_cidr(node, key, cidr, bits);
rcu_assign_pointer(*trie, node);
return 0;
}
if (node_placement(*trie, key, cidr, bits, &node, lock)) {
rcu_assign_pointer(node->peer, peer);
+ list_move_tail(&node->peer_list, &peer->allowedips_list);
return 0;
}
@@ -276,6 +242,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
if (unlikely(!newnode))
return -ENOMEM;
RCU_INIT_POINTER(newnode->peer, peer);
+ list_add_tail(&newnode->peer_list, &peer->allowedips_list);
copy_and_assign_cidr(newnode, key, cidr, bits);
if (!node) {
@@ -304,6 +271,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
kfree(newnode);
return -ENOMEM;
}
+ INIT_LIST_HEAD(&node->peer_list);
copy_and_assign_cidr(node, newnode->bits, cidr, bits);
rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down);
@@ -326,15 +294,20 @@ void wg_allowedips_init(struct allowedips *table)
void wg_allowedips_free(struct allowedips *table, struct mutex *lock)
{
struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6;
+
++table->seq;
RCU_INIT_POINTER(table->root4, NULL);
RCU_INIT_POINTER(table->root6, NULL);
- if (rcu_access_pointer(old4))
+ if (rcu_access_pointer(old4)) {
+ root_remove_peer_lists(old4);
call_rcu_bh(&rcu_dereference_protected(old4,
lockdep_is_held(lock))->rcu, root_free_rcu);
- if (rcu_access_pointer(old6))
+ }
+ if (rcu_access_pointer(old6)) {
+ root_remove_peer_lists(old6);
call_rcu_bh(&rcu_dereference_protected(old6,
lockdep_is_held(lock))->rcu, root_free_rcu);
+ }
}
int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
@@ -367,29 +340,16 @@ void wg_allowedips_remove_by_peer(struct allowedips *table,
walk_remove_by_peer(&table->root6, peer, lock);
}
-int wg_allowedips_walk_by_peer(struct allowedips *table,
- struct allowedips_cursor *cursor,
- struct wg_peer *peer,
- int (*func)(void *ctx, const u8 *ip, u8 cidr,
- int family),
- void *ctx, struct mutex *lock)
+int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
{
- int ret;
-
- if (!cursor->seq)
- cursor->seq = table->seq;
- else if (cursor->seq != table->seq)
- return 0;
-
- if (!cursor->second_half) {
- ret = walk_by_peer(table->root4, 32, cursor, peer, func, ctx,
- lock);
- if (ret)
- return ret;
- cursor->len = 0;
- cursor->second_half = true;
- }
- return walk_by_peer(table->root6, 128, cursor, peer, func, ctx, lock);
+ const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U);
+ swap_endian(ip, node->bits, node->bitlen);
+ memset(ip + cidr_bytes, 0, node->bitlen / 8U - cidr_bytes);
+ if (node->cidr)
+ ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U);
+
+ *cidr = node->cidr;
+ return node->bitlen == 32 ? AF_INET : AF_INET6;
}
/* Returns a strong reference to a peer */
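Taken together, the allowedips.c changes replace the cursor-based walk with a per-peer list plus a small decode helper. A minimal usage sketch of the new pattern follows; the dump_peer_allowedips() wrapper and its pr_info() output are illustrative only and not part of this patch, and the caller is assumed to hold the same lock that serializes allowedips mutations (device_update_lock here):

    /* Illustrative sketch, not in the patch: enumerate one peer's allowed
     * IPs via the new per-peer list and wg_allowedips_read_node().
     */
    static void dump_peer_allowedips(struct wg_peer *peer)
    {
            struct allowedips_node *node;

            list_for_each_entry(node, &peer->allowedips_list, peer_list) {
                    u8 cidr, ip[16] __aligned(__alignof(u64));
                    int family = wg_allowedips_read_node(node, ip, &cidr);

                    if (family == AF_INET)
                            pr_info("%pI4/%u\n", ip, cidr);
                    else
                            pr_info("%pI6c/%u\n", ip, cidr);
            }
    }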
diff --git a/src/allowedips.h b/src/allowedips.h
index 29e15a2..e5c83ca 100644
--- a/src/allowedips.h
+++ b/src/allowedips.h
@@ -11,7 +11,23 @@
#include <linux/ipv6.h>
struct wg_peer;
-struct allowedips_node;
+
+struct allowedips_node {
+ struct wg_peer __rcu *peer;
+ struct allowedips_node __rcu *bit[2];
+ /* While it may seem scandalous that we waste space for v4,
+ * we're alloc'ing to the nearest power of 2 anyway, so this
+ * doesn't actually make a difference.
+ */
+ u8 bits[16] __aligned(__alignof(u64));
+ u8 cidr, bit_at_a, bit_at_b, bitlen;
+
+ /* Keep rarely used list at bottom to be beyond cache line. */
+ union {
+ struct list_head peer_list;
+ struct rcu_head rcu;
+ };
+};
struct allowedips {
struct allowedips_node __rcu *root4;
@@ -19,13 +35,6 @@ struct allowedips {
u64 seq;
};
-struct allowedips_cursor {
- u64 seq;
- struct allowedips_node *stack[128];
- unsigned int len;
- bool second_half;
-};
-
void wg_allowedips_init(struct allowedips *table);
void wg_allowedips_free(struct allowedips *table, struct mutex *mutex);
int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
@@ -34,12 +43,8 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
u8 cidr, struct wg_peer *peer, struct mutex *lock);
void wg_allowedips_remove_by_peer(struct allowedips *table,
struct wg_peer *peer, struct mutex *lock);
-int wg_allowedips_walk_by_peer(struct allowedips *table,
- struct allowedips_cursor *cursor,
- struct wg_peer *peer,
- int (*func)(void *ctx, const u8 *ip, u8 cidr,
- int family),
- void *ctx, struct mutex *lock);
+/* The ip input pointer should be __aligned(__alignof(u64))) */
+int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr);
/* These return a strong reference to a peer: */
struct wg_peer *wg_allowedips_lookup_dst(struct allowedips *table,
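The relocated struct definition above also shows why the union is safe: a node needs its peer_list linkage only while it is live in the trie (under the table lock), and its rcu head only after it has been unlinked for deferred freeing, so the two members are never in use at the same time. A layout sanity check along those lines might look like the following; it is illustrative only and not something the patch adds:

    #include <linux/types.h>
    #include <linux/build_bug.h>

    /* Illustrative only: both union members are two pointers wide, so
     * overlaying the peer list on the RCU head adds no size to
     * struct allowedips_node.
     */
    static inline void allowedips_node_layout_check(void)
    {
            BUILD_BUG_ON(sizeof(struct list_head) != sizeof(struct rcu_head));
    }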
diff --git a/src/netlink.c b/src/netlink.c
index f44f211..b179b31 100644
--- a/src/netlink.c
+++ b/src/netlink.c
@@ -69,9 +69,9 @@ static struct wg_device *lookup_interface(struct nlattr **attrs,
return netdev_priv(dev);
}
-static int get_allowedips(void *ctx, const u8 *ip, u8 cidr, int family)
+static int get_allowedips(struct sk_buff *skb, const u8 *ip, u8 cidr,
+ int family)
{
- struct sk_buff *skb = ctx;
struct nlattr *allowedip_nest;
allowedip_nest = nla_nest_start(skb, 0);
@@ -90,10 +90,12 @@ static int get_allowedips(void *ctx, const u8 *ip, u8 cidr, int family)
return 0;
}
-static int get_peer(struct wg_peer *peer, struct allowedips_cursor *rt_cursor,
- struct sk_buff *skb)
+static int
+get_peer(struct wg_peer *peer, struct allowedips_node **next_allowedips_node,
+ u64 *allowedips_seq, struct sk_buff *skb)
{
struct nlattr *allowedips_nest, *peer_nest = nla_nest_start(skb, 0);
+ struct allowedips_node *allowedips_node = *next_allowedips_node;
bool fail;
if (!peer_nest)
@@ -106,7 +108,7 @@ static int get_peer(struct wg_peer *peer, struct allowedips_cursor *rt_cursor,
if (fail)
goto err;
- if (!rt_cursor->seq) {
+ if (!allowedips_node) {
const struct __kernel_timespec last_handshake = {
.tv_sec = peer->walltime_last_handshake.tv_sec,
.tv_nsec = peer->walltime_last_handshake.tv_nsec
@@ -143,21 +145,39 @@ static int get_peer(struct wg_peer *peer, struct allowedips_cursor *rt_cursor,
read_unlock_bh(&peer->endpoint_lock);
if (fail)
goto err;
+ allowedips_node =
+ list_first_entry_or_null(&peer->allowedips_list,
+ struct allowedips_node, peer_list);
}
+ if (!allowedips_node)
+ goto no_allowedips;
+ if (!*allowedips_seq)
+ *allowedips_seq = peer->device->peer_allowedips.seq;
+ else if (*allowedips_seq != peer->device->peer_allowedips.seq)
+ goto no_allowedips;
allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS);
if (!allowedips_nest)
goto err;
- if (wg_allowedips_walk_by_peer(&peer->device->peer_allowedips,
- rt_cursor, peer, get_allowedips, skb,
- &peer->device->device_update_lock)) {
- nla_nest_end(skb, allowedips_nest);
- nla_nest_end(skb, peer_nest);
- return -EMSGSIZE;
+
+ list_for_each_entry_from(allowedips_node, &peer->allowedips_list,
+ peer_list) {
+ u8 cidr, ip[16] __aligned(__alignof(u64));
+ int family;
+
+ family = wg_allowedips_read_node(allowedips_node, ip, &cidr);
+ if (get_allowedips(skb, ip, cidr, family)) {
+ nla_nest_end(skb, allowedips_nest);
+ nla_nest_end(skb, peer_nest);
+ *next_allowedips_node = allowedips_node;
+ return -EMSGSIZE;
+ }
}
- memset(rt_cursor, 0, sizeof(*rt_cursor));
nla_nest_end(skb, allowedips_nest);
+no_allowedips:
nla_nest_end(skb, peer_nest);
+ *next_allowedips_node = NULL;
+ *allowedips_seq = 0;
return 0;
err:
nla_nest_cancel(skb, peer_nest);
@@ -174,16 +194,9 @@ static int wg_get_device_start(struct netlink_callback *cb)
genl_family.maxattr, device_policy, NULL);
if (ret < 0)
return ret;
- cb->args[2] = (long)kzalloc(sizeof(struct allowedips_cursor),
- GFP_KERNEL);
- if (unlikely(!cb->args[2]))
- return -ENOMEM;
wg = lookup_interface(attrs, cb->skb);
- if (IS_ERR(wg)) {
- kfree((void *)cb->args[2]);
- cb->args[2] = 0;
+ if (IS_ERR(wg))
return PTR_ERR(wg);
- }
cb->args[0] = (long)wg;
return 0;
}
@@ -191,7 +204,6 @@ static int wg_get_device_start(struct netlink_callback *cb)
static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb)
{
struct wg_peer *peer, *next_peer_cursor, *last_peer_cursor;
- struct allowedips_cursor *rt_cursor;
struct nlattr *peers_nest;
struct wg_device *wg;
int ret = -EMSGSIZE;
@@ -201,7 +213,6 @@ static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb)
wg = (struct wg_device *)cb->args[0];
next_peer_cursor = (struct wg_peer *)cb->args[1];
last_peer_cursor = (struct wg_peer *)cb->args[1];
- rt_cursor = (struct allowedips_cursor *)cb->args[2];
rtnl_lock();
mutex_lock(&wg->device_update_lock);
@@ -253,7 +264,8 @@ static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb)
lockdep_assert_held(&wg->device_update_lock);
peer = list_prepare_entry(last_peer_cursor, &wg->peer_list, peer_list);
list_for_each_entry_continue(peer, &wg->peer_list, peer_list) {
- if (get_peer(peer, rt_cursor, skb)) {
+ if (get_peer(peer, (struct allowedips_node **)&cb->args[2],
+ (u64 *)&cb->args[4] /* and args[5] */, skb)) {
done = false;
break;
}
@@ -290,12 +302,9 @@ static int wg_get_device_done(struct netlink_callback *cb)
{
struct wg_device *wg = (struct wg_device *)cb->args[0];
struct wg_peer *peer = (struct wg_peer *)cb->args[1];
- struct allowedips_cursor *rt_cursor =
- (struct allowedips_cursor *)cb->args[2];
if (wg)
dev_put(wg->dev);
- kfree(rt_cursor);
wg_peer_put(peer);
return 0;
}
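In netlink.c the dump resumption state now lives entirely in the generic netlink callback scratch space: cb->args[2] carries the next allowedips node to emit, and cb->args[4]/args[5] together back the 64-bit table generation counter, since a long may be only 32 bits wide. A stand-alone sketch of that trick, with hypothetical helper names and the usual struct netlink_callback from <linux/netlink.h>:

    #include <linux/netlink.h>

    /* Illustrative helpers (not in the patch): store a u64 across two
     * adjacent longs of the netlink dump scratch array. args[] is opaque
     * to the netlink core, so the dump implementation may repurpose it.
     */
    static void stash_allowedips_seq(struct netlink_callback *cb, u64 seq)
    {
            *(u64 *)&cb->args[4] = seq; /* spans args[4] and args[5] on 32-bit */
    }

    static u64 load_allowedips_seq(const struct netlink_callback *cb)
    {
            return *(const u64 *)&cb->args[4];
    }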
diff --git a/src/peer.c b/src/peer.c
index 6a33df0..996f40b 100644
--- a/src/peer.c
+++ b/src/peer.c
@@ -64,6 +64,7 @@ struct wg_peer *wg_peer_create(struct wg_device *wg,
NAPI_POLL_WEIGHT);
napi_enable(&peer->napi);
list_add_tail(&peer->peer_list, &wg->peer_list);
+ INIT_LIST_HEAD(&peer->allowedips_list);
wg_pubkey_hashtable_add(wg->peer_hashtable, peer);
++wg->num_peers;
pr_debug("%s: Peer %llu created\n", wg->dev->name, peer->internal_id);
diff --git a/src/peer.h b/src/peer.h
index 2e04262..23af409 100644
--- a/src/peer.h
+++ b/src/peer.h
@@ -60,6 +60,7 @@ struct wg_peer {
struct kref refcount;
struct rcu_head rcu;
struct list_head peer_list;
+ struct list_head allowedips_list;
u64 internal_id;
struct napi_struct napi;
bool is_dead;
diff --git a/src/selftest/allowedips.c b/src/selftest/allowedips.c
index 379ac31..6e244a9 100644
--- a/src/selftest/allowedips.c
+++ b/src/selftest/allowedips.c
@@ -452,47 +452,14 @@ static __init inline struct in6_addr *ip6(u32 a, u32 b, u32 c, u32 d)
return &ip;
}
-struct walk_ctx {
- int count;
- bool found_a, found_b, found_c, found_d, found_e;
- bool found_other;
-};
-
-static __init int walk_callback(void *ctx, const u8 *ip, u8 cidr, int family)
-{
- struct walk_ctx *wctx = ctx;
-
- wctx->count++;
-
- if (cidr == 27 &&
- !memcmp(ip, ip4(192, 95, 5, 64), sizeof(struct in_addr)))
- wctx->found_a = true;
- else if (cidr == 128 &&
- !memcmp(ip, ip6(0x26075300, 0x60006b00, 0, 0xc05f0543),
- sizeof(struct in6_addr)))
- wctx->found_b = true;
- else if (cidr == 29 &&
- !memcmp(ip, ip4(10, 1, 0, 16), sizeof(struct in_addr)))
- wctx->found_c = true;
- else if (cidr == 83 &&
- !memcmp(ip, ip6(0x26075300, 0x6d8a6bf8, 0xdab1e000, 0),
- sizeof(struct in6_addr)))
- wctx->found_d = true;
- else if (cidr == 21 &&
- !memcmp(ip, ip6(0x26075000, 0, 0, 0), sizeof(struct in6_addr)))
- wctx->found_e = true;
- else
- wctx->found_other = true;
-
- return 0;
-}
-
static __init struct wg_peer *init_peer(void)
{
struct wg_peer *peer = kzalloc(sizeof(*peer), GFP_KERNEL);
- if (peer)
- kref_init(&peer->refcount);
+ if (!peer)
+ return NULL;
+ kref_init(&peer->refcount);
+ INIT_LIST_HEAD(&peer->allowedips_list);
return peer;
}
@@ -527,23 +494,24 @@ static __init struct wg_peer *init_peer(void)
bool __init wg_allowedips_selftest(void)
{
- struct allowedips_cursor *cursor = kzalloc(sizeof(*cursor), GFP_KERNEL);
+ bool found_a = false, found_b = false, found_c = false, found_d = false,
+ found_e = false, found_other = false;
struct wg_peer *a = init_peer(), *b = init_peer(), *c = init_peer(),
*d = init_peer(), *e = init_peer(), *f = init_peer(),
*g = init_peer(), *h = init_peer();
- struct walk_ctx wctx = { 0 };
+ struct allowedips_node *iter_node;
bool success = false;
struct allowedips t;
DEFINE_MUTEX(mutex);
struct in6_addr ip;
- size_t i = 0;
+ size_t i = 0, count = 0;
__be64 part;
mutex_init(&mutex);
mutex_lock(&mutex);
wg_allowedips_init(&t);
- if (!cursor || !a || !b || !c || !d || !e || !f || !g || !h) {
+ if (!a || !b || !c || !d || !e || !f || !g || !h) {
pr_err("allowedips self-test malloc: FAIL\n");
goto free;
}
@@ -649,14 +617,40 @@ bool __init wg_allowedips_selftest(void)
insert(4, a, 10, 1, 0, 20, 29);
insert(6, a, 0x26075300, 0x6d8a6bf8, 0xdab1f1df, 0xc05f1523, 83);
insert(6, a, 0x26075300, 0x6d8a6bf8, 0xdab1f1df, 0xc05f1523, 21);
- wg_allowedips_walk_by_peer(&t, cursor, a, walk_callback, &wctx, &mutex);
- test_boolean(wctx.count == 5);
- test_boolean(wctx.found_a);
- test_boolean(wctx.found_b);
- test_boolean(wctx.found_c);
- test_boolean(wctx.found_d);
- test_boolean(wctx.found_e);
- test_boolean(!wctx.found_other);
+ list_for_each_entry(iter_node, &a->allowedips_list, peer_list) {
+ u8 cidr, ip[16] __aligned(__alignof(u64));
+ int family = wg_allowedips_read_node(iter_node, ip, &cidr);
+
+ count++;
+
+ if (cidr == 27 && family == AF_INET &&
+ !memcmp(ip, ip4(192, 95, 5, 64), sizeof(struct in_addr)))
+ found_a = true;
+ else if (cidr == 128 && family == AF_INET6 &&
+ !memcmp(ip, ip6(0x26075300, 0x60006b00, 0, 0xc05f0543),
+ sizeof(struct in6_addr)))
+ found_b = true;
+ else if (cidr == 29 && family == AF_INET &&
+ !memcmp(ip, ip4(10, 1, 0, 16), sizeof(struct in_addr)))
+ found_c = true;
+ else if (cidr == 83 && family == AF_INET6 &&
+ !memcmp(ip, ip6(0x26075300, 0x6d8a6bf8, 0xdab1e000, 0),
+ sizeof(struct in6_addr)))
+ found_d = true;
+ else if (cidr == 21 && family == AF_INET6 &&
+ !memcmp(ip, ip6(0x26075000, 0, 0, 0),
+ sizeof(struct in6_addr)))
+ found_e = true;
+ else
+ found_other = true;
+ }
+ test_boolean(count == 5);
+ test_boolean(found_a);
+ test_boolean(found_b);
+ test_boolean(found_c);
+ test_boolean(found_d);
+ test_boolean(found_e);
+ test_boolean(!found_other);
if (IS_ENABLED(DEBUG_RANDOM_TRIE) && success)
success = randomized_test();
@@ -675,7 +669,6 @@ free:
kfree(g);
kfree(h);
mutex_unlock(&mutex);
- kfree(cursor);
return success;
}