diff options
Diffstat (limited to 'src/noise.c')
-rw-r--r-- | src/noise.c | 360 |
1 files changed, 247 insertions, 113 deletions
diff --git a/src/noise.c b/src/noise.c index 0f6e51b..70b53a6 100644 --- a/src/noise.c +++ b/src/noise.c @@ -35,7 +35,8 @@ void __init noise_init(void) { struct blake2s_state blake; - blake2s(handshake_init_chaining_key, handshake_name, NULL, NOISE_HASH_LEN, sizeof(handshake_name), 0); + blake2s(handshake_init_chaining_key, handshake_name, NULL, + NOISE_HASH_LEN, sizeof(handshake_name), 0); blake2s_init(&blake, NOISE_HASH_LEN); blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN); blake2s_update(&blake, identifier_name, sizeof(identifier_name)); @@ -46,16 +47,25 @@ void __init noise_init(void) bool noise_precompute_static_static(struct wireguard_peer *peer) { bool ret = true; + down_write(&peer->handshake.lock); if (peer->handshake.static_identity->has_identity) - ret = curve25519(peer->handshake.precomputed_static_static, peer->handshake.static_identity->static_private, peer->handshake.remote_static); + ret = curve25519( + peer->handshake.precomputed_static_static, + peer->handshake.static_identity->static_private, + peer->handshake.remote_static); else - memset(peer->handshake.precomputed_static_static, 0, NOISE_PUBLIC_KEY_LEN); + memset(peer->handshake.precomputed_static_static, 0, + NOISE_PUBLIC_KEY_LEN); up_write(&peer->handshake.lock); return ret; } -bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static_identity *static_identity, const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN], struct wireguard_peer *peer) +bool noise_handshake_init(struct noise_handshake *handshake, + struct noise_static_identity *static_identity, + const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], + const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN], + struct wireguard_peer *peer) { memset(handshake, 0, sizeof(struct noise_handshake)); init_rwsem(&handshake->lock); @@ -63,7 +73,8 @@ bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static handshake->entry.peer = peer; memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN); if (peer_preshared_key) - memcpy(handshake->preshared_key, peer_preshared_key, NOISE_SYMMETRIC_KEY_LEN); + memcpy(handshake->preshared_key, peer_preshared_key, + NOISE_SYMMETRIC_KEY_LEN); handshake->static_identity = static_identity; handshake->state = HANDSHAKE_ZEROED; return noise_precompute_static_static(peer); @@ -81,16 +92,19 @@ static void handshake_zero(struct noise_handshake *handshake) void noise_handshake_clear(struct noise_handshake *handshake) { - index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry); + index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, + &handshake->entry); down_write(&handshake->lock); handshake_zero(handshake); up_write(&handshake->lock); - index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry); + index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, + &handshake->entry); } static struct noise_keypair *keypair_create(struct wireguard_peer *peer) { - struct noise_keypair *keypair = kzalloc(sizeof(struct noise_keypair), GFP_KERNEL); + struct noise_keypair *keypair = + kzalloc(sizeof(struct noise_keypair), GFP_KERNEL); if (unlikely(!keypair)) return NULL; @@ -108,9 +122,14 @@ static void keypair_free_rcu(struct rcu_head *rcu) static void keypair_free_kref(struct kref *kref) { - struct noise_keypair *keypair = container_of(kref, struct noise_keypair, refcount); - net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n", keypair->entry.peer->device->dev->name, keypair->internal_id, keypair->entry.peer->internal_id); - index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry); + struct noise_keypair *keypair = + container_of(kref, struct noise_keypair, refcount); + net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n", + keypair->entry.peer->device->dev->name, + keypair->internal_id, + keypair->entry.peer->internal_id); + index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, + &keypair->entry); call_rcu_bh(&keypair->rcu, keypair_free_rcu); } @@ -119,13 +138,16 @@ void noise_keypair_put(struct noise_keypair *keypair, bool unreference_now) if (unlikely(!keypair)) return; if (unlikely(unreference_now)) - index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry); + index_hashtable_remove( + &keypair->entry.peer->device->index_hashtable, + &keypair->entry); kref_put(&keypair->refcount, keypair_free_kref); } struct noise_keypair *noise_keypair_get(struct noise_keypair *keypair) { - RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(), "Taking noise keypair reference without holding the RCU BH read lock"); + RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(), + "Taking noise keypair reference without holding the RCU BH read lock"); if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount))) return NULL; return keypair; @@ -136,55 +158,67 @@ void noise_keypairs_clear(struct noise_keypairs *keypairs) struct noise_keypair *old; spin_lock_bh(&keypairs->keypair_update_lock); - old = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); + old = rcu_dereference_protected(keypairs->previous_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); RCU_INIT_POINTER(keypairs->previous_keypair, NULL); noise_keypair_put(old, true); - old = rcu_dereference_protected(keypairs->next_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); + old = rcu_dereference_protected(keypairs->next_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); RCU_INIT_POINTER(keypairs->next_keypair, NULL); noise_keypair_put(old, true); - old = rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); + old = rcu_dereference_protected(keypairs->current_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); RCU_INIT_POINTER(keypairs->current_keypair, NULL); noise_keypair_put(old, true); spin_unlock_bh(&keypairs->keypair_update_lock); } -static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypair *new_keypair) +static void add_new_keypair(struct noise_keypairs *keypairs, + struct noise_keypair *new_keypair) { struct noise_keypair *previous_keypair, *next_keypair, *current_keypair; spin_lock_bh(&keypairs->keypair_update_lock); - previous_keypair = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); - next_keypair = rcu_dereference_protected(keypairs->next_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); - current_keypair = rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); + previous_keypair = rcu_dereference_protected(keypairs->previous_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); + next_keypair = rcu_dereference_protected(keypairs->next_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); + current_keypair = rcu_dereference_protected(keypairs->current_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); if (new_keypair->i_am_the_initiator) { - /* If we're the initiator, it means we've sent a handshake, and received - * a confirmation response, which means this new keypair can now be used. + /* If we're the initiator, it means we've sent a handshake, and + * received a confirmation response, which means this new + * keypair can now be used. */ if (next_keypair) { - /* If there already was a next keypair pending, we demote it to be - * the previous keypair, and free the existing current. - * TODO: note that this means KCI can result in this transition. It - * would perhaps be more sound to always just get rid of the unused - * next keypair instead of putting it in the previous slot, but this - * might be a bit less robust. Something to think about and decide on. + /* If there already was a next keypair pending, we + * demote it to be the previous keypair, and free the + * existing current. Note that this means KCI can result + * in this transition. It would perhaps be more sound to + * always just get rid of the unused next keypair + * instead of putting it in the previous slot, but this + * might be a bit less robust. Something to think about + * for the future. */ RCU_INIT_POINTER(keypairs->next_keypair, NULL); - rcu_assign_pointer(keypairs->previous_keypair, next_keypair); + rcu_assign_pointer(keypairs->previous_keypair, + next_keypair); noise_keypair_put(current_keypair, true); - } else /* If there wasn't an existing next keypair, we replace the - * previous with the current one. + } else /* If there wasn't an existing next keypair, we replace + * the previous with the current one. */ - rcu_assign_pointer(keypairs->previous_keypair, current_keypair); - /* At this point we can get rid of the old previous keypair, and set up - * the new keypair. + rcu_assign_pointer(keypairs->previous_keypair, + current_keypair); + /* At this point we can get rid of the old previous keypair, and + * set up the new keypair. */ noise_keypair_put(previous_keypair, true); rcu_assign_pointer(keypairs->current_keypair, new_keypair); } else { - /* If we're the responder, it means we can't use the new keypair until - * we receive confirmation via the first data packet, so we get rid of - * the existing previous one, the possibly existing next one, and slide - * in the new next one. + /* If we're the responder, it means we can't use the new keypair + * until we receive confirmation via the first data packet, so + * we get rid of the existing previous one, the possibly + * existing next one, and slide in the new next one. */ rcu_assign_pointer(keypairs->next_keypair, new_keypair); noise_keypair_put(next_keypair, true); @@ -194,19 +228,25 @@ static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypai spin_unlock_bh(&keypairs->keypair_update_lock); } -bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_keypair *received_keypair) +bool noise_received_with_keypair(struct noise_keypairs *keypairs, + struct noise_keypair *received_keypair) { - bool key_is_new; struct noise_keypair *old_keypair; + bool key_is_new; /* We first check without taking the spinlock. */ - key_is_new = received_keypair == rcu_access_pointer(keypairs->next_keypair); + key_is_new = received_keypair == + rcu_access_pointer(keypairs->next_keypair); if (likely(!key_is_new)) return false; spin_lock_bh(&keypairs->keypair_update_lock); - /* After locking, we double check that things didn't change from beneath us. */ - if (unlikely(received_keypair != rcu_dereference_protected(keypairs->next_keypair, lockdep_is_held(&keypairs->keypair_update_lock)))) { + /* After locking, we double check that things didn't change from + * beneath us. + */ + if (unlikely(received_keypair != + rcu_dereference_protected(keypairs->next_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)))) { spin_unlock_bh(&keypairs->keypair_update_lock); return false; } @@ -215,8 +255,11 @@ bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_k * into the current, the current into the previous, and get rid of * the old previous. */ - old_keypair = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); - rcu_assign_pointer(keypairs->previous_keypair, rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock))); + old_keypair = rcu_dereference_protected(keypairs->previous_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); + rcu_assign_pointer(keypairs->previous_keypair, + rcu_dereference_protected(keypairs->current_keypair, + lockdep_is_held(&keypairs->keypair_update_lock))); noise_keypair_put(old_keypair, true); rcu_assign_pointer(keypairs->current_keypair, received_keypair); RCU_INIT_POINTER(keypairs->next_keypair, NULL); @@ -226,34 +269,46 @@ bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_k } /* Must hold static_identity->lock */ -void noise_set_static_identity_private_key(struct noise_static_identity *static_identity, const u8 private_key[NOISE_PUBLIC_KEY_LEN]) +void noise_set_static_identity_private_key( + struct noise_static_identity *static_identity, + const u8 private_key[NOISE_PUBLIC_KEY_LEN]) { - memcpy(static_identity->static_private, private_key, NOISE_PUBLIC_KEY_LEN); - static_identity->has_identity = curve25519_generate_public(static_identity->static_public, private_key); + memcpy(static_identity->static_private, private_key, + NOISE_PUBLIC_KEY_LEN); + static_identity->has_identity = curve25519_generate_public( + static_identity->static_public, private_key); } /* This is Hugo Krawczyk's HKDF: * - https://eprint.iacr.org/2010/264.pdf * - https://tools.ietf.org/html/rfc5869 */ -static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, size_t first_len, size_t second_len, size_t third_len, size_t data_len, const u8 chaining_key[NOISE_HASH_LEN]) +static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, + size_t first_len, size_t second_len, size_t third_len, + size_t data_len, const u8 chaining_key[NOISE_HASH_LEN]) { - u8 secret[BLAKE2S_OUTBYTES]; u8 output[BLAKE2S_OUTBYTES + 1]; + u8 secret[BLAKE2S_OUTBYTES]; #ifdef DEBUG - BUG_ON(first_len > BLAKE2S_OUTBYTES || second_len > BLAKE2S_OUTBYTES || third_len > BLAKE2S_OUTBYTES || ((second_len || second_dst || third_len || third_dst) && (!first_len || !first_dst)) || ((third_len || third_dst) && (!second_len || !second_dst))); + BUG_ON(first_len > BLAKE2S_OUTBYTES || second_len > BLAKE2S_OUTBYTES || + third_len > BLAKE2S_OUTBYTES || + ((second_len || second_dst || third_len || third_dst) && + (!first_len || !first_dst)) || + ((third_len || third_dst) && (!second_len || !second_dst))); #endif /* Extract entropy from data into secret */ - blake2s_hmac(secret, data, chaining_key, BLAKE2S_OUTBYTES, data_len, NOISE_HASH_LEN); + blake2s_hmac(secret, data, chaining_key, BLAKE2S_OUTBYTES, data_len, + NOISE_HASH_LEN); if (!first_dst || !first_len) goto out; /* Expand first key: key = secret, data = 0x1 */ output[0] = 1; - blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, 1, BLAKE2S_OUTBYTES); + blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, 1, + BLAKE2S_OUTBYTES); memcpy(first_dst, output, first_len); if (!second_dst || !second_len) @@ -261,7 +316,8 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, si /* Expand second key: key = secret, data = first-key || 0x2 */ output[BLAKE2S_OUTBYTES] = 2; - blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, BLAKE2S_OUTBYTES + 1, BLAKE2S_OUTBYTES); + blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, + BLAKE2S_OUTBYTES + 1, BLAKE2S_OUTBYTES); memcpy(second_dst, output, second_len); if (!third_dst || !third_len) @@ -269,7 +325,8 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, si /* Expand third key: key = secret, data = second-key || 0x3 */ output[BLAKE2S_OUTBYTES] = 3; - blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, BLAKE2S_OUTBYTES + 1, BLAKE2S_OUTBYTES); + blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, + BLAKE2S_OUTBYTES + 1, BLAKE2S_OUTBYTES); memcpy(third_dst, output, third_len); out: @@ -282,25 +339,34 @@ static void symmetric_key_init(struct noise_symmetric_key *key) { spin_lock_init(&key->counter.receive.lock); atomic64_set(&key->counter.counter, 0); - memset(key->counter.receive.backtrack, 0, sizeof(key->counter.receive.backtrack)); + memset(key->counter.receive.backtrack, 0, + sizeof(key->counter.receive.backtrack)); key->birthdate = ktime_get_boot_fast_ns(); key->is_valid = true; } -static void derive_keys(struct noise_symmetric_key *first_dst, struct noise_symmetric_key *second_dst, const u8 chaining_key[NOISE_HASH_LEN]) +static void derive_keys(struct noise_symmetric_key *first_dst, + struct noise_symmetric_key *second_dst, + const u8 chaining_key[NOISE_HASH_LEN]) { - kdf(first_dst->key, second_dst->key, NULL, NULL, NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, chaining_key); + kdf(first_dst->key, second_dst->key, NULL, NULL, + NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, + chaining_key); symmetric_key_init(first_dst); symmetric_key_init(second_dst); } -static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN], u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 private[NOISE_PUBLIC_KEY_LEN], const u8 public[NOISE_PUBLIC_KEY_LEN]) +static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN], + u8 key[NOISE_SYMMETRIC_KEY_LEN], + const u8 private[NOISE_PUBLIC_KEY_LEN], + const u8 public[NOISE_PUBLIC_KEY_LEN]) { u8 dh_calculation[NOISE_PUBLIC_KEY_LEN]; if (unlikely(!curve25519(dh_calculation, private, public))) return false; - kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key); + kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN, + NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key); memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN); return true; } @@ -315,42 +381,59 @@ static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len) blake2s_final(&blake, hash, NOISE_HASH_LEN); } -static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN], u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 psk[NOISE_SYMMETRIC_KEY_LEN]) +static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN], + u8 key[NOISE_SYMMETRIC_KEY_LEN], + const u8 psk[NOISE_SYMMETRIC_KEY_LEN]) { u8 temp_hash[NOISE_HASH_LEN]; - kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key); + kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN, + NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key); mix_hash(hash, temp_hash, NOISE_HASH_LEN); memzero_explicit(temp_hash, NOISE_HASH_LEN); } -static void handshake_init(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN], const u8 remote_static[NOISE_PUBLIC_KEY_LEN]) +static void handshake_init(u8 chaining_key[NOISE_HASH_LEN], + u8 hash[NOISE_HASH_LEN], + const u8 remote_static[NOISE_PUBLIC_KEY_LEN]) { memcpy(hash, handshake_init_hash, NOISE_HASH_LEN); memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN); mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN); } -static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext, size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN], u8 hash[NOISE_HASH_LEN]) +static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext, + size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN], + u8 hash[NOISE_HASH_LEN]) { - chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash, NOISE_HASH_LEN, 0 /* Always zero for Noise_IK */, key); + chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash, + NOISE_HASH_LEN, + 0 /* Always zero for Noise_IK */, key); mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len)); } -static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext, size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN], u8 hash[NOISE_HASH_LEN]) +static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext, + size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN], + u8 hash[NOISE_HASH_LEN]) { - if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len, hash, NOISE_HASH_LEN, 0 /* Always zero for Noise_IK */, key)) + if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len, + hash, NOISE_HASH_LEN, + 0 /* Always zero for Noise_IK */, key)) return false; mix_hash(hash, src_ciphertext, src_len); return true; } -static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN], const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN], u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN]) +static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN], + const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN], + u8 chaining_key[NOISE_HASH_LEN], + u8 hash[NOISE_HASH_LEN]) { if (ephemeral_dst != ephemeral_src) memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN); mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN); - kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0, NOISE_PUBLIC_KEY_LEN, chaining_key); + kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0, + NOISE_PUBLIC_KEY_LEN, chaining_key); } static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN]) @@ -363,14 +446,15 @@ static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN]) *(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec); } -bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, struct noise_handshake *handshake) +bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, + struct noise_handshake *handshake) { u8 timestamp[NOISE_TIMESTAMP_LEN]; u8 key[NOISE_SYMMETRIC_KEY_LEN]; bool ret = false; - /* We need to wait for crng _before_ taking any locks, since curve25519_generate_secret - * uses get_random_bytes_wait. + /* We need to wait for crng _before_ taking any locks, since + * curve25519_generate_secret uses get_random_bytes_wait. */ wait_for_random_bytes(); @@ -382,29 +466,42 @@ bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION); - handshake_init(handshake->chaining_key, handshake->hash, handshake->remote_static); + handshake_init(handshake->chaining_key, handshake->hash, + handshake->remote_static); /* e */ curve25519_generate_secret(handshake->ephemeral_private); - if (!curve25519_generate_public(dst->unencrypted_ephemeral, handshake->ephemeral_private)) + if (!curve25519_generate_public(dst->unencrypted_ephemeral, + handshake->ephemeral_private)) goto out; - message_ephemeral(dst->unencrypted_ephemeral, dst->unencrypted_ephemeral, handshake->chaining_key, handshake->hash); + message_ephemeral(dst->unencrypted_ephemeral, + dst->unencrypted_ephemeral, handshake->chaining_key, + handshake->hash); /* es */ - if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private, handshake->remote_static)) + if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private, + handshake->remote_static)) goto out; /* s */ - message_encrypt(dst->encrypted_static, handshake->static_identity->static_public, NOISE_PUBLIC_KEY_LEN, key, handshake->hash); + message_encrypt(dst->encrypted_static, + handshake->static_identity->static_public, + NOISE_PUBLIC_KEY_LEN, key, handshake->hash); /* ss */ - kdf(handshake->chaining_key, key, NULL, handshake->precomputed_static_static, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, handshake->chaining_key); + kdf(handshake->chaining_key, key, NULL, + handshake->precomputed_static_static, NOISE_HASH_LEN, + NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, + handshake->chaining_key); /* {t} */ tai64n_now(timestamp); - message_encrypt(dst->encrypted_timestamp, timestamp, NOISE_TIMESTAMP_LEN, key, handshake->hash); + message_encrypt(dst->encrypted_timestamp, timestamp, + NOISE_TIMESTAMP_LEN, key, handshake->hash); - dst->sender_index = index_hashtable_insert(&handshake->entry.peer->device->index_hashtable, &handshake->entry); + dst->sender_index = index_hashtable_insert( + &handshake->entry.peer->device->index_hashtable, + &handshake->entry); handshake->state = HANDSHAKE_CREATED_INITIATION; ret = true; @@ -416,17 +513,19 @@ out: return ret; } -struct wireguard_peer *noise_handshake_consume_initiation(struct message_handshake_initiation *src, struct wireguard_device *wg) +struct wireguard_peer * +noise_handshake_consume_initiation(struct message_handshake_initiation *src, + struct wireguard_device *wg) { + struct wireguard_peer *peer = NULL, *ret_peer = NULL; + struct noise_handshake *handshake; bool replay_attack, flood_attack; + u8 key[NOISE_SYMMETRIC_KEY_LEN]; + u8 chaining_key[NOISE_HASH_LEN]; + u8 hash[NOISE_HASH_LEN]; u8 s[NOISE_PUBLIC_KEY_LEN]; u8 e[NOISE_PUBLIC_KEY_LEN]; u8 t[NOISE_TIMESTAMP_LEN]; - struct noise_handshake *handshake; - struct wireguard_peer *peer = NULL, *ret_peer = NULL; - u8 key[NOISE_SYMMETRIC_KEY_LEN]; - u8 hash[NOISE_HASH_LEN]; - u8 chaining_key[NOISE_HASH_LEN]; down_read(&wg->static_identity.lock); if (unlikely(!wg->static_identity.has_identity)) @@ -442,7 +541,8 @@ struct wireguard_peer *noise_handshake_consume_initiation(struct message_handsha goto out; /* s */ - if (!message_decrypt(s, src->encrypted_static, sizeof(src->encrypted_static), key, hash)) + if (!message_decrypt(s, src->encrypted_static, + sizeof(src->encrypted_static), key, hash)) goto out; /* Lookup which peer we're actually talking to */ @@ -452,15 +552,21 @@ struct wireguard_peer *noise_handshake_consume_initiation(struct message_handsha handshake = &peer->handshake; /* ss */ - kdf(chaining_key, key, NULL, handshake->precomputed_static_static, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key); + kdf(chaining_key, key, NULL, handshake->precomputed_static_static, + NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, + chaining_key); /* {t} */ - if (!message_decrypt(t, src->encrypted_timestamp, sizeof(src->encrypted_timestamp), key, hash)) + if (!message_decrypt(t, src->encrypted_timestamp, + sizeof(src->encrypted_timestamp), key, hash)) goto out; down_read(&handshake->lock); - replay_attack = memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) <= 0; - flood_attack = handshake->last_initiation_consumption + NSEC_PER_SEC / INITIATIONS_PER_SECOND > ktime_get_boot_fast_ns(); + replay_attack = memcmp(t, handshake->latest_timestamp, + NOISE_TIMESTAMP_LEN) <= 0; + flood_attack = handshake->last_initiation_consumption + + NSEC_PER_SEC / INITIATIONS_PER_SECOND > + ktime_get_boot_fast_ns(); up_read(&handshake->lock); if (replay_attack || flood_attack) goto out; @@ -487,13 +593,14 @@ out: return ret_peer; } -bool noise_handshake_create_response(struct message_handshake_response *dst, struct noise_handshake *handshake) +bool noise_handshake_create_response(struct message_handshake_response *dst, + struct noise_handshake *handshake) { bool ret = false; u8 key[NOISE_SYMMETRIC_KEY_LEN]; - /* We need to wait for crng _before_ taking any locks, since curve25519_generate_secret - * uses get_random_bytes_wait. + /* We need to wait for crng _before_ taking any locks, since + * curve25519_generate_secret uses get_random_bytes_wait. */ wait_for_random_bytes(); @@ -508,25 +615,33 @@ bool noise_handshake_create_response(struct message_handshake_response *dst, str /* e */ curve25519_generate_secret(handshake->ephemeral_private); - if (!curve25519_generate_public(dst->unencrypted_ephemeral, handshake->ephemeral_private)) + if (!curve25519_generate_public(dst->unencrypted_ephemeral, + handshake->ephemeral_private)) goto out; - message_ephemeral(dst->unencrypted_ephemeral, dst->unencrypted_ephemeral, handshake->chaining_key, handshake->hash); + message_ephemeral(dst->unencrypted_ephemeral, + dst->unencrypted_ephemeral, handshake->chaining_key, + handshake->hash); /* ee */ - if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private, handshake->remote_ephemeral)) + if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private, + handshake->remote_ephemeral)) goto out; /* se */ - if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private, handshake->remote_static)) + if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private, + handshake->remote_static)) goto out; /* psk */ - mix_psk(handshake->chaining_key, handshake->hash, key, handshake->preshared_key); + mix_psk(handshake->chaining_key, handshake->hash, key, + handshake->preshared_key); /* {} */ message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash); - dst->sender_index = index_hashtable_insert(&handshake->entry.peer->device->index_hashtable, &handshake->entry); + dst->sender_index = index_hashtable_insert( + &handshake->entry.peer->device->index_hashtable, + &handshake->entry); handshake->state = HANDSHAKE_CREATED_RESPONSE; ret = true; @@ -538,7 +653,9 @@ out: return ret; } -struct wireguard_peer *noise_handshake_consume_response(struct message_handshake_response *src, struct wireguard_device *wg) +struct wireguard_peer * +noise_handshake_consume_response(struct message_handshake_response *src, + struct wireguard_device *wg) { struct noise_handshake *handshake; struct wireguard_peer *peer = NULL, *ret_peer = NULL; @@ -555,7 +672,9 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake if (unlikely(!wg->static_identity.has_identity)) goto out; - handshake = (struct noise_handshake *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE, src->receiver_index, &peer); + handshake = (struct noise_handshake *)index_hashtable_lookup( + &wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE, + src->receiver_index, &peer); if (unlikely(!handshake)) goto out; @@ -563,7 +682,8 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake state = handshake->state; memcpy(hash, handshake->hash, NOISE_HASH_LEN); memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN); - memcpy(ephemeral_private, handshake->ephemeral_private, NOISE_PUBLIC_KEY_LEN); + memcpy(ephemeral_private, handshake->ephemeral_private, + NOISE_PUBLIC_KEY_LEN); up_read(&handshake->lock); if (state != HANDSHAKE_CREATED_INITIATION) @@ -584,12 +704,15 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake mix_psk(chaining_key, hash, key, handshake->preshared_key); /* {} */ - if (!message_decrypt(NULL, src->encrypted_nothing, sizeof(src->encrypted_nothing), key, hash)) + if (!message_decrypt(NULL, src->encrypted_nothing, + sizeof(src->encrypted_nothing), key, hash)) goto fail; /* Success! Copy everything to peer */ down_write(&handshake->lock); - /* It's important to check that the state is still the same, while we have an exclusive lock */ + /* It's important to check that the state is still the same, while we + * have an exclusive lock. + */ if (handshake->state != state) { up_write(&handshake->lock); goto fail; @@ -615,32 +738,43 @@ out: return ret_peer; } -bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noise_keypairs *keypairs) +bool noise_handshake_begin_session(struct noise_handshake *handshake, + struct noise_keypairs *keypairs) { struct noise_keypair *new_keypair; bool ret = false; down_write(&handshake->lock); - if (handshake->state != HANDSHAKE_CREATED_RESPONSE && handshake->state != HANDSHAKE_CONSUMED_RESPONSE) + if (handshake->state != HANDSHAKE_CREATED_RESPONSE && + handshake->state != HANDSHAKE_CONSUMED_RESPONSE) goto out; new_keypair = keypair_create(handshake->entry.peer); if (!new_keypair) goto out; - new_keypair->i_am_the_initiator = handshake->state == HANDSHAKE_CONSUMED_RESPONSE; + new_keypair->i_am_the_initiator = handshake->state == + HANDSHAKE_CONSUMED_RESPONSE; new_keypair->remote_index = handshake->remote_index; if (new_keypair->i_am_the_initiator) - derive_keys(&new_keypair->sending, &new_keypair->receiving, handshake->chaining_key); + derive_keys(&new_keypair->sending, &new_keypair->receiving, + handshake->chaining_key); else - derive_keys(&new_keypair->receiving, &new_keypair->sending, handshake->chaining_key); + derive_keys(&new_keypair->receiving, &new_keypair->sending, + handshake->chaining_key); handshake_zero(handshake); rcu_read_lock_bh(); - if (likely(!container_of(handshake, struct wireguard_peer, handshake)->is_dead)) { + if (likely(!container_of(handshake, struct wireguard_peer, + handshake)->is_dead)) { add_new_keypair(keypairs, new_keypair); - net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n", handshake->entry.peer->device->dev->name, new_keypair->internal_id, handshake->entry.peer->internal_id); - ret = index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry); + net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n", + handshake->entry.peer->device->dev->name, + new_keypair->internal_id, + handshake->entry.peer->internal_id); + ret = index_hashtable_replace( + &handshake->entry.peer->device->index_hashtable, + &handshake->entry, &new_keypair->entry); } else kzfree(new_keypair); rcu_read_unlock_bh(); |