aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src/allowedips.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/allowedips.c')
-rw-r--r--src/allowedips.c148
1 files changed, 99 insertions, 49 deletions
diff --git a/src/allowedips.c b/src/allowedips.c
index 1442bf4..4616645 100644
--- a/src/allowedips.c
+++ b/src/allowedips.c
@@ -28,7 +28,8 @@ static __always_inline void swap_endian(u8 *dst, const u8 *src, u8 bits)
}
}
-static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, u8 cidr, u8 bits)
+static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
+ u8 cidr, u8 bits)
{
node->cidr = cidr;
node->bit_at_a = cidr / 8U;
@@ -39,34 +40,43 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, u8
memcpy(node->bits, src, bits / 8U);
}
-#define choose_node(parent, key) parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]
+#define choose_node(parent, key) \
+ parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]
static void node_free_rcu(struct rcu_head *rcu)
{
kfree(container_of(rcu, struct allowedips_node, rcu));
}
-#define push_rcu(stack, p, len) ({ \
- if (rcu_access_pointer(p)) { \
- BUG_ON(len >= 128); \
- stack[len++] = rcu_dereference_raw(p); \
- } \
- true; \
-})
+#define push_rcu(stack, p, len) ({ \
+ if (rcu_access_pointer(p)) { \
+ BUG_ON(len >= 128); \
+ stack[len++] = rcu_dereference_raw(p); \
+ } \
+ true; \
+ })
static void root_free_rcu(struct rcu_head *rcu)
{
- struct allowedips_node *node, *stack[128] = { container_of(rcu, struct allowedips_node, rcu) };
+ struct allowedips_node *node, *stack[128] =
+ { container_of(rcu, struct allowedips_node, rcu) };
unsigned int len = 1;
- while (len > 0 && (node = stack[--len]) && push_rcu(stack, node->bit[0], len) && push_rcu(stack, node->bit[1], len))
+ while (len > 0 && (node = stack[--len]) &&
+ push_rcu(stack, node->bit[0], len) &&
+ push_rcu(stack, node->bit[1], len))
kfree(node);
}
-static int walk_by_peer(struct allowedips_node __rcu *top, u8 bits, struct allowedips_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock)
+static int
+walk_by_peer(struct allowedips_node __rcu *top, u8 bits,
+ struct allowedips_cursor *cursor, struct wireguard_peer *peer,
+ int (*func)(void *ctx, const u8 *ip, u8 cidr, int family),
+ void *ctx, struct mutex *lock)
{
+ const int address_family = bits == 32 ? AF_INET : AF_INET6;
+ u8 ip[16] __aligned(__alignof(u64));
struct allowedips_node *node;
int ret;
- u8 ip[16] __aligned(__alignof(u64));
if (!rcu_access_pointer(top))
return 0;
@@ -74,16 +84,21 @@ static int walk_by_peer(struct allowedips_node __rcu *top, u8 bits, struct allow
if (!cursor->len)
push_rcu(cursor->stack, top, cursor->len);
- for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]); --cursor->len, push_rcu(cursor->stack, node->bit[0], cursor->len), push_rcu(cursor->stack, node->bit[1], cursor->len)) {
- if (rcu_dereference_protected(node->peer, lockdep_is_held(lock)) != peer)
+ for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]);
+ --cursor->len, push_rcu(cursor->stack, node->bit[0], cursor->len),
+ push_rcu(cursor->stack, node->bit[1], cursor->len)) {
+ const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U);
+
+ if (rcu_dereference_protected(node->peer,
+ lockdep_is_held(lock)) != peer)
continue;
swap_endian(ip, node->bits, bits);
- memset(ip + (node->cidr + 7U) / 8U, 0, (bits / 8U) - ((node->cidr + 7U) / 8U));
+ memset(ip + cidr_bytes, 0, bits / 8U - cidr_bytes);
if (node->cidr)
- ip[(node->cidr + 7U) / 8U - 1U] &= ~0U << (-node->cidr % 8U);
+ ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U);
- ret = func(ctx, ip, node->cidr, bits == 32 ? AF_INET : AF_INET6);
+ ret = func(ctx, ip, node->cidr, address_family);
if (ret)
return ret;
}
@@ -93,8 +108,12 @@ static int walk_by_peer(struct allowedips_node __rcu *top, u8 bits, struct allow
#define ref(p) rcu_access_pointer(p)
#define deref(p) rcu_dereference_protected(*p, lockdep_is_held(lock))
-#define push(p) ({ BUG_ON(len >= 128); stack[len++] = p; })
-static void walk_remove_by_peer(struct allowedips_node __rcu **top, struct wireguard_peer *peer, struct mutex *lock)
+#define push(p) ({ \
+ BUG_ON(len >= 128); \
+ stack[len++] = p; \
+ })
+static void walk_remove_by_peer(struct allowedips_node __rcu **top,
+ struct wireguard_peer *peer, struct mutex *lock)
{
struct allowedips_node __rcu **stack[128], **nptr;
struct allowedips_node *node, *prev;
@@ -110,7 +129,8 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top, struct wireg
--len;
continue;
}
- if (!prev || ref(prev->bit[0]) == node || ref(prev->bit[1]) == node) {
+ if (!prev || ref(prev->bit[0]) == node ||
+ ref(prev->bit[1]) == node) {
if (ref(node->bit[0]))
push(&node->bit[0]);
else if (ref(node->bit[1]))
@@ -119,10 +139,12 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top, struct wireg
if (ref(node->bit[1]))
push(&node->bit[1]);
} else {
- if (rcu_dereference_protected(node->peer, lockdep_is_held(lock)) == peer) {
+ if (rcu_dereference_protected(node->peer,
+ lockdep_is_held(lock)) == peer) {
RCU_INIT_POINTER(node->peer, NULL);
if (!node->bit[0] || !node->bit[1]) {
- rcu_assign_pointer(*nptr, deref(&node->bit[!ref(node->bit[0])]));
+ rcu_assign_pointer(*nptr,
+ deref(&node->bit[!ref(node->bit[0])]));
call_rcu_bh(&node->rcu, node_free_rcu);
node = deref(nptr);
}
@@ -140,23 +162,29 @@ static __always_inline unsigned int fls128(u64 a, u64 b)
return a ? fls64(a) + 64U : fls64(b);
}
-static __always_inline u8 common_bits(const struct allowedips_node *node, const u8 *key, u8 bits)
+static __always_inline u8 common_bits(const struct allowedips_node *node,
+ const u8 *key, u8 bits)
{
if (bits == 32)
return 32U - fls(*(const u32 *)node->bits ^ *(const u32 *)key);
else if (bits == 128)
- return 128U - fls128(*(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0], *(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]);
+ return 128U - fls128(
+ *(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0],
+ *(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]);
return 0;
}
-/* This could be much faster if it actually just compared the common bits properly,
- * by precomputing a mask bswap(~0 << (32 - cidr)), and the rest, but it turns out that
- * common_bits is already super fast on modern processors, even taking into account
- * the unfortunate bswap. So, we just inline it like this instead.
+/* This could be much faster if it actually just compared the common bits
+ * properly, by precomputing a mask bswap(~0 << (32 - cidr)), and the rest, but
+ * it turns out that common_bits is already super fast on modern processors,
+ * even taking into account the unfortunate bswap. So, we just inline it like
+ * this instead.
*/
-#define prefix_matches(node, key, bits) (common_bits(node, key, bits) >= node->cidr)
+#define prefix_matches(node, key, bits) \
+ (common_bits(node, key, bits) >= node->cidr)
-static __always_inline struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits, const u8 *key)
+static __always_inline struct allowedips_node *
+find_node(struct allowedips_node *trie, u8 bits, const u8 *key)
{
struct allowedips_node *node = trie, *found = NULL;
@@ -171,11 +199,12 @@ static __always_inline struct allowedips_node *find_node(struct allowedips_node
}
/* Returns a strong reference to a peer */
-static __always_inline struct wireguard_peer *lookup(struct allowedips_node __rcu *root, u8 bits, const void *be_ip)
+static __always_inline struct wireguard_peer *
+lookup(struct allowedips_node __rcu *root, u8 bits, const void *be_ip)
{
+ u8 ip[16] __aligned(__alignof(u64));
struct wireguard_peer *peer = NULL;
struct allowedips_node *node;
- u8 ip[16] __aligned(__alignof(u64));
swap_endian(ip, be_ip, bits);
@@ -191,11 +220,14 @@ retry:
return peer;
}
-__attribute__((nonnull(1)))
-static inline bool node_placement(struct allowedips_node __rcu *trie, const u8 *key, u8 cidr, u8 bits, struct allowedips_node **rnode, struct mutex *lock)
+__attribute__((nonnull(1))) static inline bool
+node_placement(struct allowedips_node __rcu *trie, const u8 *key, u8 cidr,
+ u8 bits, struct allowedips_node **rnode, struct mutex *lock)
{
+ struct allowedips_node *node = rcu_dereference_protected(trie,
+ lockdep_is_held(lock));
+ struct allowedips_node *parent = NULL;
bool exact = false;
- struct allowedips_node *parent = NULL, *node = rcu_dereference_protected(trie, lockdep_is_held(lock));
while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) {
parent = node;
@@ -203,13 +235,15 @@ static inline bool node_placement(struct allowedips_node __rcu *trie, const u8 *
exact = true;
break;
}
- node = rcu_dereference_protected(choose_node(parent, key), lockdep_is_held(lock));
+ node = rcu_dereference_protected(choose_node(parent, key),
+ lockdep_is_held(lock));
}
*rnode = parent;
return exact;
}
-static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, u8 cidr, struct wireguard_peer *peer, struct mutex *lock)
+static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key,
+ u8 cidr, struct wireguard_peer *peer, struct mutex *lock)
{
struct allowedips_node *node, *parent, *down, *newnode;
u8 key[16] __aligned(__alignof(u64));
@@ -242,7 +276,8 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, u
if (!node)
down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
else {
- down = rcu_dereference_protected(choose_node(node, key), lockdep_is_held(lock));
+ down = rcu_dereference_protected(choose_node(node, key),
+ lockdep_is_held(lock));
if (!down) {
rcu_assign_pointer(choose_node(node, key), newnode);
return 0;
@@ -256,7 +291,8 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, u
if (!parent)
rcu_assign_pointer(*trie, newnode);
else
- rcu_assign_pointer(choose_node(parent, newnode->bits), newnode);
+ rcu_assign_pointer(choose_node(parent, newnode->bits),
+ newnode);
} else {
node = kzalloc(sizeof(*node), GFP_KERNEL);
if (!node) {
@@ -270,7 +306,8 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, u
if (!parent)
rcu_assign_pointer(*trie, node);
else
- rcu_assign_pointer(choose_node(parent, node->bits), node);
+ rcu_assign_pointer(choose_node(parent, node->bits),
+ node);
}
return 0;
}
@@ -288,31 +325,42 @@ void allowedips_free(struct allowedips *table, struct mutex *lock)
RCU_INIT_POINTER(table->root4, NULL);
RCU_INIT_POINTER(table->root6, NULL);
if (rcu_access_pointer(old4))
- call_rcu_bh(&rcu_dereference_protected(old4, lockdep_is_held(lock))->rcu, root_free_rcu);
+ call_rcu_bh(&rcu_dereference_protected(old4,
+ lockdep_is_held(lock))->rcu, root_free_rcu);
if (rcu_access_pointer(old6))
- call_rcu_bh(&rcu_dereference_protected(old6, lockdep_is_held(lock))->rcu, root_free_rcu);
+ call_rcu_bh(&rcu_dereference_protected(old6,
+ lockdep_is_held(lock))->rcu, root_free_rcu);
}
-int allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock)
+int allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
+ u8 cidr, struct wireguard_peer *peer,
+ struct mutex *lock)
{
++table->seq;
return add(&table->root4, 32, (const u8 *)ip, cidr, peer, lock);
}
-int allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock)
+int allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
+ u8 cidr, struct wireguard_peer *peer,
+ struct mutex *lock)
{
++table->seq;
return add(&table->root6, 128, (const u8 *)ip, cidr, peer, lock);
}
-void allowedips_remove_by_peer(struct allowedips *table, struct wireguard_peer *peer, struct mutex *lock)
+void allowedips_remove_by_peer(struct allowedips *table,
+ struct wireguard_peer *peer, struct mutex *lock)
{
++table->seq;
walk_remove_by_peer(&table->root4, peer, lock);
walk_remove_by_peer(&table->root6, peer, lock);
}
-int allowedips_walk_by_peer(struct allowedips *table, struct allowedips_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock)
+int allowedips_walk_by_peer(struct allowedips *table,
+ struct allowedips_cursor *cursor,
+ struct wireguard_peer *peer,
+ int (*func)(void *ctx, const u8 *ip, u8 cidr, int family),
+ void *ctx, struct mutex *lock)
{
int ret;
@@ -332,7 +380,8 @@ int allowedips_walk_by_peer(struct allowedips *table, struct allowedips_cursor *
}
/* Returns a strong reference to a peer */
-struct wireguard_peer *allowedips_lookup_dst(struct allowedips *table, struct sk_buff *skb)
+struct wireguard_peer *allowedips_lookup_dst(struct allowedips *table,
+ struct sk_buff *skb)
{
if (skb->protocol == htons(ETH_P_IP))
return lookup(table->root4, 32, &ip_hdr(skb)->daddr);
@@ -342,7 +391,8 @@ struct wireguard_peer *allowedips_lookup_dst(struct allowedips *table, struct sk
}
/* Returns a strong reference to a peer */
-struct wireguard_peer *allowedips_lookup_src(struct allowedips *table, struct sk_buff *skb)
+struct wireguard_peer *allowedips_lookup_src(struct allowedips *table,
+ struct sk_buff *skb)
{
if (skb->protocol == htons(ETH_P_IP))
return lookup(table->root4, 32, &ip_hdr(skb)->saddr);