aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-11-06 13:50:38 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-11-06 13:50:38 +0100
commit293914e47b046f862608a1af91864b6b38336aa5 (patch)
treec6851f4c0e8cd38efdcbc2aa6999395f67f1e555
parentWork on Up/Down operation on WireGuard device (diff)
downloadwireguard-rs-293914e47b046f862608a1af91864b6b38336aa5.tar.xz
wireguard-rs-293914e47b046f862608a1af91864b6b38336aa5.zip
Implement disable/enable timers
-rw-r--r--src/wireguard/peer.rs19
-rw-r--r--src/wireguard/router/device.rs18
-rw-r--r--src/wireguard/router/peer.rs32
-rw-r--r--src/wireguard/router/tests.rs2
-rw-r--r--src/wireguard/timers.rs148
-rw-r--r--src/wireguard/wireguard.rs7
6 files changed, 138 insertions, 88 deletions
diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs
index 9f24dea..b77e8c6 100644
--- a/src/wireguard/peer.rs
+++ b/src/wireguard/peer.rs
@@ -4,12 +4,11 @@ use super::timers::{Events, Timers};
use super::HandshakeJob;
use super::bind::Bind;
-use super::bind::Reader as BindReader;
-use super::tun::{Reader, Tun};
+use super::tun::Tun;
use std::fmt;
use std::ops::Deref;
-use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
+use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Instant, SystemTime};
@@ -34,8 +33,7 @@ pub struct PeerInner<B: Bind> {
pub queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue
// stats and configuration
- pub pk: PublicKey, // public key, DISCUSS: avoid this. TODO: remove
- pub keepalive_interval: AtomicU64, // keepalive interval
+ pub pk: PublicKey, // public key, DISCUSS: avoid this. TODO: remove
pub rx_bytes: AtomicU64, // received bytes
pub tx_bytes: AtomicU64, // transmitted bytes
@@ -78,11 +76,20 @@ impl<T: Tun, B: Bind> Deref for Peer<T, B> {
}
impl<T: Tun, B: Bind> Peer<T, B> {
+ /// Bring the peer down. Causing:
+ ///
+ /// - Timers to be stopped and disabled.
+ /// - All keystate to be zeroed
pub fn down(&self) {
self.stop_timers();
+ self.router.down();
}
- pub fn up(&self) {}
+ /// Bring the peer up.
+ pub fn up(&self) {
+ self.router.up();
+ self.start_timers();
+ }
}
impl<B: Bind> PeerInner<B> {
diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs
index a5028e1..b3f1787 100644
--- a/src/wireguard/router/device.rs
+++ b/src/wireguard/router/device.rs
@@ -27,13 +27,11 @@ use super::route::get_route;
use super::super::{bind, tun, Endpoint, KeyPair};
pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
- pub enabled: AtomicBool,
-
// inbound writer (TUN)
pub inbound: T,
// outbound writer (Bind)
- pub outbound: RwLock<Option<B>>,
+ pub outbound: RwLock<(bool, Option<B>)>,
// routing
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state
@@ -93,8 +91,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C,
// allocate shared device state
let inner = DeviceInner {
inbound: tun,
- enabled: AtomicBool::new(true),
- outbound: RwLock::new(None),
+ outbound: RwLock::new((true, None)),
queues: Mutex::new(Vec::with_capacity(num_workers)),
queue_next: AtomicUsize::new(0),
recv: RwLock::new(HashMap::new()),
@@ -120,12 +117,15 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C,
/// Brings the router down.
/// When the router is brought down it:
/// - Prevents transmission of outbound messages.
- /// - Erases all key state (key-wheels) of all peers
- pub fn down(&self) {}
+ pub fn down(&self) {
+ self.state.outbound.write().0 = false;
+ }
/// Brints the router up
/// When the router is brought up it enables the transmission of outbound messages.
- pub fn up(&self) {}
+ pub fn up(&self) {
+ self.state.outbound.write().0 = true;
+ }
/// A new secret key has been set for the device.
/// According to WireGuard semantics, this should cause all "sending" keys to be discarded.
@@ -209,6 +209,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C,
///
///
pub fn set_outbound_writer(&self, new: B) {
- *self.state.outbound.write() = Some(new);
+ self.state.outbound.write().1 = Some(new);
}
}
diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs
index 7527a60..0d9b435 100644
--- a/src/wireguard/router/peer.rs
+++ b/src/wireguard/router/peer.rs
@@ -206,7 +206,6 @@ pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
// spawn outbound thread
let thread_inbound = {
let peer = peer.clone();
- let device = device.clone();
thread::spawn(move || worker_outbound(peer, out_rx))
};
@@ -237,24 +236,25 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E,
pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> {
debug!("peer.send");
- // check if device is enabled
- if !self.device.enabled.load(Ordering::Acquire) {
- return Ok(());
- }
-
// send to endpoint (if known)
match self.endpoint.lock().as_ref() {
- Some(endpoint) => self
- .device
- .outbound
- .read()
- .as_ref()
- .ok_or(RouterError::SendError)
- .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError)),
+ Some(endpoint) => {
+ let outbound = self.device.outbound.read();
+ if outbound.0 {
+ outbound
+ .1
+ .as_ref()
+ .ok_or(RouterError::SendError)
+ .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError))
+ } else {
+ Ok(())
+ }
+ }
None => Err(RouterError::NoEndpoint),
}
}
+ // Transmit all staged packets
fn send_staged(&self) -> bool {
debug!("peer.send_staged");
let mut sent = false;
@@ -451,6 +451,12 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T
*self.state.ekey.lock() = None;
}
+ pub fn down(&self) {
+ self.zero_keys();
+ }
+
+ pub fn up(&self) {}
+
/// Add a new keypair
///
/// # Arguments
diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs
index d5a1133..d14b438 100644
--- a/src/wireguard/router/tests.rs
+++ b/src/wireguard/router/tests.rs
@@ -3,7 +3,7 @@ use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::sync::Mutex;
use std::thread;
-use std::time::{Duration, Instant};
+use std::time::Duration;
use num_cpus;
diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs
index e844f4d..038f6c6 100644
--- a/src/wireguard/timers.rs
+++ b/src/wireguard/timers.rs
@@ -13,7 +13,9 @@ use super::{bind, tun};
use super::types::KeyPair;
pub struct Timers {
+ // only updated during configuration
enabled: bool,
+ keepalive_interval: u64,
handshake_attempts: AtomicUsize,
sent_lastminute_handshake: AtomicBool,
@@ -68,27 +70,26 @@ impl<B: bind::Bind> PeerInner<B> {
timers.enabled = true;
// start send_persistent_keepalive
- let interval = self.keepalive_interval.load(Ordering::Acquire);
- if interval > 0 {
+ if timers.keepalive_interval > 0 {
timers.send_persistent_keepalive.start(
- Duration::from_secs(interval)
+ Duration::from_secs(timers.keepalive_interval)
);
}
}
/* should be called after an authenticated data packet is sent */
pub fn timers_data_sent(&self) {
- self.timers()
- .new_handshake
- .start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT);
+ let timers = self.timers();
+ if timers.enabled {
+ timers.new_handshake.start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT);
+ }
}
/* should be called after an authenticated data packet is received */
pub fn timers_data_received(&self) {
- if !self.timers().send_keepalive.start(KEEPALIVE_TIMEOUT) {
- self.timers()
- .need_another_keepalive
- .store(true, Ordering::SeqCst)
+ let timers = self.timers();
+ if timers.enabled && !timers.send_keepalive.start(KEEPALIVE_TIMEOUT) {
+ timers.need_another_keepalive.store(true, Ordering::SeqCst)
}
}
@@ -98,7 +99,10 @@ impl<B: bind::Bind> PeerInner<B> {
* - handshake
*/
pub fn timers_any_authenticated_packet_sent(&self) {
- self.timers().send_keepalive.stop()
+ let timers = self.timers();
+ if timers.enabled {
+ timers.send_keepalive.stop()
+ }
}
/* Should be called after any type of authenticated packet is received, whether:
@@ -107,48 +111,69 @@ impl<B: bind::Bind> PeerInner<B> {
* - handshake
*/
pub fn timers_any_authenticated_packet_received(&self) {
- self.timers().new_handshake.stop();
+ let timers = self.timers();
+ if timers.enabled {
+ timers.new_handshake.stop();
+ }
}
/* Should be called after a handshake initiation message is sent. */
pub fn timers_handshake_initiated(&self) {
- self.timers().send_keepalive.stop();
- self.timers().retransmit_handshake.reset(REKEY_TIMEOUT);
+ let timers = self.timers();
+ if timers.enabled {
+ timers.send_keepalive.stop();
+ timers.retransmit_handshake.reset(REKEY_TIMEOUT);
+ }
}
/* Should be called after a handshake response message is received and processed
* or when getting key confirmation via the first data message.
*/
pub fn timers_handshake_complete(&self) {
- self.timers().handshake_attempts.store(0, Ordering::SeqCst);
- self.timers()
- .sent_lastminute_handshake
- .store(false, Ordering::SeqCst);
- *self.walltime_last_handshake.lock() = SystemTime::now();
+ let timers = self.timers();
+ if timers.enabled {
+ timers.handshake_attempts.store(0, Ordering::SeqCst);
+ timers.sent_lastminute_handshake.store(false, Ordering::SeqCst);
+ *self.walltime_last_handshake.lock() = SystemTime::now();
+ }
}
/* Should be called after an ephemeral key is created, which is before sending a
* handshake response or after receiving a handshake response.
*/
pub fn timers_session_derived(&self) {
- self.timers().zero_key_material.reset(REJECT_AFTER_TIME * 3);
+ let timers = self.timers();
+ if timers.enabled {
+ timers.zero_key_material.reset(REJECT_AFTER_TIME * 3);
+ }
}
/* Should be called before a packet with authentication, whether
* keepalive, data, or handshake is sent, or after one is received.
*/
pub fn timers_any_authenticated_packet_traversal(&self) {
- let keepalive = self.keepalive_interval.load(Ordering::Acquire);
- if keepalive > 0 {
+ let timers = self.timers();
+ if timers.enabled && timers.keepalive_interval > 0 {
// push persistent_keepalive into the future
- self.timers()
- .send_persistent_keepalive
- .reset(Duration::from_secs(keepalive as u64));
+ timers.send_persistent_keepalive.reset(Duration::from_secs(
+ timers.keepalive_interval
+ ));
}
}
pub fn timers_session_derieved(&self) {
- self.timers().zero_key_material.reset(REJECT_AFTER_TIME * 3);
+ let timers = self.timers();
+ if timers.enabled {
+ timers.zero_key_material.reset(REJECT_AFTER_TIME * 3);
+ }
+ }
+
+ fn timers_set_retransmit_handshake(&self) {
+ let timers = self.timers();
+ if timers.enabled {
+ timers.retransmit_handshake.reset(REKEY_TIMEOUT);
+ }
+
}
/* Called after a handshake worker sends a handshake initiation to the peer
@@ -156,7 +181,7 @@ impl<B: bind::Bind> PeerInner<B> {
pub fn sent_handshake_initiation(&self) {
*self.last_handshake_sent.lock() = Instant::now();
self.handshake_queued.store(false, Ordering::SeqCst);
- self.timers().retransmit_handshake.reset(REKEY_TIMEOUT);
+ self.timers_set_retransmit_handshake();
self.timers_any_authenticated_packet_traversal();
self.timers_any_authenticated_packet_sent();
}
@@ -168,13 +193,18 @@ impl<B: bind::Bind> PeerInner<B> {
}
- pub fn set_persistent_keepalive_interval(&self, interval: u64) {
- self.timers().send_persistent_keepalive.stop();
- self.keepalive_interval.store(interval, Ordering::SeqCst);
- if interval > 0 {
- self.timers()
- .send_persistent_keepalive
- .start(Duration::from_secs(interval as u64));
+ pub fn set_persistent_keepalive_interval(&self, secs: u64) {
+ let mut timers = self.timers_mut();
+
+ // update the stored keepalive_interval
+ timers.keepalive_interval = secs;
+
+ // stop the keepalive timer with the old interval
+ timers.send_persistent_keepalive.stop();
+
+ // restart the persistent_keepalive timer with the new interval
+ if secs > 0 && timers.enabled {
+ timers.send_persistent_keepalive.start(Duration::from_secs(secs));
}
}
@@ -184,8 +214,6 @@ impl<B: bind::Bind> PeerInner<B> {
}
self.packet_send_handshake_initiation();
}
-
-
}
@@ -198,12 +226,20 @@ impl Timers {
// create a timer instance for the provided peer
Timers {
enabled: true,
+ keepalive_interval: 0, // disabled
need_another_keepalive: AtomicBool::new(false),
sent_lastminute_handshake: AtomicBool::new(false),
handshake_attempts: AtomicUsize::new(0),
retransmit_handshake: {
let peer = peer.clone();
runner.timer(move || {
+ // ignore if timers are disabled
+ let timers = peer.timers();
+ if !timers.enabled {
+ return;
+ }
+
+ // check if handshake attempts remaining
let attempts = peer.timers().handshake_attempts.fetch_add(1, Ordering::SeqCst);
if attempts > MAX_TIMER_HANDSHAKES {
debug!(
@@ -211,9 +247,9 @@ impl Timers {
peer,
attempts + 1
);
+ timers.send_keepalive.stop();
+ timers.zero_key_material.start(REJECT_AFTER_TIME * 3);
peer.router.purge_staged_packets();
- peer.timers().send_keepalive.stop();
- peer.timers().zero_key_material.start(REJECT_AFTER_TIME * 3);
} else {
debug!(
"Handshake for {} did not complete after {} seconds, retrying (try {})",
@@ -221,8 +257,8 @@ impl Timers {
REKEY_TIMEOUT.as_secs(),
attempts
);
+ timers.retransmit_handshake.reset(REKEY_TIMEOUT);
peer.router.clear_src();
- peer.timers().retransmit_handshake.reset(REKEY_TIMEOUT);
peer.packet_send_queued_handshake_initiation(true);
}
})
@@ -230,9 +266,15 @@ impl Timers {
send_keepalive: {
let peer = peer.clone();
runner.timer(move || {
+ // ignore if timers are disabled
+ let timers = peer.timers();
+ if !timers.enabled {
+ return;
+ }
+
peer.router.send_keepalive();
- if peer.timers().need_another_keepalive() {
- peer.timers().send_keepalive.start(KEEPALIVE_TIMEOUT);
+ if timers.need_another_keepalive() {
+ timers.send_keepalive.start(KEEPALIVE_TIMEOUT);
}
})
},
@@ -257,29 +299,23 @@ impl Timers {
send_persistent_keepalive: {
let peer = peer.clone();
runner.timer(move || {
- let keepalive = peer.state.keepalive_interval.load(Ordering::Acquire);
- if keepalive > 0 {
+ let timers = peer.timers();
+ if timers.enabled && timers.keepalive_interval > 0 {
peer.router.send_keepalive();
- peer.timers().send_keepalive.stop();
- peer.timers()
- .send_persistent_keepalive
- .start(Duration::from_secs(keepalive as u64));
+ timers.send_keepalive.stop();
+ timers.send_persistent_keepalive.start(Duration::from_secs(
+ timers.keepalive_interval
+ ));
}
})
},
}
}
- pub fn updated_persistent_keepalive(&self, keepalive: usize) {
- if keepalive > 0 {
- self.send_persistent_keepalive
- .reset(Duration::from_secs(keepalive as u64));
- }
- }
-
pub fn dummy(runner: &Runner) -> Timers {
Timers {
enabled: false,
+ keepalive_interval: 0,
need_another_keepalive: AtomicBool::new(false),
sent_lastminute_handshake: AtomicBool::new(false),
handshake_attempts: AtomicUsize::new(0),
@@ -290,10 +326,6 @@ impl Timers {
zero_key_material: runner.timer(|| {}),
}
}
-
- pub fn handshake_sent(&self) {
- self.send_keepalive.stop();
- }
}
/* Instance of the router callbacks */
diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs
index 722f64a..6da428c 100644
--- a/src/wireguard/wireguard.rs
+++ b/src/wireguard/wireguard.rs
@@ -130,6 +130,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
// ensure exclusive access (to avoid race with "up" call)
let peers = self.peers.write();
+ // avoid tranmission from router
+ self.router.down();
+
// set all peers down (stops timers)
for peer in peers.values() {
peer.down();
@@ -142,6 +145,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
// ensure exclusive access (to avoid race with "down" call)
let peers = self.peers.write();
+ // enable tranmission from router
+ self.router.up();
+
// set all peers up (restarts timers)
for peer in peers.values() {
peer.up();
@@ -215,7 +221,6 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
last_handshake_sent: Mutex::new(self.state.start - TIME_HORIZON),
handshake_queued: AtomicBool::new(false),
queue: Mutex::new(self.state.queue.lock().clone()),
- keepalive_interval: AtomicU64::new(0), // disabled
rx_bytes: AtomicU64::new(0),
tx_bytes: AtomicU64::new(0),
timers: RwLock::new(Timers::dummy(&self.runner)),