From 6311aa34022a24224b1dc8d0427cd72dd42e9396 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Wed, 18 Sep 2019 15:31:10 +0200 Subject: WIP: TUN IO worker Also removed the type parameters from the handshake device. --- src/handshake/device.rs | 58 +++++++------ src/handshake/noise.rs | 23 +++-- src/handshake/peer.rs | 23 ++--- src/handshake/types.rs | 10 ++- src/router/device.rs | 5 +- src/wireguard.rs | 222 +++++++++++++++++++++++++++++++----------------- 6 files changed, 212 insertions(+), 129 deletions(-) (limited to 'src') diff --git a/src/handshake/device.rs b/src/handshake/device.rs index 638d63f..2a06fa7 100644 --- a/src/handshake/device.rs +++ b/src/handshake/device.rs @@ -21,11 +21,11 @@ use super::types::*; const MAX_PEER_PER_DEVICE: usize = 1 << 20; -pub struct Device { +pub struct Device { pub sk: StaticSecret, // static secret key pub pk: PublicKey, // static public key macs: macs::Validator, // validator for the mac fields - pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state + pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state id_map: RwLock>, // receiver ids -> public key limiter: Mutex, } @@ -33,16 +33,13 @@ pub struct Device { /* A mutable reference to the device needs to be held during configuration. * Wrapping the device in a RwLock enables peer config after "configuration time" */ -impl Device -where - T: Clone, -{ +impl Device { /// Initialize a new handshake state machine /// /// # Arguments /// /// * `sk` - x25519 scalar representing the local private key - pub fn new(sk: StaticSecret) -> Device { + pub fn new(sk: StaticSecret) -> Device { let pk = PublicKey::from(&sk); Device { pk, @@ -54,6 +51,25 @@ where } } + /// Update the secret key of the device + /// + /// # Arguments + /// + /// * `sk` - x25519 scalar representing the local private key + pub fn set_sk(&mut self, sk: StaticSecret) { + // update secret and public key + let pk = PublicKey::from(&sk); + self.sk = sk; + self.pk = pk; + self.macs = macs::Validator::new(pk); + + // recalculate the shared secrets for every peer + for &mut peer in self.pk_map.values_mut() { + peer.reset_state().map(|id| self.release(id)); + peer.ss = self.sk.diffie_hellman(&peer.pk) + } + } + /// Add a new public key to the state machine /// To remove public keys, you must create a new machine instance /// @@ -61,7 +77,7 @@ where /// /// * `pk` - The public key to add /// * `identifier` - Associated identifier which can be used to distinguish the peers - pub fn add(&mut self, pk: PublicKey, identifier: T) -> Result<(), ConfigError> { + pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> { // check that the pk is not added twice if let Some(_) = self.pk_map.get(pk.as_bytes()) { return Err(ConfigError::new("Duplicate public key")); @@ -80,10 +96,8 @@ where } // map the public key to the peer state - self.pk_map.insert( - *pk.as_bytes(), - Peer::new(identifier, pk, self.sk.diffie_hellman(&pk)), - ); + self.pk_map + .insert(*pk.as_bytes(), Peer::new(pk, self.sk.diffie_hellman(&pk))); Ok(()) } @@ -204,7 +218,7 @@ where rng: &mut R, // rng instance to sample randomness from msg: &[u8], // message buffer src: Option<&'a S>, // optional source endpoint, set when "under load" - ) -> Result, HandshakeError> + ) -> Result where &'a S: Into<&'a SocketAddr>, { @@ -269,11 +283,7 @@ where .generate(resp.noise.as_bytes(), &mut resp.macs); // return unconfirmed keypair and the response as vector - Ok(( - Some(peer.identifier.clone()), - Some(resp.as_bytes().to_owned()), - Some(keys), - )) + Ok((Some(peer.pk), Some(resp.as_bytes().to_owned()), Some(keys))) } TYPE_RESPONSE => { let msg = Response::parse(msg)?; @@ -328,7 +338,7 @@ where // Internal function // // Return the peer associated with the public key - pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer, HandshakeError> { + pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer, HandshakeError> { self.pk_map .get(pk.as_bytes()) .ok_or(HandshakeError::UnknownPublicKey) @@ -337,7 +347,7 @@ where // Internal function // // Return the peer currently associated with the receiver identifier - pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer, HandshakeError> { + pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer, HandshakeError> { let im = self.id_map.read(); let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?; match self.pk_map.get(pk) { @@ -349,7 +359,7 @@ where // Internal function // // Allocated a new receiver identifier for the peer - fn allocate(&self, rng: &mut R, peer: &Peer) -> u32 { + fn allocate(&self, rng: &mut R, peer: &Peer) -> u32 { loop { let id = rng.gen(); @@ -380,7 +390,7 @@ mod tests { fn setup_devices( rng: &mut R, - ) -> (PublicKey, Device, PublicKey, Device) { + ) -> (PublicKey, Device, PublicKey, Device) { // generate new keypairs let sk1 = StaticSecret::new(rng); @@ -399,8 +409,8 @@ mod tests { let mut dev1 = Device::new(sk1); let mut dev2 = Device::new(sk2); - dev1.add(pk2, 1337).unwrap(); - dev2.add(pk1, 2600).unwrap(); + dev1.add(pk2).unwrap(); + dev2.add(pk1).unwrap(); dev1.set_psk(pk2, Some(psk)).unwrap(); dev2.set_psk(pk1, Some(psk)).unwrap(); diff --git a/src/handshake/noise.rs b/src/handshake/noise.rs index eafb9e9..1dc8402 100644 --- a/src/handshake/noise.rs +++ b/src/handshake/noise.rs @@ -215,10 +215,10 @@ mod tests { } } -pub fn create_initiation( +pub fn create_initiation( rng: &mut R, - device: &Device, - peer: &Peer, + device: &Device, + peer: &Peer, sender: u32, msg: &mut NoiseInitiation, ) -> Result<(), HandshakeError> { @@ -296,10 +296,10 @@ pub fn create_initiation( }) } -pub fn consume_initiation<'a, T: Clone>( - device: &'a Device, +pub fn consume_initiation<'a>( + device: &'a Device, msg: &NoiseInitiation, -) -> Result<(&'a Peer, TemporaryState), HandshakeError> { +) -> Result<(&'a Peer, TemporaryState), HandshakeError> { clear_stack_on_return(CLEAR_PAGES, || { // initialize new state @@ -370,9 +370,9 @@ pub fn consume_initiation<'a, T: Clone>( }) } -pub fn create_response( +pub fn create_response( rng: &mut R, - peer: &Peer, + peer: &Peer, sender: u32, // sending identifier state: TemporaryState, // state from "consume_initiation" msg: &mut NoiseResponse, // resulting response @@ -456,10 +456,7 @@ pub fn create_response( * allow concurrent processing of potential responses to the initiation, * in order to better mitigate DoS from malformed response messages. */ -pub fn consume_response( - device: &Device, - msg: &NoiseResponse, -) -> Result, HandshakeError> { +pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result { clear_stack_on_return(CLEAR_PAGES, || { // retrieve peer and copy initiation state let peer = device.lookup_id(msg.f_receiver.get())?; @@ -530,7 +527,7 @@ pub fn consume_response( // return confirmed key-pair Ok(( - Some(peer.identifier.clone()), + Some(peer.pk), None, Some(KeyPair { birth, diff --git a/src/handshake/peer.rs b/src/handshake/peer.rs index 4c6f2fd..6a85cee 100644 --- a/src/handshake/peer.rs +++ b/src/handshake/peer.rs @@ -1,5 +1,7 @@ use lazy_static::lazy_static; use spin::Mutex; + +use std::mem; use std::time::{Duration, Instant}; use generic_array::typenum::U32; @@ -24,10 +26,7 @@ lazy_static! { * * This type is only for internal use and not exposed. */ -pub struct Peer { - // external identifier - pub(crate) identifier: T, - +pub struct Peer { // mutable state pub(crate) state: Mutex, pub(crate) timestamp: Mutex>, @@ -65,18 +64,13 @@ impl Drop for State { } } -impl Peer -where - T: Clone, -{ +impl Peer { pub fn new( - identifier: T, // external identifier pk: PublicKey, // public key of peer ss: SharedSecret, // precomputed DH(static, static) ) -> Self { Self { macs: Mutex::new(macs::Generator::new(pk)), - identifier: identifier, state: Mutex::new(State::Reset), timestamp: Mutex::new(None), last_initiation_consumption: Mutex::new(None), @@ -94,6 +88,13 @@ where *self.state.lock() = state_new; } + pub fn reset_state(&self) -> Option { + match mem::replace(&mut *self.state.lock(), State::Reset) { + State::InitiationSent { sender, .. } => Some(sender), + _ => None, + } + } + /// Set the mutable state of the peer conditioned on the timestamp being newer /// /// # Arguments @@ -102,7 +103,7 @@ where /// * ts_new - The associated timestamp pub fn check_replay_flood( &self, - device: &Device, + device: &Device, timestamp_new: ×tamp::TAI64N, ) -> Result<(), HandshakeError> { let mut state = self.state.lock(); diff --git a/src/handshake/types.rs b/src/handshake/types.rs index 7b190ec..ba71ec4 100644 --- a/src/handshake/types.rs +++ b/src/handshake/types.rs @@ -1,6 +1,8 @@ use std::error::Error; use std::fmt; +use x25519_dalek::PublicKey; + use crate::types::KeyPair; /* Internal types for the noise IKpsk2 implementation */ @@ -77,10 +79,10 @@ impl Error for HandshakeError { } } -pub type Output = ( - Option, // external identifier associated with peer - Option>, // message to send - Option, // resulting key-pair of successful handshake +pub type Output = ( + Option, // external identifier associated with peer + Option>, // message to send + Option, // resulting key-pair of successful handshake ); // preshared key diff --git a/src/router/device.rs b/src/router/device.rs index e9e0fb3..e8250cb 100644 --- a/src/router/device.rs +++ b/src/router/device.rs @@ -121,7 +121,6 @@ fn get_route( } impl Device { - pub fn new(num_workers: usize, tun: T, bind: B) -> Device { // allocate shared device state let mut inner = DeviceInner { @@ -149,6 +148,10 @@ impl Device { } } + /// A new secret key has been set for the device. + /// According to WireGuard semantics, this should cause all "sending" keys to be discarded. + pub fn new_sk(&self) {} + /// Adds a new peer to the device /// /// # Returns diff --git a/src/wireguard.rs b/src/wireguard.rs index 2c166b4..f98369f 100644 --- a/src/wireguard.rs +++ b/src/wireguard.rs @@ -2,17 +2,20 @@ use crate::handshake; use crate::router; use crate::types::{Bind, Endpoint, Tun}; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; use std::thread; use std::time::{Duration, Instant}; +use std::collections::HashMap; + use log::debug; use rand::rngs::OsRng; +use spin::{Mutex, RwLock}; use byteorder::{ByteOrder, LittleEndian}; -use crossbeam_channel::bounded; -use x25519_dalek::StaticSecret; +use crossbeam_channel::{bounded, Sender}; +use x25519_dalek::{PublicKey, StaticSecret}; const SIZE_HANDSHAKE_QUEUE: usize = 128; const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4; @@ -22,8 +25,10 @@ const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000); pub struct Peer(Arc>); pub struct PeerInner { - peer: router::Peer, + router: router::Peer, timers: Timers, + rx: AtomicU64, + tx: AtomicU64, } pub struct Timers {} @@ -40,96 +45,96 @@ impl router::Callbacks for Events { fn need_key(t: &Timers) {} } -pub struct Wireguard { - router: Arc>, - handshake: Option>>, +struct Handshake { + device: handshake::Device, + active: bool, } -impl Wireguard { - fn start(&self) {} - - fn new(tun: T, bind: B, sk: StaticSecret) -> Wireguard { - let router = Arc::new(router::Device::new( - num_cpus::get(), - tun.clone(), - bind.clone(), - )); - - let handshake_staged = Arc::new(AtomicUsize::new(0)); - let handshake_device: Arc>> = - Arc::new(handshake::Device::new(sk)); +struct WireguardInner { + // identify and configuration map + peers: RwLock>>, - // start UDP read IO thread - let (handshake_tx, handshake_rx) = bounded(128); - { - let tun = tun.clone(); - let bind = bind.clone(); - thread::spawn(move || { - let mut under_load = - Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); + // cryptkey routing + router: router::Device, - loop { - // read UDP packet into vector - let size = tun.mtu() + 148; // maximum message size - let mut msg: Vec = - Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); - msg.resize(size, 0); - let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error - msg.truncate(size); + // handshake related state + handshake: RwLock, + under_load: AtomicBool, + pending: AtomicUsize, // num of pending handshake packets in queue + queue: Mutex, B::Endpoint)>>, - // message type de-multiplexer - if msg.len() < std::mem::size_of::() { - continue; - } + // IO + bind: B, +} - match LittleEndian::read_u32(&msg[..]) { - handshake::TYPE_COOKIE_REPLY - | handshake::TYPE_INITIATION - | handshake::TYPE_RESPONSE => { - // detect if under load - if handshake_staged.fetch_add(1, Ordering::SeqCst) - > THRESHOLD_UNDER_LOAD - { - under_load = Instant::now() - } +pub struct Wireguard { + state: Arc>, +} - // pass source address along if under load - handshake_tx - .send((msg, src, under_load.elapsed() < DURATION_UNDER_LOAD)) - .unwrap(); - } - router::TYPE_TRANSPORT => { - // transport message - } - _ => (), - } - } - }); +impl Wireguard { + fn set_key(&self, sk: Option) { + let mut handshake = self.state.handshake.write(); + match sk { + None => { + let mut rng = OsRng::new().unwrap(); + handshake.device.set_sk(StaticSecret::new(&mut rng)); + handshake.active = false; + } + Some(sk) => { + handshake.device.set_sk(sk); + handshake.active = true; + } } + } + + fn new(tun: T, bind: B) -> Wireguard { + // create device state + let mut rng = OsRng::new().unwrap(); + let (tx, rx): (Sender<(Vec, B::Endpoint)>, _) = bounded(SIZE_HANDSHAKE_QUEUE); + let wg = Arc::new(WireguardInner { + peers: RwLock::new(HashMap::new()), + router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()), + pending: AtomicUsize::new(0), + handshake: RwLock::new(Handshake { + device: handshake::Device::new(StaticSecret::new(&mut rng)), + active: false, + }), + under_load: AtomicBool::new(false), + bind: bind.clone(), + queue: Mutex::new(tx), + }); // start handshake workers for _ in 0..num_cpus::get() { + let wg = wg.clone(); + let rx = rx.clone(); let bind = bind.clone(); - let handshake_rx = handshake_rx.clone(); - let handshake_device = handshake_device.clone(); thread::spawn(move || { // prepare OsRng instance for this thread let mut rng = OsRng::new().unwrap(); // process elements from the handshake queue - for (msg, src, under_load) in handshake_rx { + for (msg, src) in rx { + wg.pending.fetch_sub(1, Ordering::SeqCst); + // feed message to handshake device let src_validate = (&src).into_address(); // TODO avoid - match handshake_device.process( + let state = wg.handshake.read(); + if !state.active { + continue; + } + + // process message + match state.device.process( &mut rng, &msg[..], - if under_load { + if wg.under_load.load(Ordering::Relaxed) { Some(&src_validate) } else { None }, ) { - Ok((identity, msg, keypair)) => { + Ok((pk, msg, keypair)) => { // send response if let Some(msg) = msg { let _ = bind.send(&msg[..], &src).map_err(|e| { @@ -141,11 +146,13 @@ impl Wireguard { } // update timers - if let Some(identity) = identity { + if let Some(pk) = pk { // add keypair to peer and free any unused ids if let Some(keypair) = keypair { - for id in identity.0.peer.add_keypair(keypair) { - handshake_device.release(id); + if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { + for id in peer.0.router.add_keypair(keypair) { + state.device.release(id); + } } } } @@ -156,13 +163,76 @@ impl Wireguard { }); } - // start TUN read IO thread + // start UDP read IO thread + { + let wg = wg.clone(); + let tun = tun.clone(); + let bind = bind.clone(); + thread::spawn(move || { + let mut last_under_load = + Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); + + loop { + // read UDP packet into vector + let size = tun.mtu() + 148; // maximum message size + let mut msg: Vec = Vec::with_capacity(size); + msg.resize(size, 0); + let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error + msg.truncate(size); - thread::spawn(move || {}); + // message type de-multiplexer + if msg.len() < std::mem::size_of::() { + continue; + } - Wireguard { - router, - handshake: None, + match LittleEndian::read_u32(&msg[..]) { + handshake::TYPE_COOKIE_REPLY + | handshake::TYPE_INITIATION + | handshake::TYPE_RESPONSE => { + // update under_load flag + if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD { + last_under_load = Instant::now(); + wg.under_load.store(true, Ordering::SeqCst); + } else if last_under_load.elapsed() > DURATION_UNDER_LOAD { + wg.under_load.store(false, Ordering::SeqCst); + } + + wg.queue.lock().send((msg, src)).unwrap(); + } + router::TYPE_TRANSPORT => { + // transport message + + // pad the message + + let _ = wg.router.recv(src, msg); + } + _ => (), + } + } + }); + } + + // start TUN read IO thread + { + let wg = wg.clone(); + thread::spawn(move || loop { + // read a new IP packet + let mtu = tun.mtu(); + let size = mtu + 148; + let mut msg: Vec = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); + let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap(); + msg.truncate(size); + + // pad message to multiple of 16 + while msg.len() < mtu && msg.len() % 16 != 0 { + msg.push(0); + } + + // crypt-key route + let _ = wg.router.send(msg); + }); } + + Wireguard { state: wg } } } -- cgit v1.2.3-59-g8ed1b