diff options
Diffstat (limited to 'src/wireguard/router')
-rw-r--r-- | src/wireguard/router/mod.rs | 3 | ||||
-rw-r--r-- | src/wireguard/router/peer.rs | 6 | ||||
-rw-r--r-- | src/wireguard/router/tests.rs | 17 | ||||
-rw-r--r-- | src/wireguard/router/types.rs | 7 | ||||
-rw-r--r-- | src/wireguard/router/workers.rs | 39 |
5 files changed, 40 insertions, 32 deletions
diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs index 354700a..6aa894d 100644 --- a/src/wireguard/router/mod.rs +++ b/src/wireguard/router/mod.rs @@ -14,7 +14,8 @@ mod tests; use messages::TransportHeader; use std::mem; -use super::constants::*; +use super::constants::REJECT_AFTER_MESSAGES; +use super::types::*; pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>(); pub const CAPACITY_MESSAGE_POSTFIX: usize = workers::SIZE_TAG; diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index 50fdfe7..5467eb7 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -589,6 +589,12 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T } } + pub fn clear_src(&self) { + (*self.state.endpoint.lock()) + .as_mut() + .map(|e| e.clear_src()); + } + pub fn purge_staged_packets(&self) { self.state.staged_packets.lock().clear(); } diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs index a14640c..1b122a8 100644 --- a/src/wireguard/router/tests.rs +++ b/src/wireguard/router/tests.rs @@ -3,7 +3,7 @@ use std::sync::atomic::Ordering; use std::sync::Arc; use std::sync::Mutex; use std::thread; -use std::time::Duration; +use std::time::{Duration, Instant}; use num_cpus; @@ -11,6 +11,7 @@ use super::super::bind::*; use super::super::dummy; use super::super::dummy_keypair; use super::super::tests::make_packet_dst; +use super::KeyPair; use super::SIZE_MESSAGE_PREFIX; use super::{Callbacks, Device}; @@ -85,11 +86,11 @@ mod tests { impl Callbacks for TestCallbacks { type Opaque = Opaque; - fn send(t: &Self::Opaque, size: usize, sent: bool) { + fn send(t: &Self::Opaque, size: usize, sent: bool, keypair: &Arc<KeyPair>, counter: u64) { t.0.send.lock().unwrap().push((size, sent)) } - fn recv(t: &Self::Opaque, size: usize, sent: bool) { + fn recv(t: &Self::Opaque, size: usize, sent: bool, keypair: &Arc<KeyPair>) { t.0.recv.lock().unwrap().push((size, sent)) } @@ -123,10 +124,16 @@ mod tests { struct BencherCallbacks {} impl Callbacks for BencherCallbacks { type Opaque = Arc<AtomicUsize>; - fn send(t: &Self::Opaque, size: usize, _sent: bool) { + fn send( + t: &Self::Opaque, + size: usize, + _sent: bool, + _keypair: &Arc<KeyPair>, + _counter: u64, + ) { t.fetch_add(size, Ordering::SeqCst); } - fn recv(_: &Self::Opaque, _size: usize, _sent: bool) {} + fn recv(_: &Self::Opaque, _size: usize, _sent: bool, _keypair: &Arc<KeyPair>) {} fn need_key(_: &Self::Opaque) {} fn key_confirmed(_: &Self::Opaque) {} } diff --git a/src/wireguard/router/types.rs b/src/wireguard/router/types.rs index 9f769fe..194f0d4 100644 --- a/src/wireguard/router/types.rs +++ b/src/wireguard/router/types.rs @@ -1,5 +1,8 @@ use std::error::Error; use std::fmt; +use std::sync::Arc; + +use super::KeyPair; pub trait Opaque: Send + Sync + 'static {} @@ -23,8 +26,8 @@ impl<T, F> KeyCallback<T> for F where F: Fn(&T) -> () + Sync + Send + 'static {} pub trait Callbacks: Send + Sync + 'static { type Opaque: Opaque; - fn send(opaque: &Self::Opaque, size: usize, sent: bool); - fn recv(opaque: &Self::Opaque, size: usize, sent: bool); + fn send(opaque: &Self::Opaque, size: usize, sent: bool, keypair: &Arc<KeyPair>, counter: u64); + fn recv(opaque: &Self::Opaque, size: usize, sent: bool, keypair: &Arc<KeyPair>); fn need_key(opaque: &Self::Opaque); fn key_confirmed(opaque: &Self::Opaque); } diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs index 2a12000..08c2db9 100644 --- a/src/wireguard/router/workers.rs +++ b/src/wireguard/router/workers.rs @@ -1,6 +1,5 @@ use std::sync::mpsc::Receiver; use std::sync::Arc; -use std::time::Instant; use futures::sync::oneshot; use futures::*; @@ -18,8 +17,7 @@ use super::peer::PeerInner; use super::route::check_route; use super::types::Callbacks; -use super::{KEEPALIVE_TIMEOUT, REJECT_AFTER_TIME, REKEY_TIMEOUT}; -use super::{REJECT_AFTER_MESSAGES, REKEY_AFTER_MESSAGES, REKEY_AFTER_TIME}; +use super::REJECT_AFTER_MESSAGES; use super::super::types::KeyPair; use super::super::{bind, tun, Endpoint}; @@ -61,10 +59,6 @@ pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer peer: Arc<PeerInner<E, C, T, B>>, // related peer receiver: Receiver<JobInbound<E, C, T, B>>, ) { - fn keep_key_fresh(keypair: &KeyPair) -> bool { - Instant::now() - keypair.birth > REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT - } - loop { // fetch job let (state, endpoint, rx) = match receiver.recv() { @@ -135,7 +129,7 @@ pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer } // trigger callback - C::recv(&peer.opaque, buf.msg.len(), sent); + C::recv(&peer.opaque, buf.msg.len(), sent, &buf.keypair); } else { debug!("inbound worker: authentication failure") } @@ -151,11 +145,6 @@ pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Write peer: Arc<PeerInner<E, C, T, B>>, // related peer receiver: Receiver<JobOutbound>, ) { - fn keep_key_fresh(keypair: &KeyPair, counter: u64) -> bool { - counter > REKEY_AFTER_MESSAGES - || (keypair.initiator && Instant::now() - keypair.birth > REKEY_AFTER_TIME) - } - loop { // fetch job let rx = match receiver.recv() { @@ -190,12 +179,7 @@ pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Write }; // trigger callback - C::send(&peer.opaque, buf.msg.len(), xmit); - - // keep_key_fresh semantics - if keep_key_fresh(&buf.keypair, buf.counter) { - C::need_key(&peer.opaque); - } + C::send(&peer.opaque, buf.msg.len(), xmit, &buf.keypair, buf.counter); }) .wait(); } @@ -223,7 +207,10 @@ pub fn worker_parallel(receiver: Receiver<JobParallel>) { .expect("earlier code should ensure that there is ample space"); // set header fields - debug_assert!(job.counter < REJECT_AFTER_MESSAGES); + debug_assert!( + job.counter < REJECT_AFTER_MESSAGES, + "should be checked when assigning counters" + ); header.f_type.set(TYPE_TRANSPORT); header.f_receiver.set(job.keypair.send.id); header.f_counter.set(job.counter); @@ -258,10 +245,12 @@ pub fn worker_parallel(receiver: Receiver<JobParallel>) { let _ = tx.send(match layout { Some((header, body)) => { - debug_assert_eq!(header.f_type.get(), TYPE_TRANSPORT); - if header.f_counter.get() >= REJECT_AFTER_MESSAGES { - None - } else { + debug_assert_eq!( + header.f_type.get(), + TYPE_TRANSPORT, + "type and reserved bits should be checked by message de-multiplexer" + ); + if header.f_counter.get() < REJECT_AFTER_MESSAGES { // create a nonce object let mut nonce = [0u8; 12]; debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); @@ -279,6 +268,8 @@ pub fn worker_parallel(receiver: Receiver<JobParallel>) { Ok(_) => Some(job), Err(_) => None, } + } else { + None } } None => None, |