diff options
Diffstat (limited to 'src/wireguard')
-rw-r--r-- | src/wireguard/peer.rs | 19 | ||||
-rw-r--r-- | src/wireguard/router/device.rs | 18 | ||||
-rw-r--r-- | src/wireguard/router/peer.rs | 32 | ||||
-rw-r--r-- | src/wireguard/router/tests.rs | 2 | ||||
-rw-r--r-- | src/wireguard/timers.rs | 148 | ||||
-rw-r--r-- | src/wireguard/wireguard.rs | 7 |
6 files changed, 138 insertions, 88 deletions
diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs index 9f24dea..b77e8c6 100644 --- a/src/wireguard/peer.rs +++ b/src/wireguard/peer.rs @@ -4,12 +4,11 @@ use super::timers::{Events, Timers}; use super::HandshakeJob; use super::bind::Bind; -use super::bind::Reader as BindReader; -use super::tun::{Reader, Tun}; +use super::tun::Tun; use std::fmt; use std::ops::Deref; -use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::Arc; use std::time::{Instant, SystemTime}; @@ -34,8 +33,7 @@ pub struct PeerInner<B: Bind> { pub queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue // stats and configuration - pub pk: PublicKey, // public key, DISCUSS: avoid this. TODO: remove - pub keepalive_interval: AtomicU64, // keepalive interval + pub pk: PublicKey, // public key, DISCUSS: avoid this. TODO: remove pub rx_bytes: AtomicU64, // received bytes pub tx_bytes: AtomicU64, // transmitted bytes @@ -78,11 +76,20 @@ impl<T: Tun, B: Bind> Deref for Peer<T, B> { } impl<T: Tun, B: Bind> Peer<T, B> { + /// Bring the peer down. Causing: + /// + /// - Timers to be stopped and disabled. + /// - All keystate to be zeroed pub fn down(&self) { self.stop_timers(); + self.router.down(); } - pub fn up(&self) {} + /// Bring the peer up. + pub fn up(&self) { + self.router.up(); + self.start_timers(); + } } impl<B: Bind> PeerInner<B> { diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index a5028e1..b3f1787 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -27,13 +27,11 @@ use super::route::get_route; use super::super::{bind, tun, Endpoint, KeyPair}; pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { - pub enabled: AtomicBool, - // inbound writer (TUN) pub inbound: T, // outbound writer (Bind) - pub outbound: RwLock<Option<B>>, + pub outbound: RwLock<(bool, Option<B>)>, // routing pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state @@ -93,8 +91,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, // allocate shared device state let inner = DeviceInner { inbound: tun, - enabled: AtomicBool::new(true), - outbound: RwLock::new(None), + outbound: RwLock::new((true, None)), queues: Mutex::new(Vec::with_capacity(num_workers)), queue_next: AtomicUsize::new(0), recv: RwLock::new(HashMap::new()), @@ -120,12 +117,15 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, /// Brings the router down. /// When the router is brought down it: /// - Prevents transmission of outbound messages. - /// - Erases all key state (key-wheels) of all peers - pub fn down(&self) {} + pub fn down(&self) { + self.state.outbound.write().0 = false; + } /// Brints the router up /// When the router is brought up it enables the transmission of outbound messages. - pub fn up(&self) {} + pub fn up(&self) { + self.state.outbound.write().0 = true; + } /// A new secret key has been set for the device. /// According to WireGuard semantics, this should cause all "sending" keys to be discarded. @@ -209,6 +209,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, /// /// pub fn set_outbound_writer(&self, new: B) { - *self.state.outbound.write() = Some(new); + self.state.outbound.write().1 = Some(new); } } diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index 7527a60..0d9b435 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -206,7 +206,6 @@ pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( // spawn outbound thread let thread_inbound = { let peer = peer.clone(); - let device = device.clone(); thread::spawn(move || worker_outbound(peer, out_rx)) }; @@ -237,24 +236,25 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E, pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> { debug!("peer.send"); - // check if device is enabled - if !self.device.enabled.load(Ordering::Acquire) { - return Ok(()); - } - // send to endpoint (if known) match self.endpoint.lock().as_ref() { - Some(endpoint) => self - .device - .outbound - .read() - .as_ref() - .ok_or(RouterError::SendError) - .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError)), + Some(endpoint) => { + let outbound = self.device.outbound.read(); + if outbound.0 { + outbound + .1 + .as_ref() + .ok_or(RouterError::SendError) + .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError)) + } else { + Ok(()) + } + } None => Err(RouterError::NoEndpoint), } } + // Transmit all staged packets fn send_staged(&self) -> bool { debug!("peer.send_staged"); let mut sent = false; @@ -451,6 +451,12 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T *self.state.ekey.lock() = None; } + pub fn down(&self) { + self.zero_keys(); + } + + pub fn up(&self) {} + /// Add a new keypair /// /// # Arguments diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs index d5a1133..d14b438 100644 --- a/src/wireguard/router/tests.rs +++ b/src/wireguard/router/tests.rs @@ -3,7 +3,7 @@ use std::sync::atomic::Ordering; use std::sync::Arc; use std::sync::Mutex; use std::thread; -use std::time::{Duration, Instant}; +use std::time::Duration; use num_cpus; diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index e844f4d..038f6c6 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -13,7 +13,9 @@ use super::{bind, tun}; use super::types::KeyPair; pub struct Timers { + // only updated during configuration enabled: bool, + keepalive_interval: u64, handshake_attempts: AtomicUsize, sent_lastminute_handshake: AtomicBool, @@ -68,27 +70,26 @@ impl<B: bind::Bind> PeerInner<B> { timers.enabled = true; // start send_persistent_keepalive - let interval = self.keepalive_interval.load(Ordering::Acquire); - if interval > 0 { + if timers.keepalive_interval > 0 { timers.send_persistent_keepalive.start( - Duration::from_secs(interval) + Duration::from_secs(timers.keepalive_interval) ); } } /* should be called after an authenticated data packet is sent */ pub fn timers_data_sent(&self) { - self.timers() - .new_handshake - .start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT); + let timers = self.timers(); + if timers.enabled { + timers.new_handshake.start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT); + } } /* should be called after an authenticated data packet is received */ pub fn timers_data_received(&self) { - if !self.timers().send_keepalive.start(KEEPALIVE_TIMEOUT) { - self.timers() - .need_another_keepalive - .store(true, Ordering::SeqCst) + let timers = self.timers(); + if timers.enabled && !timers.send_keepalive.start(KEEPALIVE_TIMEOUT) { + timers.need_another_keepalive.store(true, Ordering::SeqCst) } } @@ -98,7 +99,10 @@ impl<B: bind::Bind> PeerInner<B> { * - handshake */ pub fn timers_any_authenticated_packet_sent(&self) { - self.timers().send_keepalive.stop() + let timers = self.timers(); + if timers.enabled { + timers.send_keepalive.stop() + } } /* Should be called after any type of authenticated packet is received, whether: @@ -107,48 +111,69 @@ impl<B: bind::Bind> PeerInner<B> { * - handshake */ pub fn timers_any_authenticated_packet_received(&self) { - self.timers().new_handshake.stop(); + let timers = self.timers(); + if timers.enabled { + timers.new_handshake.stop(); + } } /* Should be called after a handshake initiation message is sent. */ pub fn timers_handshake_initiated(&self) { - self.timers().send_keepalive.stop(); - self.timers().retransmit_handshake.reset(REKEY_TIMEOUT); + let timers = self.timers(); + if timers.enabled { + timers.send_keepalive.stop(); + timers.retransmit_handshake.reset(REKEY_TIMEOUT); + } } /* Should be called after a handshake response message is received and processed * or when getting key confirmation via the first data message. */ pub fn timers_handshake_complete(&self) { - self.timers().handshake_attempts.store(0, Ordering::SeqCst); - self.timers() - .sent_lastminute_handshake - .store(false, Ordering::SeqCst); - *self.walltime_last_handshake.lock() = SystemTime::now(); + let timers = self.timers(); + if timers.enabled { + timers.handshake_attempts.store(0, Ordering::SeqCst); + timers.sent_lastminute_handshake.store(false, Ordering::SeqCst); + *self.walltime_last_handshake.lock() = SystemTime::now(); + } } /* Should be called after an ephemeral key is created, which is before sending a * handshake response or after receiving a handshake response. */ pub fn timers_session_derived(&self) { - self.timers().zero_key_material.reset(REJECT_AFTER_TIME * 3); + let timers = self.timers(); + if timers.enabled { + timers.zero_key_material.reset(REJECT_AFTER_TIME * 3); + } } /* Should be called before a packet with authentication, whether * keepalive, data, or handshake is sent, or after one is received. */ pub fn timers_any_authenticated_packet_traversal(&self) { - let keepalive = self.keepalive_interval.load(Ordering::Acquire); - if keepalive > 0 { + let timers = self.timers(); + if timers.enabled && timers.keepalive_interval > 0 { // push persistent_keepalive into the future - self.timers() - .send_persistent_keepalive - .reset(Duration::from_secs(keepalive as u64)); + timers.send_persistent_keepalive.reset(Duration::from_secs( + timers.keepalive_interval + )); } } pub fn timers_session_derieved(&self) { - self.timers().zero_key_material.reset(REJECT_AFTER_TIME * 3); + let timers = self.timers(); + if timers.enabled { + timers.zero_key_material.reset(REJECT_AFTER_TIME * 3); + } + } + + fn timers_set_retransmit_handshake(&self) { + let timers = self.timers(); + if timers.enabled { + timers.retransmit_handshake.reset(REKEY_TIMEOUT); + } + } /* Called after a handshake worker sends a handshake initiation to the peer @@ -156,7 +181,7 @@ impl<B: bind::Bind> PeerInner<B> { pub fn sent_handshake_initiation(&self) { *self.last_handshake_sent.lock() = Instant::now(); self.handshake_queued.store(false, Ordering::SeqCst); - self.timers().retransmit_handshake.reset(REKEY_TIMEOUT); + self.timers_set_retransmit_handshake(); self.timers_any_authenticated_packet_traversal(); self.timers_any_authenticated_packet_sent(); } @@ -168,13 +193,18 @@ impl<B: bind::Bind> PeerInner<B> { } - pub fn set_persistent_keepalive_interval(&self, interval: u64) { - self.timers().send_persistent_keepalive.stop(); - self.keepalive_interval.store(interval, Ordering::SeqCst); - if interval > 0 { - self.timers() - .send_persistent_keepalive - .start(Duration::from_secs(interval as u64)); + pub fn set_persistent_keepalive_interval(&self, secs: u64) { + let mut timers = self.timers_mut(); + + // update the stored keepalive_interval + timers.keepalive_interval = secs; + + // stop the keepalive timer with the old interval + timers.send_persistent_keepalive.stop(); + + // restart the persistent_keepalive timer with the new interval + if secs > 0 && timers.enabled { + timers.send_persistent_keepalive.start(Duration::from_secs(secs)); } } @@ -184,8 +214,6 @@ impl<B: bind::Bind> PeerInner<B> { } self.packet_send_handshake_initiation(); } - - } @@ -198,12 +226,20 @@ impl Timers { // create a timer instance for the provided peer Timers { enabled: true, + keepalive_interval: 0, // disabled need_another_keepalive: AtomicBool::new(false), sent_lastminute_handshake: AtomicBool::new(false), handshake_attempts: AtomicUsize::new(0), retransmit_handshake: { let peer = peer.clone(); runner.timer(move || { + // ignore if timers are disabled + let timers = peer.timers(); + if !timers.enabled { + return; + } + + // check if handshake attempts remaining let attempts = peer.timers().handshake_attempts.fetch_add(1, Ordering::SeqCst); if attempts > MAX_TIMER_HANDSHAKES { debug!( @@ -211,9 +247,9 @@ impl Timers { peer, attempts + 1 ); + timers.send_keepalive.stop(); + timers.zero_key_material.start(REJECT_AFTER_TIME * 3); peer.router.purge_staged_packets(); - peer.timers().send_keepalive.stop(); - peer.timers().zero_key_material.start(REJECT_AFTER_TIME * 3); } else { debug!( "Handshake for {} did not complete after {} seconds, retrying (try {})", @@ -221,8 +257,8 @@ impl Timers { REKEY_TIMEOUT.as_secs(), attempts ); + timers.retransmit_handshake.reset(REKEY_TIMEOUT); peer.router.clear_src(); - peer.timers().retransmit_handshake.reset(REKEY_TIMEOUT); peer.packet_send_queued_handshake_initiation(true); } }) @@ -230,9 +266,15 @@ impl Timers { send_keepalive: { let peer = peer.clone(); runner.timer(move || { + // ignore if timers are disabled + let timers = peer.timers(); + if !timers.enabled { + return; + } + peer.router.send_keepalive(); - if peer.timers().need_another_keepalive() { - peer.timers().send_keepalive.start(KEEPALIVE_TIMEOUT); + if timers.need_another_keepalive() { + timers.send_keepalive.start(KEEPALIVE_TIMEOUT); } }) }, @@ -257,29 +299,23 @@ impl Timers { send_persistent_keepalive: { let peer = peer.clone(); runner.timer(move || { - let keepalive = peer.state.keepalive_interval.load(Ordering::Acquire); - if keepalive > 0 { + let timers = peer.timers(); + if timers.enabled && timers.keepalive_interval > 0 { peer.router.send_keepalive(); - peer.timers().send_keepalive.stop(); - peer.timers() - .send_persistent_keepalive - .start(Duration::from_secs(keepalive as u64)); + timers.send_keepalive.stop(); + timers.send_persistent_keepalive.start(Duration::from_secs( + timers.keepalive_interval + )); } }) }, } } - pub fn updated_persistent_keepalive(&self, keepalive: usize) { - if keepalive > 0 { - self.send_persistent_keepalive - .reset(Duration::from_secs(keepalive as u64)); - } - } - pub fn dummy(runner: &Runner) -> Timers { Timers { enabled: false, + keepalive_interval: 0, need_another_keepalive: AtomicBool::new(false), sent_lastminute_handshake: AtomicBool::new(false), handshake_attempts: AtomicUsize::new(0), @@ -290,10 +326,6 @@ impl Timers { zero_key_material: runner.timer(|| {}), } } - - pub fn handshake_sent(&self) { - self.send_keepalive.stop(); - } } /* Instance of the router callbacks */ diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 722f64a..6da428c 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -130,6 +130,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { // ensure exclusive access (to avoid race with "up" call) let peers = self.peers.write(); + // avoid tranmission from router + self.router.down(); + // set all peers down (stops timers) for peer in peers.values() { peer.down(); @@ -142,6 +145,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { // ensure exclusive access (to avoid race with "down" call) let peers = self.peers.write(); + // enable tranmission from router + self.router.up(); + // set all peers up (restarts timers) for peer in peers.values() { peer.up(); @@ -215,7 +221,6 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { last_handshake_sent: Mutex::new(self.state.start - TIME_HORIZON), handshake_queued: AtomicBool::new(false), queue: Mutex::new(self.state.queue.lock().clone()), - keepalive_interval: AtomicU64::new(0), // disabled rx_bytes: AtomicU64::new(0), tx_bytes: AtomicU64::new(0), timers: RwLock::new(Timers::dummy(&self.runner)), |