aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src/noise.c
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2018-08-01 15:59:37 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2018-08-03 00:14:18 +0200
commit81eb0e30f9b39e99d1bb7b56828fd32e50ea055a (patch)
tree7b9e212d2a73644bae8b09da164147a4491d98fa /src/noise.c
parentnoise: free peer references on failure (diff)
downloadwireguard-monolithic-historical-81eb0e30f9b39e99d1bb7b56828fd32e50ea055a.tar.xz
wireguard-monolithic-historical-81eb0e30f9b39e99d1bb7b56828fd32e50ea055a.zip
peer: ensure destruction doesn't race
Completely rework peer removal to ensure peers don't jump between contexts and create races.
Diffstat (limited to 'src/noise.c')
-rw-r--r--src/noise.c58
1 files changed, 30 insertions, 28 deletions
diff --git a/src/noise.c b/src/noise.c
index a1e094b..0f6e51b 100644
--- a/src/noise.c
+++ b/src/noise.c
@@ -103,24 +103,23 @@ static struct noise_keypair *keypair_create(struct wireguard_peer *peer)
static void keypair_free_rcu(struct rcu_head *rcu)
{
- struct noise_keypair *keypair = container_of(rcu, struct noise_keypair, rcu);
-
- net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n", keypair->entry.peer->device->dev->name, keypair->internal_id, keypair->entry.peer->internal_id);
- kzfree(keypair);
+ kzfree(container_of(rcu, struct noise_keypair, rcu));
}
static void keypair_free_kref(struct kref *kref)
{
struct noise_keypair *keypair = container_of(kref, struct noise_keypair, refcount);
-
+ net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n", keypair->entry.peer->device->dev->name, keypair->internal_id, keypair->entry.peer->internal_id);
index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry);
call_rcu_bh(&keypair->rcu, keypair_free_rcu);
}
-void noise_keypair_put(struct noise_keypair *keypair)
+void noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
{
if (unlikely(!keypair))
return;
+ if (unlikely(unreference_now))
+ index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry);
kref_put(&keypair->refcount, keypair_free_kref);
}
@@ -139,13 +138,13 @@ void noise_keypairs_clear(struct noise_keypairs *keypairs)
spin_lock_bh(&keypairs->keypair_update_lock);
old = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock));
RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
- noise_keypair_put(old);
+ noise_keypair_put(old, true);
old = rcu_dereference_protected(keypairs->next_keypair, lockdep_is_held(&keypairs->keypair_update_lock));
RCU_INIT_POINTER(keypairs->next_keypair, NULL);
- noise_keypair_put(old);
+ noise_keypair_put(old, true);
old = rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock));
RCU_INIT_POINTER(keypairs->current_keypair, NULL);
- noise_keypair_put(old);
+ noise_keypair_put(old, true);
spin_unlock_bh(&keypairs->keypair_update_lock);
}
@@ -171,7 +170,7 @@ static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypai
*/
RCU_INIT_POINTER(keypairs->next_keypair, NULL);
rcu_assign_pointer(keypairs->previous_keypair, next_keypair);
- noise_keypair_put(current_keypair);
+ noise_keypair_put(current_keypair, true);
} else /* If there wasn't an existing next keypair, we replace the
* previous with the current one.
*/
@@ -179,7 +178,7 @@ static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypai
/* At this point we can get rid of the old previous keypair, and set up
* the new keypair.
*/
- noise_keypair_put(previous_keypair);
+ noise_keypair_put(previous_keypair, true);
rcu_assign_pointer(keypairs->current_keypair, new_keypair);
} else {
/* If we're the responder, it means we can't use the new keypair until
@@ -188,9 +187,9 @@ static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypai
* in the new next one.
*/
rcu_assign_pointer(keypairs->next_keypair, new_keypair);
- noise_keypair_put(next_keypair);
+ noise_keypair_put(next_keypair, true);
RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
- noise_keypair_put(previous_keypair);
+ noise_keypair_put(previous_keypair, true);
}
spin_unlock_bh(&keypairs->keypair_update_lock);
}
@@ -218,7 +217,7 @@ bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_k
*/
old_keypair = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock));
rcu_assign_pointer(keypairs->previous_keypair, rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock)));
- noise_keypair_put(old_keypair);
+ noise_keypair_put(old_keypair, true);
rcu_assign_pointer(keypairs->current_keypair, received_keypair);
RCU_INIT_POINTER(keypairs->next_keypair, NULL);
@@ -542,7 +541,7 @@ out:
struct wireguard_peer *noise_handshake_consume_response(struct message_handshake_response *src, struct wireguard_device *wg)
{
struct noise_handshake *handshake;
- struct wireguard_peer *ret_peer = NULL;
+ struct wireguard_peer *peer = NULL, *ret_peer = NULL;
u8 key[NOISE_SYMMETRIC_KEY_LEN];
u8 hash[NOISE_HASH_LEN];
u8 chaining_key[NOISE_HASH_LEN];
@@ -556,7 +555,7 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake
if (unlikely(!wg->static_identity.has_identity))
goto out;
- handshake = (struct noise_handshake *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE, src->receiver_index);
+ handshake = (struct noise_handshake *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE, src->receiver_index, &peer);
if (unlikely(!handshake))
goto out;
@@ -601,11 +600,11 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake
handshake->remote_index = src->sender_index;
handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
up_write(&handshake->lock);
- ret_peer = handshake->entry.peer;
+ ret_peer = peer;
goto out;
fail:
- peer_put(handshake->entry.peer);
+ peer_put(peer);
out:
memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
memzero_explicit(hash, NOISE_HASH_LEN);
@@ -619,14 +618,15 @@ out:
bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noise_keypairs *keypairs)
{
struct noise_keypair *new_keypair;
+ bool ret = false;
down_write(&handshake->lock);
if (handshake->state != HANDSHAKE_CREATED_RESPONSE && handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
- goto fail;
+ goto out;
new_keypair = keypair_create(handshake->entry.peer);
if (!new_keypair)
- goto fail;
+ goto out;
new_keypair->i_am_the_initiator = handshake->state == HANDSHAKE_CONSUMED_RESPONSE;
new_keypair->remote_index = handshake->remote_index;
@@ -636,14 +636,16 @@ bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noi
derive_keys(&new_keypair->receiving, &new_keypair->sending, handshake->chaining_key);
handshake_zero(handshake);
- add_new_keypair(keypairs, new_keypair);
- net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n", handshake->entry.peer->device->dev->name, new_keypair->internal_id, handshake->entry.peer->internal_id);
- WARN_ON(!index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry));
- up_write(&handshake->lock);
-
- return true;
+ rcu_read_lock_bh();
+ if (likely(!container_of(handshake, struct wireguard_peer, handshake)->is_dead)) {
+ add_new_keypair(keypairs, new_keypair);
+ net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n", handshake->entry.peer->device->dev->name, new_keypair->internal_id, handshake->entry.peer->internal_id);
+ ret = index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry);
+ } else
+ kzfree(new_keypair);
+ rcu_read_unlock_bh();
-fail:
+out:
up_write(&handshake->lock);
- return false;
+ return ret;
}