From b25c21885bf97e74802549e3ac22f57bc0c44d76 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Thu, 31 Oct 2019 17:11:09 +0100 Subject: Work on timer semantics --- src/wireguard/constants.rs | 6 +++ src/wireguard/endpoint.rs | 29 +++++++++++ src/wireguard/mod.rs | 1 + src/wireguard/router/mod.rs | 3 +- src/wireguard/router/peer.rs | 6 +++ src/wireguard/router/tests.rs | 17 ++++-- src/wireguard/router/types.rs | 7 ++- src/wireguard/router/workers.rs | 39 ++++++-------- src/wireguard/timers.rs | 113 ++++++++++++++++++++++++++-------------- src/wireguard/wireguard.rs | 40 ++++++++++---- 10 files changed, 181 insertions(+), 80 deletions(-) create mode 100644 src/wireguard/endpoint.rs diff --git a/src/wireguard/constants.rs b/src/wireguard/constants.rs index ec60801..c53c559 100644 --- a/src/wireguard/constants.rs +++ b/src/wireguard/constants.rs @@ -18,3 +18,9 @@ pub const TIMERS_SLOTS: usize = (TIMER_MAX_DURATION.as_micros() / TIMERS_TICK.as pub const TIMERS_CAPACITY: usize = 1024; pub const MESSAGE_PADDING_MULTIPLE: usize = 16; + +/* A long duration (compared to the WireGuard time constants), + * used in places to avoid Option by instead using a long "expired" Instant: + * (Instant::now() - TIME_HORIZON) + */ +pub const TIME_HORIZON: Duration = Duration::from_secs(3600 * 24); diff --git a/src/wireguard/endpoint.rs b/src/wireguard/endpoint.rs new file mode 100644 index 0000000..f6a560b --- /dev/null +++ b/src/wireguard/endpoint.rs @@ -0,0 +1,29 @@ +use spin::{Mutex, MutexGuard}; +use std::sync::Arc; + +use super::super::platform::Endpoint; + +#[derive(Clone)] +struct EndpointStore { + endpoint: Arc>>, +} + +impl EndpointStore { + pub fn new() -> EndpointStore { + EndpointStore { + endpoint: Arc::new(Mutex::new(None)), + } + } + + pub fn set(&self, endpoint: E) { + *self.endpoint.lock() = Some(endpoint); + } + + pub fn get(&self) -> MutexGuard> { + self.endpoint.lock() + } + + pub fn clear_src(&self) { + (*self.endpoint.lock()).as_mut().map(|e| e.clear_src()); + } +} diff --git a/src/wireguard/mod.rs b/src/wireguard/mod.rs index c3e9c58..83f9e8a 100644 --- a/src/wireguard/mod.rs +++ b/src/wireguard/mod.rs @@ -2,6 +2,7 @@ mod constants; mod timers; mod wireguard; +mod endpoint; mod handshake; mod router; mod types; diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs index 354700a..6aa894d 100644 --- a/src/wireguard/router/mod.rs +++ b/src/wireguard/router/mod.rs @@ -14,7 +14,8 @@ mod tests; use messages::TransportHeader; use std::mem; -use super::constants::*; +use super::constants::REJECT_AFTER_MESSAGES; +use super::types::*; pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::(); pub const CAPACITY_MESSAGE_POSTFIX: usize = workers::SIZE_TAG; diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index 50fdfe7..5467eb7 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -589,6 +589,12 @@ impl> Peer, counter: u64) { t.0.send.lock().unwrap().push((size, sent)) } - fn recv(t: &Self::Opaque, size: usize, sent: bool) { + fn recv(t: &Self::Opaque, size: usize, sent: bool, keypair: &Arc) { t.0.recv.lock().unwrap().push((size, sent)) } @@ -123,10 +124,16 @@ mod tests { struct BencherCallbacks {} impl Callbacks for BencherCallbacks { type Opaque = Arc; - fn send(t: &Self::Opaque, size: usize, _sent: bool) { + fn send( + t: &Self::Opaque, + size: usize, + _sent: bool, + _keypair: &Arc, + _counter: u64, + ) { t.fetch_add(size, Ordering::SeqCst); } - fn recv(_: &Self::Opaque, _size: usize, _sent: bool) {} + fn recv(_: &Self::Opaque, _size: usize, _sent: bool, _keypair: &Arc) {} fn need_key(_: &Self::Opaque) {} fn key_confirmed(_: &Self::Opaque) {} } diff --git a/src/wireguard/router/types.rs b/src/wireguard/router/types.rs index 9f769fe..194f0d4 100644 --- a/src/wireguard/router/types.rs +++ b/src/wireguard/router/types.rs @@ -1,5 +1,8 @@ use std::error::Error; use std::fmt; +use std::sync::Arc; + +use super::KeyPair; pub trait Opaque: Send + Sync + 'static {} @@ -23,8 +26,8 @@ impl KeyCallback for F where F: Fn(&T) -> () + Sync + Send + 'static {} pub trait Callbacks: Send + Sync + 'static { type Opaque: Opaque; - fn send(opaque: &Self::Opaque, size: usize, sent: bool); - fn recv(opaque: &Self::Opaque, size: usize, sent: bool); + fn send(opaque: &Self::Opaque, size: usize, sent: bool, keypair: &Arc, counter: u64); + fn recv(opaque: &Self::Opaque, size: usize, sent: bool, keypair: &Arc); fn need_key(opaque: &Self::Opaque); fn key_confirmed(opaque: &Self::Opaque); } diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs index 2a12000..08c2db9 100644 --- a/src/wireguard/router/workers.rs +++ b/src/wireguard/router/workers.rs @@ -1,6 +1,5 @@ use std::sync::mpsc::Receiver; use std::sync::Arc; -use std::time::Instant; use futures::sync::oneshot; use futures::*; @@ -18,8 +17,7 @@ use super::peer::PeerInner; use super::route::check_route; use super::types::Callbacks; -use super::{KEEPALIVE_TIMEOUT, REJECT_AFTER_TIME, REKEY_TIMEOUT}; -use super::{REJECT_AFTER_MESSAGES, REKEY_AFTER_MESSAGES, REKEY_AFTER_TIME}; +use super::REJECT_AFTER_MESSAGES; use super::super::types::KeyPair; use super::super::{bind, tun, Endpoint}; @@ -61,10 +59,6 @@ pub fn worker_inbound>, // related peer receiver: Receiver>, ) { - fn keep_key_fresh(keypair: &KeyPair) -> bool { - Instant::now() - keypair.birth > REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT - } - loop { // fetch job let (state, endpoint, rx) = match receiver.recv() { @@ -135,7 +129,7 @@ pub fn worker_inbound>, // related peer receiver: Receiver, ) { - fn keep_key_fresh(keypair: &KeyPair, counter: u64) -> bool { - counter > REKEY_AFTER_MESSAGES - || (keypair.initiator && Instant::now() - keypair.birth > REKEY_AFTER_TIME) - } - loop { // fetch job let rx = match receiver.recv() { @@ -190,12 +179,7 @@ pub fn worker_outbound) { .expect("earlier code should ensure that there is ample space"); // set header fields - debug_assert!(job.counter < REJECT_AFTER_MESSAGES); + debug_assert!( + job.counter < REJECT_AFTER_MESSAGES, + "should be checked when assigning counters" + ); header.f_type.set(TYPE_TRANSPORT); header.f_receiver.set(job.keypair.send.id); header.f_counter.set(job.counter); @@ -258,10 +245,12 @@ pub fn worker_parallel(receiver: Receiver) { let _ = tx.send(match layout { Some((header, body)) => { - debug_assert_eq!(header.f_type.get(), TYPE_TRANSPORT); - if header.f_counter.get() >= REJECT_AFTER_MESSAGES { - None - } else { + debug_assert_eq!( + header.f_type.get(), + TYPE_TRANSPORT, + "type and reserved bits should be checked by message de-multiplexer" + ); + if header.f_counter.get() < REJECT_AFTER_MESSAGES { // create a nonce object let mut nonce = [0u8; 12]; debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); @@ -279,6 +268,8 @@ pub fn worker_parallel(receiver: Receiver) { Ok(_) => Some(job), Err(_) => None, } + } else { + None } } None => None, diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index 3b16bf6..2e9263d 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -1,10 +1,10 @@ use std::marker::PhantomData; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; -use std::time::{Duration, SystemTime}; - -use log::info; +use std::time::{Duration, Instant, SystemTime}; +use log::{debug, info}; +use spin::Mutex; use hjul::{Runner, Timer}; use super::constants::*; @@ -12,8 +12,9 @@ use super::router::{message_data_len, Callbacks}; use super::wireguard::{Peer, PeerInner}; use super::{bind, tun}; +use super::types::KeyPair; + pub struct Timers { - handshake_pending: AtomicBool, handshake_attempts: AtomicUsize, retransmit_handshake: Timer, @@ -98,6 +99,7 @@ impl PeerInner { pub fn timers_any_authenticated_packet_traversal(&self) { let keepalive = self.keepalive.load(Ordering::Acquire); if keepalive > 0 { + // push persistent_keepalive into the future self.timers() .send_persistent_keepalive .reset(Duration::from_secs(keepalive as u64)); @@ -107,15 +109,24 @@ impl PeerInner { /* Called after a handshake worker sends a handshake initiation to the peer */ pub fn sent_handshake_initiation(&self) { - *self.last_handshake.lock() = SystemTime::now(); + *self.last_handshake_sent.lock() = Instant::now(); self.handshake_queued.store(false, Ordering::SeqCst); + self.timers().retransmit_handshake.reset(REKEY_TIMEOUT); self.timers_any_authenticated_packet_traversal(); self.timers_any_authenticated_packet_sent(); } pub fn sent_handshake_response(&self) { + *self.last_handshake_sent.lock() = Instant::now(); self.timers_any_authenticated_packet_traversal(); self.timers_any_authenticated_packet_sent(); + } + + fn packet_send_queued_handshake_initiation(&self, is_retry: bool) { + if !is_retry { + self.timers().handshake_attempts.store(0, Ordering::SeqCst); + } + self.packet_send_handshake_initiation(); } } @@ -127,21 +138,32 @@ impl Timers { { // create a timer instance for the provided peer Timers { - handshake_pending: AtomicBool::new(false), need_another_keepalive: AtomicBool::new(false), sent_lastminute_handshake: AtomicBool::new(false), handshake_attempts: AtomicUsize::new(0), retransmit_handshake: { let peer = peer.clone(); runner.timer(move || { - if peer.timers().handshake_retry() { - info!("Retransmit handshake for {}", peer); - peer.new_handshake(); - } else { - info!("Failed to complete handshake for {}", peer); + let attempts = peer.timers().handshake_attempts.fetch_add(1, Ordering::SeqCst); + if attempts > MAX_TIMER_HANDSHAKES { + debug!( + "Handshake for peer {} did not complete after {} attempts, giving up", + peer, + attempts + 1 + ); peer.router.purge_staged_packets(); peer.timers().send_keepalive.stop(); peer.timers().zero_key_material.start(REJECT_AFTER_TIME * 3); + } else { + debug!( + "Handshake for {} did not complete after {} seconds, retrying (try {})", + peer, + REKEY_TIMEOUT.as_secs(), + attempts + ); + peer.router.clear_src(); + peer.timers().retransmit_handshake.reset(REKEY_TIMEOUT); + peer.packet_send_queued_handshake_initiation(true); } }) }, @@ -157,9 +179,13 @@ impl Timers { new_handshake: { let peer = peer.clone(); runner.timer(move || { - info!("Initiate new handshake with {}", peer); - peer.new_handshake(); - peer.timers.read().handshake_begun(); + debug!( + "Retrying handshake with {} because we stopped hearing back after {} seconds", + peer, + (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT).as_secs() + ); + peer.router.clear_src(); + peer.packet_send_queued_handshake_initiation(false); }) }, zero_key_material: { @@ -184,22 +210,6 @@ impl Timers { } } - fn handshake_begun(&self) { - self.handshake_pending.store(true, Ordering::SeqCst); - self.handshake_attempts.store(0, Ordering::SeqCst); - self.retransmit_handshake.reset(REKEY_TIMEOUT); - } - - fn handshake_retry(&self) -> bool { - if self.handshake_attempts.fetch_add(1, Ordering::SeqCst) <= MAX_TIMER_HANDSHAKES { - self.retransmit_handshake.reset(REKEY_TIMEOUT); - true - } else { - self.handshake_pending.store(false, Ordering::SeqCst); - false - } - } - pub fn updated_persistent_keepalive(&self, keepalive: usize) { if keepalive > 0 { self.send_persistent_keepalive @@ -209,7 +219,6 @@ impl Timers { pub fn dummy(runner: &Runner) -> Timers { Timers { - handshake_pending: AtomicBool::new(false), need_another_keepalive: AtomicBool::new(false), sent_lastminute_handshake: AtomicBool::new(false), handshake_attempts: AtomicUsize::new(0), @@ -236,13 +245,28 @@ impl Callbacks for Events { /* Called after the router encrypts a transport message destined for the peer. * This method is called, even if the encrypted payload is empty (keepalive) */ - fn send(peer: &Self::Opaque, size: usize, sent: bool) { + #[inline(always)] + fn send(peer: &Self::Opaque, size: usize, sent: bool, keypair: &Arc, counter: u64) { + + // update timers and stats + peer.timers_any_authenticated_packet_traversal(); peer.timers_any_authenticated_packet_sent(); peer.tx_bytes.fetch_add(size as u64, Ordering::Relaxed); if size > message_data_len(0) && sent { peer.timers_data_sent(); } + + // keep_key_fresh + + fn keep_key_fresh(keypair: &Arc, counter: u64) -> bool { + counter > REKEY_AFTER_MESSAGES + || (keypair.initiator && Instant::now() - keypair.birth > REKEY_AFTER_TIME) + } + + if keep_key_fresh(keypair, counter) { + peer.packet_send_queued_handshake_initiation(false); + } } /* Called after the router successfully decrypts a transport message from a peer. @@ -252,13 +276,28 @@ impl Callbacks for Events { * - A malformed IP packet * - Fails to cryptkey route */ - fn recv(peer: &Self::Opaque, size: usize, sent: bool) { + #[inline(always)] + fn recv(peer: &Self::Opaque, size: usize, sent: bool, keypair: &Arc) { + + // update timers and stats + peer.timers_any_authenticated_packet_traversal(); peer.timers_any_authenticated_packet_received(); peer.rx_bytes.fetch_add(size as u64, Ordering::Relaxed); if size > 0 && sent { peer.timers_data_received(); } + + // keep_key_fresh + + #[inline(always)] + fn keep_key_fresh(keypair: &Arc) -> bool { + Instant::now() - keypair.birth > REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT + } + + if keep_key_fresh(keypair) && !peer.timers().sent_lastminute_handshake.swap(true, Ordering::Acquire) { + peer.packet_send_queued_handshake_initiation(false); + } } /* Called every time the router detects that a key is required, @@ -267,14 +306,12 @@ impl Callbacks for Events { * The message is called continuously * (e.g. for every packet that must be encrypted, until a key becomes available) */ + #[inline(always)] fn need_key(peer: &Self::Opaque) { - let timers = peer.timers(); - if !timers.handshake_pending.swap(true, Ordering::SeqCst) { - timers.handshake_attempts.store(0, Ordering::SeqCst); - timers.new_handshake.fire(); - } + peer.packet_send_queued_handshake_initiation(false); } + #[inline(always)] fn key_confirmed(peer: &Self::Opaque) { peer.timers().retransmit_handshake.stop(); } diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 233559e..e308c50 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -38,23 +38,28 @@ pub struct Peer { } pub struct PeerInner { + // internal id (for logging) pub id: u64, - pub keepalive: AtomicUsize, // keepalive interval - pub rx_bytes: AtomicU64, - pub tx_bytes: AtomicU64, + // handshake state + pub last_handshake_sent: Mutex, // instant for last handshake + pub handshake_queued: AtomicBool, // is a handshake job currently queued for the peer? + pub queue: Mutex>>, // handshake queue - pub last_handshake: Mutex, - pub handshake_queued: AtomicBool, + // stats and configuration + pub pk: PublicKey, // public key, DISCUSS: avoid this. TODO: remove + pub keepalive: AtomicUsize, // keepalive interval + pub rx_bytes: AtomicU64, // received bytes + pub tx_bytes: AtomicU64, // transmitted bytes - pub queue: Mutex>>, // handshake queue - pub pk: PublicKey, // DISCUSS: Change layout in handshake module (adopt pattern of router), to avoid this. TODO: remove - pub timers: RwLock, // + // timer model + pub timers: RwLock, } pub struct WireguardInner { // identifier (for logging) id: u32, + start: Instant, // provides access to the MTU value of the tun device // (otherwise owned solely by the router and a dedicated read IO thread) @@ -122,8 +127,22 @@ impl Deref for Peer { impl PeerInner { /* Queue a handshake request for the parallel workers * (if one does not already exist) + * + * The function is ratelimited. */ - pub fn new_handshake(&self) { + pub fn packet_send_handshake_initiation(&self) { + // the function is rate limited + + { + let mut lhs = self.last_handshake_sent.lock(); + if lhs.elapsed() < REKEY_TIMEOUT { + return; + } + *lhs = Instant::now(); + } + + // create a new handshake job for the peer + if !self.handshake_queued.swap(true, Ordering::SeqCst) { self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap(); } @@ -225,7 +244,7 @@ impl Wireguard { let state = Arc::new(PeerInner { id: rng.gen(), pk, - last_handshake: Mutex::new(SystemTime::UNIX_EPOCH), + last_handshake_sent: Mutex::new(self.state.start - TIME_HORIZON), handshake_queued: AtomicBool::new(false), queue: Mutex::new(self.state.queue.lock().clone()), keepalive: AtomicUsize::new(0), @@ -335,6 +354,7 @@ impl Wireguard { let mut rng = OsRng::new().unwrap(); let (tx, rx): (Sender>, _) = bounded(SIZE_HANDSHAKE_QUEUE); let wg = Arc::new(WireguardInner { + start: Instant::now(), id: rng.gen(), mtu: mtu.clone(), peers: RwLock::new(HashMap::new()), -- cgit v1.2.3-59-g8ed1b