From 81eb0e30f9b39e99d1bb7b56828fd32e50ea055a Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Wed, 1 Aug 2018 15:59:37 +0200 Subject: peer: ensure destruction doesn't race Completely rework peer removal to ensure peers don't jump between contexts and create races. --- src/noise.c | 58 ++++++++++++++++++++++++++++++---------------------------- 1 file changed, 30 insertions(+), 28 deletions(-) (limited to 'src/noise.c') diff --git a/src/noise.c b/src/noise.c index a1e094b..0f6e51b 100644 --- a/src/noise.c +++ b/src/noise.c @@ -103,24 +103,23 @@ static struct noise_keypair *keypair_create(struct wireguard_peer *peer) static void keypair_free_rcu(struct rcu_head *rcu) { - struct noise_keypair *keypair = container_of(rcu, struct noise_keypair, rcu); - - net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n", keypair->entry.peer->device->dev->name, keypair->internal_id, keypair->entry.peer->internal_id); - kzfree(keypair); + kzfree(container_of(rcu, struct noise_keypair, rcu)); } static void keypair_free_kref(struct kref *kref) { struct noise_keypair *keypair = container_of(kref, struct noise_keypair, refcount); - + net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n", keypair->entry.peer->device->dev->name, keypair->internal_id, keypair->entry.peer->internal_id); index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry); call_rcu_bh(&keypair->rcu, keypair_free_rcu); } -void noise_keypair_put(struct noise_keypair *keypair) +void noise_keypair_put(struct noise_keypair *keypair, bool unreference_now) { if (unlikely(!keypair)) return; + if (unlikely(unreference_now)) + index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry); kref_put(&keypair->refcount, keypair_free_kref); } @@ -139,13 +138,13 @@ void noise_keypairs_clear(struct noise_keypairs *keypairs) spin_lock_bh(&keypairs->keypair_update_lock); old = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); RCU_INIT_POINTER(keypairs->previous_keypair, NULL); - noise_keypair_put(old); + noise_keypair_put(old, true); old = rcu_dereference_protected(keypairs->next_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); RCU_INIT_POINTER(keypairs->next_keypair, NULL); - noise_keypair_put(old); + noise_keypair_put(old, true); old = rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); RCU_INIT_POINTER(keypairs->current_keypair, NULL); - noise_keypair_put(old); + noise_keypair_put(old, true); spin_unlock_bh(&keypairs->keypair_update_lock); } @@ -171,7 +170,7 @@ static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypai */ RCU_INIT_POINTER(keypairs->next_keypair, NULL); rcu_assign_pointer(keypairs->previous_keypair, next_keypair); - noise_keypair_put(current_keypair); + noise_keypair_put(current_keypair, true); } else /* If there wasn't an existing next keypair, we replace the * previous with the current one. */ @@ -179,7 +178,7 @@ static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypai /* At this point we can get rid of the old previous keypair, and set up * the new keypair. */ - noise_keypair_put(previous_keypair); + noise_keypair_put(previous_keypair, true); rcu_assign_pointer(keypairs->current_keypair, new_keypair); } else { /* If we're the responder, it means we can't use the new keypair until @@ -188,9 +187,9 @@ static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypai * in the new next one. */ rcu_assign_pointer(keypairs->next_keypair, new_keypair); - noise_keypair_put(next_keypair); + noise_keypair_put(next_keypair, true); RCU_INIT_POINTER(keypairs->previous_keypair, NULL); - noise_keypair_put(previous_keypair); + noise_keypair_put(previous_keypair, true); } spin_unlock_bh(&keypairs->keypair_update_lock); } @@ -218,7 +217,7 @@ bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_k */ old_keypair = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); rcu_assign_pointer(keypairs->previous_keypair, rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock))); - noise_keypair_put(old_keypair); + noise_keypair_put(old_keypair, true); rcu_assign_pointer(keypairs->current_keypair, received_keypair); RCU_INIT_POINTER(keypairs->next_keypair, NULL); @@ -542,7 +541,7 @@ out: struct wireguard_peer *noise_handshake_consume_response(struct message_handshake_response *src, struct wireguard_device *wg) { struct noise_handshake *handshake; - struct wireguard_peer *ret_peer = NULL; + struct wireguard_peer *peer = NULL, *ret_peer = NULL; u8 key[NOISE_SYMMETRIC_KEY_LEN]; u8 hash[NOISE_HASH_LEN]; u8 chaining_key[NOISE_HASH_LEN]; @@ -556,7 +555,7 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake if (unlikely(!wg->static_identity.has_identity)) goto out; - handshake = (struct noise_handshake *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE, src->receiver_index); + handshake = (struct noise_handshake *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE, src->receiver_index, &peer); if (unlikely(!handshake)) goto out; @@ -601,11 +600,11 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake handshake->remote_index = src->sender_index; handshake->state = HANDSHAKE_CONSUMED_RESPONSE; up_write(&handshake->lock); - ret_peer = handshake->entry.peer; + ret_peer = peer; goto out; fail: - peer_put(handshake->entry.peer); + peer_put(peer); out: memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN); memzero_explicit(hash, NOISE_HASH_LEN); @@ -619,14 +618,15 @@ out: bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noise_keypairs *keypairs) { struct noise_keypair *new_keypair; + bool ret = false; down_write(&handshake->lock); if (handshake->state != HANDSHAKE_CREATED_RESPONSE && handshake->state != HANDSHAKE_CONSUMED_RESPONSE) - goto fail; + goto out; new_keypair = keypair_create(handshake->entry.peer); if (!new_keypair) - goto fail; + goto out; new_keypair->i_am_the_initiator = handshake->state == HANDSHAKE_CONSUMED_RESPONSE; new_keypair->remote_index = handshake->remote_index; @@ -636,14 +636,16 @@ bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noi derive_keys(&new_keypair->receiving, &new_keypair->sending, handshake->chaining_key); handshake_zero(handshake); - add_new_keypair(keypairs, new_keypair); - net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n", handshake->entry.peer->device->dev->name, new_keypair->internal_id, handshake->entry.peer->internal_id); - WARN_ON(!index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry)); - up_write(&handshake->lock); - - return true; + rcu_read_lock_bh(); + if (likely(!container_of(handshake, struct wireguard_peer, handshake)->is_dead)) { + add_new_keypair(keypairs, new_keypair); + net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n", handshake->entry.peer->device->dev->name, new_keypair->internal_id, handshake->entry.peer->internal_id); + ret = index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry); + } else + kzfree(new_keypair); + rcu_read_unlock_bh(); -fail: +out: up_write(&handshake->lock); - return false; + return ret; } -- cgit v1.2.3-59-g8ed1b