aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2017-04-03 05:20:25 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2017-04-04 03:44:35 +0200
commit293e9d604f949db9501d0ce01570350198e59c0b (patch)
treed9d02ae5be1ebc9eb7e76c90baa781e45e9189fe
parentqemu: new stable kernel (diff)
downloadwireguard-monolithic-historical-293e9d604f949db9501d0ce01570350198e59c0b.tar.xz
wireguard-monolithic-historical-293e9d604f949db9501d0ce01570350198e59c0b.zip
locking: always use _bh
All locks are potentially between user context and softirq, which means we need to take the _bh variant.
-rw-r--r--src/data.c12
-rw-r--r--src/hashtables.c38
-rw-r--r--src/noise.c14
-rw-r--r--src/peer.c8
-rw-r--r--src/receive.c6
-rw-r--r--src/routingtable.c48
-rw-r--r--src/send.c19
-rw-r--r--src/socket.c14
8 files changed, 81 insertions, 78 deletions
diff --git a/src/data.c b/src/data.c
index dcbbd10..4751eb8 100644
--- a/src/data.c
+++ b/src/data.c
@@ -282,11 +282,11 @@ int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer,
struct noise_keypair *keypair;
struct sk_buff *skb;
- rcu_read_lock();
- keypair = noise_keypair_get(rcu_dereference(peer->keypairs.current_keypair));
+ rcu_read_lock_bh();
+ keypair = noise_keypair_get(rcu_dereference_bh(peer->keypairs.current_keypair));
if (unlikely(!keypair))
goto err_rcu;
- rcu_read_unlock();
+ rcu_read_unlock_bh();
skb_queue_walk(queue, skb) {
if (unlikely(!get_encryption_nonce(&PACKET_CB(skb)->nonce, &keypair->sending)))
@@ -338,7 +338,7 @@ err:
noise_keypair_put(keypair);
return ret;
err_rcu:
- rcu_read_unlock();
+ rcu_read_unlock_bh();
return ret;
}
@@ -421,9 +421,9 @@ void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg, packe
__le32 idx = ((struct message_data *)skb->data)->key_idx;
ret = -EINVAL;
- rcu_read_lock();
+ rcu_read_lock_bh();
keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx));
- rcu_read_unlock();
+ rcu_read_unlock_bh();
if (unlikely(!keypair))
goto err;
diff --git a/src/hashtables.c b/src/hashtables.c
index 4cb8441..efd7111 100644
--- a/src/hashtables.c
+++ b/src/hashtables.c
@@ -36,15 +36,15 @@ void pubkey_hashtable_remove(struct pubkey_hashtable *table, struct wireguard_pe
struct wireguard_peer *pubkey_hashtable_lookup(struct pubkey_hashtable *table, const u8 pubkey[NOISE_PUBLIC_KEY_LEN])
{
struct wireguard_peer *iter_peer, *peer = NULL;
- rcu_read_lock();
- hlist_for_each_entry_rcu(iter_peer, pubkey_bucket(table, pubkey), pubkey_hash) {
+ rcu_read_lock_bh();
+ hlist_for_each_entry_rcu_bh(iter_peer, pubkey_bucket(table, pubkey), pubkey_hash) {
if (!memcmp(pubkey, iter_peer->handshake.remote_static, NOISE_PUBLIC_KEY_LEN)) {
peer = iter_peer;
break;
}
}
peer = peer_get(peer);
- rcu_read_unlock();
+ rcu_read_unlock_bh();
return peer;
}
@@ -65,60 +65,60 @@ __le32 index_hashtable_insert(struct index_hashtable *table, struct index_hashta
{
struct index_hashtable_entry *existing_entry;
- spin_lock(&table->lock);
+ spin_lock_bh(&table->lock);
hlist_del_init_rcu(&entry->index_hash);
- spin_unlock(&table->lock);
+ spin_unlock_bh(&table->lock);
- rcu_read_lock();
+ rcu_read_lock_bh();
search_unused_slot:
/* First we try to find an unused slot, randomly, while unlocked. */
entry->index = (__force __le32)get_random_u32();
- hlist_for_each_entry_rcu(existing_entry, index_bucket(table, entry->index), index_hash) {
+ hlist_for_each_entry_rcu_bh(existing_entry, index_bucket(table, entry->index), index_hash) {
if (existing_entry->index == entry->index)
goto search_unused_slot; /* If it's already in use, we continue searching. */
}
/* Once we've found an unused slot, we lock it, and then double-check
* that nobody else stole it from us. */
- spin_lock(&table->lock);
- hlist_for_each_entry_rcu(existing_entry, index_bucket(table, entry->index), index_hash) {
+ spin_lock_bh(&table->lock);
+ hlist_for_each_entry_rcu_bh(existing_entry, index_bucket(table, entry->index), index_hash) {
if (existing_entry->index == entry->index) {
- spin_unlock(&table->lock);
+ spin_unlock_bh(&table->lock);
goto search_unused_slot; /* If it was stolen, we start over. */
}
}
/* Otherwise, we know we have it exclusively (since we're locked), so we insert. */
hlist_add_head_rcu(&entry->index_hash, index_bucket(table, entry->index));
- spin_unlock(&table->lock);
+ spin_unlock_bh(&table->lock);
- rcu_read_unlock();
+ rcu_read_unlock_bh();
return entry->index;
}
void index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new)
{
- spin_lock(&table->lock);
+ spin_lock_bh(&table->lock);
new->index = old->index;
hlist_replace_rcu(&old->index_hash, &new->index_hash);
INIT_HLIST_NODE(&old->index_hash);
- spin_unlock(&table->lock);
+ spin_unlock_bh(&table->lock);
}
void index_hashtable_remove(struct index_hashtable *table, struct index_hashtable_entry *entry)
{
- spin_lock(&table->lock);
+ spin_lock_bh(&table->lock);
hlist_del_init_rcu(&entry->index_hash);
- spin_unlock(&table->lock);
+ spin_unlock_bh(&table->lock);
}
/* Returns a strong reference to a entry->peer */
struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index)
{
struct index_hashtable_entry *iter_entry, *entry = NULL;
- rcu_read_lock();
- hlist_for_each_entry_rcu(iter_entry, index_bucket(table, index), index_hash) {
+ rcu_read_lock_bh();
+ hlist_for_each_entry_rcu_bh(iter_entry, index_bucket(table, index), index_hash) {
if (iter_entry->index == index && (iter_entry->type & type_mask)) {
entry = iter_entry;
break;
@@ -129,6 +129,6 @@ struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *tab
if (unlikely(!entry->peer))
entry = NULL;
}
- rcu_read_unlock();
+ rcu_read_unlock_bh();
return entry;
}
diff --git a/src/noise.c b/src/noise.c
index 52a9be3..d6c6398 100644
--- a/src/noise.c
+++ b/src/noise.c
@@ -85,7 +85,7 @@ static void keypair_free_kref(struct kref *kref)
{
struct noise_keypair *keypair = container_of(kref, struct noise_keypair, refcount);
index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry);
- call_rcu(&keypair->rcu, keypair_free_rcu);
+ call_rcu_bh(&keypair->rcu, keypair_free_rcu);
}
void noise_keypair_put(struct noise_keypair *keypair)
@@ -97,7 +97,7 @@ void noise_keypair_put(struct noise_keypair *keypair)
struct noise_keypair *noise_keypair_get(struct noise_keypair *keypair)
{
- RCU_LOCKDEP_WARN(!rcu_read_lock_held(), "Calling noise_keypair_get without holding the RCU read lock.");
+ RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(), "Calling noise_keypair_get without holding the RCU BH read lock");
if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
return NULL;
return keypair;
@@ -167,19 +167,19 @@ bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_k
/* TODO: probably this needs the actual mutex, but we're in atomic context,
* so we can't take it here. Instead we just rely on RCU for the lookups. */
- rcu_read_lock();
- if (unlikely(received_keypair == rcu_dereference(keypairs->next_keypair))) {
+ rcu_read_lock_bh();
+ if (unlikely(received_keypair == rcu_dereference_bh(keypairs->next_keypair))) {
ret = true;
/* When we've finally received the confirmation, we slide the next
* into the current, the current into the previous, and get rid of
* the old previous. */
- old_keypair = rcu_dereference(keypairs->previous_keypair);
- rcu_assign_pointer(keypairs->previous_keypair, rcu_dereference(keypairs->current_keypair));
+ old_keypair = rcu_dereference_bh(keypairs->previous_keypair);
+ rcu_assign_pointer(keypairs->previous_keypair, rcu_dereference_bh(keypairs->current_keypair));
noise_keypair_put(old_keypair);
rcu_assign_pointer(keypairs->current_keypair, received_keypair);
rcu_assign_pointer(keypairs->next_keypair, NULL);
}
- rcu_read_unlock();
+ rcu_read_unlock_bh();
return ret;
}
diff --git a/src/peer.c b/src/peer.c
index 4264fa0..cd093b4 100644
--- a/src/peer.c
+++ b/src/peer.c
@@ -49,7 +49,7 @@ struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_
struct wireguard_peer *peer_get(struct wireguard_peer *peer)
{
- RCU_LOCKDEP_WARN(!rcu_read_lock_held(), "Calling peer_get without holding the RCU read lock.");
+ RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(), "Calling peer_get without holding the RCU read lock");
if (unlikely(!peer || !kref_get_unless_zero(&peer->refcount)))
return NULL;
return peer;
@@ -57,9 +57,9 @@ struct wireguard_peer *peer_get(struct wireguard_peer *peer)
struct wireguard_peer *peer_rcu_get(struct wireguard_peer *peer)
{
- rcu_read_lock();
+ rcu_read_lock_bh();
peer = peer_get(peer);
- rcu_read_unlock();
+ rcu_read_unlock_bh();
return peer;
}
@@ -95,7 +95,7 @@ static void rcu_release(struct rcu_head *rcu)
static void kref_release(struct kref *refcount)
{
struct wireguard_peer *peer = container_of(refcount, struct wireguard_peer, refcount);
- call_rcu(&peer->rcu, rcu_release);
+ call_rcu_bh(&peer->rcu, rcu_release);
}
void peer_put(struct wireguard_peer *peer)
diff --git a/src/receive.c b/src/receive.c
index f791a2e..3b375ae 100644
--- a/src/receive.c
+++ b/src/receive.c
@@ -192,12 +192,12 @@ static void keep_key_fresh(struct wireguard_peer *peer)
if (peer->sent_lastminute_handshake)
return;
- rcu_read_lock();
- keypair = rcu_dereference(peer->keypairs.current_keypair);
+ rcu_read_lock_bh();
+ keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
if (likely(keypair && keypair->sending.is_valid) && keypair->i_am_the_initiator &&
unlikely(time_is_before_eq_jiffies64(keypair->sending.birthdate + REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT)))
send = true;
- rcu_read_unlock();
+ rcu_read_unlock_bh();
if (send) {
peer->sent_lastminute_handshake = true;
diff --git a/src/routingtable.c b/src/routingtable.c
index 36e0bc9..1de7727 100644
--- a/src/routingtable.c
+++ b/src/routingtable.c
@@ -34,9 +34,13 @@ static inline void copy_and_assign_cidr(struct routing_table_node *node, const u
* return;
* free_node(node->bit[0]);
* free_node(node->bit[1]);
- * kfree_rcu(node);
+ * kfree_rcu_bh(node);
* }
*/
+static void node_free_rcu(struct rcu_head *rcu)
+{
+ kfree(container_of(rcu, struct routing_table_node, rcu));
+}
#define ref(p) rcu_access_pointer(p)
#define push(p) do { BUG_ON(len >= 128); stack[len++] = rcu_dereference_protected(p, lockdep_is_held(lock)); } while (0)
static void free_node(struct routing_table_node *top, struct mutex *lock)
@@ -61,7 +65,7 @@ static void free_node(struct routing_table_node *top, struct mutex *lock)
if (ref(node->bit[1]))
push(node->bit[1]);
} else {
- kfree_rcu(node, rcu);
+ call_rcu_bh(&node->rcu, node_free_rcu);
--len;
}
prev = node;
@@ -185,7 +189,7 @@ static inline struct routing_table_node *find_node(struct routing_table_node *tr
found = node;
if (node->cidr == bits)
break;
- node = rcu_dereference(node->bit[bit_at(key, node->bit_at_a, node->bit_at_b)]);
+ node = rcu_dereference_bh(node->bit[bit_at(key, node->bit_at_a, node->bit_at_b)]);
}
return found;
}
@@ -276,7 +280,7 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u
}
#define push(p) do { \
- struct routing_table_node *next = (maybe_lock ? rcu_dereference_protected(p, lockdep_is_held(maybe_lock)) : rcu_dereference(p)); \
+ struct routing_table_node *next = (maybe_lock ? rcu_dereference_protected(p, lockdep_is_held(maybe_lock)) : rcu_dereference_bh(p)); \
if (next) { \
BUG_ON(len >= 128); \
stack[len++] = next; \
@@ -385,11 +389,11 @@ inline struct wireguard_peer *routing_table_lookup_v4(struct routing_table *tabl
struct wireguard_peer *peer = NULL;
struct routing_table_node *node;
- rcu_read_lock();
- node = find_node(rcu_dereference(table->root4), 32, (const u8 *)ip);
+ rcu_read_lock_bh();
+ node = find_node(rcu_dereference_bh(table->root4), 32, (const u8 *)ip);
if (node)
peer = peer_get(node->peer);
- rcu_read_unlock();
+ rcu_read_unlock_bh();
return peer;
}
@@ -399,11 +403,11 @@ inline struct wireguard_peer *routing_table_lookup_v6(struct routing_table *tabl
struct wireguard_peer *peer = NULL;
struct routing_table_node *node;
- rcu_read_lock();
- node = find_node(rcu_dereference(table->root6), 128, (const u8 *)ip);
+ rcu_read_lock_bh();
+ node = find_node(rcu_dereference_bh(table->root6), 128, (const u8 *)ip);
if (node)
peer = peer_get(node->peer);
- rcu_read_unlock();
+ rcu_read_unlock_bh();
return peer;
}
@@ -439,28 +443,28 @@ int routing_table_remove_by_peer(struct routing_table *table, struct wireguard_p
int routing_table_walk_ips(struct routing_table *table, void *ctx, int (*func)(void *ctx, struct wireguard_peer *peer, union nf_inet_addr ip, u8 cidr, int family))
{
int ret;
- rcu_read_lock();
- ret = walk_ips(rcu_dereference(table->root4), AF_INET, ctx, func, NULL);
- rcu_read_unlock();
+ rcu_read_lock_bh();
+ ret = walk_ips(rcu_dereference_bh(table->root4), AF_INET, ctx, func, NULL);
+ rcu_read_unlock_bh();
if (ret)
return ret;
- rcu_read_lock();
- ret = walk_ips(rcu_dereference(table->root6), AF_INET6, ctx, func, NULL);
- rcu_read_unlock();
+ rcu_read_lock_bh();
+ ret = walk_ips(rcu_dereference_bh(table->root6), AF_INET6, ctx, func, NULL);
+ rcu_read_unlock_bh();
return ret;
}
int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family))
{
int ret;
- rcu_read_lock();
- ret = walk_ips_by_peer(rcu_dereference(table->root4), AF_INET, ctx, peer, func, NULL);
- rcu_read_unlock();
+ rcu_read_lock_bh();
+ ret = walk_ips_by_peer(rcu_dereference_bh(table->root4), AF_INET, ctx, peer, func, NULL);
+ rcu_read_unlock_bh();
if (ret)
return ret;
- rcu_read_lock();
- ret = walk_ips_by_peer(rcu_dereference(table->root6), AF_INET6, ctx, peer, func, NULL);
- rcu_read_unlock();
+ rcu_read_lock_bh();
+ ret = walk_ips_by_peer(rcu_dereference_bh(table->root6), AF_INET6, ctx, peer, func, NULL);
+ rcu_read_unlock_bh();
return ret;
}
diff --git a/src/send.c b/src/send.c
index f5414e1..046b62e 100644
--- a/src/send.c
+++ b/src/send.c
@@ -91,13 +91,13 @@ static inline void keep_key_fresh(struct wireguard_peer *peer)
struct noise_keypair *keypair;
bool send = false;
- rcu_read_lock();
- keypair = rcu_dereference(peer->keypairs.current_keypair);
+ rcu_read_lock_bh();
+ keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
if (likely(keypair && keypair->sending.is_valid) &&
(unlikely(atomic64_read(&keypair->sending.counter.counter) > REKEY_AFTER_MESSAGES) ||
(keypair->i_am_the_initiator && unlikely(time_is_before_eq_jiffies64(keypair->sending.birthdate + REKEY_AFTER_TIME)))))
send = true;
- rcu_read_unlock();
+ rcu_read_unlock_bh();
if (send)
packet_queue_handshake_initiation(peer);
@@ -144,15 +144,14 @@ static void message_create_data_done(struct sk_buff_head *queue, struct wireguar
void packet_send_queue(struct wireguard_peer *peer)
{
struct sk_buff_head queue;
- unsigned long flags;
peer->need_resend_queue = false;
/* Steal the current queue into our local one. */
skb_queue_head_init(&queue);
- spin_lock_irqsave(&peer->tx_packet_queue.lock, flags);
+ spin_lock_bh(&peer->tx_packet_queue.lock);
skb_queue_splice_init(&peer->tx_packet_queue, &queue);
- spin_unlock_irqrestore(&peer->tx_packet_queue.lock, flags);
+ spin_unlock_bh(&peer->tx_packet_queue.lock);
if (unlikely(!skb_queue_len(&queue)))
return;
@@ -172,17 +171,17 @@ void packet_send_queue(struct wireguard_peer *peer)
/* We stick the remaining skbs from local_queue at the top of the peer's
* queue again, setting the top of local_queue to be the skb that begins
* the requeueing. */
- spin_lock_irqsave(&peer->tx_packet_queue.lock, flags);
+ spin_lock_bh(&peer->tx_packet_queue.lock);
skb_queue_splice(&queue, &peer->tx_packet_queue);
- spin_unlock_irqrestore(&peer->tx_packet_queue.lock, flags);
+ spin_unlock_bh(&peer->tx_packet_queue.lock);
break;
case -ENOKEY:
/* ENOKEY means that we don't have a valid session for the peer, which
* means we should initiate a session, but after requeuing like above. */
- spin_lock_irqsave(&peer->tx_packet_queue.lock, flags);
+ spin_lock_bh(&peer->tx_packet_queue.lock);
skb_queue_splice(&queue, &peer->tx_packet_queue);
- spin_unlock_irqrestore(&peer->tx_packet_queue.lock, flags);
+ spin_unlock_bh(&peer->tx_packet_queue.lock);
packet_queue_handshake_initiation(peer);
break;
diff --git a/src/socket.c b/src/socket.c
index a2b64b3..54b1ba2 100644
--- a/src/socket.c
+++ b/src/socket.c
@@ -30,8 +30,8 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct
skb->next = skb->prev = NULL;
skb->dev = netdev_pub(wg);
- rcu_read_lock();
- sock = rcu_dereference(wg->sock4);
+ rcu_read_lock_bh();
+ sock = rcu_dereference_bh(wg->sock4);
if (unlikely(!sock)) {
ret = -ENONET;
@@ -73,7 +73,7 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct
err:
kfree_skb(skb);
out:
- rcu_read_unlock();
+ rcu_read_unlock_bh();
return ret;
}
@@ -97,8 +97,8 @@ static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct
skb->next = skb->prev = NULL;
skb->dev = netdev_pub(wg);
- rcu_read_lock();
- sock = rcu_dereference(wg->sock6);
+ rcu_read_lock_bh();
+ sock = rcu_dereference_bh(wg->sock6);
if (unlikely(!sock)) {
ret = -ENONET;
@@ -139,7 +139,7 @@ static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct
err:
kfree_skb(skb);
out:
- rcu_read_unlock();
+ rcu_read_unlock_bh();
return ret;
#else
return -EAFNOSUPPORT;
@@ -377,7 +377,7 @@ void socket_uninit(struct wireguard_device *wg)
rcu_assign_pointer(wg->sock4, NULL);
rcu_assign_pointer(wg->sock6, NULL);
mutex_unlock(&wg->socket_update_lock);
- synchronize_rcu();
+ synchronize_rcu_bh();
synchronize_net();
sock_free(old4);
sock_free(old6);