diff options
Diffstat (limited to 'src/wireguard')
-rw-r--r-- | src/wireguard/handshake/device.rs | 64 | ||||
-rw-r--r-- | src/wireguard/handshake/macs.rs | 19 | ||||
-rw-r--r-- | src/wireguard/handshake/noise.rs | 35 | ||||
-rw-r--r-- | src/wireguard/handshake/peer.rs | 34 | ||||
-rw-r--r-- | src/wireguard/handshake/ratelimiter.rs | 5 | ||||
-rw-r--r-- | src/wireguard/handshake/tests.rs | 18 | ||||
-rw-r--r-- | src/wireguard/handshake/timestamp.rs | 2 | ||||
-rw-r--r-- | src/wireguard/mod.rs | 21 | ||||
-rw-r--r-- | src/wireguard/peer.rs | 8 | ||||
-rw-r--r-- | src/wireguard/queue.rs | 5 | ||||
-rw-r--r-- | src/wireguard/router/device.rs | 34 | ||||
-rw-r--r-- | src/wireguard/router/peer.rs | 96 | ||||
-rw-r--r-- | src/wireguard/router/queue.rs | 4 | ||||
-rw-r--r-- | src/wireguard/router/route.rs | 4 | ||||
-rw-r--r-- | src/wireguard/router/tests/bench.rs | 15 | ||||
-rw-r--r-- | src/wireguard/router/tests/tests.rs | 1 | ||||
-rw-r--r-- | src/wireguard/router/types.rs | 16 | ||||
-rw-r--r-- | src/wireguard/router/worker.rs | 1 | ||||
-rw-r--r-- | src/wireguard/timers.rs | 5 | ||||
-rw-r--r-- | src/wireguard/wireguard.rs | 15 | ||||
-rw-r--r-- | src/wireguard/workers.rs | 4 |
21 files changed, 204 insertions, 202 deletions
diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs index 3a3d023..47ca401 100644 --- a/src/wireguard/handshake/device.rs +++ b/src/wireguard/handshake/device.rs @@ -1,14 +1,15 @@ -use spin::RwLock; use std::collections::hash_map; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Mutex; -use zerocopy::AsBytes; use byteorder::{ByteOrder, LittleEndian}; +use dashmap::mapref::entry::Entry; +use dashmap::DashMap; +use zerocopy::AsBytes; -use rand::prelude::{CryptoRng, RngCore}; use rand::Rng; +use rand_core::{CryptoRng, RngCore}; use clear_on_drop::clear::Clear; @@ -36,7 +37,7 @@ pub struct KeyState { /// (the instance is a Peer object in the parent module) pub struct Device<O> { keyst: Option<KeyState>, - id_map: RwLock<HashMap<u32, [u8; 32]>>, + id_map: DashMap<u32, [u8; 32]>, // concurrent map pk_map: HashMap<[u8; 32], Peer<O>>, limiter: Mutex<RateLimiter>, } @@ -62,7 +63,7 @@ impl<'a, O> Iterator for Iter<'a, O> { */ impl<O> Device<O> { pub fn clear(&mut self) { - self.id_map.write().clear(); + self.id_map.clear(); self.pk_map.clear(); } @@ -96,7 +97,7 @@ impl<O> Device<O> { pub fn new() -> Device<O> { Device { keyst: None, - id_map: RwLock::new(HashMap::new()), + id_map: DashMap::new(), pk_map: HashMap::new(), limiter: Mutex::new(RateLimiter::new()), } @@ -117,7 +118,9 @@ impl<O> Device<O> { } else { peer.ss.clear(); } - peer.reset_state().map(|id| ids.push(id)); + if let Some(id) = peer.reset_state() { + ids.push(id) + } } (ids, same) @@ -208,16 +211,14 @@ impl<O> Device<O> { /// /// The call might fail if the public key is not found pub fn remove(&mut self, pk: &PublicKey) -> Result<(), ConfigError> { - // take write-lock on receive id table - let mut id_map = self.id_map.write(); - // remove the peer self.pk_map .remove(pk.as_bytes()) - .ok_or(ConfigError::new("Public key not in device"))?; + .ok_or_else(|| ConfigError::new("Public key not in device"))?; - // purge the id map (linear scan) - id_map.retain(|_, v| v != pk.as_bytes()); + // remove every id entry for the peer in the public key map + // O(n) operations, however it is rare: only when removing peers. + self.id_map.retain(|_, v| v != pk.as_bytes()); Ok(()) } @@ -265,9 +266,8 @@ impl<O> Device<O> { /// /// * `id` - The (sender) id to release pub fn release(&self, id: u32) { - let mut m = self.id_map.write(); - debug_assert!(m.contains_key(&id), "Releasing id not allocated"); - m.remove(&id); + let old = self.id_map.remove(&id); + assert!(old.is_some(), "released id not allocated"); } /// Begin a new handshake @@ -391,9 +391,6 @@ impl<O> Device<O> { // address validation & DoS mitigation if let Some(src) = src { - // obtain ref to socket addr - let src = src.into(); - // check mac2 field if !keyst.macs.check_mac2(msg.noise.as_bytes(), &src, &msg.macs) { let mut reply = Default::default(); @@ -446,32 +443,37 @@ impl<O> Device<O> { // // Return the peer currently associated with the receiver identifier pub(super) fn lookup_id(&self, id: u32) -> Result<(&Peer<O>, PublicKey), HandshakeError> { - let im = self.id_map.read(); - let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?; - match self.pk_map.get(pk) { + // obtain a read reference to entry in the id_map + let pk = self + .id_map + .get(&id) + .ok_or(HandshakeError::UnknownReceiverId)?; + + // lookup the public key from the pk map + match self.pk_map.get(&*pk) { Some(peer) => Ok((peer, PublicKey::from(*pk))), - _ => unreachable!(), // if the id-lookup succeeded, the peer should exist + _ => unreachable!(), } } // Internal function // - // Allocated a new receiver identifier for the peer + // Allocated a new receiver identifier for the peer. + // Implemented via rejection sampling. fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, pk: &PublicKey) -> u32 { loop { let id = rng.gen(); - // check membership with read lock - if self.id_map.read().contains_key(&id) { + // read lock the shard and do quick check + if self.id_map.contains_key(&id) { continue; } - // take write lock and add index - let mut m = self.id_map.write(); - if !m.contains_key(&id) { - m.insert(id, *pk.as_bytes()); + // write lock the shard and insert + if let Entry::Vacant(entry) = self.id_map.entry(id) { + entry.insert(*pk.as_bytes()); return id; - } + }; } } } diff --git a/src/wireguard/handshake/macs.rs b/src/wireguard/handshake/macs.rs index cb5d7d4..f4f5586 100644 --- a/src/wireguard/handshake/macs.rs +++ b/src/wireguard/handshake/macs.rs @@ -1,5 +1,5 @@ use generic_array::GenericArray; -use rand::{CryptoRng, RngCore}; +use rand_core::{CryptoRng, RngCore}; use spin::RwLock; use std::time::{Duration, Instant}; @@ -8,6 +8,7 @@ use std::net::SocketAddr; use x25519_dalek::PublicKey; // AEAD + use aead::{Aead, NewAead, Payload}; use chacha20poly1305::XChaCha20Poly1305; @@ -33,30 +34,29 @@ macro_rules! HASH { use blake2::Digest; let mut hsh = Blake2s::new(); $( - hsh.input($input); + hsh.update($input); )* - hsh.result() + hsh.finalize() }}; } macro_rules! MAC { ( $key:expr, $($input:expr),* ) => {{ use blake2::VarBlake2s; - use digest::Input; - use digest::VariableOutput; + use blake2::digest::{Update, VariableOutput}; let mut tag = [0u8; SIZE_MAC]; let mut mac = VarBlake2s::new_keyed($key, SIZE_MAC); $( - mac.input($input); + mac.update($input); )* - mac.variable_result(|buf| tag.copy_from_slice(buf)); + mac.finalize_variable(|buf| tag.copy_from_slice(buf)); tag }}; } macro_rules! XSEAL { ($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{ - let ct = XChaCha20Poly1305::new(*GenericArray::from_slice($key)) + let ct = XChaCha20Poly1305::new(GenericArray::from_slice($key)) .encrypt( GenericArray::from_slice($nonce), Payload { msg: $pt, aad: $ad }, @@ -70,7 +70,7 @@ macro_rules! XSEAL { macro_rules! XOPEN { ($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{ debug_assert_eq!($ct.len(), $pt.len() + SIZE_TAG); - XChaCha20Poly1305::new(*GenericArray::from_slice($key)) + XChaCha20Poly1305::new(GenericArray::from_slice($key)) .decrypt( GenericArray::from_slice($nonce), Payload { msg: $ct, aad: $ad }, @@ -141,6 +141,7 @@ impl Generator { pub fn process(&mut self, reply: &CookieReply) -> Result<(), HandshakeError> { let mac1 = self.last_mac1.ok_or(HandshakeError::InvalidState)?; let mut tau = [0u8; SIZE_COOKIE]; + #[allow(clippy::unnecessary_mut_passed)] XOPEN!( &self.cookie_key, // key &reply.f_nonce, // nonce diff --git a/src/wireguard/handshake/noise.rs b/src/wireguard/handshake/noise.rs index beb99c2..92c8c5f 100644 --- a/src/wireguard/handshake/noise.rs +++ b/src/wireguard/handshake/noise.rs @@ -1,7 +1,7 @@ use std::time::Instant; // DH -use x25519_dalek::{PublicKey, StaticSecret, SharedSecret}; +use x25519_dalek::{PublicKey, SharedSecret, StaticSecret}; // HASH & MAC use blake2::Blake2s; @@ -11,15 +11,13 @@ use hmac::Hmac; use aead::{Aead, NewAead, Payload}; use chacha20poly1305::ChaCha20Poly1305; -use log; - -use rand::prelude::{CryptoRng, RngCore}; +use rand_core::{CryptoRng, RngCore}; use generic_array::typenum::*; use generic_array::*; use clear_on_drop::clear::Clear; -use clear_on_drop::clear_stack_on_return; +use clear_on_drop::clear_stack_on_return_fnonce; use subtle::ConstantTimeEq; @@ -65,20 +63,20 @@ macro_rules! HASH { use blake2::Digest; let mut hsh = Blake2s::new(); $( - hsh.input($input); + hsh.update($input); )* - hsh.result() + hsh.finalize() }}; } macro_rules! HMAC { ($key:expr, $($input:expr),*) => {{ - use hmac::Mac; + use hmac::{Mac, NewMac}; let mut mac = HMACBlake2s::new_varkey($key).unwrap(); $( - mac.input($input); + mac.update($input); )* - mac.result().code() + mac.finalize().into_bytes() }}; } @@ -114,7 +112,7 @@ macro_rules! KDF3 { macro_rules! SEAL { ($key:expr, $ad:expr, $pt:expr, $ct:expr) => { - ChaCha20Poly1305::new(*GenericArray::from_slice($key)) + ChaCha20Poly1305::new(GenericArray::from_slice($key)) .encrypt(&ZERO_NONCE.into(), Payload { msg: $pt, aad: $ad }) .map(|ct| $ct.copy_from_slice(&ct)) .unwrap() @@ -123,7 +121,7 @@ macro_rules! SEAL { macro_rules! OPEN { ($key:expr, $ad:expr, $pt:expr, $ct:expr) => { - ChaCha20Poly1305::new(*GenericArray::from_slice($key)) + ChaCha20Poly1305::new(GenericArray::from_slice($key)) .decrypt(&ZERO_NONCE.into(), Payload { msg: $ct, aad: $ad }) .map_err(|_| HandshakeError::DecryptionFailure) .map(|pt| $pt.copy_from_slice(&pt)) @@ -215,7 +213,7 @@ mod tests { } // Computes an X25519 shared secret. -// +// // This function wraps dalek to add a zero-check. // This is not recommended by the Noise specification, // but implemented in the kernel with which we strive for absolute equivalent behavior. @@ -244,7 +242,7 @@ pub(super) fn create_initiation<R: RngCore + CryptoRng, O>( return Err(HandshakeError::InvalidSharedSecret); } - clear_stack_on_return(CLEAR_PAGES, || { + clear_stack_on_return_fnonce(CLEAR_PAGES, || { // initialize state let ck = INITIAL_CK; @@ -290,7 +288,6 @@ pub(super) fn create_initiation<R: RngCore + CryptoRng, O>( // (C, k) := Kdf2(C, DH(S_priv, S_pub)) - let (ck, key) = KDF2!(&ck, &peer.ss); // msg.timestamp := Aead(k, 0, Timestamp(), H) @@ -326,7 +323,7 @@ pub(super) fn consume_initiation<'a, O>( ) -> Result<(&'a Peer<O>, PublicKey, TemporaryState), HandshakeError> { log::debug!("consume initiation"); - clear_stack_on_return(CLEAR_PAGES, || { + clear_stack_on_return_fnonce(CLEAR_PAGES, || { // initialize new state let ck = INITIAL_CK; @@ -360,7 +357,7 @@ pub(super) fn consume_initiation<'a, O>( let peer = device.lookup_pk(&PublicKey::from(pk))?; // check for zero shared-secret (see "shared_secret" note). - + if peer.ss.ct_eq(&[0u8; 32]).into() { return Err(HandshakeError::InvalidSharedSecret); } @@ -415,7 +412,7 @@ pub(super) fn create_response<R: RngCore + CryptoRng, O>( msg: &mut NoiseResponse, // resulting response ) -> Result<KeyPair, HandshakeError> { log::debug!("create response"); - clear_stack_on_return(CLEAR_PAGES, || { + clear_stack_on_return_fnonce(CLEAR_PAGES, || { // unpack state let (receiver, eph_r_pk, hs, ck) = state; @@ -500,7 +497,7 @@ pub(super) fn consume_response<'a, O>( msg: &NoiseResponse, ) -> Result<Output<'a, O>, HandshakeError> { log::debug!("consume response"); - clear_stack_on_return(CLEAR_PAGES, || { + clear_stack_on_return_fnonce(CLEAR_PAGES, || { // retrieve peer and copy initiation state let (peer, _) = device.lookup_id(msg.f_receiver.get())?; diff --git a/src/wireguard/handshake/peer.rs b/src/wireguard/handshake/peer.rs index 1636e62..f847725 100644 --- a/src/wireguard/handshake/peer.rs +++ b/src/wireguard/handshake/peer.rs @@ -50,13 +50,10 @@ pub enum State { impl Drop for State { fn drop(&mut self) { - match self { - State::InitiationSent { hs, ck, .. } => { - // eph_sk already cleared by dalek-x25519 - hs.clear(); - ck.clear(); - } - _ => (), + if let State::InitiationSent { hs, ck, .. } = self { + // eph_sk already cleared by dalek-x25519 + hs.clear(); + ck.clear(); } } } @@ -97,29 +94,22 @@ impl<O> Peer<O> { let mut last_initiation_consumption = self.last_initiation_consumption.lock(); // check replay attack - match *timestamp { - Some(timestamp_old) => { - if !timestamp::compare(×tamp_old, ×tamp_new) { - return Err(HandshakeError::OldTimestamp); - } + if let Some(timestamp_old) = *timestamp { + if !timestamp::compare(×tamp_old, ×tamp_new) { + return Err(HandshakeError::OldTimestamp); } - _ => (), }; // check flood attack - match *last_initiation_consumption { - Some(last) => { - if last.elapsed() < TIME_BETWEEN_INITIATIONS { - return Err(HandshakeError::InitiationFlood); - } + if let Some(last) = *last_initiation_consumption { + if last.elapsed() < TIME_BETWEEN_INITIATIONS { + return Err(HandshakeError::InitiationFlood); } - _ => (), } // reset state - match *state { - State::InitiationSent { local, .. } => device.release(local), - _ => (), + if let State::InitiationSent { local, .. } = *state { + device.release(local) } // update replay & flood protection diff --git a/src/wireguard/handshake/ratelimiter.rs b/src/wireguard/handshake/ratelimiter.rs index 89109e9..9e796a0 100644 --- a/src/wireguard/handshake/ratelimiter.rs +++ b/src/wireguard/handshake/ratelimiter.rs @@ -5,8 +5,6 @@ use std::sync::{Arc, Condvar, Mutex}; use std::thread; use std::time::{Duration, Instant}; -use spin; - const PACKETS_PER_SECOND: u64 = 20; const PACKETS_BURSTABLE: u64 = 5; const PACKET_COST: u64 = 1_000_000_000 / PACKETS_PER_SECOND; @@ -39,6 +37,7 @@ impl Drop for RateLimiter { impl RateLimiter { pub fn new() -> Self { + #[allow(clippy::mutex_atomic)] RateLimiter(Arc::new(RateLimiterInner { gc_dropped: (Mutex::new(false), Condvar::new()), gc_running: AtomicBool::from(false), @@ -145,7 +144,7 @@ mod tests { expected.push(Result { allowed: true, wait: Duration::new(0, 0), - text: "inital burst", + text: "initial burst", }); } diff --git a/src/wireguard/handshake/tests.rs b/src/wireguard/handshake/tests.rs index 5174d2e..35ff152 100644 --- a/src/wireguard/handshake/tests.rs +++ b/src/wireguard/handshake/tests.rs @@ -6,8 +6,8 @@ use std::time::Duration; use hex; -use rand::prelude::{CryptoRng, RngCore}; use rand::rngs::OsRng; +use rand_core::{CryptoRng, RngCore}; use x25519_dalek::PublicKey; use x25519_dalek::StaticSecret; @@ -15,20 +15,22 @@ use x25519_dalek::StaticSecret; use super::messages::{Initiation, Response}; fn setup_devices<R: RngCore + CryptoRng, O: Default>( - rng: &mut R, + rng1: &mut R, + rng2: &mut R, + rng3: &mut R, ) -> (PublicKey, Device<O>, PublicKey, Device<O>) { // generate new key pairs - let sk1 = StaticSecret::new(rng); + let sk1 = StaticSecret::new(rng1); let pk1 = PublicKey::from(&sk1); - let sk2 = StaticSecret::new(rng); + let sk2 = StaticSecret::new(rng2); let pk2 = PublicKey::from(&sk2); // pick random psk let mut psk = [0u8; 32]; - rng.fill_bytes(&mut psk[..]); + rng3.fill_bytes(&mut psk[..]); // initialize devices on both ends @@ -63,7 +65,8 @@ fn wait() { */ #[test] fn handshake_under_load() { - let (_pk1, dev1, pk2, dev2): (_, Device<usize>, _, _) = setup_devices(&mut OsRng); + let (_pk1, dev1, pk2, dev2): (_, Device<usize>, _, _) = + setup_devices(&mut OsRng, &mut OsRng, &mut OsRng); let src1: SocketAddr = "172.16.0.1:8080".parse().unwrap(); let src2: SocketAddr = "172.16.0.2:7070".parse().unwrap(); @@ -140,7 +143,8 @@ fn handshake_under_load() { #[test] fn handshake_no_load() { - let (pk1, mut dev1, pk2, mut dev2): (_, Device<usize>, _, _) = setup_devices(&mut OsRng); + let (pk1, mut dev1, pk2, mut dev2): (_, Device<usize>, _, _) = + setup_devices(&mut OsRng, &mut OsRng, &mut OsRng); // do a few handshakes (every handshake should succeed) diff --git a/src/wireguard/handshake/timestamp.rs b/src/wireguard/handshake/timestamp.rs index b5bd9f0..485bb8d 100644 --- a/src/wireguard/handshake/timestamp.rs +++ b/src/wireguard/handshake/timestamp.rs @@ -28,5 +28,5 @@ pub fn compare(old: &TAI64N, new: &TAI64N) -> bool { return true; } } - return false; + false } diff --git a/src/wireguard/mod.rs b/src/wireguard/mod.rs index ca17737..e79a250 100644 --- a/src/wireguard/mod.rs +++ b/src/wireguard/mod.rs @@ -1,12 +1,11 @@ -/* The wireguard sub-module represents a full, pure, WireGuard implementation: - * - * The WireGuard device described here does not depend on particular IO implementations - * or UAPI, and can be instantiated in unit-tests with the dummy IO implementation. - * - * The code at this level serves to "glue" the handshake state-machine - * and the crypto-key router code together, - * e.g. every WireGuard peer consists of a handshake and router peer. - */ +/// The wireguard sub-module represents a full, pure, WireGuard implementation: +/// +/// The WireGuard device described here does not depend on particular IO implementations +/// or UAPI, and can be instantiated in unit-tests with the dummy IO implementation. +/// +/// The code at this level serves to "glue" the handshake state-machine +/// and the crypto-key router code together, +/// e.g. every WireGuard peer consists of one handshake peer and one router peer. mod constants; mod handshake; mod peer; @@ -14,12 +13,14 @@ mod queue; mod router; mod timers; mod types; -mod wireguard; mod workers; #[cfg(test)] mod tests; +#[allow(clippy::module_inception)] +mod wireguard; + // represents a WireGuard interface pub use wireguard::WireGuard; diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs index 27d39bd..170d2b1 100644 --- a/src/wireguard/peer.rs +++ b/src/wireguard/peer.rs @@ -22,13 +22,15 @@ pub struct PeerInner<T: Tun, B: UDP> { // wireguard device state pub wg: WireGuard<T, B>, + // TODO: eliminate + pub pk: PublicKey, + // handshake state - pub walltime_last_handshake: Mutex<Option<SystemTime>>, // walltime for last handshake (for UAPI status) + pub walltime_last_handshake: Mutex<Option<SystemTime>>, /* walltime for last handshake (for UAPI status) */ pub last_handshake_sent: Mutex<Instant>, // instant for last handshake - pub handshake_queued: AtomicBool, // is a handshake job currently queued for the peer? + pub handshake_queued: AtomicBool, // is a handshake job currently queued? // stats and configuration - pub pk: PublicKey, // public key (TODO: there has to be a way to remove this) pub rx_bytes: AtomicU64, // received bytes pub tx_bytes: AtomicU64, // transmitted bytes diff --git a/src/wireguard/queue.rs b/src/wireguard/queue.rs index 75b9104..f9e4150 100644 --- a/src/wireguard/queue.rs +++ b/src/wireguard/queue.rs @@ -12,7 +12,6 @@ impl<T> ParallelQueue<T> { /// /// - `queues`: number of readers /// - `capacity`: capacity of each internal queue - /// pub fn new(queues: usize, capacity: usize) -> (Self, Vec<Receiver<T>>) { let mut receivers = Vec::with_capacity(queues); let (tx, rx) = bounded(capacity); @@ -28,9 +27,9 @@ impl<T> ParallelQueue<T> { } pub fn send(&self, v: T) { - self.queue.lock().unwrap().as_ref().map(|s| { + if let Some(s) = self.queue.lock().unwrap().as_ref() { let _ = s.send(v); - }); + } } pub fn close(&self) { diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index 7c90f22..eeae621 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -3,9 +3,7 @@ use std::ops::Deref; use std::sync::atomic::AtomicBool; use std::sync::Arc; use std::thread; -use std::time::Instant; -use log; use spin::{Mutex, RwLock}; use zerocopy::LayoutVerified; @@ -26,31 +24,30 @@ use super::ParallelQueue; pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { // inbound writer (TUN) - pub inbound: T, + pub(super) inbound: T, // outbound writer (Bind) - pub outbound: RwLock<(bool, Option<B>)>, + pub(super) outbound: RwLock<(bool, Option<B>)>, // routing - pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state - pub table: RoutingTable<Peer<E, C, T, B>>, + #[allow(clippy::type_complexity)] + pub(super) recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, /* receiver id -> decryption state */ + pub(super) table: RoutingTable<Peer<E, C, T, B>>, // work queue - pub work: ParallelQueue<JobUnion<E, C, T, B>>, + pub(super) work: ParallelQueue<JobUnion<E, C, T, B>>, } pub struct EncryptionState { - pub keypair: Arc<KeyPair>, // keypair - pub nonce: u64, // next available nonce - pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout) + pub(super) keypair: Arc<KeyPair>, // keypair + pub(super) nonce: u64, // next available nonce } pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { - pub keypair: Arc<KeyPair>, - pub confirmed: AtomicBool, - pub protector: Mutex<AntiReplay>, - pub peer: Peer<E, C, T, B>, - pub death: Instant, // time when the key can no longer be used for decryption + pub(super) keypair: Arc<KeyPair>, + pub(super) confirmed: AtomicBool, + pub(super) protector: Mutex<AntiReplay>, + pub(super) peer: Peer<E, C, T, B>, } pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { @@ -144,7 +141,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle< return bind.write(msg, dst); } } - return Ok(()); + Ok(()) } /// Brings the router down. @@ -181,7 +178,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle< /// # Arguments /// /// - msg: IP packet to crypt-key route - /// pub fn send(&self, msg: Vec<u8>) -> Result<(), RouterError> { debug_assert!(msg.len() > SIZE_MESSAGE_PREFIX); log::trace!( @@ -212,8 +208,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle< /// - msg: Encrypted transport message /// /// # Returns - /// - /// pub fn recv(&self, src: E, msg: Vec<u8>) -> Result<(), RouterError> { log::trace!("receive, src: {}", src.into_address()); @@ -256,8 +250,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle< } /// Set outbound writer - /// - /// pub fn set_outbound_writer(&self, new: B) { self.state.outbound.write().1 = Some(new); } diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index 8248a55..0803b13 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -26,7 +26,6 @@ use std::fmt; use std::net::{IpAddr, SocketAddr}; use arraydeque::{ArrayDeque, Wrapping}; -use log; use spin::Mutex; pub struct KeyWheel { @@ -37,16 +36,22 @@ pub struct KeyWheel { } pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { - pub device: Device<E, C, T, B>, - pub opaque: C::Opaque, - pub outbound: Queue<SendJob<E, C, T, B>>, - pub inbound: Queue<ReceiveJob<E, C, T, B>>, - pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_QUEUED_PACKETS], Wrapping>>, - pub keys: Mutex<KeyWheel>, - pub enc_key: Mutex<Option<EncryptionState>>, - pub endpoint: Mutex<Option<E>>, + pub(super) device: Device<E, C, T, B>, + pub(super) opaque: C::Opaque, + pub(super) outbound: Queue<SendJob<E, C, T, B>>, + pub(super) inbound: Queue<ReceiveJob<E, C, T, B>>, + pub(super) staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_QUEUED_PACKETS], Wrapping>>, + pub(super) keys: Mutex<KeyWheel>, + pub(super) enc_key: Mutex<Option<EncryptionState>>, + pub(super) endpoint: Mutex<Option<E>>, } +/// A Peer dereferences to its opaque type: +/// This allows the router code to take ownership of the opaque type +/// used for callback events, while still enabling the rest of the code to access the opaque type +/// (which might expose other functionality in their scope) from a Peer pointer. +/// +/// e.g. it can take ownership of the timer state of a peer. impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for PeerInner<E, C, T, B> { type Target = C::Opaque; @@ -55,10 +60,20 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Pee } } +/// A Peer represents a reference to the router state associated with a peer pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { inner: Arc<PeerInner<E, C, T, B>>, } +/// A PeerHandle is a specially designated reference to the peer +/// which removes the peer from the device when dropped. +/// +/// A PeerHandle cannot be cloned (unlike the wrapped type). +/// A PeerHandle dereferences to a Peer (meaning you can use it like a Peer struct) +pub struct PeerHandle<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { + peer: Peer<E, C, T, B>, +} + impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for Peer<E, C, T, B> { fn clone(&self) -> Self { Peer { @@ -67,7 +82,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for Pee } } -/* Equality of peers is defined as pointer equality +/* Equality of peers is defined as pointer equality of * the atomic reference counted pointer. */ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq for Peer<E, C, T, B> { @@ -89,25 +104,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Pee } } -/* A peer handle is a specially designated peer pointer - * which removes the peer from the device when dropped. - */ -pub struct PeerHandle<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { - peer: Peer<E, C, T, B>, -} - -/* -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone - for PeerHandle<E, C, T, B> -{ - fn clone(&self) -> Self { - PeerHandle { - peer: self.peer.clone(), - } - } -} -*/ - impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for PeerHandle<E, C, T, B> { @@ -130,7 +126,6 @@ impl EncryptionState { EncryptionState { nonce: 0, keypair: keypair.clone(), - death: keypair.birth + REJECT_AFTER_TIME, } } } @@ -141,7 +136,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DecryptionSta confirmed: AtomicBool::new(keypair.initiator), keypair: keypair.clone(), protector: spin::Mutex::new(AntiReplay::new()), - death: keypair.birth + REJECT_AFTER_TIME, peer, } } @@ -160,11 +154,17 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Peer let mut keys = peer.keys.lock(); let mut release = Vec::with_capacity(3); - keys.next.as_ref().map(|k| release.push(k.recv.id)); - keys.current.as_ref().map(|k| release.push(k.recv.id)); - keys.previous.as_ref().map(|k| release.push(k.recv.id)); + if let Some(k) = keys.next.as_ref() { + release.push(k.recv.id) + } + if let Some(k) = keys.current.as_ref() { + release.push(k.recv.id) + } + if let Some(k) = keys.previous.as_ref() { + release.push(k.recv.id) + } - if release.len() > 0 { + if !release.is_empty() { let mut recv = peer.device.recv.write(); for id in &release { recv.remove(id); @@ -190,7 +190,6 @@ pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( ) -> PeerHandle<E, C, T, B> { // allocate peer object let peer = { - let device = device.clone(); Peer { inner: Arc::new(PeerInner { opaque, @@ -250,7 +249,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, /// /// - `msg` : A padded vector holding the message (allows in-place construction of the transport header) /// - `stage`: Should the message be staged if no key is available - /// pub(super) fn send(&self, msg: Vec<u8>, stage: bool) { // check if key available let (job, need_key) = { @@ -390,9 +388,15 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, // update key-wheel - mem::replace(&mut keys.next, None).map(|k| release.push(k.local_id())); - mem::replace(&mut keys.current, None).map(|k| release.push(k.local_id())); - mem::replace(&mut keys.previous, None).map(|k| release.push(k.local_id())); + if let Some(k) = mem::replace(&mut keys.next, None) { + release.push(k.local_id()) + } + if let Some(k) = mem::replace(&mut keys.current, None) { + release.push(k.local_id()) + } + if let Some(k) = mem::replace(&mut keys.previous, None) { + release.push(k.local_id()) + } keys.retired.extend(&release[..]); // update inbound "recv" map @@ -444,11 +448,11 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, *self.peer.enc_key.lock() = Some(EncryptionState::new(&new)); // move current into previous - keys.previous = keys.current.as_ref().map(|v| v.clone()); + keys.previous = keys.current.as_ref().cloned(); keys.current = Some(new.clone()); } else { // store the key and await confirmation - keys.previous = keys.next.as_ref().map(|v| v.clone()); + keys.previous = keys.next.as_ref().cloned(); keys.next = Some(new.clone()); }; @@ -458,10 +462,10 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, let mut recv = self.peer.device.recv.write(); // purge recv map of previous id - keys.previous.as_ref().map(|k| { + if let Some(k) = &keys.previous { recv.remove(&k.local_id()); release.push(k.local_id()); - }); + } // map new id to decryption state debug_assert!(!recv.contains_key(&new.recv.id)); @@ -536,7 +540,9 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, } pub fn clear_src(&self) { - (*self.peer.endpoint.lock()).as_mut().map(|e| e.clear_src()); + if let Some(e) = (*self.peer.endpoint.lock()).as_mut() { + e.clear_src() + } } pub fn purge_staged_packets(&self) { diff --git a/src/wireguard/router/queue.rs b/src/wireguard/router/queue.rs index d5d657a..b266a57 100644 --- a/src/wireguard/router/queue.rs +++ b/src/wireguard/router/queue.rs @@ -67,9 +67,7 @@ impl<J: SequentialJob> Queue<J> { match queue.front() { None => break, Some(job) => { - if job.is_ready() { - () - } else { + if !job.is_ready() { break; } } diff --git a/src/wireguard/router/route.rs b/src/wireguard/router/route.rs index a556010..7e50153 100644 --- a/src/wireguard/router/route.rs +++ b/src/wireguard/router/route.rs @@ -88,7 +88,7 @@ impl<T: Eq + Clone> RoutingTable<T> { self.ipv4 .read() .longest_match(Ipv4Addr::from(header.f_destination)) - .and_then(|(_, _, p)| Some(p.clone())) + .map(|(_, _, p)| p.clone()) } VERSION_IP6 => { // check length and cast to IPv6 header @@ -104,7 +104,7 @@ impl<T: Eq + Clone> RoutingTable<T> { self.ipv6 .read() .longest_match(Ipv6Addr::from(header.f_destination)) - .and_then(|(_, _, p)| Some(p.clone())) + .map(|(_, _, p)| p.clone()) } v => { log::trace!("router, invalid IP version {}", v); diff --git a/src/wireguard/router/tests/bench.rs b/src/wireguard/router/tests/bench.rs index f025dc9..c2334b3 100644 --- a/src/wireguard/router/tests/bench.rs +++ b/src/wireguard/router/tests/bench.rs @@ -1,13 +1,21 @@ +#[cfg(feature = "unstable")] extern crate test; use super::*; -use std::net::IpAddr; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; use std::sync::Arc; +// only used in benchmark +#[cfg(feature = "unstable")] +use std::net::IpAddr; + +// only used in benchmark +#[cfg(feature = "unstable")] use num_cpus; + +#[cfg(feature = "unstable")] use test::Bencher; // @@ -17,6 +25,7 @@ struct TransmissionCounter { } impl TransmissionCounter { + #[allow(dead_code)] fn new() -> TransmissionCounter { TransmissionCounter { sent: AtomicUsize::new(0), @@ -24,15 +33,18 @@ impl TransmissionCounter { } } + #[allow(dead_code)] fn reset(&self) { self.sent.store(0, Ordering::SeqCst); self.recv.store(0, Ordering::SeqCst); } + #[allow(dead_code)] fn sent(&self) -> usize { self.sent.load(Ordering::Acquire) } + #[allow(dead_code)] fn recv(&self) -> usize { self.recv.load(Ordering::Acquire) } @@ -78,6 +90,7 @@ fn profiler_start(name: &str) { } } +#[cfg(feature = "unstable")] #[bench] fn bench_router_outbound(b: &mut Bencher) { // 10 GB transmission per iteration diff --git a/src/wireguard/router/tests/tests.rs b/src/wireguard/router/tests/tests.rs index 6819644..f6205d5 100644 --- a/src/wireguard/router/tests/tests.rs +++ b/src/wireguard/router/tests/tests.rs @@ -11,6 +11,7 @@ use rand::Rng; use super::*; +#[cfg(feature = "unstable")] extern crate test; const SIZE_MSG: usize = 1024; diff --git a/src/wireguard/router/types.rs b/src/wireguard/router/types.rs index e0cd459..e44963f 100644 --- a/src/wireguard/router/types.rs +++ b/src/wireguard/router/types.rs @@ -15,16 +15,16 @@ impl<T> Opaque for T where T: Send + Sync + 'static {} /// * `0`, a reference to the opaque value assigned to the peer /// * `1`, a bool indicating whether the message contained data (not just keepalive) /// * `2`, a bool indicating whether the message was transmitted (i.e. did the peer have an associated endpoint?) -pub trait Callback<T>: Fn(&T, usize, bool) -> () + Sync + Send + 'static {} +pub trait Callback<T>: Fn(&T, usize, bool) + Sync + Send + 'static {} -impl<T, F> Callback<T> for F where F: Fn(&T, usize, bool) -> () + Sync + Send + 'static {} +impl<T, F> Callback<T> for F where F: Fn(&T, usize, bool) + Sync + Send + 'static {} /// A key callback takes 1 argument /// /// * `0`, a reference to the opaque value assigned to the peer -pub trait KeyCallback<T>: Fn(&T) -> () + Sync + Send + 'static {} +pub trait KeyCallback<T>: Fn(&T) + Sync + Send + 'static {} -impl<T, F> KeyCallback<T> for F where F: Fn(&T) -> () + Sync + Send + 'static {} +impl<T, F> KeyCallback<T> for F where F: Fn(&T) + Sync + Send + 'static {} pub trait Callbacks: Send + Sync + 'static { type Opaque: Opaque; @@ -58,11 +58,11 @@ impl fmt::Display for RouterError { } impl Error for RouterError { - fn description(&self) -> &str { - "Generic Handshake Error" - } - fn source(&self) -> Option<&(dyn Error + 'static)> { None } + + fn description(&self) -> &str { + "Generic Handshake Error" + } } diff --git a/src/wireguard/router/worker.rs b/src/wireguard/router/worker.rs index 4913a21..99c2a1d 100644 --- a/src/wireguard/router/worker.rs +++ b/src/wireguard/router/worker.rs @@ -6,7 +6,6 @@ use super::super::{tun, udp, Endpoint}; use super::types::Callbacks; use crossbeam_channel::Receiver; -use log; pub enum JobUnion<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { Outbound(SendJob<E, C, T, B>), diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index 94a95ab..be0f5f9 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -268,7 +268,6 @@ impl Timers { handshake_attempts: AtomicUsize::new(0), retransmit_handshake: { let wg = wg.clone(); - let pk = pk.clone(); runner.timer(move || { // fetch peer by public key fetch_peer!(wg, pk, peer); @@ -300,7 +299,6 @@ impl Timers { }, send_keepalive: { let wg = wg.clone(); - let pk = pk.clone(); runner.timer(move || { // fetch peer by public key fetch_peer!(wg, pk, peer); @@ -315,7 +313,6 @@ impl Timers { }, new_handshake: { let wg = wg.clone(); - let pk = pk.clone(); runner.timer(move || { // fetch peer by public key fetch_peer!(wg, pk, peer); @@ -333,7 +330,6 @@ impl Timers { }, zero_key_material: { let wg = wg.clone(); - let pk = pk.clone(); runner.timer(move || { // fetch peer by public key fetch_peer!(wg, pk, peer); @@ -345,7 +341,6 @@ impl Timers { }, send_persistent_keepalive: { let wg = wg.clone(); - let pk = pk.clone(); runner.timer(move || { // fetch peer by public key fetch_peer!(wg, pk, peer); diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 91526aa..44d698f 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -13,19 +13,20 @@ use super::udp::UDP; use super::workers::{handshake_worker, tun_worker, udp_worker}; use std::fmt; +use std::thread; + use std::ops::Deref; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; use std::sync::Condvar; use std::sync::Mutex as StdMutex; -use std::thread; use std::time::Instant; -use hjul::Runner; use rand::rngs::OsRng; use rand::Rng; -use spin::{Mutex, RwLock}; +use hjul::Runner; +use spin::{Mutex, RwLock}; use x25519_dalek::{PublicKey, StaticSecret}; pub struct WireguardInner<T: Tun, B: UDP> { @@ -45,6 +46,7 @@ pub struct WireguardInner<T: Tun, B: UDP> { pub mtu: AtomicUsize, // peer map + #[allow(clippy::type_complexity)] pub peers: RwLock< handshake::Device<router::PeerHandle<B::Endpoint, PeerInner<T, B>, T::Writer, B::Writer>>, >, @@ -85,6 +87,7 @@ impl<T: Tun, B: UDP> Clone for WireGuard<T, B> { } } +#[allow(clippy::mutex_atomic)] impl WaitCounter { pub fn wait(&self) { let mut nread = self.0.lock().unwrap(); @@ -126,7 +129,7 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { let mut enabled = self.enabled.write(); // check if already down - if *enabled == false { + if !(*enabled) { return; } @@ -206,10 +209,10 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { } // prevent up/down while inserting - let enabled = *self.enabled.read(); + let enabled = self.enabled.read(); // create timers (lookup by public key) - let timers = Timers::new::<T, B>(self.clone(), pk.clone(), enabled); + let timers = Timers::new::<T, B>(self.clone(), pk, *enabled); // create new router peer let peer: router::PeerHandle<B::Endpoint, PeerInner<T, B>, T::Writer, B::Writer> = diff --git a/src/wireguard/workers.rs b/src/wireguard/workers.rs index b4673cd..27acf2f 100644 --- a/src/wireguard/workers.rs +++ b/src/wireguard/workers.rs @@ -231,7 +231,7 @@ pub fn handshake_worker<T: Tun, B: UDP>( } // add any new keypair to peer - keypair.map(|kp| { + if let Some(kp) = keypair { debug!("{} : handshake worker, new keypair for {}", wg, peer); // this means that a handshake response was processed or sent @@ -241,7 +241,7 @@ pub fn handshake_worker<T: Tun, B: UDP>( for id in peer.add_keypair(kp) { device.release(id); } - }); + }; } } Err(e) => debug!("{} : handshake worker, error = {:?}", wg, e), |