From 0911027c09cf3f734f39f0d3b1bfe4119b73b100 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 9 Oct 2017 02:48:33 +0200 Subject: routingtable: only use device's mutex, not a special rt one --- src/device.c | 2 +- src/netlink.c | 8 ++-- src/peer.c | 2 +- src/routingtable.c | 96 ++++++++------------------------------------- src/routingtable.h | 14 +++---- src/selftest/routingtable.h | 16 +++++--- 6 files changed, 38 insertions(+), 100 deletions(-) diff --git a/src/device.c b/src/device.c index 5102acc..0fb5dcd 100644 --- a/src/device.c +++ b/src/device.c @@ -212,7 +212,7 @@ static void destruct(struct net_device *dev) packet_queue_free(&wg->decrypt_queue, true); packet_queue_free(&wg->encrypt_queue, true); destroy_workqueue(wg->packet_crypt_wq); - routing_table_free(&wg->peer_routing_table); + routing_table_free(&wg->peer_routing_table, &wg->device_update_lock); ratelimiter_uninit(); memzero_explicit(&wg->static_identity, sizeof(struct noise_static_identity)); skb_queue_purge(&wg->incoming_handshakes); diff --git a/src/netlink.c b/src/netlink.c index b813508..fc27b7f 100644 --- a/src/netlink.c +++ b/src/netlink.c @@ -126,7 +126,7 @@ static int get_peer(struct wireguard_peer *peer, unsigned int index, unsigned in allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS); if (!allowedips_nest) goto err; - if (routing_table_walk_ips_by_peer_sleepable(&peer->device->peer_routing_table, &ctx, peer, get_allowedips)) { + if (routing_table_walk_ips_by_peer(&peer->device->peer_routing_table, &ctx, peer, get_allowedips, &peer->device->device_update_lock)) { *allowedips_idx_cursor = ctx.idx; nla_nest_end(skb, allowedips_nest); nla_nest_end(skb, peer_nest); @@ -274,9 +274,9 @@ static int set_allowedip(struct wireguard_peer *peer, struct nlattr **attrs) cidr = nla_get_u8(attrs[WGALLOWEDIP_A_CIDR_MASK]); if (family == AF_INET && cidr <= 32 && nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr)) - ret = routing_table_insert_v4(&peer->device->peer_routing_table, nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer); + ret = routing_table_insert_v4(&peer->device->peer_routing_table, nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, &peer->device->device_update_lock); else if (family == AF_INET6 && cidr <= 128 && nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr)) - ret = routing_table_insert_v6(&peer->device->peer_routing_table, nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer); + ret = routing_table_insert_v6(&peer->device->peer_routing_table, nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, &peer->device->device_update_lock); return ret; } @@ -343,7 +343,7 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs) } if (flags & WGPEER_F_REPLACE_ALLOWEDIPS) - routing_table_remove_by_peer(&wg->peer_routing_table, peer); + routing_table_remove_by_peer(&wg->peer_routing_table, peer, &wg->device_update_lock); if (attrs[WGPEER_A_ALLOWEDIPS]) { int rem; diff --git a/src/peer.c b/src/peer.c index 4408201..8cef1f9 100644 --- a/src/peer.c +++ b/src/peer.c @@ -79,7 +79,7 @@ void peer_remove(struct wireguard_peer *peer) if (unlikely(!peer)) return; lockdep_assert_held(&peer->device->device_update_lock); - routing_table_remove_by_peer(&peer->device->peer_routing_table, peer); + routing_table_remove_by_peer(&peer->device->peer_routing_table, peer, &peer->device->device_update_lock); pubkey_hashtable_remove(&peer->device->peer_hashtable, peer); skb_queue_purge(&peer->staged_packet_queue); noise_handshake_clear(&peer->handshake); diff --git a/src/routingtable.c b/src/routingtable.c index 781c758..e884383 100644 --- a/src/routingtable.c +++ b/src/routingtable.c @@ -45,18 +45,6 @@ static void free_root_node(struct routing_table_node __rcu *top, struct mutex *l call_rcu_bh(&node->rcu, node_free_rcu); } -static size_t count_nodes(struct routing_table_node __rcu *top) -{ - size_t ret = 0; - walk_prep; - - walk (top, NULL) { - if (node->peer) - ++ret; - } - return ret; -} - static int walk_ips_by_peer(struct routing_table_node __rcu *top, int family, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family), struct mutex *maybe_lock) { int ret; @@ -185,6 +173,9 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u { struct routing_table_node *node, *parent, *down, *newnode; + if (unlikely(cidr > bits || !peer)) + return -EINVAL; + if (!rcu_access_pointer(*trie)) { node = kzalloc(sizeof(*node) + (bits + 7) / 8, GFP_KERNEL); if (!node) @@ -244,91 +235,36 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u void routing_table_init(struct routing_table *table) { memset(table, 0, sizeof(struct routing_table)); - mutex_init(&table->table_update_lock); } -void routing_table_free(struct routing_table *table) +void routing_table_free(struct routing_table *table, struct mutex *mutex) { - mutex_lock(&table->table_update_lock); - free_root_node(table->root4, &table->table_update_lock); + free_root_node(table->root4, mutex); rcu_assign_pointer(table->root4, NULL); - free_root_node(table->root6, &table->table_update_lock); + free_root_node(table->root6, mutex); rcu_assign_pointer(table->root6, NULL); - mutex_unlock(&table->table_update_lock); -} - -int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer) -{ - int ret; - - if (unlikely(cidr > 32 || !peer)) - return -EINVAL; - mutex_lock(&table->table_update_lock); - ret = add(&table->root4, 32, (const u8 *)ip, cidr, peer, &table->table_update_lock); - mutex_unlock(&table->table_update_lock); - return ret; -} - -int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer) -{ - int ret; - - if (unlikely(cidr > 128 || !peer)) - return -EINVAL; - mutex_lock(&table->table_update_lock); - ret = add(&table->root6, 128, (const u8 *)ip, cidr, peer, &table->table_update_lock); - mutex_unlock(&table->table_update_lock); - return ret; } -void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer) +int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *mutex) { - mutex_lock(&table->table_update_lock); - walk_remove_by_peer(&table->root4, peer, &table->table_update_lock); - walk_remove_by_peer(&table->root6, peer, &table->table_update_lock); - mutex_unlock(&table->table_update_lock); + return add(&table->root4, 32, (const u8 *)ip, cidr, peer, mutex); } -size_t routing_table_count_nodes(struct routing_table *table) +int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *mutex) { - size_t ret; - - rcu_read_lock_bh(); - ret = count_nodes(table->root4) + count_nodes(table->root6); - rcu_read_unlock_bh(); - return ret; + return add(&table->root6, 128, (const u8 *)ip, cidr, peer, mutex); } -int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family)) +void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer, struct mutex *mutex) { - int ret; - - rcu_read_lock_bh(); - ret = walk_ips_by_peer(table->root4, AF_INET, ctx, peer, func, NULL); - rcu_read_unlock_bh(); - if (ret) - return ret; - - rcu_read_lock_bh(); - ret = walk_ips_by_peer(table->root6, AF_INET6, ctx, peer, func, NULL); - rcu_read_unlock_bh(); - return ret; + walk_remove_by_peer(&table->root4, peer, mutex); + walk_remove_by_peer(&table->root6, peer, mutex); } -int routing_table_walk_ips_by_peer_sleepable(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family)) +int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family), struct mutex *mutex) { - int ret; - - mutex_lock(&table->table_update_lock); - ret = walk_ips_by_peer(table->root4, AF_INET, ctx, peer, func, &table->table_update_lock); - mutex_unlock(&table->table_update_lock); - if (ret) - return ret; - - mutex_lock(&table->table_update_lock); - ret = walk_ips_by_peer(table->root6, AF_INET6, ctx, peer, func, &table->table_update_lock); - mutex_unlock(&table->table_update_lock); - return ret; + return walk_ips_by_peer(table->root4, AF_INET, ctx, peer, func, mutex) ?: + walk_ips_by_peer(table->root6, AF_INET6, ctx, peer, func, mutex); } /* Returns a strong reference to a peer */ diff --git a/src/routingtable.h b/src/routingtable.h index c251354..815118c 100644 --- a/src/routingtable.h +++ b/src/routingtable.h @@ -13,17 +13,15 @@ struct routing_table_node; struct routing_table { struct routing_table_node __rcu *root4; struct routing_table_node __rcu *root6; - struct mutex table_update_lock; }; void routing_table_init(struct routing_table *table); -void routing_table_free(struct routing_table *table); -int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer); -int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer); -void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer); -size_t routing_table_count_nodes(struct routing_table *table); -int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family)); -int routing_table_walk_ips_by_peer_sleepable(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family)); +void routing_table_free(struct routing_table *table, struct mutex *mutex); +int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *mutex); +int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *mutex); +void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer, struct mutex *mutex); + +int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family), struct mutex *mutex); /* These return a strong reference to a peer: */ struct wireguard_peer *routing_table_lookup_dst(struct routing_table *table, struct sk_buff *skb); diff --git a/src/selftest/routingtable.h b/src/selftest/routingtable.h index 951eb59..4e30b98 100644 --- a/src/selftest/routingtable.h +++ b/src/selftest/routingtable.h @@ -349,7 +349,7 @@ static __init inline struct in6_addr *ip6(u32 a, u32 b, u32 c, u32 d) } while (0) #define insert(version, mem, ipa, ipb, ipc, ipd, cidr) \ - routing_table_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), cidr, mem) + routing_table_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), cidr, mem, &mutex) #define maybe_fail \ ++i; \ @@ -370,6 +370,7 @@ static __init inline struct in6_addr *ip6(u32 a, u32 b, u32 c, u32 d) bool __init routing_table_selftest(void) { + DEFINE_MUTEX(mutex); struct routing_table t; struct wireguard_peer *a = NULL, *b = NULL, *c = NULL, *d = NULL, *e = NULL, *f = NULL, *g = NULL, *h = NULL; size_t i = 0; @@ -377,6 +378,8 @@ bool __init routing_table_selftest(void) struct in6_addr ip; __be64 part; + mutex_lock(&mutex); + routing_table_init(&t); init_peer(a); init_peer(b); @@ -452,18 +455,18 @@ bool __init routing_table_selftest(void) insert(4, a, 128, 0, 0, 0, 32); insert(4, a, 192, 0, 0, 0, 32); insert(4, a, 255, 0, 0, 0, 32); - routing_table_remove_by_peer(&t, a); + routing_table_remove_by_peer(&t, a, &mutex); test_negative(4, a, 1, 0, 0, 0); test_negative(4, a, 64, 0, 0, 0); test_negative(4, a, 128, 0, 0, 0); test_negative(4, a, 192, 0, 0, 0); test_negative(4, a, 255, 0, 0, 0); - routing_table_free(&t); + routing_table_free(&t, &mutex); routing_table_init(&t); insert(4, a, 192, 168, 0, 0, 16); insert(4, a, 192, 168, 0, 0, 24); - routing_table_remove_by_peer(&t, a); + routing_table_remove_by_peer(&t, a, &mutex); test_negative(4, a, 192, 168, 0, 1); /* These will hit the BUG_ON(len >= 128) in free_node if something goes wrong. */ @@ -471,7 +474,7 @@ bool __init routing_table_selftest(void) part = cpu_to_be64(~(1LLU << (i % 64))); memset(&ip, 0xff, 16); memcpy((u8 *)&ip + (i < 64) * 8, &part, 8); - routing_table_insert_v6(&t, &ip, 128, a); + routing_table_insert_v6(&t, &ip, 128, a, &mutex); } #ifdef DEBUG_RANDOM_TRIE @@ -483,7 +486,7 @@ bool __init routing_table_selftest(void) pr_info("routing table self-tests: pass\n"); free: - routing_table_free(&t); + routing_table_free(&t, &mutex); kfree(a); kfree(b); kfree(c); @@ -492,6 +495,7 @@ free: kfree(f); kfree(g); kfree(h); + mutex_unlock(&mutex); return success; } -- cgit v1.2.3-59-g8ed1b