From dcd567c08f126b09548a98df0468ef1fe86d9f0a Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sat, 1 Feb 2020 14:39:19 +0100 Subject: Squashed commit of the following: commit 1e26a0bef44e65023a97a16ecf3b123e688d19f7 Author: Mathias Hall-Andersen Date: Sat Feb 1 14:36:50 2020 +0100 Initial version of sticky sockets for Linux commit 605cc656ad235d09ba6cd12d03dee2c5e0a9a80a Author: Mathias Hall-Andersen Date: Thu Jan 30 14:57:00 2020 +0100 Clear src when sendmsg fails with EINVAL commit dffd2b228af70f681e2a161642bbdaa348419bf3 Author: Mathias Hall-Andersen Date: Sun Jan 26 14:01:28 2020 +0100 Fix typoes commit 2015663706fbe15ed1ac443a31de86b3e6c643c7 Author: Mathias Hall-Andersen Date: Sun Jan 26 13:51:59 2020 +0100 Restructure of public key -> peer state Restructured the mapping of public keys to peer state in the project. The handshake device is now generic over an opaque type, which enables it to be the sole place where public keys are mapped to the peer states. This gets rid of the "peer" map in the WireGuard devices and avoids having to include the public key in the handshake peer state. commit bbcfaad4bcc5cf16bacdef0cefe7d29ba1519a23 Author: Mathias Hall-Andersen Date: Fri Jan 10 21:10:27 2020 +0100 Fixed bind6 also binding on IPv4 commit acbca236b70598c20c24de474690bcad883241d4 Author: Mathias Hall-Andersen Date: Thu Jan 9 11:24:13 2020 +0100 Work on sticky sockets --- src/wireguard/handshake/device.rs | 198 ++++++++++++++++++++++++-------------- src/wireguard/handshake/macs.rs | 6 +- src/wireguard/handshake/noise.rs | 46 +++++---- src/wireguard/handshake/peer.rs | 26 ++--- src/wireguard/handshake/tests.rs | 62 ++++++------ src/wireguard/handshake/types.rs | 14 ++- src/wireguard/peer.rs | 2 +- src/wireguard/router/device.rs | 2 +- src/wireguard/router/peer.rs | 2 +- src/wireguard/wireguard.rs | 85 ++++++---------- src/wireguard/workers.rs | 80 ++++++++------- 11 files changed, 277 insertions(+), 246 deletions(-) (limited to 'src/wireguard') diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs index edd1a07..4b5d8f6 100644 --- a/src/wireguard/handshake/device.rs +++ b/src/wireguard/handshake/device.rs @@ -1,4 +1,5 @@ use spin::RwLock; +use std::collections::hash_map; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Mutex; @@ -6,7 +7,10 @@ use zerocopy::AsBytes; use byteorder::{ByteOrder, LittleEndian}; -use rand::prelude::*; +use rand::Rng; +use rand_core::{CryptoRng, RngCore}; + +use clear_on_drop::clear::Clear; use x25519_dalek::PublicKey; use x25519_dalek::StaticSecret; @@ -22,42 +26,101 @@ use super::types::*; const MAX_PEER_PER_DEVICE: usize = 1 << 20; pub struct KeyState { - pub sk: StaticSecret, // static secret key - pub pk: PublicKey, // static public key - macs: macs::Validator, // validator for the mac fields + pub(super) sk: StaticSecret, // static secret key + pub(super) pk: PublicKey, // static public key + macs: macs::Validator, // validator for the mac fields } -pub struct Device { - keyst: Option, // secret/public key - pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state - id_map: RwLock>, // receiver ids -> public key +/// The device is generic over an "opaque" type +/// which can be used to associate the public key with this value. +/// (the instance is a Peer object in the parent module) +pub struct Device { + keyst: Option, + id_map: RwLock>, + pk_map: HashMap<[u8; 32], Peer>, limiter: Mutex, } +pub struct Iter<'a, O> { + iter: hash_map::Iter<'a, [u8; 32], Peer>, +} + +impl<'a, O> Iterator for Iter<'a, O> { + type Item = (PublicKey, &'a O); + + fn next(&mut self) -> Option { + self.iter + .next() + .map(|(pk, peer)| (PublicKey::from(*pk), &peer.opaque)) + } +} + +/* These methods enable the Device to act as a map + * from public keys to the set of contained opaque values. + * + * It also abstracts away the problem of PublicKey not being hashable. + */ +impl Device { + pub fn clear(&mut self) { + self.id_map.write().clear(); + self.pk_map.clear(); + } + + pub fn len(&self) -> usize { + self.pk_map.len() + } + + /// Enables enumeration of (public key, opaque) pairs + /// without exposing internal peer type. + pub fn iter(&self) -> Iter { + Iter { + iter: self.pk_map.iter(), + } + } + + /// Enables lookup by public key without exposing internal peer type. + pub fn get(&self, pk: &PublicKey) -> Option<&O> { + self.pk_map.get(pk.as_bytes()).map(|peer| &peer.opaque) + } + + pub fn contains_key(&self, pk: &PublicKey) -> bool { + self.pk_map.contains_key(pk.as_bytes()) + } +} + /* A mutable reference to the device needs to be held during configuration. * Wrapping the device in a RwLock enables peer config after "configuration time" */ -impl Device { +impl Device { /// Initialize a new handshake state machine - pub fn new() -> Device { + pub fn new() -> Device { Device { keyst: None, - pk_map: HashMap::new(), id_map: RwLock::new(HashMap::new()), + pk_map: HashMap::new(), limiter: Mutex::new(RateLimiter::new()), } } - fn update_ss(&self, peer: &mut Peer) -> Option { - if let Some(key) = self.keyst.as_ref() { - if *peer.pk.as_bytes() == *key.pk.as_bytes() { - return Some(peer.pk); + fn update_ss(&mut self) -> (Vec, Option) { + let mut same = None; + let mut ids = Vec::with_capacity(self.pk_map.len()); + for (pk, peer) in self.pk_map.iter_mut() { + if let Some(key) = self.keyst.as_ref() { + if key.pk.as_bytes() == pk { + same = Some(PublicKey::from(*pk)); + peer.ss.clear() + } else { + let pk = PublicKey::from(*pk); + peer.ss = *key.sk.diffie_hellman(&pk).as_bytes(); + } + } else { + peer.ss.clear(); } - peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes(); - } else { - peer.ss = [0u8; 32]; - }; - None + peer.reset_state().map(|id| ids.push(id)); + } + + (ids, same) } /// Update the secret key of the device @@ -74,29 +137,15 @@ impl Device { }); // recalculate / erase the shared secrets for every peer - let mut ids = vec![]; - let mut same = None; - for mut peer in self.pk_map.values_mut() { - // clear any existing handshake state - peer.reset_state().map(|id| ids.push(id)); - - // update precomputed shared secret - if let Some(key) = self.keyst.as_ref() { - peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes(); - if *peer.pk.as_bytes() == *key.pk.as_bytes() { - same = Some(peer.pk) - } - } else { - peer.ss = [0u8; 32]; - }; - } + let (ids, same) = self.update_ss(); // release ids from aborted handshakes for id in ids { self.release(id) } - // if we found a peer matching the device public key, remove it. + // if we found a peer matching the device public key + // remove it and return its value to the caller same.map(|pk| { self.pk_map.remove(pk.as_bytes()); pk @@ -119,29 +168,32 @@ impl Device { /// /// * `pk` - The public key to add /// * `identifier` - Associated identifier which can be used to distinguish the peers - pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> { + pub fn add(&mut self, pk: PublicKey, opaque: O) -> Result<(), ConfigError> { // ensure less than 2^20 peers if self.pk_map.len() > MAX_PEER_PER_DEVICE { return Err(ConfigError::new("Too many peers for device")); } - // create peer and precompute static secret - let mut peer = Peer::new( - pk, - self.keyst - .as_ref() - .map(|key| *key.sk.diffie_hellman(&pk).as_bytes()) - .unwrap_or([0u8; 32]), - ); - - // add peer to device - match self.update_ss(&mut peer) { - Some(_) => Err(ConfigError::new("Public key of peer matches the device")), - None => { - self.pk_map.insert(*pk.as_bytes(), peer); - Ok(()) + // error if public key matches device + if let Some(key) = self.keyst.as_ref() { + if pk.as_bytes() == key.pk.as_bytes() { + return Err(ConfigError::new("Public key of peer matches the device")); } } + + // pre-compute shared secret and add to pk_map + self.pk_map.insert( + *pk.as_bytes(), + Peer::new( + pk, + self.keyst + .as_ref() + .map(|key| *key.sk.diffie_hellman(&pk).as_bytes()) + .unwrap_or([0u8; 32]), + opaque, + ), + ); + Ok(()) } /// Remove a peer by public key @@ -163,7 +215,7 @@ impl Device { .remove(pk.as_bytes()) .ok_or(ConfigError::new("Public key not in device"))?; - // pruge the id map (linear scan) + // purge the id map (linear scan) id_map.retain(|_, v| v != pk.as_bytes()); Ok(()) } @@ -231,11 +283,11 @@ impl Device { (_, None) => Err(HandshakeError::UnknownPublicKey), (None, _) => Err(HandshakeError::UnknownPublicKey), (Some(keyst), Some(peer)) => { - let local = self.allocate(rng, peer); + let local = self.allocate(rng, pk); let mut msg = Initiation::default(); // create noise part of initation - noise::create_initiation(rng, keyst, peer, local, &mut msg.noise)?; + noise::create_initiation(rng, keyst, peer, pk, local, &mut msg.noise)?; // add macs to initation peer.macs @@ -253,11 +305,11 @@ impl Device { /// /// * `msg` - Byte slice containing the message (untrusted input) pub fn process<'a, R: RngCore + CryptoRng>( - &self, - rng: &mut R, // rng instance to sample randomness from - msg: &[u8], // message buffer + &'a self, + rng: &mut R, // rng instance to sample randomness from + msg: &[u8], // message buffer src: Option, // optional source endpoint, set when "under load" - ) -> Result { + ) -> Result, HandshakeError> { // ensure type read in-range if msg.len() < 4 { return Err(HandshakeError::InvalidMessageFormat); @@ -303,17 +355,17 @@ impl Device { } // consume the initiation - let (peer, st) = noise::consume_initiation(self, keyst, &msg.noise)?; + let (peer, pk, st) = noise::consume_initiation(self, keyst, &msg.noise)?; // allocate new index for response - let local = self.allocate(rng, peer); + let local = self.allocate(rng, &pk); // prepare memory for response, TODO: take slice for zero allocation let mut resp = Response::default(); // create response (release id on error) - let keys = - noise::create_response(rng, peer, local, st, &mut resp.noise).map_err(|e| { + let keys = noise::create_response(rng, peer, &pk, local, st, &mut resp.noise) + .map_err(|e| { self.release(local); e })?; @@ -324,7 +376,11 @@ impl Device { .generate(resp.noise.as_bytes(), &mut resp.macs); // return unconfirmed keypair and the response as vector - Ok((Some(peer.pk), Some(resp.as_bytes().to_owned()), Some(keys))) + Ok(( + Some(&peer.opaque), + Some(resp.as_bytes().to_owned()), + Some(keys), + )) } TYPE_RESPONSE => { let msg = Response::parse(msg)?; @@ -363,7 +419,7 @@ impl Device { let msg = CookieReply::parse(msg)?; // lookup peer - let peer = self.lookup_id(msg.f_receiver.get())?; + let (peer, _) = self.lookup_id(msg.f_receiver.get())?; // validate cookie reply peer.macs.lock().process(&msg)?; @@ -379,7 +435,7 @@ impl Device { // Internal function // // Return the peer associated with the public key - pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer, HandshakeError> { + pub(super) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer, HandshakeError> { self.pk_map .get(pk.as_bytes()) .ok_or(HandshakeError::UnknownPublicKey) @@ -388,11 +444,11 @@ impl Device { // Internal function // // Return the peer currently associated with the receiver identifier - pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer, HandshakeError> { + pub(super) fn lookup_id(&self, id: u32) -> Result<(&Peer, PublicKey), HandshakeError> { let im = self.id_map.read(); let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?; match self.pk_map.get(pk) { - Some(peer) => Ok(peer), + Some(peer) => Ok((peer, PublicKey::from(*pk))), _ => unreachable!(), // if the id-lookup succeeded, the peer should exist } } @@ -400,7 +456,7 @@ impl Device { // Internal function // // Allocated a new receiver identifier for the peer - fn allocate(&self, rng: &mut R, peer: &Peer) -> u32 { + fn allocate(&self, rng: &mut R, pk: &PublicKey) -> u32 { loop { let id = rng.gen(); @@ -412,7 +468,7 @@ impl Device { // take write lock and add index let mut m = self.id_map.write(); if !m.contains_key(&id) { - m.insert(id, *peer.pk.as_bytes()); + m.insert(id, *pk.as_bytes()); return id; } } diff --git a/src/wireguard/handshake/macs.rs b/src/wireguard/handshake/macs.rs index 689826b..cb5d7d4 100644 --- a/src/wireguard/handshake/macs.rs +++ b/src/wireguard/handshake/macs.rs @@ -286,8 +286,7 @@ mod tests { use x25519_dalek::StaticSecret; fn new_validator_generator() -> (Validator, Generator) { - let mut rng = OsRng::new().unwrap(); - let sk = StaticSecret::new(&mut rng); + let sk = StaticSecret::new(&mut OsRng); let pk = PublicKey::from(&sk); (Validator::new(pk), Generator::new(pk)) } @@ -296,7 +295,6 @@ mod tests { #[test] fn test_cookie_reply(inner1 : Vec, inner2 : Vec, receiver : u32) { let mut msg = CookieReply::default(); - let mut rng = OsRng::new().expect("failed to create rng"); let mut macs = MacsFooter::default(); let src = "192.0.2.16:8080".parse().unwrap(); let (validator, mut generator) = new_validator_generator(); @@ -309,7 +307,7 @@ mod tests { // check validity of mac1 validator.check_mac1(&inner1[..], &macs).expect("mac1 of inner1 did not validate"); assert_eq!(validator.check_mac2(&inner1[..], &src, &macs), false, "mac2 of inner2 did not validate"); - validator.create_cookie_reply(&mut rng, receiver, &src, &macs, &mut msg); + validator.create_cookie_reply(&mut OsRng, receiver, &src, &macs, &mut msg); // consume cookie reply generator.process(&msg).expect("failed to process CookieReply"); diff --git a/src/wireguard/handshake/noise.rs b/src/wireguard/handshake/noise.rs index 072ac13..9e431cf 100644 --- a/src/wireguard/handshake/noise.rs +++ b/src/wireguard/handshake/noise.rs @@ -10,7 +10,7 @@ use hmac::Hmac; use aead::{Aead, NewAead, Payload}; use chacha20poly1305::ChaCha20Poly1305; -use rand::{CryptoRng, RngCore}; +use rand_core::{CryptoRng, RngCore}; use log::debug; @@ -215,20 +215,21 @@ mod tests { } } -pub fn create_initiation( +pub(super) fn create_initiation( rng: &mut R, keyst: &KeyState, - peer: &Peer, + peer: &Peer, + pk: &PublicKey, local: u32, msg: &mut NoiseInitiation, ) -> Result<(), HandshakeError> { - debug!("create initation"); + debug!("create initiation"); clear_stack_on_return(CLEAR_PAGES, || { // initialize state let ck = INITIAL_CK; let hs = INITIAL_HS; - let hs = HASH!(&hs, peer.pk.as_bytes()); + let hs = HASH!(&hs, pk.as_bytes()); msg.f_type.set(TYPE_INITIATION as u32); msg.f_sender.set(local); // from us @@ -252,7 +253,7 @@ pub fn create_initiation( // (C, k) := Kdf2(C, DH(E_priv, S_pub)) - let (ck, key) = KDF2!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes()); + let (ck, key) = KDF2!(&ck, eph_sk.diffie_hellman(&pk).as_bytes()); // msg.static := Aead(k, 0, S_pub, H) @@ -297,12 +298,12 @@ pub fn create_initiation( }) } -pub fn consume_initiation<'a>( - device: &'a Device, +pub(super) fn consume_initiation<'a, O>( + device: &'a Device, keyst: &KeyState, msg: &NoiseInitiation, -) -> Result<(&'a Peer, TemporaryState), HandshakeError> { - debug!("consume initation"); +) -> Result<(&'a Peer, PublicKey, TemporaryState), HandshakeError> { + debug!("consume initiation"); clear_stack_on_return(CLEAR_PAGES, || { // initialize new state @@ -369,13 +370,18 @@ pub fn consume_initiation<'a>( // return state (to create response) - Ok((peer, (msg.f_sender.get(), eph_r_pk, hs, ck))) + Ok(( + peer, + PublicKey::from(pk), + (msg.f_sender.get(), eph_r_pk, hs, ck), + )) }) } -pub fn create_response( +pub(super) fn create_response( rng: &mut R, - peer: &Peer, + peer: &Peer, + pk: &PublicKey, local: u32, // sending identifier state: TemporaryState, // state from "consume_initiation" msg: &mut NoiseResponse, // resulting response @@ -388,7 +394,7 @@ pub fn create_response( msg.f_type.set(TYPE_RESPONSE as u32); msg.f_sender.set(local); // from us - msg.f_receiver.set(receiver); // to the sender of the initation + msg.f_receiver.set(receiver); // to the sender of the initiation // (E_priv, E_pub) := DH-Generate() @@ -413,7 +419,7 @@ pub fn create_response( // C := Kdf1(C, DH(E_priv, S_pub)) - let ck = KDF1!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes()); + let ck = KDF1!(&ck, eph_sk.diffie_hellman(&pk).as_bytes()); // (C, tau, k) := Kdf3(C, Q) @@ -460,15 +466,15 @@ pub fn create_response( * allow concurrent processing of potential responses to the initiation, * in order to better mitigate DoS from malformed response messages. */ -pub fn consume_response( - device: &Device, +pub(super) fn consume_response<'a, O>( + device: &'a Device, keyst: &KeyState, msg: &NoiseResponse, -) -> Result { +) -> Result, HandshakeError> { debug!("consume response"); clear_stack_on_return(CLEAR_PAGES, || { // retrieve peer and copy initiation state - let peer = device.lookup_id(msg.f_receiver.get())?; + let (peer, _) = device.lookup_id(msg.f_receiver.get())?; let (hs, ck, local, eph_sk) = match *peer.state.lock() { State::InitiationSent { @@ -537,7 +543,7 @@ pub fn consume_response( // return confirmed key-pair Ok(( - Some(peer.pk), + Some(&peer.opaque), None, Some(KeyPair { birth, diff --git a/src/wireguard/handshake/peer.rs b/src/wireguard/handshake/peer.rs index a4df560..f4d15fc 100644 --- a/src/wireguard/handshake/peer.rs +++ b/src/wireguard/handshake/peer.rs @@ -22,19 +22,21 @@ const TIME_BETWEEN_INITIATIONS: Duration = Duration::from_millis(20); * * This type is only for internal use and not exposed. */ -pub struct Peer { +pub(super) struct Peer { + // opaque type which identifies a peer + pub opaque: O, + // mutable state - pub(crate) state: Mutex, - pub(crate) timestamp: Mutex>, - pub(crate) last_initiation_consumption: Mutex>, + pub state: Mutex, + pub timestamp: Mutex>, + pub last_initiation_consumption: Mutex>, // state related to DoS mitigation fields - pub(crate) macs: Mutex, + pub macs: Mutex, // constant state - pub(crate) pk: PublicKey, // public key of peer - pub(crate) ss: [u8; 32], // precomputed DH(static, static) - pub(crate) psk: Psk, // psk of peer + pub ss: [u8; 32], // precomputed DH(static, static) + pub psk: Psk, // psk of peer } pub enum State { @@ -60,14 +62,14 @@ impl Drop for State { } } -impl Peer { - pub fn new(pk: PublicKey, ss: [u8; 32]) -> Self { +impl Peer { + pub fn new(pk: PublicKey, ss: [u8; 32], opaque: O) -> Self { Self { + opaque, macs: Mutex::new(macs::Generator::new(pk)), state: Mutex::new(State::Reset), timestamp: Mutex::new(None), last_initiation_consumption: Mutex::new(None), - pk, ss, psk: [0u8; 32], } @@ -88,7 +90,7 @@ impl Peer { /// * ts_new - The associated timestamp pub fn check_replay_flood( &self, - device: &Device, + device: &Device, timestamp_new: ×tamp::TAI64N, ) -> Result<(), HandshakeError> { let mut state = self.state.lock(); diff --git a/src/wireguard/handshake/tests.rs b/src/wireguard/handshake/tests.rs index ff27b3e..bfdc5ab 100644 --- a/src/wireguard/handshake/tests.rs +++ b/src/wireguard/handshake/tests.rs @@ -12,8 +12,10 @@ use x25519_dalek::StaticSecret; use super::messages::{Initiation, Response}; -fn setup_devices(rng: &mut R) -> (PublicKey, Device, PublicKey, Device) { - // generate new keypairs +fn setup_devices( + rng: &mut R, +) -> (PublicKey, Device, PublicKey, Device) { + // generate new key pairs let sk1 = StaticSecret::new(rng); let pk1 = PublicKey::from(&sk1); @@ -26,7 +28,7 @@ fn setup_devices(rng: &mut R) -> (PublicKey, Device, Pub let mut psk = [0u8; 32]; rng.fill_bytes(&mut psk[..]); - // intialize devices on both ends + // initialize devices on both ends let mut dev1 = Device::new(); let mut dev2 = Device::new(); @@ -34,8 +36,8 @@ fn setup_devices(rng: &mut R) -> (PublicKey, Device, Pub dev1.set_sk(Some(sk1)); dev2.set_sk(Some(sk2)); - dev1.add(pk2).unwrap(); - dev2.add(pk1).unwrap(); + dev1.add(pk2, O::default()).unwrap(); + dev2.add(pk1, O::default()).unwrap(); dev1.set_psk(pk2, psk).unwrap(); dev2.set_psk(pk1, psk).unwrap(); @@ -49,45 +51,44 @@ fn wait() { /* Test longest possible handshake interaction (7 messages): * - * 1. I -> R (initation) + * 1. I -> R (initiation) * 2. I <- R (cookie reply) - * 3. I -> R (initation) + * 3. I -> R (initiation) * 4. I <- R (response) * 5. I -> R (cookie reply) - * 6. I -> R (initation) + * 6. I -> R (initiation) * 7. I <- R (response) */ #[test] fn handshake_under_load() { - let mut rng = OsRng::new().unwrap(); - let (_pk1, dev1, pk2, dev2) = setup_devices(&mut rng); + let (_pk1, dev1, pk2, dev2): (_, Device, _, _) = setup_devices(&mut OsRng); let src1: SocketAddr = "172.16.0.1:8080".parse().unwrap(); let src2: SocketAddr = "172.16.0.2:7070".parse().unwrap(); - // 1. device-1 : create first initation - let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + // 1. device-1 : create first initiation + let msg_init = dev1.begin(&mut OsRng, &pk2).unwrap(); // 2. device-2 : responds with CookieReply - let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(src1)).unwrap() { + let msg_cookie = match dev2.process(&mut OsRng, &msg_init, Some(src1)).unwrap() { (None, Some(msg), None) => msg, _ => panic!("unexpected response"), }; // device-1 : processes CookieReply (no response) - match dev1.process(&mut rng, &msg_cookie, Some(src2)).unwrap() { + match dev1.process(&mut OsRng, &msg_cookie, Some(src2)).unwrap() { (None, None, None) => (), _ => panic!("unexpected response"), } - // avoid initation flood detection + // avoid initiation flood detection wait(); - // 3. device-1 : create second initation - let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + // 3. device-1 : create second initiation + let msg_init = dev1.begin(&mut OsRng, &pk2).unwrap(); // 4. device-2 : responds with noise response - let msg_response = match dev2.process(&mut rng, &msg_init, Some(src1)).unwrap() { + let msg_response = match dev2.process(&mut OsRng, &msg_init, Some(src1)).unwrap() { (Some(_), Some(msg), Some(kp)) => { assert_eq!(kp.initiator, false); msg @@ -96,25 +97,25 @@ fn handshake_under_load() { }; // 5. device-1 : responds with CookieReply - let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(src2)).unwrap() { + let msg_cookie = match dev1.process(&mut OsRng, &msg_response, Some(src2)).unwrap() { (None, Some(msg), None) => msg, _ => panic!("unexpected response"), }; // device-2 : processes CookieReply (no response) - match dev2.process(&mut rng, &msg_cookie, Some(src1)).unwrap() { + match dev2.process(&mut OsRng, &msg_cookie, Some(src1)).unwrap() { (None, None, None) => (), _ => panic!("unexpected response"), } - // avoid initation flood detection + // avoid initiation flood detection wait(); - // 6. device-1 : create third initation - let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + // 6. device-1 : create third initiation + let msg_init = dev1.begin(&mut OsRng, &pk2).unwrap(); // 7. device-2 : responds with noise response - let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(src1)).unwrap() { + let (msg_response, kp1) = match dev2.process(&mut OsRng, &msg_init, Some(src1)).unwrap() { (Some(_), Some(msg), Some(kp)) => { assert_eq!(kp.initiator, false); (msg, kp) @@ -123,7 +124,7 @@ fn handshake_under_load() { }; // device-1 : process noise response - let kp2 = match dev1.process(&mut rng, &msg_response, Some(src2)).unwrap() { + let kp2 = match dev1.process(&mut OsRng, &msg_response, Some(src2)).unwrap() { (Some(_), None, Some(kp)) => { assert_eq!(kp.initiator, true); kp @@ -137,8 +138,7 @@ fn handshake_under_load() { #[test] fn handshake_no_load() { - let mut rng = OsRng::new().unwrap(); - let (pk1, mut dev1, pk2, mut dev2) = setup_devices(&mut rng); + let (pk1, mut dev1, pk2, mut dev2): (_, Device, _, _) = setup_devices(&mut OsRng); // do a few handshakes (every handshake should succeed) @@ -147,7 +147,7 @@ fn handshake_no_load() { // create initiation - let msg1 = dev1.begin(&mut rng, &pk2).unwrap(); + let msg1 = dev1.begin(&mut OsRng, &pk2).unwrap(); println!("msg1 = {} : {} bytes", hex::encode(&msg1[..]), msg1.len()); println!( @@ -158,7 +158,7 @@ fn handshake_no_load() { // process initiation and create response let (_, msg2, ks_r) = dev2 - .process(&mut rng, &msg1, None) + .process(&mut OsRng, &msg1, None) .expect("failed to process initiation"); let ks_r = ks_r.unwrap(); @@ -175,7 +175,7 @@ fn handshake_no_load() { // process response and obtain confirmed key-pair let (_, msg3, ks_i) = dev1 - .process(&mut rng, &msg2, None) + .process(&mut OsRng, &msg2, None) .expect("failed to process response"); let ks_i = ks_i.unwrap(); @@ -188,7 +188,7 @@ fn handshake_no_load() { dev1.release(ks_i.local_id()); dev2.release(ks_r.local_id()); - // avoid initation flood detection + // avoid initiation flood detection wait(); } diff --git a/src/wireguard/handshake/types.rs b/src/wireguard/handshake/types.rs index 5f984cc..ed2fcbb 100644 --- a/src/wireguard/handshake/types.rs +++ b/src/wireguard/handshake/types.rs @@ -1,10 +1,8 @@ +use super::super::types::KeyPair; + use std::error::Error; use std::fmt; -use x25519_dalek::PublicKey; - -use super::super::types::KeyPair; - /* Internal types for the noise IKpsk2 implementation */ // config error @@ -79,10 +77,10 @@ impl Error for HandshakeError { } } -pub type Output = ( - Option, // external identifier associated with peer - Option>, // message to send - Option, // resulting key-pair of successful handshake +pub type Output<'a, O> = ( + Option<&'a O>, // external identifier associated with peer + Option>, // message to send + Option, // resulting key-pair of successful handshake ); // preshared key diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs index 1af4df3..b3656fe 100644 --- a/src/wireguard/peer.rs +++ b/src/wireguard/peer.rs @@ -31,7 +31,7 @@ pub struct PeerInner { pub handshake_queued: AtomicBool, // is a handshake job currently queued for the peer? // stats and configuration - pub pk: PublicKey, // public key, DISCUSS: avoid this. TODO: remove + pub pk: PublicKey, // public key pub rx_bytes: AtomicU64, // received bytes pub tx_bytes: AtomicU64, // transmitted bytes diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index f903a8e..6c59491 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -142,7 +142,7 @@ impl> DeviceHandle< }; // start worker threads - let mut threads = Vec::with_capacity(num_workers); + let mut threads = Vec::with_capacity(4 * num_workers); // inbound/decryption workers for _ in 0..num_workers { diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index b8110f0..8fe2e1c 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -204,7 +204,7 @@ impl> PeerInner { let outbound = self.device.outbound.read(); if outbound.0 { diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index bf550ef..ecbb9c1 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -21,9 +21,6 @@ use std::sync::Mutex as StdMutex; use std::thread; use std::time::Instant; -use std::collections::hash_map::Entry; -use std::collections::HashMap; - use hjul::Runner; use rand::rngs::OsRng; use rand::Rng; @@ -50,14 +47,13 @@ pub struct WireguardInner { // outbound writer pub send: RwLock>, - // identity and configuration map - pub peers: RwLock>>, + // peer map + pub peers: RwLock>>, // cryptokey router pub router: router::Device, T::Writer, B::Writer>, // handshake related state - pub handshake: RwLock, pub last_under_load: Mutex, pub pending: AtomicUsize, // number of pending handshake packets in queue pub queue: ParallelQueue>, @@ -142,7 +138,7 @@ impl WireGuard { self.router.down(); // set all peers down (stops timers) - for peer in self.peers.write().values() { + for (_, peer) in self.peers.write().iter() { peer.down(); } @@ -163,11 +159,11 @@ impl WireGuard { return; } - // enable tranmission from router + // enable transmission from router self.router.up(); // set all peers up (restarts timers) - for peer in self.peers.write().values() { + for (_, peer) in self.peers.write().iter() { peer.up(); } @@ -179,54 +175,51 @@ impl WireGuard { } pub fn remove_peer(&self, pk: &PublicKey) { - if self.handshake.write().remove(pk).is_ok() { - self.peers.write().remove(pk.as_bytes()); - } + let _ = self.peers.write().remove(pk); } pub fn lookup_peer(&self, pk: &PublicKey) -> Option> { - self.peers.read().get(pk.as_bytes()).map(|p| p.clone()) + self.peers.read().get(pk).map(|p| p.clone()) } pub fn list_peers(&self) -> Vec> { let peers = self.peers.read(); let mut list = Vec::with_capacity(peers.len()); for (k, v) in peers.iter() { - debug_assert!(k == v.pk.as_bytes()); + debug_assert!(k.as_bytes() == v.pk.as_bytes()); list.push(v.clone()); } list } pub fn set_key(&self, sk: Option) { - let mut handshake = self.handshake.write(); - handshake.set_sk(sk); + let mut peers = self.peers.write(); + peers.set_sk(sk); self.router.clear_sending_keys(); - // handshake lock is released and new handshakes can be initated } pub fn get_sk(&self) -> Option { - self.handshake + self.peers .read() .get_sk() .map(|sk| StaticSecret::from(sk.to_bytes())) } pub fn set_psk(&self, pk: PublicKey, psk: [u8; 32]) -> bool { - self.handshake.write().set_psk(pk, psk).is_ok() + self.peers.write().set_psk(pk, psk).is_ok() } pub fn get_psk(&self, pk: &PublicKey) -> Option<[u8; 32]> { - self.handshake.read().get_psk(pk).ok() + self.peers.read().get_psk(pk).ok() } pub fn add_peer(&self, pk: PublicKey) -> bool { - if self.peers.read().contains_key(pk.as_bytes()) { + let mut peers = self.peers.write(); + if peers.contains_key(&pk) { return false; } - let mut rng = OsRng::new().unwrap(); let state = Arc::new(PeerInner { - id: rng.gen(), + id: OsRng.gen(), pk, wg: self.clone(), walltime_last_handshake: Mutex::new(None), @@ -243,33 +236,19 @@ impl WireGuard { // form WireGuard peer let peer = Peer { router, state }; + // prevent up/down while inserting + let enabled = self.enabled.read(); + + /* The need for dummy timers arises from the chicken-egg + * problem of the timer callbacks being able to set timers themselves. + * + * This is in fact the only place where the write lock is ever taken. + * TODO: Consider the ease of using atomic pointers instead. + */ + *peer.timers.write() = Timers::new(&*self.runner.lock(), *enabled, peer.clone()); + // finally, add the peer to the wireguard device - let mut peers = self.peers.write(); - match peers.entry(*pk.as_bytes()) { - Entry::Occupied(_) => false, - Entry::Vacant(vacancy) => { - // check that the public key does not cause conflict with the private key of the device - let ok_pk = self.handshake.write().add(pk).is_ok(); - if !ok_pk { - return false; - } - - // prevent up/down while inserting - let enabled = self.enabled.read(); - - /* The need for dummy timers arises from the chicken-egg - * problem of the timer callbacks being able to set timers themselves. - * - * This is in fact the only place where the write lock is ever taken. - * TODO: Consider the ease of using atomic pointers instead. - */ - *peer.timers.write() = Timers::new(&*self.runner.lock(), *enabled, peer.clone()); - - // insert into peer map (takes ownership and ensures that the peer is not dropped) - vacancy.insert(peer); - true - } - } + peers.add(pk, peer).is_ok() } /// Begin consuming messages from the reader. @@ -311,9 +290,6 @@ impl WireGuard { // workers equal to number of physical cores let cpus = num_cpus::get(); - // create device state - let mut rng = OsRng::new().unwrap(); - // create handshake queue let (tx, mut rxs) = ParallelQueue::new(cpus, 128); @@ -322,14 +298,13 @@ impl WireGuard { inner: Arc::new(WireguardInner { enabled: RwLock::new(false), tun_readers: WaitCounter::new(), - id: rng.gen(), + id: OsRng.gen(), mtu: AtomicUsize::new(0), - peers: RwLock::new(HashMap::new()), last_under_load: Mutex::new(Instant::now() - TIME_HORIZON), send: RwLock::new(None), router: router::Device::new(num_cpus::get(), writer), // router owns the writing half pending: AtomicUsize::new(0), - handshake: RwLock::new(handshake::Device::new()), + peers: RwLock::new(handshake::Device::new()), runner: Mutex::new(Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY)), queue: tx, }), diff --git a/src/wireguard/workers.rs b/src/wireguard/workers.rs index e1d3899..c1a2af7 100644 --- a/src/wireguard/workers.rs +++ b/src/wireguard/workers.rs @@ -152,9 +152,6 @@ pub fn handshake_worker( ) { debug!("{} : handshake worker, started", wg); - // prepare OsRng instance for this thread - let mut rng = OsRng::new().expect("Unable to obtain a CSPRNG"); - // process elements from the handshake queue for job in rx { // check if under load @@ -181,11 +178,11 @@ pub fn handshake_worker( // de-multiplex staged handshake jobs and handshake messages match job { - HandshakeJob::Message(msg, src) => { + HandshakeJob::Message(msg, mut src) => { // process message - let device = wg.handshake.read(); + let device = wg.peers.read(); match device.process( - &mut rng, + &mut OsRng, &msg[..], if under_load { Some(src.into_address()) @@ -193,7 +190,7 @@ pub fn handshake_worker( None }, ) { - Ok((pk, resp, keypair)) => { + Ok((peer, resp, keypair)) => { // send response (might be cookie reply or handshake response) let mut resp_len: u64 = 0; if let Some(msg) = resp { @@ -204,7 +201,7 @@ pub fn handshake_worker( "{} : handshake worker, send response ({} bytes)", wg, resp_len ); - let _ = writer.write(&msg[..], &src).map_err(|e| { + let _ = writer.write(&msg[..], &mut src).map_err(|e| { debug!( "{} : handshake worker, failed to send response, error = {}", wg, @@ -215,56 +212,55 @@ pub fn handshake_worker( } // update peer state - if let Some(pk) = pk { + if let Some(peer) = peer { // authenticated handshake packet received - if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { - // add to rx_bytes and tx_bytes - let req_len = msg.len() as u64; - peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed); - peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed); - // update endpoint - peer.router.set_endpoint(src); + // add to rx_bytes and tx_bytes + let req_len = msg.len() as u64; + peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed); + peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed); - if resp_len > 0 { - // update timers after sending handshake response - debug!("{} : handshake worker, handshake response sent", wg); - peer.state.sent_handshake_response(); - } else { - // update timers after receiving handshake response - debug!( - "{} : handshake worker, handshake response was received", - wg - ); - peer.state.timers_handshake_complete(); - } + // update endpoint + peer.router.set_endpoint(src); + + if resp_len > 0 { + // update timers after sending handshake response + debug!("{} : handshake worker, handshake response sent", wg); + peer.state.sent_handshake_response(); + } else { + // update timers after receiving handshake response + debug!( + "{} : handshake worker, handshake response was received", + wg + ); + peer.state.timers_handshake_complete(); + } - // add any new keypair to peer - keypair.map(|kp| { - debug!("{} : handshake worker, new keypair for {}", wg, peer); + // add any new keypair to peer + keypair.map(|kp| { + debug!("{} : handshake worker, new keypair for {}", wg, peer); - // this means that a handshake response was processed or sent - peer.timers_session_derived(); + // this means that a handshake response was processed or sent + peer.timers_session_derived(); - // free any unused ids - for id in peer.router.add_keypair(kp) { - device.release(id); - } - }); - } + // free any unused ids + for id in peer.router.add_keypair(kp) { + device.release(id); + } + }); } } Err(e) => debug!("{} : handshake worker, error = {:?}", wg, e), } } HandshakeJob::New(pk) => { - if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { + if let Some(peer) = wg.peers.read().get(&pk) { debug!( "{} : handshake worker, new handshake requested for {}", wg, peer ); - let device = wg.handshake.read(); - let _ = device.begin(&mut rng, &peer.pk).map(|msg| { + let device = wg.peers.read(); + let _ = device.begin(&mut OsRng, &peer.pk).map(|msg| { let _ = peer.router.send(&msg[..]).map_err(|e| { debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e) }); -- cgit v1.2.3-59-g8ed1b