From 293e9d604f949db9501d0ce01570350198e59c0b Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 3 Apr 2017 05:20:25 +0200 Subject: locking: always use _bh All locks are potentially between user context and softirq, which means we need to take the _bh variant. --- src/data.c | 12 ++++++------ src/hashtables.c | 38 +++++++++++++++++++------------------- src/noise.c | 14 +++++++------- src/peer.c | 8 ++++---- src/receive.c | 6 +++--- src/routingtable.c | 48 ++++++++++++++++++++++++++---------------------- src/send.c | 19 +++++++++---------- src/socket.c | 14 +++++++------- 8 files changed, 81 insertions(+), 78 deletions(-) diff --git a/src/data.c b/src/data.c index dcbbd10..4751eb8 100644 --- a/src/data.c +++ b/src/data.c @@ -282,11 +282,11 @@ int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer, struct noise_keypair *keypair; struct sk_buff *skb; - rcu_read_lock(); - keypair = noise_keypair_get(rcu_dereference(peer->keypairs.current_keypair)); + rcu_read_lock_bh(); + keypair = noise_keypair_get(rcu_dereference_bh(peer->keypairs.current_keypair)); if (unlikely(!keypair)) goto err_rcu; - rcu_read_unlock(); + rcu_read_unlock_bh(); skb_queue_walk(queue, skb) { if (unlikely(!get_encryption_nonce(&PACKET_CB(skb)->nonce, &keypair->sending))) @@ -338,7 +338,7 @@ err: noise_keypair_put(keypair); return ret; err_rcu: - rcu_read_unlock(); + rcu_read_unlock_bh(); return ret; } @@ -421,9 +421,9 @@ void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg, packe __le32 idx = ((struct message_data *)skb->data)->key_idx; ret = -EINVAL; - rcu_read_lock(); + rcu_read_lock_bh(); keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx)); - rcu_read_unlock(); + rcu_read_unlock_bh(); if (unlikely(!keypair)) goto err; diff --git a/src/hashtables.c b/src/hashtables.c index 4cb8441..efd7111 100644 --- a/src/hashtables.c +++ b/src/hashtables.c @@ -36,15 +36,15 @@ void pubkey_hashtable_remove(struct pubkey_hashtable *table, struct wireguard_pe struct wireguard_peer *pubkey_hashtable_lookup(struct pubkey_hashtable *table, const u8 pubkey[NOISE_PUBLIC_KEY_LEN]) { struct wireguard_peer *iter_peer, *peer = NULL; - rcu_read_lock(); - hlist_for_each_entry_rcu(iter_peer, pubkey_bucket(table, pubkey), pubkey_hash) { + rcu_read_lock_bh(); + hlist_for_each_entry_rcu_bh(iter_peer, pubkey_bucket(table, pubkey), pubkey_hash) { if (!memcmp(pubkey, iter_peer->handshake.remote_static, NOISE_PUBLIC_KEY_LEN)) { peer = iter_peer; break; } } peer = peer_get(peer); - rcu_read_unlock(); + rcu_read_unlock_bh(); return peer; } @@ -65,60 +65,60 @@ __le32 index_hashtable_insert(struct index_hashtable *table, struct index_hashta { struct index_hashtable_entry *existing_entry; - spin_lock(&table->lock); + spin_lock_bh(&table->lock); hlist_del_init_rcu(&entry->index_hash); - spin_unlock(&table->lock); + spin_unlock_bh(&table->lock); - rcu_read_lock(); + rcu_read_lock_bh(); search_unused_slot: /* First we try to find an unused slot, randomly, while unlocked. */ entry->index = (__force __le32)get_random_u32(); - hlist_for_each_entry_rcu(existing_entry, index_bucket(table, entry->index), index_hash) { + hlist_for_each_entry_rcu_bh(existing_entry, index_bucket(table, entry->index), index_hash) { if (existing_entry->index == entry->index) goto search_unused_slot; /* If it's already in use, we continue searching. */ } /* Once we've found an unused slot, we lock it, and then double-check * that nobody else stole it from us. */ - spin_lock(&table->lock); - hlist_for_each_entry_rcu(existing_entry, index_bucket(table, entry->index), index_hash) { + spin_lock_bh(&table->lock); + hlist_for_each_entry_rcu_bh(existing_entry, index_bucket(table, entry->index), index_hash) { if (existing_entry->index == entry->index) { - spin_unlock(&table->lock); + spin_unlock_bh(&table->lock); goto search_unused_slot; /* If it was stolen, we start over. */ } } /* Otherwise, we know we have it exclusively (since we're locked), so we insert. */ hlist_add_head_rcu(&entry->index_hash, index_bucket(table, entry->index)); - spin_unlock(&table->lock); + spin_unlock_bh(&table->lock); - rcu_read_unlock(); + rcu_read_unlock_bh(); return entry->index; } void index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new) { - spin_lock(&table->lock); + spin_lock_bh(&table->lock); new->index = old->index; hlist_replace_rcu(&old->index_hash, &new->index_hash); INIT_HLIST_NODE(&old->index_hash); - spin_unlock(&table->lock); + spin_unlock_bh(&table->lock); } void index_hashtable_remove(struct index_hashtable *table, struct index_hashtable_entry *entry) { - spin_lock(&table->lock); + spin_lock_bh(&table->lock); hlist_del_init_rcu(&entry->index_hash); - spin_unlock(&table->lock); + spin_unlock_bh(&table->lock); } /* Returns a strong reference to a entry->peer */ struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index) { struct index_hashtable_entry *iter_entry, *entry = NULL; - rcu_read_lock(); - hlist_for_each_entry_rcu(iter_entry, index_bucket(table, index), index_hash) { + rcu_read_lock_bh(); + hlist_for_each_entry_rcu_bh(iter_entry, index_bucket(table, index), index_hash) { if (iter_entry->index == index && (iter_entry->type & type_mask)) { entry = iter_entry; break; @@ -129,6 +129,6 @@ struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *tab if (unlikely(!entry->peer)) entry = NULL; } - rcu_read_unlock(); + rcu_read_unlock_bh(); return entry; } diff --git a/src/noise.c b/src/noise.c index 52a9be3..d6c6398 100644 --- a/src/noise.c +++ b/src/noise.c @@ -85,7 +85,7 @@ static void keypair_free_kref(struct kref *kref) { struct noise_keypair *keypair = container_of(kref, struct noise_keypair, refcount); index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry); - call_rcu(&keypair->rcu, keypair_free_rcu); + call_rcu_bh(&keypair->rcu, keypair_free_rcu); } void noise_keypair_put(struct noise_keypair *keypair) @@ -97,7 +97,7 @@ void noise_keypair_put(struct noise_keypair *keypair) struct noise_keypair *noise_keypair_get(struct noise_keypair *keypair) { - RCU_LOCKDEP_WARN(!rcu_read_lock_held(), "Calling noise_keypair_get without holding the RCU read lock."); + RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(), "Calling noise_keypair_get without holding the RCU BH read lock"); if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount))) return NULL; return keypair; @@ -167,19 +167,19 @@ bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_k /* TODO: probably this needs the actual mutex, but we're in atomic context, * so we can't take it here. Instead we just rely on RCU for the lookups. */ - rcu_read_lock(); - if (unlikely(received_keypair == rcu_dereference(keypairs->next_keypair))) { + rcu_read_lock_bh(); + if (unlikely(received_keypair == rcu_dereference_bh(keypairs->next_keypair))) { ret = true; /* When we've finally received the confirmation, we slide the next * into the current, the current into the previous, and get rid of * the old previous. */ - old_keypair = rcu_dereference(keypairs->previous_keypair); - rcu_assign_pointer(keypairs->previous_keypair, rcu_dereference(keypairs->current_keypair)); + old_keypair = rcu_dereference_bh(keypairs->previous_keypair); + rcu_assign_pointer(keypairs->previous_keypair, rcu_dereference_bh(keypairs->current_keypair)); noise_keypair_put(old_keypair); rcu_assign_pointer(keypairs->current_keypair, received_keypair); rcu_assign_pointer(keypairs->next_keypair, NULL); } - rcu_read_unlock(); + rcu_read_unlock_bh(); return ret; } diff --git a/src/peer.c b/src/peer.c index 4264fa0..cd093b4 100644 --- a/src/peer.c +++ b/src/peer.c @@ -49,7 +49,7 @@ struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_ struct wireguard_peer *peer_get(struct wireguard_peer *peer) { - RCU_LOCKDEP_WARN(!rcu_read_lock_held(), "Calling peer_get without holding the RCU read lock."); + RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(), "Calling peer_get without holding the RCU read lock"); if (unlikely(!peer || !kref_get_unless_zero(&peer->refcount))) return NULL; return peer; @@ -57,9 +57,9 @@ struct wireguard_peer *peer_get(struct wireguard_peer *peer) struct wireguard_peer *peer_rcu_get(struct wireguard_peer *peer) { - rcu_read_lock(); + rcu_read_lock_bh(); peer = peer_get(peer); - rcu_read_unlock(); + rcu_read_unlock_bh(); return peer; } @@ -95,7 +95,7 @@ static void rcu_release(struct rcu_head *rcu) static void kref_release(struct kref *refcount) { struct wireguard_peer *peer = container_of(refcount, struct wireguard_peer, refcount); - call_rcu(&peer->rcu, rcu_release); + call_rcu_bh(&peer->rcu, rcu_release); } void peer_put(struct wireguard_peer *peer) diff --git a/src/receive.c b/src/receive.c index f791a2e..3b375ae 100644 --- a/src/receive.c +++ b/src/receive.c @@ -192,12 +192,12 @@ static void keep_key_fresh(struct wireguard_peer *peer) if (peer->sent_lastminute_handshake) return; - rcu_read_lock(); - keypair = rcu_dereference(peer->keypairs.current_keypair); + rcu_read_lock_bh(); + keypair = rcu_dereference_bh(peer->keypairs.current_keypair); if (likely(keypair && keypair->sending.is_valid) && keypair->i_am_the_initiator && unlikely(time_is_before_eq_jiffies64(keypair->sending.birthdate + REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT))) send = true; - rcu_read_unlock(); + rcu_read_unlock_bh(); if (send) { peer->sent_lastminute_handshake = true; diff --git a/src/routingtable.c b/src/routingtable.c index 36e0bc9..1de7727 100644 --- a/src/routingtable.c +++ b/src/routingtable.c @@ -34,9 +34,13 @@ static inline void copy_and_assign_cidr(struct routing_table_node *node, const u * return; * free_node(node->bit[0]); * free_node(node->bit[1]); - * kfree_rcu(node); + * kfree_rcu_bh(node); * } */ +static void node_free_rcu(struct rcu_head *rcu) +{ + kfree(container_of(rcu, struct routing_table_node, rcu)); +} #define ref(p) rcu_access_pointer(p) #define push(p) do { BUG_ON(len >= 128); stack[len++] = rcu_dereference_protected(p, lockdep_is_held(lock)); } while (0) static void free_node(struct routing_table_node *top, struct mutex *lock) @@ -61,7 +65,7 @@ static void free_node(struct routing_table_node *top, struct mutex *lock) if (ref(node->bit[1])) push(node->bit[1]); } else { - kfree_rcu(node, rcu); + call_rcu_bh(&node->rcu, node_free_rcu); --len; } prev = node; @@ -185,7 +189,7 @@ static inline struct routing_table_node *find_node(struct routing_table_node *tr found = node; if (node->cidr == bits) break; - node = rcu_dereference(node->bit[bit_at(key, node->bit_at_a, node->bit_at_b)]); + node = rcu_dereference_bh(node->bit[bit_at(key, node->bit_at_a, node->bit_at_b)]); } return found; } @@ -276,7 +280,7 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u } #define push(p) do { \ - struct routing_table_node *next = (maybe_lock ? rcu_dereference_protected(p, lockdep_is_held(maybe_lock)) : rcu_dereference(p)); \ + struct routing_table_node *next = (maybe_lock ? rcu_dereference_protected(p, lockdep_is_held(maybe_lock)) : rcu_dereference_bh(p)); \ if (next) { \ BUG_ON(len >= 128); \ stack[len++] = next; \ @@ -385,11 +389,11 @@ inline struct wireguard_peer *routing_table_lookup_v4(struct routing_table *tabl struct wireguard_peer *peer = NULL; struct routing_table_node *node; - rcu_read_lock(); - node = find_node(rcu_dereference(table->root4), 32, (const u8 *)ip); + rcu_read_lock_bh(); + node = find_node(rcu_dereference_bh(table->root4), 32, (const u8 *)ip); if (node) peer = peer_get(node->peer); - rcu_read_unlock(); + rcu_read_unlock_bh(); return peer; } @@ -399,11 +403,11 @@ inline struct wireguard_peer *routing_table_lookup_v6(struct routing_table *tabl struct wireguard_peer *peer = NULL; struct routing_table_node *node; - rcu_read_lock(); - node = find_node(rcu_dereference(table->root6), 128, (const u8 *)ip); + rcu_read_lock_bh(); + node = find_node(rcu_dereference_bh(table->root6), 128, (const u8 *)ip); if (node) peer = peer_get(node->peer); - rcu_read_unlock(); + rcu_read_unlock_bh(); return peer; } @@ -439,28 +443,28 @@ int routing_table_remove_by_peer(struct routing_table *table, struct wireguard_p int routing_table_walk_ips(struct routing_table *table, void *ctx, int (*func)(void *ctx, struct wireguard_peer *peer, union nf_inet_addr ip, u8 cidr, int family)) { int ret; - rcu_read_lock(); - ret = walk_ips(rcu_dereference(table->root4), AF_INET, ctx, func, NULL); - rcu_read_unlock(); + rcu_read_lock_bh(); + ret = walk_ips(rcu_dereference_bh(table->root4), AF_INET, ctx, func, NULL); + rcu_read_unlock_bh(); if (ret) return ret; - rcu_read_lock(); - ret = walk_ips(rcu_dereference(table->root6), AF_INET6, ctx, func, NULL); - rcu_read_unlock(); + rcu_read_lock_bh(); + ret = walk_ips(rcu_dereference_bh(table->root6), AF_INET6, ctx, func, NULL); + rcu_read_unlock_bh(); return ret; } int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family)) { int ret; - rcu_read_lock(); - ret = walk_ips_by_peer(rcu_dereference(table->root4), AF_INET, ctx, peer, func, NULL); - rcu_read_unlock(); + rcu_read_lock_bh(); + ret = walk_ips_by_peer(rcu_dereference_bh(table->root4), AF_INET, ctx, peer, func, NULL); + rcu_read_unlock_bh(); if (ret) return ret; - rcu_read_lock(); - ret = walk_ips_by_peer(rcu_dereference(table->root6), AF_INET6, ctx, peer, func, NULL); - rcu_read_unlock(); + rcu_read_lock_bh(); + ret = walk_ips_by_peer(rcu_dereference_bh(table->root6), AF_INET6, ctx, peer, func, NULL); + rcu_read_unlock_bh(); return ret; } diff --git a/src/send.c b/src/send.c index f5414e1..046b62e 100644 --- a/src/send.c +++ b/src/send.c @@ -91,13 +91,13 @@ static inline void keep_key_fresh(struct wireguard_peer *peer) struct noise_keypair *keypair; bool send = false; - rcu_read_lock(); - keypair = rcu_dereference(peer->keypairs.current_keypair); + rcu_read_lock_bh(); + keypair = rcu_dereference_bh(peer->keypairs.current_keypair); if (likely(keypair && keypair->sending.is_valid) && (unlikely(atomic64_read(&keypair->sending.counter.counter) > REKEY_AFTER_MESSAGES) || (keypair->i_am_the_initiator && unlikely(time_is_before_eq_jiffies64(keypair->sending.birthdate + REKEY_AFTER_TIME))))) send = true; - rcu_read_unlock(); + rcu_read_unlock_bh(); if (send) packet_queue_handshake_initiation(peer); @@ -144,15 +144,14 @@ static void message_create_data_done(struct sk_buff_head *queue, struct wireguar void packet_send_queue(struct wireguard_peer *peer) { struct sk_buff_head queue; - unsigned long flags; peer->need_resend_queue = false; /* Steal the current queue into our local one. */ skb_queue_head_init(&queue); - spin_lock_irqsave(&peer->tx_packet_queue.lock, flags); + spin_lock_bh(&peer->tx_packet_queue.lock); skb_queue_splice_init(&peer->tx_packet_queue, &queue); - spin_unlock_irqrestore(&peer->tx_packet_queue.lock, flags); + spin_unlock_bh(&peer->tx_packet_queue.lock); if (unlikely(!skb_queue_len(&queue))) return; @@ -172,17 +171,17 @@ void packet_send_queue(struct wireguard_peer *peer) /* We stick the remaining skbs from local_queue at the top of the peer's * queue again, setting the top of local_queue to be the skb that begins * the requeueing. */ - spin_lock_irqsave(&peer->tx_packet_queue.lock, flags); + spin_lock_bh(&peer->tx_packet_queue.lock); skb_queue_splice(&queue, &peer->tx_packet_queue); - spin_unlock_irqrestore(&peer->tx_packet_queue.lock, flags); + spin_unlock_bh(&peer->tx_packet_queue.lock); break; case -ENOKEY: /* ENOKEY means that we don't have a valid session for the peer, which * means we should initiate a session, but after requeuing like above. */ - spin_lock_irqsave(&peer->tx_packet_queue.lock, flags); + spin_lock_bh(&peer->tx_packet_queue.lock); skb_queue_splice(&queue, &peer->tx_packet_queue); - spin_unlock_irqrestore(&peer->tx_packet_queue.lock, flags); + spin_unlock_bh(&peer->tx_packet_queue.lock); packet_queue_handshake_initiation(peer); break; diff --git a/src/socket.c b/src/socket.c index a2b64b3..54b1ba2 100644 --- a/src/socket.c +++ b/src/socket.c @@ -30,8 +30,8 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct skb->next = skb->prev = NULL; skb->dev = netdev_pub(wg); - rcu_read_lock(); - sock = rcu_dereference(wg->sock4); + rcu_read_lock_bh(); + sock = rcu_dereference_bh(wg->sock4); if (unlikely(!sock)) { ret = -ENONET; @@ -73,7 +73,7 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct err: kfree_skb(skb); out: - rcu_read_unlock(); + rcu_read_unlock_bh(); return ret; } @@ -97,8 +97,8 @@ static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct skb->next = skb->prev = NULL; skb->dev = netdev_pub(wg); - rcu_read_lock(); - sock = rcu_dereference(wg->sock6); + rcu_read_lock_bh(); + sock = rcu_dereference_bh(wg->sock6); if (unlikely(!sock)) { ret = -ENONET; @@ -139,7 +139,7 @@ static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct err: kfree_skb(skb); out: - rcu_read_unlock(); + rcu_read_unlock_bh(); return ret; #else return -EAFNOSUPPORT; @@ -377,7 +377,7 @@ void socket_uninit(struct wireguard_device *wg) rcu_assign_pointer(wg->sock4, NULL); rcu_assign_pointer(wg->sock6, NULL); mutex_unlock(&wg->socket_update_lock); - synchronize_rcu(); + synchronize_rcu_bh(); synchronize_net(); sock_free(old4); sock_free(old6); -- cgit v1.2.3-59-g8ed1b