diff options
Diffstat (limited to 'src/wireguard')
-rw-r--r-- | src/wireguard/handshake/device.rs | 198 | ||||
-rw-r--r-- | src/wireguard/handshake/macs.rs | 6 | ||||
-rw-r--r-- | src/wireguard/handshake/noise.rs | 46 | ||||
-rw-r--r-- | src/wireguard/handshake/peer.rs | 26 | ||||
-rw-r--r-- | src/wireguard/handshake/tests.rs | 62 | ||||
-rw-r--r-- | src/wireguard/handshake/types.rs | 14 | ||||
-rw-r--r-- | src/wireguard/peer.rs | 2 | ||||
-rw-r--r-- | src/wireguard/router/device.rs | 2 | ||||
-rw-r--r-- | src/wireguard/router/peer.rs | 2 | ||||
-rw-r--r-- | src/wireguard/wireguard.rs | 85 | ||||
-rw-r--r-- | src/wireguard/workers.rs | 80 |
11 files changed, 277 insertions, 246 deletions
diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs index edd1a07..4b5d8f6 100644 --- a/src/wireguard/handshake/device.rs +++ b/src/wireguard/handshake/device.rs @@ -1,4 +1,5 @@ use spin::RwLock; +use std::collections::hash_map; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Mutex; @@ -6,7 +7,10 @@ use zerocopy::AsBytes; use byteorder::{ByteOrder, LittleEndian}; -use rand::prelude::*; +use rand::Rng; +use rand_core::{CryptoRng, RngCore}; + +use clear_on_drop::clear::Clear; use x25519_dalek::PublicKey; use x25519_dalek::StaticSecret; @@ -22,42 +26,101 @@ use super::types::*; const MAX_PEER_PER_DEVICE: usize = 1 << 20; pub struct KeyState { - pub sk: StaticSecret, // static secret key - pub pk: PublicKey, // static public key - macs: macs::Validator, // validator for the mac fields + pub(super) sk: StaticSecret, // static secret key + pub(super) pk: PublicKey, // static public key + macs: macs::Validator, // validator for the mac fields } -pub struct Device { - keyst: Option<KeyState>, // secret/public key - pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state - id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key +/// The device is generic over an "opaque" type +/// which can be used to associate the public key with this value. +/// (the instance is a Peer object in the parent module) +pub struct Device<O> { + keyst: Option<KeyState>, + id_map: RwLock<HashMap<u32, [u8; 32]>>, + pk_map: HashMap<[u8; 32], Peer<O>>, limiter: Mutex<RateLimiter>, } +pub struct Iter<'a, O> { + iter: hash_map::Iter<'a, [u8; 32], Peer<O>>, +} + +impl<'a, O> Iterator for Iter<'a, O> { + type Item = (PublicKey, &'a O); + + fn next(&mut self) -> Option<Self::Item> { + self.iter + .next() + .map(|(pk, peer)| (PublicKey::from(*pk), &peer.opaque)) + } +} + +/* These methods enable the Device to act as a map + * from public keys to the set of contained opaque values. + * + * It also abstracts away the problem of PublicKey not being hashable. + */ +impl<O> Device<O> { + pub fn clear(&mut self) { + self.id_map.write().clear(); + self.pk_map.clear(); + } + + pub fn len(&self) -> usize { + self.pk_map.len() + } + + /// Enables enumeration of (public key, opaque) pairs + /// without exposing internal peer type. + pub fn iter(&self) -> Iter<O> { + Iter { + iter: self.pk_map.iter(), + } + } + + /// Enables lookup by public key without exposing internal peer type. + pub fn get(&self, pk: &PublicKey) -> Option<&O> { + self.pk_map.get(pk.as_bytes()).map(|peer| &peer.opaque) + } + + pub fn contains_key(&self, pk: &PublicKey) -> bool { + self.pk_map.contains_key(pk.as_bytes()) + } +} + /* A mutable reference to the device needs to be held during configuration. * Wrapping the device in a RwLock enables peer config after "configuration time" */ -impl Device { +impl<O> Device<O> { /// Initialize a new handshake state machine - pub fn new() -> Device { + pub fn new() -> Device<O> { Device { keyst: None, - pk_map: HashMap::new(), id_map: RwLock::new(HashMap::new()), + pk_map: HashMap::new(), limiter: Mutex::new(RateLimiter::new()), } } - fn update_ss(&self, peer: &mut Peer) -> Option<PublicKey> { - if let Some(key) = self.keyst.as_ref() { - if *peer.pk.as_bytes() == *key.pk.as_bytes() { - return Some(peer.pk); + fn update_ss(&mut self) -> (Vec<u32>, Option<PublicKey>) { + let mut same = None; + let mut ids = Vec::with_capacity(self.pk_map.len()); + for (pk, peer) in self.pk_map.iter_mut() { + if let Some(key) = self.keyst.as_ref() { + if key.pk.as_bytes() == pk { + same = Some(PublicKey::from(*pk)); + peer.ss.clear() + } else { + let pk = PublicKey::from(*pk); + peer.ss = *key.sk.diffie_hellman(&pk).as_bytes(); + } + } else { + peer.ss.clear(); } - peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes(); - } else { - peer.ss = [0u8; 32]; - }; - None + peer.reset_state().map(|id| ids.push(id)); + } + + (ids, same) } /// Update the secret key of the device @@ -74,29 +137,15 @@ impl Device { }); // recalculate / erase the shared secrets for every peer - let mut ids = vec![]; - let mut same = None; - for mut peer in self.pk_map.values_mut() { - // clear any existing handshake state - peer.reset_state().map(|id| ids.push(id)); - - // update precomputed shared secret - if let Some(key) = self.keyst.as_ref() { - peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes(); - if *peer.pk.as_bytes() == *key.pk.as_bytes() { - same = Some(peer.pk) - } - } else { - peer.ss = [0u8; 32]; - }; - } + let (ids, same) = self.update_ss(); // release ids from aborted handshakes for id in ids { self.release(id) } - // if we found a peer matching the device public key, remove it. + // if we found a peer matching the device public key + // remove it and return its value to the caller same.map(|pk| { self.pk_map.remove(pk.as_bytes()); pk @@ -119,29 +168,32 @@ impl Device { /// /// * `pk` - The public key to add /// * `identifier` - Associated identifier which can be used to distinguish the peers - pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> { + pub fn add(&mut self, pk: PublicKey, opaque: O) -> Result<(), ConfigError> { // ensure less than 2^20 peers if self.pk_map.len() > MAX_PEER_PER_DEVICE { return Err(ConfigError::new("Too many peers for device")); } - // create peer and precompute static secret - let mut peer = Peer::new( - pk, - self.keyst - .as_ref() - .map(|key| *key.sk.diffie_hellman(&pk).as_bytes()) - .unwrap_or([0u8; 32]), - ); - - // add peer to device - match self.update_ss(&mut peer) { - Some(_) => Err(ConfigError::new("Public key of peer matches the device")), - None => { - self.pk_map.insert(*pk.as_bytes(), peer); - Ok(()) + // error if public key matches device + if let Some(key) = self.keyst.as_ref() { + if pk.as_bytes() == key.pk.as_bytes() { + return Err(ConfigError::new("Public key of peer matches the device")); } } + + // pre-compute shared secret and add to pk_map + self.pk_map.insert( + *pk.as_bytes(), + Peer::new( + pk, + self.keyst + .as_ref() + .map(|key| *key.sk.diffie_hellman(&pk).as_bytes()) + .unwrap_or([0u8; 32]), + opaque, + ), + ); + Ok(()) } /// Remove a peer by public key @@ -163,7 +215,7 @@ impl Device { .remove(pk.as_bytes()) .ok_or(ConfigError::new("Public key not in device"))?; - // pruge the id map (linear scan) + // purge the id map (linear scan) id_map.retain(|_, v| v != pk.as_bytes()); Ok(()) } @@ -231,11 +283,11 @@ impl Device { (_, None) => Err(HandshakeError::UnknownPublicKey), (None, _) => Err(HandshakeError::UnknownPublicKey), (Some(keyst), Some(peer)) => { - let local = self.allocate(rng, peer); + let local = self.allocate(rng, pk); let mut msg = Initiation::default(); // create noise part of initation - noise::create_initiation(rng, keyst, peer, local, &mut msg.noise)?; + noise::create_initiation(rng, keyst, peer, pk, local, &mut msg.noise)?; // add macs to initation peer.macs @@ -253,11 +305,11 @@ impl Device { /// /// * `msg` - Byte slice containing the message (untrusted input) pub fn process<'a, R: RngCore + CryptoRng>( - &self, - rng: &mut R, // rng instance to sample randomness from - msg: &[u8], // message buffer + &'a self, + rng: &mut R, // rng instance to sample randomness from + msg: &[u8], // message buffer src: Option<SocketAddr>, // optional source endpoint, set when "under load" - ) -> Result<Output, HandshakeError> { + ) -> Result<Output<'a, O>, HandshakeError> { // ensure type read in-range if msg.len() < 4 { return Err(HandshakeError::InvalidMessageFormat); @@ -303,17 +355,17 @@ impl Device { } // consume the initiation - let (peer, st) = noise::consume_initiation(self, keyst, &msg.noise)?; + let (peer, pk, st) = noise::consume_initiation(self, keyst, &msg.noise)?; // allocate new index for response - let local = self.allocate(rng, peer); + let local = self.allocate(rng, &pk); // prepare memory for response, TODO: take slice for zero allocation let mut resp = Response::default(); // create response (release id on error) - let keys = - noise::create_response(rng, peer, local, st, &mut resp.noise).map_err(|e| { + let keys = noise::create_response(rng, peer, &pk, local, st, &mut resp.noise) + .map_err(|e| { self.release(local); e })?; @@ -324,7 +376,11 @@ impl Device { .generate(resp.noise.as_bytes(), &mut resp.macs); // return unconfirmed keypair and the response as vector - Ok((Some(peer.pk), Some(resp.as_bytes().to_owned()), Some(keys))) + Ok(( + Some(&peer.opaque), + Some(resp.as_bytes().to_owned()), + Some(keys), + )) } TYPE_RESPONSE => { let msg = Response::parse(msg)?; @@ -363,7 +419,7 @@ impl Device { let msg = CookieReply::parse(msg)?; // lookup peer - let peer = self.lookup_id(msg.f_receiver.get())?; + let (peer, _) = self.lookup_id(msg.f_receiver.get())?; // validate cookie reply peer.macs.lock().process(&msg)?; @@ -379,7 +435,7 @@ impl Device { // Internal function // // Return the peer associated with the public key - pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer, HandshakeError> { + pub(super) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer<O>, HandshakeError> { self.pk_map .get(pk.as_bytes()) .ok_or(HandshakeError::UnknownPublicKey) @@ -388,11 +444,11 @@ impl Device { // Internal function // // Return the peer currently associated with the receiver identifier - pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer, HandshakeError> { + pub(super) fn lookup_id(&self, id: u32) -> Result<(&Peer<O>, PublicKey), HandshakeError> { let im = self.id_map.read(); let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?; match self.pk_map.get(pk) { - Some(peer) => Ok(peer), + Some(peer) => Ok((peer, PublicKey::from(*pk))), _ => unreachable!(), // if the id-lookup succeeded, the peer should exist } } @@ -400,7 +456,7 @@ impl Device { // Internal function // // Allocated a new receiver identifier for the peer - fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, peer: &Peer) -> u32 { + fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, pk: &PublicKey) -> u32 { loop { let id = rng.gen(); @@ -412,7 +468,7 @@ impl Device { // take write lock and add index let mut m = self.id_map.write(); if !m.contains_key(&id) { - m.insert(id, *peer.pk.as_bytes()); + m.insert(id, *pk.as_bytes()); return id; } } diff --git a/src/wireguard/handshake/macs.rs b/src/wireguard/handshake/macs.rs index 689826b..cb5d7d4 100644 --- a/src/wireguard/handshake/macs.rs +++ b/src/wireguard/handshake/macs.rs @@ -286,8 +286,7 @@ mod tests { use x25519_dalek::StaticSecret; fn new_validator_generator() -> (Validator, Generator) { - let mut rng = OsRng::new().unwrap(); - let sk = StaticSecret::new(&mut rng); + let sk = StaticSecret::new(&mut OsRng); let pk = PublicKey::from(&sk); (Validator::new(pk), Generator::new(pk)) } @@ -296,7 +295,6 @@ mod tests { #[test] fn test_cookie_reply(inner1 : Vec<u8>, inner2 : Vec<u8>, receiver : u32) { let mut msg = CookieReply::default(); - let mut rng = OsRng::new().expect("failed to create rng"); let mut macs = MacsFooter::default(); let src = "192.0.2.16:8080".parse().unwrap(); let (validator, mut generator) = new_validator_generator(); @@ -309,7 +307,7 @@ mod tests { // check validity of mac1 validator.check_mac1(&inner1[..], &macs).expect("mac1 of inner1 did not validate"); assert_eq!(validator.check_mac2(&inner1[..], &src, &macs), false, "mac2 of inner2 did not validate"); - validator.create_cookie_reply(&mut rng, receiver, &src, &macs, &mut msg); + validator.create_cookie_reply(&mut OsRng, receiver, &src, &macs, &mut msg); // consume cookie reply generator.process(&msg).expect("failed to process CookieReply"); diff --git a/src/wireguard/handshake/noise.rs b/src/wireguard/handshake/noise.rs index 072ac13..9e431cf 100644 --- a/src/wireguard/handshake/noise.rs +++ b/src/wireguard/handshake/noise.rs @@ -10,7 +10,7 @@ use hmac::Hmac; use aead::{Aead, NewAead, Payload}; use chacha20poly1305::ChaCha20Poly1305; -use rand::{CryptoRng, RngCore}; +use rand_core::{CryptoRng, RngCore}; use log::debug; @@ -215,20 +215,21 @@ mod tests { } } -pub fn create_initiation<R: RngCore + CryptoRng>( +pub(super) fn create_initiation<R: RngCore + CryptoRng, O>( rng: &mut R, keyst: &KeyState, - peer: &Peer, + peer: &Peer<O>, + pk: &PublicKey, local: u32, msg: &mut NoiseInitiation, ) -> Result<(), HandshakeError> { - debug!("create initation"); + debug!("create initiation"); clear_stack_on_return(CLEAR_PAGES, || { // initialize state let ck = INITIAL_CK; let hs = INITIAL_HS; - let hs = HASH!(&hs, peer.pk.as_bytes()); + let hs = HASH!(&hs, pk.as_bytes()); msg.f_type.set(TYPE_INITIATION as u32); msg.f_sender.set(local); // from us @@ -252,7 +253,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>( // (C, k) := Kdf2(C, DH(E_priv, S_pub)) - let (ck, key) = KDF2!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes()); + let (ck, key) = KDF2!(&ck, eph_sk.diffie_hellman(&pk).as_bytes()); // msg.static := Aead(k, 0, S_pub, H) @@ -297,12 +298,12 @@ pub fn create_initiation<R: RngCore + CryptoRng>( }) } -pub fn consume_initiation<'a>( - device: &'a Device, +pub(super) fn consume_initiation<'a, O>( + device: &'a Device<O>, keyst: &KeyState, msg: &NoiseInitiation, -) -> Result<(&'a Peer, TemporaryState), HandshakeError> { - debug!("consume initation"); +) -> Result<(&'a Peer<O>, PublicKey, TemporaryState), HandshakeError> { + debug!("consume initiation"); clear_stack_on_return(CLEAR_PAGES, || { // initialize new state @@ -369,13 +370,18 @@ pub fn consume_initiation<'a>( // return state (to create response) - Ok((peer, (msg.f_sender.get(), eph_r_pk, hs, ck))) + Ok(( + peer, + PublicKey::from(pk), + (msg.f_sender.get(), eph_r_pk, hs, ck), + )) }) } -pub fn create_response<R: RngCore + CryptoRng>( +pub(super) fn create_response<R: RngCore + CryptoRng, O>( rng: &mut R, - peer: &Peer, + peer: &Peer<O>, + pk: &PublicKey, local: u32, // sending identifier state: TemporaryState, // state from "consume_initiation" msg: &mut NoiseResponse, // resulting response @@ -388,7 +394,7 @@ pub fn create_response<R: RngCore + CryptoRng>( msg.f_type.set(TYPE_RESPONSE as u32); msg.f_sender.set(local); // from us - msg.f_receiver.set(receiver); // to the sender of the initation + msg.f_receiver.set(receiver); // to the sender of the initiation // (E_priv, E_pub) := DH-Generate() @@ -413,7 +419,7 @@ pub fn create_response<R: RngCore + CryptoRng>( // C := Kdf1(C, DH(E_priv, S_pub)) - let ck = KDF1!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes()); + let ck = KDF1!(&ck, eph_sk.diffie_hellman(&pk).as_bytes()); // (C, tau, k) := Kdf3(C, Q) @@ -460,15 +466,15 @@ pub fn create_response<R: RngCore + CryptoRng>( * allow concurrent processing of potential responses to the initiation, * in order to better mitigate DoS from malformed response messages. */ -pub fn consume_response( - device: &Device, +pub(super) fn consume_response<'a, O>( + device: &'a Device<O>, keyst: &KeyState, msg: &NoiseResponse, -) -> Result<Output, HandshakeError> { +) -> Result<Output<'a, O>, HandshakeError> { debug!("consume response"); clear_stack_on_return(CLEAR_PAGES, || { // retrieve peer and copy initiation state - let peer = device.lookup_id(msg.f_receiver.get())?; + let (peer, _) = device.lookup_id(msg.f_receiver.get())?; let (hs, ck, local, eph_sk) = match *peer.state.lock() { State::InitiationSent { @@ -537,7 +543,7 @@ pub fn consume_response( // return confirmed key-pair Ok(( - Some(peer.pk), + Some(&peer.opaque), None, Some(KeyPair { birth, diff --git a/src/wireguard/handshake/peer.rs b/src/wireguard/handshake/peer.rs index a4df560..f4d15fc 100644 --- a/src/wireguard/handshake/peer.rs +++ b/src/wireguard/handshake/peer.rs @@ -22,19 +22,21 @@ const TIME_BETWEEN_INITIATIONS: Duration = Duration::from_millis(20); * * This type is only for internal use and not exposed. */ -pub struct Peer { +pub(super) struct Peer<O> { + // opaque type which identifies a peer + pub opaque: O, + // mutable state - pub(crate) state: Mutex<State>, - pub(crate) timestamp: Mutex<Option<timestamp::TAI64N>>, - pub(crate) last_initiation_consumption: Mutex<Option<Instant>>, + pub state: Mutex<State>, + pub timestamp: Mutex<Option<timestamp::TAI64N>>, + pub last_initiation_consumption: Mutex<Option<Instant>>, // state related to DoS mitigation fields - pub(crate) macs: Mutex<macs::Generator>, + pub macs: Mutex<macs::Generator>, // constant state - pub(crate) pk: PublicKey, // public key of peer - pub(crate) ss: [u8; 32], // precomputed DH(static, static) - pub(crate) psk: Psk, // psk of peer + pub ss: [u8; 32], // precomputed DH(static, static) + pub psk: Psk, // psk of peer } pub enum State { @@ -60,14 +62,14 @@ impl Drop for State { } } -impl Peer { - pub fn new(pk: PublicKey, ss: [u8; 32]) -> Self { +impl<O> Peer<O> { + pub fn new(pk: PublicKey, ss: [u8; 32], opaque: O) -> Self { Self { + opaque, macs: Mutex::new(macs::Generator::new(pk)), state: Mutex::new(State::Reset), timestamp: Mutex::new(None), last_initiation_consumption: Mutex::new(None), - pk, ss, psk: [0u8; 32], } @@ -88,7 +90,7 @@ impl Peer { /// * ts_new - The associated timestamp pub fn check_replay_flood( &self, - device: &Device, + device: &Device<O>, timestamp_new: ×tamp::TAI64N, ) -> Result<(), HandshakeError> { let mut state = self.state.lock(); diff --git a/src/wireguard/handshake/tests.rs b/src/wireguard/handshake/tests.rs index ff27b3e..bfdc5ab 100644 --- a/src/wireguard/handshake/tests.rs +++ b/src/wireguard/handshake/tests.rs @@ -12,8 +12,10 @@ use x25519_dalek::StaticSecret; use super::messages::{Initiation, Response}; -fn setup_devices<R: RngCore + CryptoRng>(rng: &mut R) -> (PublicKey, Device, PublicKey, Device) { - // generate new keypairs +fn setup_devices<R: RngCore + CryptoRng, O: Default>( + rng: &mut R, +) -> (PublicKey, Device<O>, PublicKey, Device<O>) { + // generate new key pairs let sk1 = StaticSecret::new(rng); let pk1 = PublicKey::from(&sk1); @@ -26,7 +28,7 @@ fn setup_devices<R: RngCore + CryptoRng>(rng: &mut R) -> (PublicKey, Device, Pub let mut psk = [0u8; 32]; rng.fill_bytes(&mut psk[..]); - // intialize devices on both ends + // initialize devices on both ends let mut dev1 = Device::new(); let mut dev2 = Device::new(); @@ -34,8 +36,8 @@ fn setup_devices<R: RngCore + CryptoRng>(rng: &mut R) -> (PublicKey, Device, Pub dev1.set_sk(Some(sk1)); dev2.set_sk(Some(sk2)); - dev1.add(pk2).unwrap(); - dev2.add(pk1).unwrap(); + dev1.add(pk2, O::default()).unwrap(); + dev2.add(pk1, O::default()).unwrap(); dev1.set_psk(pk2, psk).unwrap(); dev2.set_psk(pk1, psk).unwrap(); @@ -49,45 +51,44 @@ fn wait() { /* Test longest possible handshake interaction (7 messages): * - * 1. I -> R (initation) + * 1. I -> R (initiation) * 2. I <- R (cookie reply) - * 3. I -> R (initation) + * 3. I -> R (initiation) * 4. I <- R (response) * 5. I -> R (cookie reply) - * 6. I -> R (initation) + * 6. I -> R (initiation) * 7. I <- R (response) */ #[test] fn handshake_under_load() { - let mut rng = OsRng::new().unwrap(); - let (_pk1, dev1, pk2, dev2) = setup_devices(&mut rng); + let (_pk1, dev1, pk2, dev2): (_, Device<usize>, _, _) = setup_devices(&mut OsRng); let src1: SocketAddr = "172.16.0.1:8080".parse().unwrap(); let src2: SocketAddr = "172.16.0.2:7070".parse().unwrap(); - // 1. device-1 : create first initation - let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + // 1. device-1 : create first initiation + let msg_init = dev1.begin(&mut OsRng, &pk2).unwrap(); // 2. device-2 : responds with CookieReply - let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(src1)).unwrap() { + let msg_cookie = match dev2.process(&mut OsRng, &msg_init, Some(src1)).unwrap() { (None, Some(msg), None) => msg, _ => panic!("unexpected response"), }; // device-1 : processes CookieReply (no response) - match dev1.process(&mut rng, &msg_cookie, Some(src2)).unwrap() { + match dev1.process(&mut OsRng, &msg_cookie, Some(src2)).unwrap() { (None, None, None) => (), _ => panic!("unexpected response"), } - // avoid initation flood detection + // avoid initiation flood detection wait(); - // 3. device-1 : create second initation - let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + // 3. device-1 : create second initiation + let msg_init = dev1.begin(&mut OsRng, &pk2).unwrap(); // 4. device-2 : responds with noise response - let msg_response = match dev2.process(&mut rng, &msg_init, Some(src1)).unwrap() { + let msg_response = match dev2.process(&mut OsRng, &msg_init, Some(src1)).unwrap() { (Some(_), Some(msg), Some(kp)) => { assert_eq!(kp.initiator, false); msg @@ -96,25 +97,25 @@ fn handshake_under_load() { }; // 5. device-1 : responds with CookieReply - let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(src2)).unwrap() { + let msg_cookie = match dev1.process(&mut OsRng, &msg_response, Some(src2)).unwrap() { (None, Some(msg), None) => msg, _ => panic!("unexpected response"), }; // device-2 : processes CookieReply (no response) - match dev2.process(&mut rng, &msg_cookie, Some(src1)).unwrap() { + match dev2.process(&mut OsRng, &msg_cookie, Some(src1)).unwrap() { (None, None, None) => (), _ => panic!("unexpected response"), } - // avoid initation flood detection + // avoid initiation flood detection wait(); - // 6. device-1 : create third initation - let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + // 6. device-1 : create third initiation + let msg_init = dev1.begin(&mut OsRng, &pk2).unwrap(); // 7. device-2 : responds with noise response - let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(src1)).unwrap() { + let (msg_response, kp1) = match dev2.process(&mut OsRng, &msg_init, Some(src1)).unwrap() { (Some(_), Some(msg), Some(kp)) => { assert_eq!(kp.initiator, false); (msg, kp) @@ -123,7 +124,7 @@ fn handshake_under_load() { }; // device-1 : process noise response - let kp2 = match dev1.process(&mut rng, &msg_response, Some(src2)).unwrap() { + let kp2 = match dev1.process(&mut OsRng, &msg_response, Some(src2)).unwrap() { (Some(_), None, Some(kp)) => { assert_eq!(kp.initiator, true); kp @@ -137,8 +138,7 @@ fn handshake_under_load() { #[test] fn handshake_no_load() { - let mut rng = OsRng::new().unwrap(); - let (pk1, mut dev1, pk2, mut dev2) = setup_devices(&mut rng); + let (pk1, mut dev1, pk2, mut dev2): (_, Device<usize>, _, _) = setup_devices(&mut OsRng); // do a few handshakes (every handshake should succeed) @@ -147,7 +147,7 @@ fn handshake_no_load() { // create initiation - let msg1 = dev1.begin(&mut rng, &pk2).unwrap(); + let msg1 = dev1.begin(&mut OsRng, &pk2).unwrap(); println!("msg1 = {} : {} bytes", hex::encode(&msg1[..]), msg1.len()); println!( @@ -158,7 +158,7 @@ fn handshake_no_load() { // process initiation and create response let (_, msg2, ks_r) = dev2 - .process(&mut rng, &msg1, None) + .process(&mut OsRng, &msg1, None) .expect("failed to process initiation"); let ks_r = ks_r.unwrap(); @@ -175,7 +175,7 @@ fn handshake_no_load() { // process response and obtain confirmed key-pair let (_, msg3, ks_i) = dev1 - .process(&mut rng, &msg2, None) + .process(&mut OsRng, &msg2, None) .expect("failed to process response"); let ks_i = ks_i.unwrap(); @@ -188,7 +188,7 @@ fn handshake_no_load() { dev1.release(ks_i.local_id()); dev2.release(ks_r.local_id()); - // avoid initation flood detection + // avoid initiation flood detection wait(); } diff --git a/src/wireguard/handshake/types.rs b/src/wireguard/handshake/types.rs index 5f984cc..ed2fcbb 100644 --- a/src/wireguard/handshake/types.rs +++ b/src/wireguard/handshake/types.rs @@ -1,10 +1,8 @@ +use super::super::types::KeyPair; + use std::error::Error; use std::fmt; -use x25519_dalek::PublicKey; - -use super::super::types::KeyPair; - /* Internal types for the noise IKpsk2 implementation */ // config error @@ -79,10 +77,10 @@ impl Error for HandshakeError { } } -pub type Output = ( - Option<PublicKey>, // external identifier associated with peer - Option<Vec<u8>>, // message to send - Option<KeyPair>, // resulting key-pair of successful handshake +pub type Output<'a, O> = ( + Option<&'a O>, // external identifier associated with peer + Option<Vec<u8>>, // message to send + Option<KeyPair>, // resulting key-pair of successful handshake ); // preshared key diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs index 1af4df3..b3656fe 100644 --- a/src/wireguard/peer.rs +++ b/src/wireguard/peer.rs @@ -31,7 +31,7 @@ pub struct PeerInner<T: Tun, B: UDP> { pub handshake_queued: AtomicBool, // is a handshake job currently queued for the peer? // stats and configuration - pub pk: PublicKey, // public key, DISCUSS: avoid this. TODO: remove + pub pk: PublicKey, // public key pub rx_bytes: AtomicU64, // received bytes pub tx_bytes: AtomicU64, // transmitted bytes diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index f903a8e..6c59491 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -142,7 +142,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle< }; // start worker threads - let mut threads = Vec::with_capacity(num_workers); + let mut threads = Vec::with_capacity(4 * num_workers); // inbound/decryption workers for _ in 0..num_workers { diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index b8110f0..8fe2e1c 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -204,7 +204,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, debug!("peer.send"); // send to endpoint (if known) - match self.endpoint.lock().as_ref() { + match self.endpoint.lock().as_mut() { Some(endpoint) => { let outbound = self.device.outbound.read(); if outbound.0 { diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index bf550ef..ecbb9c1 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -21,9 +21,6 @@ use std::sync::Mutex as StdMutex; use std::thread; use std::time::Instant; -use std::collections::hash_map::Entry; -use std::collections::HashMap; - use hjul::Runner; use rand::rngs::OsRng; use rand::Rng; @@ -50,14 +47,13 @@ pub struct WireguardInner<T: Tun, B: UDP> { // outbound writer pub send: RwLock<Option<B::Writer>>, - // identity and configuration map - pub peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>, + // peer map + pub peers: RwLock<handshake::Device<Peer<T, B>>>, // cryptokey router pub router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>, // handshake related state - pub handshake: RwLock<handshake::Device>, pub last_under_load: Mutex<Instant>, pub pending: AtomicUsize, // number of pending handshake packets in queue pub queue: ParallelQueue<HandshakeJob<B::Endpoint>>, @@ -142,7 +138,7 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { self.router.down(); // set all peers down (stops timers) - for peer in self.peers.write().values() { + for (_, peer) in self.peers.write().iter() { peer.down(); } @@ -163,11 +159,11 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { return; } - // enable tranmission from router + // enable transmission from router self.router.up(); // set all peers up (restarts timers) - for peer in self.peers.write().values() { + for (_, peer) in self.peers.write().iter() { peer.up(); } @@ -179,54 +175,51 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { } pub fn remove_peer(&self, pk: &PublicKey) { - if self.handshake.write().remove(pk).is_ok() { - self.peers.write().remove(pk.as_bytes()); - } + let _ = self.peers.write().remove(pk); } pub fn lookup_peer(&self, pk: &PublicKey) -> Option<Peer<T, B>> { - self.peers.read().get(pk.as_bytes()).map(|p| p.clone()) + self.peers.read().get(pk).map(|p| p.clone()) } pub fn list_peers(&self) -> Vec<Peer<T, B>> { let peers = self.peers.read(); let mut list = Vec::with_capacity(peers.len()); for (k, v) in peers.iter() { - debug_assert!(k == v.pk.as_bytes()); + debug_assert!(k.as_bytes() == v.pk.as_bytes()); list.push(v.clone()); } list } pub fn set_key(&self, sk: Option<StaticSecret>) { - let mut handshake = self.handshake.write(); - handshake.set_sk(sk); + let mut peers = self.peers.write(); + peers.set_sk(sk); self.router.clear_sending_keys(); - // handshake lock is released and new handshakes can be initated } pub fn get_sk(&self) -> Option<StaticSecret> { - self.handshake + self.peers .read() .get_sk() .map(|sk| StaticSecret::from(sk.to_bytes())) } pub fn set_psk(&self, pk: PublicKey, psk: [u8; 32]) -> bool { - self.handshake.write().set_psk(pk, psk).is_ok() + self.peers.write().set_psk(pk, psk).is_ok() } pub fn get_psk(&self, pk: &PublicKey) -> Option<[u8; 32]> { - self.handshake.read().get_psk(pk).ok() + self.peers.read().get_psk(pk).ok() } pub fn add_peer(&self, pk: PublicKey) -> bool { - if self.peers.read().contains_key(pk.as_bytes()) { + let mut peers = self.peers.write(); + if peers.contains_key(&pk) { return false; } - let mut rng = OsRng::new().unwrap(); let state = Arc::new(PeerInner { - id: rng.gen(), + id: OsRng.gen(), pk, wg: self.clone(), walltime_last_handshake: Mutex::new(None), @@ -243,33 +236,19 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { // form WireGuard peer let peer = Peer { router, state }; + // prevent up/down while inserting + let enabled = self.enabled.read(); + + /* The need for dummy timers arises from the chicken-egg + * problem of the timer callbacks being able to set timers themselves. + * + * This is in fact the only place where the write lock is ever taken. + * TODO: Consider the ease of using atomic pointers instead. + */ + *peer.timers.write() = Timers::new(&*self.runner.lock(), *enabled, peer.clone()); + // finally, add the peer to the wireguard device - let mut peers = self.peers.write(); - match peers.entry(*pk.as_bytes()) { - Entry::Occupied(_) => false, - Entry::Vacant(vacancy) => { - // check that the public key does not cause conflict with the private key of the device - let ok_pk = self.handshake.write().add(pk).is_ok(); - if !ok_pk { - return false; - } - - // prevent up/down while inserting - let enabled = self.enabled.read(); - - /* The need for dummy timers arises from the chicken-egg - * problem of the timer callbacks being able to set timers themselves. - * - * This is in fact the only place where the write lock is ever taken. - * TODO: Consider the ease of using atomic pointers instead. - */ - *peer.timers.write() = Timers::new(&*self.runner.lock(), *enabled, peer.clone()); - - // insert into peer map (takes ownership and ensures that the peer is not dropped) - vacancy.insert(peer); - true - } - } + peers.add(pk, peer).is_ok() } /// Begin consuming messages from the reader. @@ -311,9 +290,6 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { // workers equal to number of physical cores let cpus = num_cpus::get(); - // create device state - let mut rng = OsRng::new().unwrap(); - // create handshake queue let (tx, mut rxs) = ParallelQueue::new(cpus, 128); @@ -322,14 +298,13 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { inner: Arc::new(WireguardInner { enabled: RwLock::new(false), tun_readers: WaitCounter::new(), - id: rng.gen(), + id: OsRng.gen(), mtu: AtomicUsize::new(0), - peers: RwLock::new(HashMap::new()), last_under_load: Mutex::new(Instant::now() - TIME_HORIZON), send: RwLock::new(None), router: router::Device::new(num_cpus::get(), writer), // router owns the writing half pending: AtomicUsize::new(0), - handshake: RwLock::new(handshake::Device::new()), + peers: RwLock::new(handshake::Device::new()), runner: Mutex::new(Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY)), queue: tx, }), diff --git a/src/wireguard/workers.rs b/src/wireguard/workers.rs index e1d3899..c1a2af7 100644 --- a/src/wireguard/workers.rs +++ b/src/wireguard/workers.rs @@ -152,9 +152,6 @@ pub fn handshake_worker<T: Tun, B: UDP>( ) { debug!("{} : handshake worker, started", wg); - // prepare OsRng instance for this thread - let mut rng = OsRng::new().expect("Unable to obtain a CSPRNG"); - // process elements from the handshake queue for job in rx { // check if under load @@ -181,11 +178,11 @@ pub fn handshake_worker<T: Tun, B: UDP>( // de-multiplex staged handshake jobs and handshake messages match job { - HandshakeJob::Message(msg, src) => { + HandshakeJob::Message(msg, mut src) => { // process message - let device = wg.handshake.read(); + let device = wg.peers.read(); match device.process( - &mut rng, + &mut OsRng, &msg[..], if under_load { Some(src.into_address()) @@ -193,7 +190,7 @@ pub fn handshake_worker<T: Tun, B: UDP>( None }, ) { - Ok((pk, resp, keypair)) => { + Ok((peer, resp, keypair)) => { // send response (might be cookie reply or handshake response) let mut resp_len: u64 = 0; if let Some(msg) = resp { @@ -204,7 +201,7 @@ pub fn handshake_worker<T: Tun, B: UDP>( "{} : handshake worker, send response ({} bytes)", wg, resp_len ); - let _ = writer.write(&msg[..], &src).map_err(|e| { + let _ = writer.write(&msg[..], &mut src).map_err(|e| { debug!( "{} : handshake worker, failed to send response, error = {}", wg, @@ -215,56 +212,55 @@ pub fn handshake_worker<T: Tun, B: UDP>( } // update peer state - if let Some(pk) = pk { + if let Some(peer) = peer { // authenticated handshake packet received - if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { - // add to rx_bytes and tx_bytes - let req_len = msg.len() as u64; - peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed); - peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed); - // update endpoint - peer.router.set_endpoint(src); + // add to rx_bytes and tx_bytes + let req_len = msg.len() as u64; + peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed); + peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed); - if resp_len > 0 { - // update timers after sending handshake response - debug!("{} : handshake worker, handshake response sent", wg); - peer.state.sent_handshake_response(); - } else { - // update timers after receiving handshake response - debug!( - "{} : handshake worker, handshake response was received", - wg - ); - peer.state.timers_handshake_complete(); - } + // update endpoint + peer.router.set_endpoint(src); + + if resp_len > 0 { + // update timers after sending handshake response + debug!("{} : handshake worker, handshake response sent", wg); + peer.state.sent_handshake_response(); + } else { + // update timers after receiving handshake response + debug!( + "{} : handshake worker, handshake response was received", + wg + ); + peer.state.timers_handshake_complete(); + } - // add any new keypair to peer - keypair.map(|kp| { - debug!("{} : handshake worker, new keypair for {}", wg, peer); + // add any new keypair to peer + keypair.map(|kp| { + debug!("{} : handshake worker, new keypair for {}", wg, peer); - // this means that a handshake response was processed or sent - peer.timers_session_derived(); + // this means that a handshake response was processed or sent + peer.timers_session_derived(); - // free any unused ids - for id in peer.router.add_keypair(kp) { - device.release(id); - } - }); - } + // free any unused ids + for id in peer.router.add_keypair(kp) { + device.release(id); + } + }); } } Err(e) => debug!("{} : handshake worker, error = {:?}", wg, e), } } HandshakeJob::New(pk) => { - if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { + if let Some(peer) = wg.peers.read().get(&pk) { debug!( "{} : handshake worker, new handshake requested for {}", wg, peer ); - let device = wg.handshake.read(); - let _ = device.begin(&mut rng, &peer.pk).map(|msg| { + let device = wg.peers.read(); + let _ = device.begin(&mut OsRng, &peer.pk).map(|msg| { let _ = peer.router.send(&msg[..]).map_err(|e| { debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e) }); |