aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--src/hashtables.c5
-rw-r--r--src/hashtables.h2
-rw-r--r--src/noise.c28
3 files changed, 24 insertions, 11 deletions
diff --git a/src/hashtables.c b/src/hashtables.c
index db97f7e..a01a899 100644
--- a/src/hashtables.c
+++ b/src/hashtables.c
@@ -97,13 +97,16 @@ search_unused_slot:
return entry->index;
}
-void index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new)
+bool index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new)
{
+ if (unlikely(hlist_unhashed(&old->index_hash)))
+ return false;
spin_lock_bh(&table->lock);
new->index = old->index;
hlist_replace_rcu(&old->index_hash, &new->index_hash);
INIT_HLIST_NODE(&old->index_hash);
spin_unlock_bh(&table->lock);
+ return true;
}
void index_hashtable_remove(struct index_hashtable *table, struct index_hashtable_entry *entry)
diff --git a/src/hashtables.h b/src/hashtables.h
index 9fa47d5..08a2a5d 100644
--- a/src/hashtables.h
+++ b/src/hashtables.h
@@ -40,7 +40,7 @@ struct index_hashtable_entry {
};
void index_hashtable_init(struct index_hashtable *table);
__le32 index_hashtable_insert(struct index_hashtable *table, struct index_hashtable_entry *entry);
-void index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new);
+bool index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new);
void index_hashtable_remove(struct index_hashtable *table, struct index_hashtable_entry *entry);
struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index);
diff --git a/src/noise.c b/src/noise.c
index 7ca2a67..9583ab1 100644
--- a/src/noise.c
+++ b/src/noise.c
@@ -59,16 +59,21 @@ bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static
return noise_precompute_static_static(peer);
}
-void noise_handshake_clear(struct noise_handshake *handshake)
+static void handshake_zero(struct noise_handshake *handshake)
{
- index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry);
- down_write(&handshake->lock);
memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
memset(&handshake->hash, 0, NOISE_HASH_LEN);
memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
handshake->remote_index = 0;
handshake->state = HANDSHAKE_ZEROED;
+}
+
+void noise_handshake_clear(struct noise_handshake *handshake)
+{
+ index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry);
+ down_write(&handshake->lock);
+ handshake_zero(handshake);
up_write(&handshake->lock);
index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry);
}
@@ -371,8 +376,8 @@ bool noise_handshake_create_initiation(struct message_handshake_initiation *dst,
dst->sender_index = index_hashtable_insert(&handshake->entry.peer->device->index_hashtable, &handshake->entry);
- ret = true;
handshake->state = HANDSHAKE_CREATED_INITIATION;
+ ret = true;
out:
up_write(&handshake->lock);
@@ -548,6 +553,11 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake
/* Success! Copy everything to peer */
down_write(&handshake->lock);
+ /* It's important to check that the state is still the same, while we have an exclusive lock */
+ if (handshake->state != state) {
+ up_write(&handshake->lock);
+ goto fail;
+ }
memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
memcpy(handshake->hash, hash, NOISE_HASH_LEN);
memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
@@ -573,7 +583,7 @@ bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noi
{
struct noise_keypair *new_keypair;
- down_read(&handshake->lock);
+ down_write(&handshake->lock);
if (handshake->state != HANDSHAKE_CREATED_RESPONSE && handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
goto fail;
@@ -587,16 +597,16 @@ bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noi
derive_keys(&new_keypair->sending, &new_keypair->receiving, handshake->chaining_key);
else
derive_keys(&new_keypair->receiving, &new_keypair->sending, handshake->chaining_key);
- up_read(&handshake->lock);
+ handshake_zero(handshake);
add_new_keypair(keypairs, new_keypair);
- index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry);
- noise_handshake_clear(handshake);
net_dbg_ratelimited("%s: Keypair %Lu created for peer %Lu\n", netdev_pub(new_keypair->entry.peer->device)->name, new_keypair->internal_id, new_keypair->entry.peer->internal_id);
+ WARN_ON(!index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry));
+ up_write(&handshake->lock);
return true;
fail:
- up_read(&handshake->lock);
+ up_write(&handshake->lock);
return false;
}