summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-10-31 17:11:09 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-10-31 17:11:09 +0100
commitb25c21885bf97e74802549e3ac22f57bc0c44d76 (patch)
treed35eb556666846045e434e27f91648fa94bebd46
parentRemove unused dependencies (diff)
downloadwireguard-rs-b25c21885bf97e74802549e3ac22f57bc0c44d76.tar.xz
wireguard-rs-b25c21885bf97e74802549e3ac22f57bc0c44d76.zip
Work on timer semantics
-rw-r--r--src/wireguard/constants.rs6
-rw-r--r--src/wireguard/endpoint.rs29
-rw-r--r--src/wireguard/mod.rs1
-rw-r--r--src/wireguard/router/mod.rs3
-rw-r--r--src/wireguard/router/peer.rs6
-rw-r--r--src/wireguard/router/tests.rs17
-rw-r--r--src/wireguard/router/types.rs7
-rw-r--r--src/wireguard/router/workers.rs39
-rw-r--r--src/wireguard/timers.rs113
-rw-r--r--src/wireguard/wireguard.rs40
10 files changed, 181 insertions, 80 deletions
diff --git a/src/wireguard/constants.rs b/src/wireguard/constants.rs
index ec60801..c53c559 100644
--- a/src/wireguard/constants.rs
+++ b/src/wireguard/constants.rs
@@ -18,3 +18,9 @@ pub const TIMERS_SLOTS: usize = (TIMER_MAX_DURATION.as_micros() / TIMERS_TICK.as
pub const TIMERS_CAPACITY: usize = 1024;
pub const MESSAGE_PADDING_MULTIPLE: usize = 16;
+
+/* A long duration (compared to the WireGuard time constants),
+ * used in places to avoid Option<Instant> by instead using a long "expired" Instant:
+ * (Instant::now() - TIME_HORIZON)
+ */
+pub const TIME_HORIZON: Duration = Duration::from_secs(3600 * 24);
diff --git a/src/wireguard/endpoint.rs b/src/wireguard/endpoint.rs
new file mode 100644
index 0000000..f6a560b
--- /dev/null
+++ b/src/wireguard/endpoint.rs
@@ -0,0 +1,29 @@
+use spin::{Mutex, MutexGuard};
+use std::sync::Arc;
+
+use super::super::platform::Endpoint;
+
+#[derive(Clone)]
+struct EndpointStore<E: Endpoint> {
+ endpoint: Arc<Mutex<Option<E>>>,
+}
+
+impl<E: Endpoint> EndpointStore<E> {
+ pub fn new() -> EndpointStore<E> {
+ EndpointStore {
+ endpoint: Arc::new(Mutex::new(None)),
+ }
+ }
+
+ pub fn set(&self, endpoint: E) {
+ *self.endpoint.lock() = Some(endpoint);
+ }
+
+ pub fn get(&self) -> MutexGuard<Option<E>> {
+ self.endpoint.lock()
+ }
+
+ pub fn clear_src(&self) {
+ (*self.endpoint.lock()).as_mut().map(|e| e.clear_src());
+ }
+}
diff --git a/src/wireguard/mod.rs b/src/wireguard/mod.rs
index c3e9c58..83f9e8a 100644
--- a/src/wireguard/mod.rs
+++ b/src/wireguard/mod.rs
@@ -2,6 +2,7 @@ mod constants;
mod timers;
mod wireguard;
+mod endpoint;
mod handshake;
mod router;
mod types;
diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs
index 354700a..6aa894d 100644
--- a/src/wireguard/router/mod.rs
+++ b/src/wireguard/router/mod.rs
@@ -14,7 +14,8 @@ mod tests;
use messages::TransportHeader;
use std::mem;
-use super::constants::*;
+use super::constants::REJECT_AFTER_MESSAGES;
+use super::types::*;
pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>();
pub const CAPACITY_MESSAGE_POSTFIX: usize = workers::SIZE_TAG;
diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs
index 50fdfe7..5467eb7 100644
--- a/src/wireguard/router/peer.rs
+++ b/src/wireguard/router/peer.rs
@@ -589,6 +589,12 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T
}
}
+ pub fn clear_src(&self) {
+ (*self.state.endpoint.lock())
+ .as_mut()
+ .map(|e| e.clear_src());
+ }
+
pub fn purge_staged_packets(&self) {
self.state.staged_packets.lock().clear();
}
diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs
index a14640c..1b122a8 100644
--- a/src/wireguard/router/tests.rs
+++ b/src/wireguard/router/tests.rs
@@ -3,7 +3,7 @@ use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::sync::Mutex;
use std::thread;
-use std::time::Duration;
+use std::time::{Duration, Instant};
use num_cpus;
@@ -11,6 +11,7 @@ use super::super::bind::*;
use super::super::dummy;
use super::super::dummy_keypair;
use super::super::tests::make_packet_dst;
+use super::KeyPair;
use super::SIZE_MESSAGE_PREFIX;
use super::{Callbacks, Device};
@@ -85,11 +86,11 @@ mod tests {
impl Callbacks for TestCallbacks {
type Opaque = Opaque;
- fn send(t: &Self::Opaque, size: usize, sent: bool) {
+ fn send(t: &Self::Opaque, size: usize, sent: bool, keypair: &Arc<KeyPair>, counter: u64) {
t.0.send.lock().unwrap().push((size, sent))
}
- fn recv(t: &Self::Opaque, size: usize, sent: bool) {
+ fn recv(t: &Self::Opaque, size: usize, sent: bool, keypair: &Arc<KeyPair>) {
t.0.recv.lock().unwrap().push((size, sent))
}
@@ -123,10 +124,16 @@ mod tests {
struct BencherCallbacks {}
impl Callbacks for BencherCallbacks {
type Opaque = Arc<AtomicUsize>;
- fn send(t: &Self::Opaque, size: usize, _sent: bool) {
+ fn send(
+ t: &Self::Opaque,
+ size: usize,
+ _sent: bool,
+ _keypair: &Arc<KeyPair>,
+ _counter: u64,
+ ) {
t.fetch_add(size, Ordering::SeqCst);
}
- fn recv(_: &Self::Opaque, _size: usize, _sent: bool) {}
+ fn recv(_: &Self::Opaque, _size: usize, _sent: bool, _keypair: &Arc<KeyPair>) {}
fn need_key(_: &Self::Opaque) {}
fn key_confirmed(_: &Self::Opaque) {}
}
diff --git a/src/wireguard/router/types.rs b/src/wireguard/router/types.rs
index 9f769fe..194f0d4 100644
--- a/src/wireguard/router/types.rs
+++ b/src/wireguard/router/types.rs
@@ -1,5 +1,8 @@
use std::error::Error;
use std::fmt;
+use std::sync::Arc;
+
+use super::KeyPair;
pub trait Opaque: Send + Sync + 'static {}
@@ -23,8 +26,8 @@ impl<T, F> KeyCallback<T> for F where F: Fn(&T) -> () + Sync + Send + 'static {}
pub trait Callbacks: Send + Sync + 'static {
type Opaque: Opaque;
- fn send(opaque: &Self::Opaque, size: usize, sent: bool);
- fn recv(opaque: &Self::Opaque, size: usize, sent: bool);
+ fn send(opaque: &Self::Opaque, size: usize, sent: bool, keypair: &Arc<KeyPair>, counter: u64);
+ fn recv(opaque: &Self::Opaque, size: usize, sent: bool, keypair: &Arc<KeyPair>);
fn need_key(opaque: &Self::Opaque);
fn key_confirmed(opaque: &Self::Opaque);
}
diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs
index 2a12000..08c2db9 100644
--- a/src/wireguard/router/workers.rs
+++ b/src/wireguard/router/workers.rs
@@ -1,6 +1,5 @@
use std::sync::mpsc::Receiver;
use std::sync::Arc;
-use std::time::Instant;
use futures::sync::oneshot;
use futures::*;
@@ -18,8 +17,7 @@ use super::peer::PeerInner;
use super::route::check_route;
use super::types::Callbacks;
-use super::{KEEPALIVE_TIMEOUT, REJECT_AFTER_TIME, REKEY_TIMEOUT};
-use super::{REJECT_AFTER_MESSAGES, REKEY_AFTER_MESSAGES, REKEY_AFTER_TIME};
+use super::REJECT_AFTER_MESSAGES;
use super::super::types::KeyPair;
use super::super::{bind, tun, Endpoint};
@@ -61,10 +59,6 @@ pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer
peer: Arc<PeerInner<E, C, T, B>>, // related peer
receiver: Receiver<JobInbound<E, C, T, B>>,
) {
- fn keep_key_fresh(keypair: &KeyPair) -> bool {
- Instant::now() - keypair.birth > REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT
- }
-
loop {
// fetch job
let (state, endpoint, rx) = match receiver.recv() {
@@ -135,7 +129,7 @@ pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer
}
// trigger callback
- C::recv(&peer.opaque, buf.msg.len(), sent);
+ C::recv(&peer.opaque, buf.msg.len(), sent, &buf.keypair);
} else {
debug!("inbound worker: authentication failure")
}
@@ -151,11 +145,6 @@ pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Write
peer: Arc<PeerInner<E, C, T, B>>, // related peer
receiver: Receiver<JobOutbound>,
) {
- fn keep_key_fresh(keypair: &KeyPair, counter: u64) -> bool {
- counter > REKEY_AFTER_MESSAGES
- || (keypair.initiator && Instant::now() - keypair.birth > REKEY_AFTER_TIME)
- }
-
loop {
// fetch job
let rx = match receiver.recv() {
@@ -190,12 +179,7 @@ pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Write
};
// trigger callback
- C::send(&peer.opaque, buf.msg.len(), xmit);
-
- // keep_key_fresh semantics
- if keep_key_fresh(&buf.keypair, buf.counter) {
- C::need_key(&peer.opaque);
- }
+ C::send(&peer.opaque, buf.msg.len(), xmit, &buf.keypair, buf.counter);
})
.wait();
}
@@ -223,7 +207,10 @@ pub fn worker_parallel(receiver: Receiver<JobParallel>) {
.expect("earlier code should ensure that there is ample space");
// set header fields
- debug_assert!(job.counter < REJECT_AFTER_MESSAGES);
+ debug_assert!(
+ job.counter < REJECT_AFTER_MESSAGES,
+ "should be checked when assigning counters"
+ );
header.f_type.set(TYPE_TRANSPORT);
header.f_receiver.set(job.keypair.send.id);
header.f_counter.set(job.counter);
@@ -258,10 +245,12 @@ pub fn worker_parallel(receiver: Receiver<JobParallel>) {
let _ = tx.send(match layout {
Some((header, body)) => {
- debug_assert_eq!(header.f_type.get(), TYPE_TRANSPORT);
- if header.f_counter.get() >= REJECT_AFTER_MESSAGES {
- None
- } else {
+ debug_assert_eq!(
+ header.f_type.get(),
+ TYPE_TRANSPORT,
+ "type and reserved bits should be checked by message de-multiplexer"
+ );
+ if header.f_counter.get() < REJECT_AFTER_MESSAGES {
// create a nonce object
let mut nonce = [0u8; 12];
debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
@@ -279,6 +268,8 @@ pub fn worker_parallel(receiver: Receiver<JobParallel>) {
Ok(_) => Some(job),
Err(_) => None,
}
+ } else {
+ None
}
}
None => None,
diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs
index 3b16bf6..2e9263d 100644
--- a/src/wireguard/timers.rs
+++ b/src/wireguard/timers.rs
@@ -1,10 +1,10 @@
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
-use std::time::{Duration, SystemTime};
-
-use log::info;
+use std::time::{Duration, Instant, SystemTime};
+use log::{debug, info};
+use spin::Mutex;
use hjul::{Runner, Timer};
use super::constants::*;
@@ -12,8 +12,9 @@ use super::router::{message_data_len, Callbacks};
use super::wireguard::{Peer, PeerInner};
use super::{bind, tun};
+use super::types::KeyPair;
+
pub struct Timers {
- handshake_pending: AtomicBool,
handshake_attempts: AtomicUsize,
retransmit_handshake: Timer,
@@ -98,6 +99,7 @@ impl<B: bind::Bind> PeerInner<B> {
pub fn timers_any_authenticated_packet_traversal(&self) {
let keepalive = self.keepalive.load(Ordering::Acquire);
if keepalive > 0 {
+ // push persistent_keepalive into the future
self.timers()
.send_persistent_keepalive
.reset(Duration::from_secs(keepalive as u64));
@@ -107,15 +109,24 @@ impl<B: bind::Bind> PeerInner<B> {
/* Called after a handshake worker sends a handshake initiation to the peer
*/
pub fn sent_handshake_initiation(&self) {
- *self.last_handshake.lock() = SystemTime::now();
+ *self.last_handshake_sent.lock() = Instant::now();
self.handshake_queued.store(false, Ordering::SeqCst);
+ self.timers().retransmit_handshake.reset(REKEY_TIMEOUT);
self.timers_any_authenticated_packet_traversal();
self.timers_any_authenticated_packet_sent();
}
pub fn sent_handshake_response(&self) {
+ *self.last_handshake_sent.lock() = Instant::now();
self.timers_any_authenticated_packet_traversal();
self.timers_any_authenticated_packet_sent();
+ }
+
+ fn packet_send_queued_handshake_initiation(&self, is_retry: bool) {
+ if !is_retry {
+ self.timers().handshake_attempts.store(0, Ordering::SeqCst);
+ }
+ self.packet_send_handshake_initiation();
}
}
@@ -127,21 +138,32 @@ impl Timers {
{
// create a timer instance for the provided peer
Timers {
- handshake_pending: AtomicBool::new(false),
need_another_keepalive: AtomicBool::new(false),
sent_lastminute_handshake: AtomicBool::new(false),
handshake_attempts: AtomicUsize::new(0),
retransmit_handshake: {
let peer = peer.clone();
runner.timer(move || {
- if peer.timers().handshake_retry() {
- info!("Retransmit handshake for {}", peer);
- peer.new_handshake();
- } else {
- info!("Failed to complete handshake for {}", peer);
+ let attempts = peer.timers().handshake_attempts.fetch_add(1, Ordering::SeqCst);
+ if attempts > MAX_TIMER_HANDSHAKES {
+ debug!(
+ "Handshake for peer {} did not complete after {} attempts, giving up",
+ peer,
+ attempts + 1
+ );
peer.router.purge_staged_packets();
peer.timers().send_keepalive.stop();
peer.timers().zero_key_material.start(REJECT_AFTER_TIME * 3);
+ } else {
+ debug!(
+ "Handshake for {} did not complete after {} seconds, retrying (try {})",
+ peer,
+ REKEY_TIMEOUT.as_secs(),
+ attempts
+ );
+ peer.router.clear_src();
+ peer.timers().retransmit_handshake.reset(REKEY_TIMEOUT);
+ peer.packet_send_queued_handshake_initiation(true);
}
})
},
@@ -157,9 +179,13 @@ impl Timers {
new_handshake: {
let peer = peer.clone();
runner.timer(move || {
- info!("Initiate new handshake with {}", peer);
- peer.new_handshake();
- peer.timers.read().handshake_begun();
+ debug!(
+ "Retrying handshake with {} because we stopped hearing back after {} seconds",
+ peer,
+ (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT).as_secs()
+ );
+ peer.router.clear_src();
+ peer.packet_send_queued_handshake_initiation(false);
})
},
zero_key_material: {
@@ -184,22 +210,6 @@ impl Timers {
}
}
- fn handshake_begun(&self) {
- self.handshake_pending.store(true, Ordering::SeqCst);
- self.handshake_attempts.store(0, Ordering::SeqCst);
- self.retransmit_handshake.reset(REKEY_TIMEOUT);
- }
-
- fn handshake_retry(&self) -> bool {
- if self.handshake_attempts.fetch_add(1, Ordering::SeqCst) <= MAX_TIMER_HANDSHAKES {
- self.retransmit_handshake.reset(REKEY_TIMEOUT);
- true
- } else {
- self.handshake_pending.store(false, Ordering::SeqCst);
- false
- }
- }
-
pub fn updated_persistent_keepalive(&self, keepalive: usize) {
if keepalive > 0 {
self.send_persistent_keepalive
@@ -209,7 +219,6 @@ impl Timers {
pub fn dummy(runner: &Runner) -> Timers {
Timers {
- handshake_pending: AtomicBool::new(false),
need_another_keepalive: AtomicBool::new(false),
sent_lastminute_handshake: AtomicBool::new(false),
handshake_attempts: AtomicUsize::new(0),
@@ -236,13 +245,28 @@ impl<T: tun::Tun, B: bind::Bind> Callbacks for Events<T, B> {
/* Called after the router encrypts a transport message destined for the peer.
* This method is called, even if the encrypted payload is empty (keepalive)
*/
- fn send(peer: &Self::Opaque, size: usize, sent: bool) {
+ #[inline(always)]
+ fn send(peer: &Self::Opaque, size: usize, sent: bool, keypair: &Arc<KeyPair>, counter: u64) {
+
+ // update timers and stats
+
peer.timers_any_authenticated_packet_traversal();
peer.timers_any_authenticated_packet_sent();
peer.tx_bytes.fetch_add(size as u64, Ordering::Relaxed);
if size > message_data_len(0) && sent {
peer.timers_data_sent();
}
+
+ // keep_key_fresh
+
+ fn keep_key_fresh(keypair: &Arc<KeyPair>, counter: u64) -> bool {
+ counter > REKEY_AFTER_MESSAGES
+ || (keypair.initiator && Instant::now() - keypair.birth > REKEY_AFTER_TIME)
+ }
+
+ if keep_key_fresh(keypair, counter) {
+ peer.packet_send_queued_handshake_initiation(false);
+ }
}
/* Called after the router successfully decrypts a transport message from a peer.
@@ -252,13 +276,28 @@ impl<T: tun::Tun, B: bind::Bind> Callbacks for Events<T, B> {
* - A malformed IP packet
* - Fails to cryptkey route
*/
- fn recv(peer: &Self::Opaque, size: usize, sent: bool) {
+ #[inline(always)]
+ fn recv(peer: &Self::Opaque, size: usize, sent: bool, keypair: &Arc<KeyPair>) {
+
+ // update timers and stats
+
peer.timers_any_authenticated_packet_traversal();
peer.timers_any_authenticated_packet_received();
peer.rx_bytes.fetch_add(size as u64, Ordering::Relaxed);
if size > 0 && sent {
peer.timers_data_received();
}
+
+ // keep_key_fresh
+
+ #[inline(always)]
+ fn keep_key_fresh(keypair: &Arc<KeyPair>) -> bool {
+ Instant::now() - keypair.birth > REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT
+ }
+
+ if keep_key_fresh(keypair) && !peer.timers().sent_lastminute_handshake.swap(true, Ordering::Acquire) {
+ peer.packet_send_queued_handshake_initiation(false);
+ }
}
/* Called every time the router detects that a key is required,
@@ -267,14 +306,12 @@ impl<T: tun::Tun, B: bind::Bind> Callbacks for Events<T, B> {
* The message is called continuously
* (e.g. for every packet that must be encrypted, until a key becomes available)
*/
+ #[inline(always)]
fn need_key(peer: &Self::Opaque) {
- let timers = peer.timers();
- if !timers.handshake_pending.swap(true, Ordering::SeqCst) {
- timers.handshake_attempts.store(0, Ordering::SeqCst);
- timers.new_handshake.fire();
- }
+ peer.packet_send_queued_handshake_initiation(false);
}
+ #[inline(always)]
fn key_confirmed(peer: &Self::Opaque) {
peer.timers().retransmit_handshake.stop();
}
diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs
index 233559e..e308c50 100644
--- a/src/wireguard/wireguard.rs
+++ b/src/wireguard/wireguard.rs
@@ -38,23 +38,28 @@ pub struct Peer<T: Tun, B: Bind> {
}
pub struct PeerInner<B: Bind> {
+ // internal id (for logging)
pub id: u64,
- pub keepalive: AtomicUsize, // keepalive interval
- pub rx_bytes: AtomicU64,
- pub tx_bytes: AtomicU64,
+ // handshake state
+ pub last_handshake_sent: Mutex<Instant>, // instant for last handshake
+ pub handshake_queued: AtomicBool, // is a handshake job currently queued for the peer?
+ pub queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue
- pub last_handshake: Mutex<SystemTime>,
- pub handshake_queued: AtomicBool,
+ // stats and configuration
+ pub pk: PublicKey, // public key, DISCUSS: avoid this. TODO: remove
+ pub keepalive: AtomicUsize, // keepalive interval
+ pub rx_bytes: AtomicU64, // received bytes
+ pub tx_bytes: AtomicU64, // transmitted bytes
- pub queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue
- pub pk: PublicKey, // DISCUSS: Change layout in handshake module (adopt pattern of router), to avoid this. TODO: remove
- pub timers: RwLock<Timers>, //
+ // timer model
+ pub timers: RwLock<Timers>,
}
pub struct WireguardInner<T: Tun, B: Bind> {
// identifier (for logging)
id: u32,
+ start: Instant,
// provides access to the MTU value of the tun device
// (otherwise owned solely by the router and a dedicated read IO thread)
@@ -122,8 +127,22 @@ impl<T: Tun, B: Bind> Deref for Peer<T, B> {
impl<B: Bind> PeerInner<B> {
/* Queue a handshake request for the parallel workers
* (if one does not already exist)
+ *
+ * The function is ratelimited.
*/
- pub fn new_handshake(&self) {
+ pub fn packet_send_handshake_initiation(&self) {
+ // the function is rate limited
+
+ {
+ let mut lhs = self.last_handshake_sent.lock();
+ if lhs.elapsed() < REKEY_TIMEOUT {
+ return;
+ }
+ *lhs = Instant::now();
+ }
+
+ // create a new handshake job for the peer
+
if !self.handshake_queued.swap(true, Ordering::SeqCst) {
self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap();
}
@@ -225,7 +244,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
let state = Arc::new(PeerInner {
id: rng.gen(),
pk,
- last_handshake: Mutex::new(SystemTime::UNIX_EPOCH),
+ last_handshake_sent: Mutex::new(self.state.start - TIME_HORIZON),
handshake_queued: AtomicBool::new(false),
queue: Mutex::new(self.state.queue.lock().clone()),
keepalive: AtomicUsize::new(0),
@@ -335,6 +354,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
let mut rng = OsRng::new().unwrap();
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
let wg = Arc::new(WireguardInner {
+ start: Instant::now(),
id: rng.gen(),
mtu: mtu.clone(),
peers: RwLock::new(HashMap::new()),