From 5cc108349968fbaa6998220631eb749276e64f45 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sat, 21 Sep 2019 17:22:03 +0200 Subject: Added zero_key to peer --- src/handshake/device.rs | 10 ++++- src/main.rs | 1 + src/router/peer.rs | 58 ++++++++++++++++++------ src/timers.rs | 65 +++++++++++++++++++++++++++ src/types/keys.rs | 6 +++ src/wireguard.rs | 116 +++++++++++++++++++++++++++--------------------- 6 files changed, 190 insertions(+), 66 deletions(-) create mode 100644 src/timers.rs (limited to 'src') diff --git a/src/handshake/device.rs b/src/handshake/device.rs index 2a06fa7..6178831 100644 --- a/src/handshake/device.rs +++ b/src/handshake/device.rs @@ -64,10 +64,16 @@ impl Device { self.macs = macs::Validator::new(pk); // recalculate the shared secrets for every peer - for &mut peer in self.pk_map.values_mut() { - peer.reset_state().map(|id| self.release(id)); + let mut ids = vec![]; + for mut peer in self.pk_map.values_mut() { + peer.reset_state().map(|id| ids.push(id)); peer.ss = self.sk.diffie_hellman(&peer.pk) } + + // release ids from aborted handshakes + for id in ids { + self.release(id) + } } /// Add a new public key to the state machine diff --git a/src/main.rs b/src/main.rs index 103bc65..a52eecc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,6 +8,7 @@ static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc; mod constants; mod handshake; mod router; +mod timers; mod types; mod wireguard; diff --git a/src/router/peer.rs b/src/router/peer.rs index 952e439..7a3ede8 100644 --- a/src/router/peer.rs +++ b/src/router/peer.rs @@ -36,7 +36,7 @@ pub struct KeyWheel { next: Option>, // next key state (unconfirmed) current: Option>, // current key state (used for encryption) previous: Option>, // old key state (used for decryption) - retired: Option, // retired id (previous id, after confirming key-pair) + retired: Vec, // retired ids } pub struct PeerInner { @@ -188,7 +188,7 @@ pub fn new_peer( next: None, current: None, previous: None, - retired: None, + retired: vec![], }), staged_packets: spin::Mutex::new(ArrayDeque::new()), }) @@ -375,6 +375,11 @@ impl Peer { *self.state.endpoint.lock() = Some(B::Endpoint::from_address(address)); } + /// Returns the current endpoint of the peer (for configuration) + /// + /// # Note + /// + /// Does not convey potential "sticky socket" information pub fn get_endpoint(&self) -> Option { self.state .endpoint @@ -383,6 +388,30 @@ impl Peer { .map(|e| e.into_address()) } + /// Zero all key-material related to the peer + pub fn zero_keys(&self) { + let mut release: Vec = Vec::with_capacity(3); + let mut keys = self.state.keys.lock(); + + // update key-wheel + + mem::replace(&mut keys.next, None).map(|k| release.push(k.local_id())); + mem::replace(&mut keys.current, None).map(|k| release.push(k.local_id())); + mem::replace(&mut keys.previous, None).map(|k| release.push(k.local_id())); + keys.retired.extend(&release[..]); + + // update inbound "recv" map + { + let mut recv = self.state.device.recv.write(); + for id in release { + recv.remove(&id); + } + } + + // clear encryption state + *self.state.ekey.lock() = None; + } + /// Add a new keypair /// /// # Arguments @@ -393,14 +422,16 @@ impl Peer { /// /// A vector of ids which has been released. /// These should be released in the handshake module. + /// + /// # Note + /// + /// The number of ids to be released can be at most 3, + /// since the only way to add additional keys to the peer is by using this method + /// and a peer can have at most 3 keys allocated in the router at any time. pub fn add_keypair(&self, new: KeyPair) -> Vec { - let mut keys = self.state.keys.lock(); - let mut release = Vec::with_capacity(2); let new = Arc::new(new); - - // collect ids to be released - keys.retired.map(|v| release.push(v)); - keys.previous.as_ref().map(|k| release.push(k.recv.id)); + let mut keys = self.state.keys.lock(); + let mut release = mem::replace(&mut keys.retired, vec![]); // update key-wheel if new.initiator { @@ -420,10 +451,11 @@ impl Peer { { let mut recv = self.state.device.recv.write(); - // purge recv map of released ids - for id in &release { - recv.remove(&id); - } + // purge recv map of previous id + keys.previous.as_ref().map(|k| { + recv.remove(&k.local_id()); + release.push(k.local_id()); + }); // map new id to decryption state debug_assert!(!recv.contains_key(&new.recv.id)); @@ -442,7 +474,7 @@ impl Peer { } } - // return the released id (for handshake state machine) + debug_assert!(release.len() <= 3); release } diff --git a/src/timers.rs b/src/timers.rs new file mode 100644 index 0000000..0d69c3f --- /dev/null +++ b/src/timers.rs @@ -0,0 +1,65 @@ +use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +use hjul::{Runner, Timer}; + +use crate::router::Callbacks; + +const ZERO_DURATION: Duration = Duration::from_micros(0); + +pub struct TimersInner { + handshake_pending: AtomicBool, + handshake_attempts: AtomicUsize, + + retransmit_handshake: Timer, + send_keepalive: Timer, + zero_key_material: Timer, + new_handshake: Timer, + + // stats + rx_bytes: AtomicU64, + tx_bytes: AtomicU64, +} + +impl TimersInner { + pub fn new(runner: &Runner) -> Timers { + Arc::new(TimersInner { + handshake_pending: AtomicBool::new(false), + handshake_attempts: AtomicUsize::new(0), + retransmit_handshake: runner.timer(|| {}), + new_handshake: runner.timer(|| {}), + send_keepalive: runner.timer(|| {}), + zero_key_material: runner.timer(|| {}), + rx_bytes: AtomicU64::new(0), + tx_bytes: AtomicU64::new(0), + }) + } + + pub fn handshake_sent(&self) { + self.send_keepalive.stop(); + } +} + +pub type Timers = Arc; + +pub struct Events(); + +impl Callbacks for Events { + type Opaque = Timers; + + fn send(t: &Timers, size: usize, data: bool, sent: bool) { + t.tx_bytes.fetch_add(size as u64, Ordering::Relaxed); + } + + fn recv(t: &Timers, size: usize, data: bool, sent: bool) { + t.rx_bytes.fetch_add(size as u64, Ordering::Relaxed); + } + + fn need_key(t: &Timers) { + if !t.handshake_pending.swap(true, Ordering::SeqCst) { + t.handshake_attempts.store(0, Ordering::SeqCst); + t.new_handshake.reset(ZERO_DURATION); + } + } +} diff --git a/src/types/keys.rs b/src/types/keys.rs index 89cacf9..282c4ae 100644 --- a/src/types/keys.rs +++ b/src/types/keys.rs @@ -28,3 +28,9 @@ pub struct KeyPair { pub send: Key, // key for outbound messages pub recv: Key, // key for inbound messages } + +impl KeyPair { + pub fn local_id(&self) -> u32 { + self.recv.id + } +} diff --git a/src/wireguard.rs b/src/wireguard.rs index f98369f..3b4724e 100644 --- a/src/wireguard.rs +++ b/src/wireguard.rs @@ -1,5 +1,6 @@ use crate::handshake; use crate::router; +use crate::timers::{Events, Timers}; use crate::types::{Bind, Endpoint, Tun}; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; @@ -21,28 +22,19 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128; const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4; const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000); -#[derive(Clone)] -pub struct Peer(Arc>); +type Peer = Arc>; pub struct PeerInner { - router: router::Peer, - timers: Timers, - rx: AtomicU64, - tx: AtomicU64, + queue: Mutex>>, // handshake queue + router: router::Peer, // router peer + timers: Option, // } -pub struct Timers {} - -pub struct Events(); - -impl router::Callbacks for Events { - type Opaque = Timers; - - fn send(t: &Timers, size: usize, data: bool, sent: bool) {} - - fn recv(t: &Timers, size: usize, data: bool, sent: bool) {} - - fn need_key(t: &Timers) {} +impl PeerInner { + #[inline(always)] + fn timers(&self) -> &Timers { + self.timers.as_ref().unwrap() + } } struct Handshake { @@ -50,6 +42,11 @@ struct Handshake { active: bool, } +enum HandshakeJob { + Message(Vec, E), + New(PublicKey), +} + struct WireguardInner { // identify and configuration map peers: RwLock>>, @@ -61,7 +58,7 @@ struct WireguardInner { handshake: RwLock, under_load: AtomicBool, pending: AtomicUsize, // num of pending handshake packets in queue - queue: Mutex, B::Endpoint)>>, + queue: Mutex>>, // IO bind: B, @@ -90,7 +87,7 @@ impl Wireguard { fn new(tun: T, bind: B) -> Wireguard { // create device state let mut rng = OsRng::new().unwrap(); - let (tx, rx): (Sender<(Vec, B::Endpoint)>, _) = bounded(SIZE_HANDSHAKE_QUEUE); + let (tx, rx): (Sender>, _) = bounded(SIZE_HANDSHAKE_QUEUE); let wg = Arc::new(WireguardInner { peers: RwLock::new(HashMap::new()), router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()), @@ -114,50 +111,64 @@ impl Wireguard { let mut rng = OsRng::new().unwrap(); // process elements from the handshake queue - for (msg, src) in rx { + for job in rx { wg.pending.fetch_sub(1, Ordering::SeqCst); - - // feed message to handshake device - let src_validate = (&src).into_address(); // TODO avoid let state = wg.handshake.read(); if !state.active { continue; } - // process message - match state.device.process( - &mut rng, - &msg[..], - if wg.under_load.load(Ordering::Relaxed) { - Some(&src_validate) - } else { - None - }, - ) { - Ok((pk, msg, keypair)) => { - // send response - if let Some(msg) = msg { - let _ = bind.send(&msg[..], &src).map_err(|e| { - debug!( + match job { + HandshakeJob::Message(msg, src) => { + // feed message to handshake device + let src_validate = (&src).into_address(); // TODO avoid + + // process message + match state.device.process( + &mut rng, + &msg[..], + if wg.under_load.load(Ordering::Relaxed) { + Some(&src_validate) + } else { + None + }, + ) { + Ok((pk, msg, keypair)) => { + // send response + if let Some(msg) = msg { + let _ = bind.send(&msg[..], &src).map_err(|e| { + debug!( "handshake worker, failed to send response, error = {:?}", e ) - }); - } + }); + } - // update timers - if let Some(pk) = pk { - // add keypair to peer and free any unused ids - if let Some(keypair) = keypair { - if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { - for id in peer.0.router.add_keypair(keypair) { - state.device.release(id); + // update timers + if let Some(pk) = pk { + if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { + // update endpoint (DISCUSS: right semantics?) + peer.router.set_endpoint(src_validate); + + // add keypair to peer and free any unused ids + if let Some(keypair) = keypair { + for id in peer.router.add_keypair(keypair) { + state.device.release(id); + } + } } } } + Err(e) => debug!("handshake worker, error = {:?}", e), + } + } + HandshakeJob::New(pk) => { + let msg = state.device.begin(&mut rng, &pk).unwrap(); // TODO handle + if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { + peer.router.send(&msg[..]); + peer.timers().handshake_sent(); } } - Err(e) => debug!("handshake worker, error = {:?}", e), } } }); @@ -197,7 +208,10 @@ impl Wireguard { wg.under_load.store(false, Ordering::SeqCst); } - wg.queue.lock().send((msg, src)).unwrap(); + wg.queue + .lock() + .send(HandshakeJob::Message(msg, src)) + .unwrap(); } router::TYPE_TRANSPORT => { // transport message @@ -223,7 +237,7 @@ impl Wireguard { let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap(); msg.truncate(size); - // pad message to multiple of 16 + // pad message to multiple of 16 bytes while msg.len() < mtu && msg.len() % 16 != 0 { msg.push(0); } -- cgit v1.2.3-59-g8ed1b