diff options
Diffstat (limited to 'src/wireguard/handshake')
-rw-r--r-- | src/wireguard/handshake/device.rs | 64 | ||||
-rw-r--r-- | src/wireguard/handshake/macs.rs | 19 | ||||
-rw-r--r-- | src/wireguard/handshake/noise.rs | 35 | ||||
-rw-r--r-- | src/wireguard/handshake/peer.rs | 34 | ||||
-rw-r--r-- | src/wireguard/handshake/ratelimiter.rs | 5 | ||||
-rw-r--r-- | src/wireguard/handshake/tests.rs | 18 | ||||
-rw-r--r-- | src/wireguard/handshake/timestamp.rs | 2 |
7 files changed, 85 insertions, 92 deletions
diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs index 3a3d023..47ca401 100644 --- a/src/wireguard/handshake/device.rs +++ b/src/wireguard/handshake/device.rs @@ -1,14 +1,15 @@ -use spin::RwLock; use std::collections::hash_map; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Mutex; -use zerocopy::AsBytes; use byteorder::{ByteOrder, LittleEndian}; +use dashmap::mapref::entry::Entry; +use dashmap::DashMap; +use zerocopy::AsBytes; -use rand::prelude::{CryptoRng, RngCore}; use rand::Rng; +use rand_core::{CryptoRng, RngCore}; use clear_on_drop::clear::Clear; @@ -36,7 +37,7 @@ pub struct KeyState { /// (the instance is a Peer object in the parent module) pub struct Device<O> { keyst: Option<KeyState>, - id_map: RwLock<HashMap<u32, [u8; 32]>>, + id_map: DashMap<u32, [u8; 32]>, // concurrent map pk_map: HashMap<[u8; 32], Peer<O>>, limiter: Mutex<RateLimiter>, } @@ -62,7 +63,7 @@ impl<'a, O> Iterator for Iter<'a, O> { */ impl<O> Device<O> { pub fn clear(&mut self) { - self.id_map.write().clear(); + self.id_map.clear(); self.pk_map.clear(); } @@ -96,7 +97,7 @@ impl<O> Device<O> { pub fn new() -> Device<O> { Device { keyst: None, - id_map: RwLock::new(HashMap::new()), + id_map: DashMap::new(), pk_map: HashMap::new(), limiter: Mutex::new(RateLimiter::new()), } @@ -117,7 +118,9 @@ impl<O> Device<O> { } else { peer.ss.clear(); } - peer.reset_state().map(|id| ids.push(id)); + if let Some(id) = peer.reset_state() { + ids.push(id) + } } (ids, same) @@ -208,16 +211,14 @@ impl<O> Device<O> { /// /// The call might fail if the public key is not found pub fn remove(&mut self, pk: &PublicKey) -> Result<(), ConfigError> { - // take write-lock on receive id table - let mut id_map = self.id_map.write(); - // remove the peer self.pk_map .remove(pk.as_bytes()) - .ok_or(ConfigError::new("Public key not in device"))?; + .ok_or_else(|| ConfigError::new("Public key not in device"))?; - // purge the id map (linear scan) - id_map.retain(|_, v| v != pk.as_bytes()); + // remove every id entry for the peer in the public key map + // O(n) operations, however it is rare: only when removing peers. + self.id_map.retain(|_, v| v != pk.as_bytes()); Ok(()) } @@ -265,9 +266,8 @@ impl<O> Device<O> { /// /// * `id` - The (sender) id to release pub fn release(&self, id: u32) { - let mut m = self.id_map.write(); - debug_assert!(m.contains_key(&id), "Releasing id not allocated"); - m.remove(&id); + let old = self.id_map.remove(&id); + assert!(old.is_some(), "released id not allocated"); } /// Begin a new handshake @@ -391,9 +391,6 @@ impl<O> Device<O> { // address validation & DoS mitigation if let Some(src) = src { - // obtain ref to socket addr - let src = src.into(); - // check mac2 field if !keyst.macs.check_mac2(msg.noise.as_bytes(), &src, &msg.macs) { let mut reply = Default::default(); @@ -446,32 +443,37 @@ impl<O> Device<O> { // // Return the peer currently associated with the receiver identifier pub(super) fn lookup_id(&self, id: u32) -> Result<(&Peer<O>, PublicKey), HandshakeError> { - let im = self.id_map.read(); - let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?; - match self.pk_map.get(pk) { + // obtain a read reference to entry in the id_map + let pk = self + .id_map + .get(&id) + .ok_or(HandshakeError::UnknownReceiverId)?; + + // lookup the public key from the pk map + match self.pk_map.get(&*pk) { Some(peer) => Ok((peer, PublicKey::from(*pk))), - _ => unreachable!(), // if the id-lookup succeeded, the peer should exist + _ => unreachable!(), } } // Internal function // - // Allocated a new receiver identifier for the peer + // Allocated a new receiver identifier for the peer. + // Implemented via rejection sampling. fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, pk: &PublicKey) -> u32 { loop { let id = rng.gen(); - // check membership with read lock - if self.id_map.read().contains_key(&id) { + // read lock the shard and do quick check + if self.id_map.contains_key(&id) { continue; } - // take write lock and add index - let mut m = self.id_map.write(); - if !m.contains_key(&id) { - m.insert(id, *pk.as_bytes()); + // write lock the shard and insert + if let Entry::Vacant(entry) = self.id_map.entry(id) { + entry.insert(*pk.as_bytes()); return id; - } + }; } } } diff --git a/src/wireguard/handshake/macs.rs b/src/wireguard/handshake/macs.rs index cb5d7d4..f4f5586 100644 --- a/src/wireguard/handshake/macs.rs +++ b/src/wireguard/handshake/macs.rs @@ -1,5 +1,5 @@ use generic_array::GenericArray; -use rand::{CryptoRng, RngCore}; +use rand_core::{CryptoRng, RngCore}; use spin::RwLock; use std::time::{Duration, Instant}; @@ -8,6 +8,7 @@ use std::net::SocketAddr; use x25519_dalek::PublicKey; // AEAD + use aead::{Aead, NewAead, Payload}; use chacha20poly1305::XChaCha20Poly1305; @@ -33,30 +34,29 @@ macro_rules! HASH { use blake2::Digest; let mut hsh = Blake2s::new(); $( - hsh.input($input); + hsh.update($input); )* - hsh.result() + hsh.finalize() }}; } macro_rules! MAC { ( $key:expr, $($input:expr),* ) => {{ use blake2::VarBlake2s; - use digest::Input; - use digest::VariableOutput; + use blake2::digest::{Update, VariableOutput}; let mut tag = [0u8; SIZE_MAC]; let mut mac = VarBlake2s::new_keyed($key, SIZE_MAC); $( - mac.input($input); + mac.update($input); )* - mac.variable_result(|buf| tag.copy_from_slice(buf)); + mac.finalize_variable(|buf| tag.copy_from_slice(buf)); tag }}; } macro_rules! XSEAL { ($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{ - let ct = XChaCha20Poly1305::new(*GenericArray::from_slice($key)) + let ct = XChaCha20Poly1305::new(GenericArray::from_slice($key)) .encrypt( GenericArray::from_slice($nonce), Payload { msg: $pt, aad: $ad }, @@ -70,7 +70,7 @@ macro_rules! XSEAL { macro_rules! XOPEN { ($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{ debug_assert_eq!($ct.len(), $pt.len() + SIZE_TAG); - XChaCha20Poly1305::new(*GenericArray::from_slice($key)) + XChaCha20Poly1305::new(GenericArray::from_slice($key)) .decrypt( GenericArray::from_slice($nonce), Payload { msg: $ct, aad: $ad }, @@ -141,6 +141,7 @@ impl Generator { pub fn process(&mut self, reply: &CookieReply) -> Result<(), HandshakeError> { let mac1 = self.last_mac1.ok_or(HandshakeError::InvalidState)?; let mut tau = [0u8; SIZE_COOKIE]; + #[allow(clippy::unnecessary_mut_passed)] XOPEN!( &self.cookie_key, // key &reply.f_nonce, // nonce diff --git a/src/wireguard/handshake/noise.rs b/src/wireguard/handshake/noise.rs index beb99c2..92c8c5f 100644 --- a/src/wireguard/handshake/noise.rs +++ b/src/wireguard/handshake/noise.rs @@ -1,7 +1,7 @@ use std::time::Instant; // DH -use x25519_dalek::{PublicKey, StaticSecret, SharedSecret}; +use x25519_dalek::{PublicKey, SharedSecret, StaticSecret}; // HASH & MAC use blake2::Blake2s; @@ -11,15 +11,13 @@ use hmac::Hmac; use aead::{Aead, NewAead, Payload}; use chacha20poly1305::ChaCha20Poly1305; -use log; - -use rand::prelude::{CryptoRng, RngCore}; +use rand_core::{CryptoRng, RngCore}; use generic_array::typenum::*; use generic_array::*; use clear_on_drop::clear::Clear; -use clear_on_drop::clear_stack_on_return; +use clear_on_drop::clear_stack_on_return_fnonce; use subtle::ConstantTimeEq; @@ -65,20 +63,20 @@ macro_rules! HASH { use blake2::Digest; let mut hsh = Blake2s::new(); $( - hsh.input($input); + hsh.update($input); )* - hsh.result() + hsh.finalize() }}; } macro_rules! HMAC { ($key:expr, $($input:expr),*) => {{ - use hmac::Mac; + use hmac::{Mac, NewMac}; let mut mac = HMACBlake2s::new_varkey($key).unwrap(); $( - mac.input($input); + mac.update($input); )* - mac.result().code() + mac.finalize().into_bytes() }}; } @@ -114,7 +112,7 @@ macro_rules! KDF3 { macro_rules! SEAL { ($key:expr, $ad:expr, $pt:expr, $ct:expr) => { - ChaCha20Poly1305::new(*GenericArray::from_slice($key)) + ChaCha20Poly1305::new(GenericArray::from_slice($key)) .encrypt(&ZERO_NONCE.into(), Payload { msg: $pt, aad: $ad }) .map(|ct| $ct.copy_from_slice(&ct)) .unwrap() @@ -123,7 +121,7 @@ macro_rules! SEAL { macro_rules! OPEN { ($key:expr, $ad:expr, $pt:expr, $ct:expr) => { - ChaCha20Poly1305::new(*GenericArray::from_slice($key)) + ChaCha20Poly1305::new(GenericArray::from_slice($key)) .decrypt(&ZERO_NONCE.into(), Payload { msg: $ct, aad: $ad }) .map_err(|_| HandshakeError::DecryptionFailure) .map(|pt| $pt.copy_from_slice(&pt)) @@ -215,7 +213,7 @@ mod tests { } // Computes an X25519 shared secret. -// +// // This function wraps dalek to add a zero-check. // This is not recommended by the Noise specification, // but implemented in the kernel with which we strive for absolute equivalent behavior. @@ -244,7 +242,7 @@ pub(super) fn create_initiation<R: RngCore + CryptoRng, O>( return Err(HandshakeError::InvalidSharedSecret); } - clear_stack_on_return(CLEAR_PAGES, || { + clear_stack_on_return_fnonce(CLEAR_PAGES, || { // initialize state let ck = INITIAL_CK; @@ -290,7 +288,6 @@ pub(super) fn create_initiation<R: RngCore + CryptoRng, O>( // (C, k) := Kdf2(C, DH(S_priv, S_pub)) - let (ck, key) = KDF2!(&ck, &peer.ss); // msg.timestamp := Aead(k, 0, Timestamp(), H) @@ -326,7 +323,7 @@ pub(super) fn consume_initiation<'a, O>( ) -> Result<(&'a Peer<O>, PublicKey, TemporaryState), HandshakeError> { log::debug!("consume initiation"); - clear_stack_on_return(CLEAR_PAGES, || { + clear_stack_on_return_fnonce(CLEAR_PAGES, || { // initialize new state let ck = INITIAL_CK; @@ -360,7 +357,7 @@ pub(super) fn consume_initiation<'a, O>( let peer = device.lookup_pk(&PublicKey::from(pk))?; // check for zero shared-secret (see "shared_secret" note). - + if peer.ss.ct_eq(&[0u8; 32]).into() { return Err(HandshakeError::InvalidSharedSecret); } @@ -415,7 +412,7 @@ pub(super) fn create_response<R: RngCore + CryptoRng, O>( msg: &mut NoiseResponse, // resulting response ) -> Result<KeyPair, HandshakeError> { log::debug!("create response"); - clear_stack_on_return(CLEAR_PAGES, || { + clear_stack_on_return_fnonce(CLEAR_PAGES, || { // unpack state let (receiver, eph_r_pk, hs, ck) = state; @@ -500,7 +497,7 @@ pub(super) fn consume_response<'a, O>( msg: &NoiseResponse, ) -> Result<Output<'a, O>, HandshakeError> { log::debug!("consume response"); - clear_stack_on_return(CLEAR_PAGES, || { + clear_stack_on_return_fnonce(CLEAR_PAGES, || { // retrieve peer and copy initiation state let (peer, _) = device.lookup_id(msg.f_receiver.get())?; diff --git a/src/wireguard/handshake/peer.rs b/src/wireguard/handshake/peer.rs index 1636e62..f847725 100644 --- a/src/wireguard/handshake/peer.rs +++ b/src/wireguard/handshake/peer.rs @@ -50,13 +50,10 @@ pub enum State { impl Drop for State { fn drop(&mut self) { - match self { - State::InitiationSent { hs, ck, .. } => { - // eph_sk already cleared by dalek-x25519 - hs.clear(); - ck.clear(); - } - _ => (), + if let State::InitiationSent { hs, ck, .. } = self { + // eph_sk already cleared by dalek-x25519 + hs.clear(); + ck.clear(); } } } @@ -97,29 +94,22 @@ impl<O> Peer<O> { let mut last_initiation_consumption = self.last_initiation_consumption.lock(); // check replay attack - match *timestamp { - Some(timestamp_old) => { - if !timestamp::compare(×tamp_old, ×tamp_new) { - return Err(HandshakeError::OldTimestamp); - } + if let Some(timestamp_old) = *timestamp { + if !timestamp::compare(×tamp_old, ×tamp_new) { + return Err(HandshakeError::OldTimestamp); } - _ => (), }; // check flood attack - match *last_initiation_consumption { - Some(last) => { - if last.elapsed() < TIME_BETWEEN_INITIATIONS { - return Err(HandshakeError::InitiationFlood); - } + if let Some(last) = *last_initiation_consumption { + if last.elapsed() < TIME_BETWEEN_INITIATIONS { + return Err(HandshakeError::InitiationFlood); } - _ => (), } // reset state - match *state { - State::InitiationSent { local, .. } => device.release(local), - _ => (), + if let State::InitiationSent { local, .. } = *state { + device.release(local) } // update replay & flood protection diff --git a/src/wireguard/handshake/ratelimiter.rs b/src/wireguard/handshake/ratelimiter.rs index 89109e9..9e796a0 100644 --- a/src/wireguard/handshake/ratelimiter.rs +++ b/src/wireguard/handshake/ratelimiter.rs @@ -5,8 +5,6 @@ use std::sync::{Arc, Condvar, Mutex}; use std::thread; use std::time::{Duration, Instant}; -use spin; - const PACKETS_PER_SECOND: u64 = 20; const PACKETS_BURSTABLE: u64 = 5; const PACKET_COST: u64 = 1_000_000_000 / PACKETS_PER_SECOND; @@ -39,6 +37,7 @@ impl Drop for RateLimiter { impl RateLimiter { pub fn new() -> Self { + #[allow(clippy::mutex_atomic)] RateLimiter(Arc::new(RateLimiterInner { gc_dropped: (Mutex::new(false), Condvar::new()), gc_running: AtomicBool::from(false), @@ -145,7 +144,7 @@ mod tests { expected.push(Result { allowed: true, wait: Duration::new(0, 0), - text: "inital burst", + text: "initial burst", }); } diff --git a/src/wireguard/handshake/tests.rs b/src/wireguard/handshake/tests.rs index 5174d2e..35ff152 100644 --- a/src/wireguard/handshake/tests.rs +++ b/src/wireguard/handshake/tests.rs @@ -6,8 +6,8 @@ use std::time::Duration; use hex; -use rand::prelude::{CryptoRng, RngCore}; use rand::rngs::OsRng; +use rand_core::{CryptoRng, RngCore}; use x25519_dalek::PublicKey; use x25519_dalek::StaticSecret; @@ -15,20 +15,22 @@ use x25519_dalek::StaticSecret; use super::messages::{Initiation, Response}; fn setup_devices<R: RngCore + CryptoRng, O: Default>( - rng: &mut R, + rng1: &mut R, + rng2: &mut R, + rng3: &mut R, ) -> (PublicKey, Device<O>, PublicKey, Device<O>) { // generate new key pairs - let sk1 = StaticSecret::new(rng); + let sk1 = StaticSecret::new(rng1); let pk1 = PublicKey::from(&sk1); - let sk2 = StaticSecret::new(rng); + let sk2 = StaticSecret::new(rng2); let pk2 = PublicKey::from(&sk2); // pick random psk let mut psk = [0u8; 32]; - rng.fill_bytes(&mut psk[..]); + rng3.fill_bytes(&mut psk[..]); // initialize devices on both ends @@ -63,7 +65,8 @@ fn wait() { */ #[test] fn handshake_under_load() { - let (_pk1, dev1, pk2, dev2): (_, Device<usize>, _, _) = setup_devices(&mut OsRng); + let (_pk1, dev1, pk2, dev2): (_, Device<usize>, _, _) = + setup_devices(&mut OsRng, &mut OsRng, &mut OsRng); let src1: SocketAddr = "172.16.0.1:8080".parse().unwrap(); let src2: SocketAddr = "172.16.0.2:7070".parse().unwrap(); @@ -140,7 +143,8 @@ fn handshake_under_load() { #[test] fn handshake_no_load() { - let (pk1, mut dev1, pk2, mut dev2): (_, Device<usize>, _, _) = setup_devices(&mut OsRng); + let (pk1, mut dev1, pk2, mut dev2): (_, Device<usize>, _, _) = + setup_devices(&mut OsRng, &mut OsRng, &mut OsRng); // do a few handshakes (every handshake should succeed) diff --git a/src/wireguard/handshake/timestamp.rs b/src/wireguard/handshake/timestamp.rs index b5bd9f0..485bb8d 100644 --- a/src/wireguard/handshake/timestamp.rs +++ b/src/wireguard/handshake/timestamp.rs @@ -28,5 +28,5 @@ pub fn compare(old: &TAI64N, new: &TAI64N) -> bool { return true; } } - return false; + false } |