From 293914e47b046f862608a1af91864b6b38336aa5 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Wed, 6 Nov 2019 13:50:38 +0100 Subject: Implement disable/enable timers --- src/wireguard/peer.rs | 19 ++++-- src/wireguard/router/device.rs | 18 ++--- src/wireguard/router/peer.rs | 32 +++++---- src/wireguard/router/tests.rs | 2 +- src/wireguard/timers.rs | 148 +++++++++++++++++++++++++---------------- src/wireguard/wireguard.rs | 7 +- 6 files changed, 138 insertions(+), 88 deletions(-) diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs index 9f24dea..b77e8c6 100644 --- a/src/wireguard/peer.rs +++ b/src/wireguard/peer.rs @@ -4,12 +4,11 @@ use super::timers::{Events, Timers}; use super::HandshakeJob; use super::bind::Bind; -use super::bind::Reader as BindReader; -use super::tun::{Reader, Tun}; +use super::tun::Tun; use std::fmt; use std::ops::Deref; -use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::Arc; use std::time::{Instant, SystemTime}; @@ -34,8 +33,7 @@ pub struct PeerInner { pub queue: Mutex>>, // handshake queue // stats and configuration - pub pk: PublicKey, // public key, DISCUSS: avoid this. TODO: remove - pub keepalive_interval: AtomicU64, // keepalive interval + pub pk: PublicKey, // public key, DISCUSS: avoid this. TODO: remove pub rx_bytes: AtomicU64, // received bytes pub tx_bytes: AtomicU64, // transmitted bytes @@ -78,11 +76,20 @@ impl Deref for Peer { } impl Peer { + /// Bring the peer down. Causing: + /// + /// - Timers to be stopped and disabled. + /// - All keystate to be zeroed pub fn down(&self) { self.stop_timers(); + self.router.down(); } - pub fn up(&self) {} + /// Bring the peer up. + pub fn up(&self) { + self.router.up(); + self.start_timers(); + } } impl PeerInner { diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index a5028e1..b3f1787 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -27,13 +27,11 @@ use super::route::get_route; use super::super::{bind, tun, Endpoint, KeyPair}; pub struct DeviceInner> { - pub enabled: AtomicBool, - // inbound writer (TUN) pub inbound: T, // outbound writer (Bind) - pub outbound: RwLock>, + pub outbound: RwLock<(bool, Option)>, // routing pub recv: RwLock>>>, // receiver id -> decryption state @@ -93,8 +91,7 @@ impl> Device> Device> Device>( // spawn outbound thread let thread_inbound = { let peer = peer.clone(); - let device = device.clone(); thread::spawn(move || worker_outbound(peer, out_rx)) }; @@ -237,24 +236,25 @@ impl> PeerInner Result<(), RouterError> { debug!("peer.send"); - // check if device is enabled - if !self.device.enabled.load(Ordering::Acquire) { - return Ok(()); - } - // send to endpoint (if known) match self.endpoint.lock().as_ref() { - Some(endpoint) => self - .device - .outbound - .read() - .as_ref() - .ok_or(RouterError::SendError) - .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError)), + Some(endpoint) => { + let outbound = self.device.outbound.read(); + if outbound.0 { + outbound + .1 + .as_ref() + .ok_or(RouterError::SendError) + .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError)) + } else { + Ok(()) + } + } None => Err(RouterError::NoEndpoint), } } + // Transmit all staged packets fn send_staged(&self) -> bool { debug!("peer.send_staged"); let mut sent = false; @@ -451,6 +451,12 @@ impl> Peer PeerInner { timers.enabled = true; // start send_persistent_keepalive - let interval = self.keepalive_interval.load(Ordering::Acquire); - if interval > 0 { + if timers.keepalive_interval > 0 { timers.send_persistent_keepalive.start( - Duration::from_secs(interval) + Duration::from_secs(timers.keepalive_interval) ); } } /* should be called after an authenticated data packet is sent */ pub fn timers_data_sent(&self) { - self.timers() - .new_handshake - .start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT); + let timers = self.timers(); + if timers.enabled { + timers.new_handshake.start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT); + } } /* should be called after an authenticated data packet is received */ pub fn timers_data_received(&self) { - if !self.timers().send_keepalive.start(KEEPALIVE_TIMEOUT) { - self.timers() - .need_another_keepalive - .store(true, Ordering::SeqCst) + let timers = self.timers(); + if timers.enabled && !timers.send_keepalive.start(KEEPALIVE_TIMEOUT) { + timers.need_another_keepalive.store(true, Ordering::SeqCst) } } @@ -98,7 +99,10 @@ impl PeerInner { * - handshake */ pub fn timers_any_authenticated_packet_sent(&self) { - self.timers().send_keepalive.stop() + let timers = self.timers(); + if timers.enabled { + timers.send_keepalive.stop() + } } /* Should be called after any type of authenticated packet is received, whether: @@ -107,48 +111,69 @@ impl PeerInner { * - handshake */ pub fn timers_any_authenticated_packet_received(&self) { - self.timers().new_handshake.stop(); + let timers = self.timers(); + if timers.enabled { + timers.new_handshake.stop(); + } } /* Should be called after a handshake initiation message is sent. */ pub fn timers_handshake_initiated(&self) { - self.timers().send_keepalive.stop(); - self.timers().retransmit_handshake.reset(REKEY_TIMEOUT); + let timers = self.timers(); + if timers.enabled { + timers.send_keepalive.stop(); + timers.retransmit_handshake.reset(REKEY_TIMEOUT); + } } /* Should be called after a handshake response message is received and processed * or when getting key confirmation via the first data message. */ pub fn timers_handshake_complete(&self) { - self.timers().handshake_attempts.store(0, Ordering::SeqCst); - self.timers() - .sent_lastminute_handshake - .store(false, Ordering::SeqCst); - *self.walltime_last_handshake.lock() = SystemTime::now(); + let timers = self.timers(); + if timers.enabled { + timers.handshake_attempts.store(0, Ordering::SeqCst); + timers.sent_lastminute_handshake.store(false, Ordering::SeqCst); + *self.walltime_last_handshake.lock() = SystemTime::now(); + } } /* Should be called after an ephemeral key is created, which is before sending a * handshake response or after receiving a handshake response. */ pub fn timers_session_derived(&self) { - self.timers().zero_key_material.reset(REJECT_AFTER_TIME * 3); + let timers = self.timers(); + if timers.enabled { + timers.zero_key_material.reset(REJECT_AFTER_TIME * 3); + } } /* Should be called before a packet with authentication, whether * keepalive, data, or handshake is sent, or after one is received. */ pub fn timers_any_authenticated_packet_traversal(&self) { - let keepalive = self.keepalive_interval.load(Ordering::Acquire); - if keepalive > 0 { + let timers = self.timers(); + if timers.enabled && timers.keepalive_interval > 0 { // push persistent_keepalive into the future - self.timers() - .send_persistent_keepalive - .reset(Duration::from_secs(keepalive as u64)); + timers.send_persistent_keepalive.reset(Duration::from_secs( + timers.keepalive_interval + )); } } pub fn timers_session_derieved(&self) { - self.timers().zero_key_material.reset(REJECT_AFTER_TIME * 3); + let timers = self.timers(); + if timers.enabled { + timers.zero_key_material.reset(REJECT_AFTER_TIME * 3); + } + } + + fn timers_set_retransmit_handshake(&self) { + let timers = self.timers(); + if timers.enabled { + timers.retransmit_handshake.reset(REKEY_TIMEOUT); + } + } /* Called after a handshake worker sends a handshake initiation to the peer @@ -156,7 +181,7 @@ impl PeerInner { pub fn sent_handshake_initiation(&self) { *self.last_handshake_sent.lock() = Instant::now(); self.handshake_queued.store(false, Ordering::SeqCst); - self.timers().retransmit_handshake.reset(REKEY_TIMEOUT); + self.timers_set_retransmit_handshake(); self.timers_any_authenticated_packet_traversal(); self.timers_any_authenticated_packet_sent(); } @@ -168,13 +193,18 @@ impl PeerInner { } - pub fn set_persistent_keepalive_interval(&self, interval: u64) { - self.timers().send_persistent_keepalive.stop(); - self.keepalive_interval.store(interval, Ordering::SeqCst); - if interval > 0 { - self.timers() - .send_persistent_keepalive - .start(Duration::from_secs(interval as u64)); + pub fn set_persistent_keepalive_interval(&self, secs: u64) { + let mut timers = self.timers_mut(); + + // update the stored keepalive_interval + timers.keepalive_interval = secs; + + // stop the keepalive timer with the old interval + timers.send_persistent_keepalive.stop(); + + // restart the persistent_keepalive timer with the new interval + if secs > 0 && timers.enabled { + timers.send_persistent_keepalive.start(Duration::from_secs(secs)); } } @@ -184,8 +214,6 @@ impl PeerInner { } self.packet_send_handshake_initiation(); } - - } @@ -198,12 +226,20 @@ impl Timers { // create a timer instance for the provided peer Timers { enabled: true, + keepalive_interval: 0, // disabled need_another_keepalive: AtomicBool::new(false), sent_lastminute_handshake: AtomicBool::new(false), handshake_attempts: AtomicUsize::new(0), retransmit_handshake: { let peer = peer.clone(); runner.timer(move || { + // ignore if timers are disabled + let timers = peer.timers(); + if !timers.enabled { + return; + } + + // check if handshake attempts remaining let attempts = peer.timers().handshake_attempts.fetch_add(1, Ordering::SeqCst); if attempts > MAX_TIMER_HANDSHAKES { debug!( @@ -211,9 +247,9 @@ impl Timers { peer, attempts + 1 ); + timers.send_keepalive.stop(); + timers.zero_key_material.start(REJECT_AFTER_TIME * 3); peer.router.purge_staged_packets(); - peer.timers().send_keepalive.stop(); - peer.timers().zero_key_material.start(REJECT_AFTER_TIME * 3); } else { debug!( "Handshake for {} did not complete after {} seconds, retrying (try {})", @@ -221,8 +257,8 @@ impl Timers { REKEY_TIMEOUT.as_secs(), attempts ); + timers.retransmit_handshake.reset(REKEY_TIMEOUT); peer.router.clear_src(); - peer.timers().retransmit_handshake.reset(REKEY_TIMEOUT); peer.packet_send_queued_handshake_initiation(true); } }) @@ -230,9 +266,15 @@ impl Timers { send_keepalive: { let peer = peer.clone(); runner.timer(move || { + // ignore if timers are disabled + let timers = peer.timers(); + if !timers.enabled { + return; + } + peer.router.send_keepalive(); - if peer.timers().need_another_keepalive() { - peer.timers().send_keepalive.start(KEEPALIVE_TIMEOUT); + if timers.need_another_keepalive() { + timers.send_keepalive.start(KEEPALIVE_TIMEOUT); } }) }, @@ -257,29 +299,23 @@ impl Timers { send_persistent_keepalive: { let peer = peer.clone(); runner.timer(move || { - let keepalive = peer.state.keepalive_interval.load(Ordering::Acquire); - if keepalive > 0 { + let timers = peer.timers(); + if timers.enabled && timers.keepalive_interval > 0 { peer.router.send_keepalive(); - peer.timers().send_keepalive.stop(); - peer.timers() - .send_persistent_keepalive - .start(Duration::from_secs(keepalive as u64)); + timers.send_keepalive.stop(); + timers.send_persistent_keepalive.start(Duration::from_secs( + timers.keepalive_interval + )); } }) }, } } - pub fn updated_persistent_keepalive(&self, keepalive: usize) { - if keepalive > 0 { - self.send_persistent_keepalive - .reset(Duration::from_secs(keepalive as u64)); - } - } - pub fn dummy(runner: &Runner) -> Timers { Timers { enabled: false, + keepalive_interval: 0, need_another_keepalive: AtomicBool::new(false), sent_lastminute_handshake: AtomicBool::new(false), handshake_attempts: AtomicUsize::new(0), @@ -290,10 +326,6 @@ impl Timers { zero_key_material: runner.timer(|| {}), } } - - pub fn handshake_sent(&self) { - self.send_keepalive.stop(); - } } /* Instance of the router callbacks */ diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 722f64a..6da428c 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -130,6 +130,9 @@ impl Wireguard { // ensure exclusive access (to avoid race with "up" call) let peers = self.peers.write(); + // avoid tranmission from router + self.router.down(); + // set all peers down (stops timers) for peer in peers.values() { peer.down(); @@ -142,6 +145,9 @@ impl Wireguard { // ensure exclusive access (to avoid race with "down" call) let peers = self.peers.write(); + // enable tranmission from router + self.router.up(); + // set all peers up (restarts timers) for peer in peers.values() { peer.up(); @@ -215,7 +221,6 @@ impl Wireguard { last_handshake_sent: Mutex::new(self.state.start - TIME_HORIZON), handshake_queued: AtomicBool::new(false), queue: Mutex::new(self.state.queue.lock().clone()), - keepalive_interval: AtomicU64::new(0), // disabled rx_bytes: AtomicU64::new(0), tx_bytes: AtomicU64::new(0), timers: RwLock::new(Timers::dummy(&self.runner)), -- cgit v1.2.3-59-g8ed1b