aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-12-26 22:55:33 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-12-26 22:55:33 +0100
commitbb0a8acea3161a08ac69cc2e35489f8d33741d1a (patch)
treee81a6384c4dc26743b67937fd0e0e721a73dc5b3
parentRemove unused test code. (diff)
downloadwireguard-rs-bb0a8acea3161a08ac69cc2e35489f8d33741d1a.tar.xz
wireguard-rs-bb0a8acea3161a08ac69cc2e35489f8d33741d1a.zip
Make under_load global for WireGuard device
-rw-r--r--src/configuration/config.rs18
-rw-r--r--src/configuration/mod.rs4
-rw-r--r--src/main.rs6
-rw-r--r--src/platform/dummy/mod.rs2
-rw-r--r--src/wireguard/handshake/device.rs24
-rw-r--r--src/wireguard/handshake/tests.rs14
-rw-r--r--src/wireguard/mod.rs2
-rw-r--r--src/wireguard/peer.rs4
-rw-r--r--src/wireguard/router/tests.rs3
-rw-r--r--src/wireguard/router/workers.rs258
-rw-r--r--src/wireguard/tests.rs17
-rw-r--r--src/wireguard/wireguard.rs22
-rw-r--r--src/wireguard/workers.rs40
13 files changed, 77 insertions, 337 deletions
diff --git a/src/configuration/config.rs b/src/configuration/config.rs
index ac6e9a1..aec943f 100644
--- a/src/configuration/config.rs
+++ b/src/configuration/config.rs
@@ -27,24 +27,24 @@ pub struct PeerState {
pub preshared_key: [u8; 32], // 0^32 is the "default value" (though treated like any other psk)
}
-pub struct WireguardConfig<T: tun::Tun, B: udp::PlatformUDP>(Arc<Mutex<Inner<T, B>>>);
+pub struct WireGuardConfig<T: tun::Tun, B: udp::PlatformUDP>(Arc<Mutex<Inner<T, B>>>);
struct Inner<T: tun::Tun, B: udp::PlatformUDP> {
- wireguard: Wireguard<T, B>,
+ wireguard: WireGuard<T, B>,
port: u16,
bind: Option<B::Owner>,
fwmark: Option<u32>,
}
-impl<T: tun::Tun, B: udp::PlatformUDP> WireguardConfig<T, B> {
+impl<T: tun::Tun, B: udp::PlatformUDP> WireGuardConfig<T, B> {
fn lock(&self) -> MutexGuard<Inner<T, B>> {
self.0.lock().unwrap()
}
}
-impl<T: tun::Tun, B: udp::PlatformUDP> WireguardConfig<T, B> {
- pub fn new(wg: Wireguard<T, B>) -> WireguardConfig<T, B> {
- WireguardConfig(Arc::new(Mutex::new(Inner {
+impl<T: tun::Tun, B: udp::PlatformUDP> WireGuardConfig<T, B> {
+ pub fn new(wg: WireGuard<T, B>) -> WireGuardConfig<T, B> {
+ WireGuardConfig(Arc::new(Mutex::new(Inner {
wireguard: wg,
port: 0,
bind: None,
@@ -53,9 +53,9 @@ impl<T: tun::Tun, B: udp::PlatformUDP> WireguardConfig<T, B> {
}
}
-impl<T: tun::Tun, B: udp::PlatformUDP> Clone for WireguardConfig<T, B> {
+impl<T: tun::Tun, B: udp::PlatformUDP> Clone for WireGuardConfig<T, B> {
fn clone(&self) -> Self {
- WireguardConfig(self.0.clone())
+ WireGuardConfig(self.0.clone())
}
}
@@ -195,7 +195,7 @@ pub trait Configuration {
fn get_fwmark(&self) -> Option<u32>;
}
-impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireguardConfig<T, B> {
+impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireGuardConfig<T, B> {
fn up(&self, mtu: usize) {
self.lock().wireguard.up(mtu);
}
diff --git a/src/configuration/mod.rs b/src/configuration/mod.rs
index d7524d9..a3c11d9 100644
--- a/src/configuration/mod.rs
+++ b/src/configuration/mod.rs
@@ -4,9 +4,9 @@ pub mod uapi;
use super::platform::Endpoint;
use super::platform::{tun, udp};
-use super::wireguard::Wireguard;
+use super::wireguard::WireGuard;
pub use error::ConfigError;
pub use config::Configuration;
-pub use config::WireguardConfig;
+pub use config::WireGuardConfig;
diff --git a/src/main.rs b/src/main.rs
index e9dbe82..a0f4a23 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -25,6 +25,8 @@ use platform::tun::{PlatformTun, Status};
use platform::uapi::{BindUAPI, PlatformUAPI};
use platform::*;
+use wireguard::WireGuard;
+
#[cfg(feature = "profiler")]
fn profiler_stop() {
println!("Stopping profiler");
@@ -118,7 +120,7 @@ fn main() {
profiler_start(name.as_str());
// create WireGuard device
- let wg: wireguard::Wireguard<plt::Tun, plt::UDP> = wireguard::Wireguard::new(writer);
+ let wg: WireGuard<plt::Tun, plt::UDP> = WireGuard::new(writer);
// add all Tun readers
while let Some(reader) = readers.pop() {
@@ -126,7 +128,7 @@ fn main() {
}
// wrap in configuration interface
- let cfg = configuration::WireguardConfig::new(wg.clone());
+ let cfg = configuration::WireGuardConfig::new(wg.clone());
// start Tun event thread
{
diff --git a/src/platform/dummy/mod.rs b/src/platform/dummy/mod.rs
index 2d2e7c6..ed34da4 100644
--- a/src/platform/dummy/mod.rs
+++ b/src/platform/dummy/mod.rs
@@ -1,6 +1,6 @@
-mod udp;
mod endpoint;
mod tun;
+mod udp;
/* A pure dummy platform available during "test-time"
*
diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs
index c684965..edd1a07 100644
--- a/src/wireguard/handshake/device.rs
+++ b/src/wireguard/handshake/device.rs
@@ -252,15 +252,12 @@ impl Device {
/// # Arguments
///
/// * `msg` - Byte slice containing the message (untrusted input)
- pub fn process<'a, R: RngCore + CryptoRng, S>(
+ pub fn process<'a, R: RngCore + CryptoRng>(
&self,
- rng: &mut R, // rng instance to sample randomness from
- msg: &[u8], // message buffer
- src: Option<&'a S>, // optional source endpoint, set when "under load"
- ) -> Result<Output, HandshakeError>
- where
- &'a S: Into<&'a SocketAddr>,
- {
+ rng: &mut R, // rng instance to sample randomness from
+ msg: &[u8], // message buffer
+ src: Option<SocketAddr>, // optional source endpoint, set when "under load"
+ ) -> Result<Output, HandshakeError> {
// ensure type read in-range
if msg.len() < 4 {
return Err(HandshakeError::InvalidMessageFormat);
@@ -286,16 +283,13 @@ impl Device {
// address validation & DoS mitigation
if let Some(src) = src {
- // obtain ref to socket addr
- let src = src.into();
-
// check mac2 field
- if !keyst.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
+ if !keyst.macs.check_mac2(msg.noise.as_bytes(), &src, &msg.macs) {
let mut reply = Default::default();
keyst.macs.create_cookie_reply(
rng,
msg.noise.f_sender.get(),
- src,
+ &src,
&msg.macs,
&mut reply,
);
@@ -344,12 +338,12 @@ impl Device {
let src = src.into();
// check mac2 field
- if !keyst.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
+ if !keyst.macs.check_mac2(msg.noise.as_bytes(), &src, &msg.macs) {
let mut reply = Default::default();
keyst.macs.create_cookie_reply(
rng,
msg.noise.f_sender.get(),
- src,
+ &src,
&msg.macs,
&mut reply,
);
diff --git a/src/wireguard/handshake/tests.rs b/src/wireguard/handshake/tests.rs
index 1df046d..ff27b3e 100644
--- a/src/wireguard/handshake/tests.rs
+++ b/src/wireguard/handshake/tests.rs
@@ -69,13 +69,13 @@ fn handshake_under_load() {
let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
// 2. device-2 : responds with CookieReply
- let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
+ let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(src1)).unwrap() {
(None, Some(msg), None) => msg,
_ => panic!("unexpected response"),
};
// device-1 : processes CookieReply (no response)
- match dev1.process(&mut rng, &msg_cookie, Some(&src2)).unwrap() {
+ match dev1.process(&mut rng, &msg_cookie, Some(src2)).unwrap() {
(None, None, None) => (),
_ => panic!("unexpected response"),
}
@@ -87,7 +87,7 @@ fn handshake_under_load() {
let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
// 4. device-2 : responds with noise response
- let msg_response = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
+ let msg_response = match dev2.process(&mut rng, &msg_init, Some(src1)).unwrap() {
(Some(_), Some(msg), Some(kp)) => {
assert_eq!(kp.initiator, false);
msg
@@ -96,13 +96,13 @@ fn handshake_under_load() {
};
// 5. device-1 : responds with CookieReply
- let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
+ let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(src2)).unwrap() {
(None, Some(msg), None) => msg,
_ => panic!("unexpected response"),
};
// device-2 : processes CookieReply (no response)
- match dev2.process(&mut rng, &msg_cookie, Some(&src1)).unwrap() {
+ match dev2.process(&mut rng, &msg_cookie, Some(src1)).unwrap() {
(None, None, None) => (),
_ => panic!("unexpected response"),
}
@@ -114,7 +114,7 @@ fn handshake_under_load() {
let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
// 7. device-2 : responds with noise response
- let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
+ let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(src1)).unwrap() {
(Some(_), Some(msg), Some(kp)) => {
assert_eq!(kp.initiator, false);
(msg, kp)
@@ -123,7 +123,7 @@ fn handshake_under_load() {
};
// device-1 : process noise response
- let kp2 = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
+ let kp2 = match dev1.process(&mut rng, &msg_response, Some(src2)).unwrap() {
(Some(_), None, Some(kp)) => {
assert_eq!(kp.initiator, true);
kp
diff --git a/src/wireguard/mod.rs b/src/wireguard/mod.rs
index 5310e96..c08fe1e 100644
--- a/src/wireguard/mod.rs
+++ b/src/wireguard/mod.rs
@@ -24,7 +24,7 @@ mod tests;
pub use peer::Peer;
// represents a WireGuard interface
-pub use wireguard::Wireguard;
+pub use wireguard::WireGuard;
#[cfg(test)]
pub use types::dummy_keypair;
diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs
index 5d15cf3..e02d2e0 100644
--- a/src/wireguard/peer.rs
+++ b/src/wireguard/peer.rs
@@ -3,8 +3,8 @@ use super::timers::{Events, Timers};
use super::tun::Tun;
use super::udp::UDP;
-use super::Wireguard;
+use super::wireguard::WireGuard;
use super::constants::REKEY_TIMEOUT;
use super::workers::HandshakeJob;
@@ -23,7 +23,7 @@ pub struct PeerInner<T: Tun, B: UDP> {
pub id: u64,
// wireguard device state
- pub wg: Wireguard<T, B>,
+ pub wg: WireGuard<T, B>,
// handshake state
pub walltime_last_handshake: Mutex<Option<SystemTime>>, // walltime for last handshake (for UAPI status)
diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs
index 8d1e812..bad657c 100644
--- a/src/wireguard/router/tests.rs
+++ b/src/wireguard/router/tests.rs
@@ -50,6 +50,7 @@ mod tests {
}))
}
+ #[allow(dead_code)]
fn reset(&self) {
self.0.send.lock().unwrap().clear();
self.0.recv.lock().unwrap().clear();
@@ -103,7 +104,7 @@ mod tests {
}
}
- // wait for scheduling
+ // wait for scheduling (VERY conservative)
fn wait() {
thread::sleep(Duration::from_millis(30));
}
diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs
deleted file mode 100644
index 43464a0..0000000
--- a/src/wireguard/router/workers.rs
+++ /dev/null
@@ -1,258 +0,0 @@
-use std::sync::Arc;
-
-use log::{debug, trace};
-
-use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
-
-use crossbeam_channel::Receiver;
-use std::sync::atomic::Ordering;
-use zerocopy::{AsBytes, LayoutVerified};
-
-use super::device::{DecryptionState, DeviceInner};
-use super::messages::{TransportHeader, TYPE_TRANSPORT};
-use super::peer::PeerInner;
-use super::types::Callbacks;
-
-use super::REJECT_AFTER_MESSAGES;
-
-use super::super::types::KeyPair;
-use super::super::{tun, udp, Endpoint};
-
-pub const SIZE_TAG: usize = 16;
-
-pub struct JobEncryption {
- pub msg: Vec<u8>,
- pub keypair: Arc<KeyPair>,
- pub counter: u64,
-}
-
-pub struct JobDecryption {
- pub msg: Vec<u8>,
- pub keypair: Arc<KeyPair>,
-}
-
-pub enum JobParallel {
- Encryption(oneshot::Sender<JobEncryption>, JobEncryption),
- Decryption(oneshot::Sender<Option<JobDecryption>>, JobDecryption),
-}
-
-#[allow(type_alias_bounds)]
-pub type JobInbound<E, C, T, B: udp::Writer<E>> = (
- Arc<DecryptionState<E, C, T, B>>,
- E,
- oneshot::Receiver<Option<JobDecryption>>,
-);
-
-pub type JobOutbound = oneshot::Receiver<JobEncryption>;
-
-/* TODO: Replace with run-queue
- */
-pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- device: Arc<DeviceInner<E, C, T, B>>, // related device
- peer: Arc<PeerInner<E, C, T, B>>, // related peer
- receiver: Receiver<JobInbound<E, C, T, B>>,
-) {
- loop {
- // fetch job
- let (state, endpoint, rx) = match receiver.recv() {
- Ok(v) => v,
- _ => {
- return;
- }
- };
- debug!("inbound worker: obtained job");
-
- // wait for job to complete
- let _ = rx
- .map(|buf| {
- debug!("inbound worker: job complete");
- if let Some(buf) = buf {
- // cast transport header
- let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
- match LayoutVerified::new_from_prefix(&buf.msg[..]) {
- Some(v) => v,
- None => {
- debug!("inbound worker: failed to parse message");
- return;
- }
- };
-
- debug_assert!(
- packet.len() >= CHACHA20_POLY1305.tag_len(),
- "this should be checked earlier in the pipeline (decryption should fail)"
- );
-
- // check for replay
- if !state.protector.lock().update(header.f_counter.get()) {
- debug!("inbound worker: replay detected");
- return;
- }
-
- // check for confirms key
- if !state.confirmed.swap(true, Ordering::SeqCst) {
- debug!("inbound worker: message confirms key");
- peer.confirm_key(&state.keypair);
- }
-
- // update endpoint
- *peer.endpoint.lock() = Some(endpoint);
-
- // calculate length of IP packet + padding
- let length = packet.len() - SIZE_TAG;
- debug!("inbound worker: plaintext length = {}", length);
-
- // check if should be written to TUN
- let mut sent = false;
- if length > 0 {
- if let Some(inner_len) = device.table.check_route(&peer, &packet[..length])
- {
- // TODO: Consider moving the cryptkey route check to parallel decryption worker
- debug_assert!(inner_len <= length, "should be validated earlier");
- if inner_len <= length {
- sent = match device.inbound.write(&packet[..inner_len]) {
- Err(e) => {
- debug!("failed to write inbound packet to TUN: {:?}", e);
- false
- }
- Ok(_) => true,
- }
- }
- }
- } else {
- debug!("inbound worker: received keepalive")
- }
-
- // trigger callback
- C::recv(&peer.opaque, buf.msg.len(), sent, &buf.keypair);
- } else {
- debug!("inbound worker: authentication failure")
- }
- })
- .wait();
- }
-}
-
-/* TODO: Replace with run-queue
- */
-pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- peer: Arc<PeerInner<E, C, T, B>>,
- receiver: Receiver<JobOutbound>,
-) {
- loop {
- // fetch job
- let rx = match receiver.recv() {
- Ok(v) => v,
- _ => {
- return;
- }
- };
- debug!("outbound worker: obtained job");
-
- // wait for job to complete
- let _ = rx
- .map(|buf| {
- debug!("outbound worker: job complete");
-
- // send to peer
- let xmit = peer.send(&buf.msg[..]).is_ok();
-
- // trigger callback
- C::send(&peer.opaque, buf.msg.len(), xmit, &buf.keypair, buf.counter);
- })
- .wait();
- }
-}
-
-pub fn worker_parallel(receiver: Receiver<JobParallel>) {
- loop {
- // fetch next job
- let job = match receiver.recv() {
- Err(_) => {
- return;
- }
- Ok(val) => val,
- };
- trace!("parallel worker: obtained job");
-
- // handle job
- match job {
- JobParallel::Encryption(tx, mut job) => {
- job.msg.extend([0u8; SIZE_TAG].iter());
-
- // cast to header (should never fail)
- let (mut header, body): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
- LayoutVerified::new_from_prefix(&mut job.msg[..])
- .expect("earlier code should ensure that there is ample space");
-
- // set header fields
- debug_assert!(
- job.counter < REJECT_AFTER_MESSAGES,
- "should be checked when assigning counters"
- );
- header.f_type.set(TYPE_TRANSPORT);
- header.f_receiver.set(job.keypair.send.id);
- header.f_counter.set(job.counter);
-
- // create a nonce object
- let mut nonce = [0u8; 12];
- debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
- nonce[4..].copy_from_slice(header.f_counter.as_bytes());
- let nonce = Nonce::assume_unique_for_key(nonce);
-
- // do the weird ring AEAD dance
- let key = LessSafeKey::new(
- UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.send.key[..]).unwrap(),
- );
-
- // encrypt content of transport message in-place
- let end = body.len() - SIZE_TAG;
- let tag = key
- .seal_in_place_separate_tag(nonce, Aad::empty(), &mut body[..end])
- .unwrap();
-
- // append tag
- body[end..].copy_from_slice(tag.as_ref());
-
- // pass ownership
- let _ = tx.send(job);
- }
- JobParallel::Decryption(tx, mut job) => {
- // cast to header (could fail)
- let layout: Option<(LayoutVerified<&mut [u8], TransportHeader>, &mut [u8])> =
- LayoutVerified::new_from_prefix(&mut job.msg[..]);
-
- let _ = tx.send(match layout {
- Some((header, body)) => {
- debug_assert_eq!(
- header.f_type.get(),
- TYPE_TRANSPORT,
- "type and reserved bits should be checked by message de-multiplexer"
- );
- if header.f_counter.get() < REJECT_AFTER_MESSAGES {
- // create a nonce object
- let mut nonce = [0u8; 12];
- debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
- nonce[4..].copy_from_slice(header.f_counter.as_bytes());
- let nonce = Nonce::assume_unique_for_key(nonce);
-
- // do the weird ring AEAD dance
- let key = LessSafeKey::new(
- UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.recv.key[..])
- .unwrap(),
- );
-
- // attempt to open (and authenticate) the body
- match key.open_in_place(nonce, Aad::empty(), body) {
- Ok(_) => Some(job),
- Err(_) => None,
- }
- } else {
- None
- }
- }
- None => None,
- });
- }
- }
- }
-}
diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs
index f71576a..2ed2202 100644
--- a/src/wireguard/tests.rs
+++ b/src/wireguard/tests.rs
@@ -1,12 +1,6 @@
-use super::dummy;
-use super::wireguard::Wireguard;
-
use std::net::IpAddr;
-use std::thread;
-use std::time::Duration;
use hex;
-
use rand_chacha::ChaCha8Rng;
use rand_core::{RngCore, SeedableRng};
use x25519_dalek::{PublicKey, StaticSecret};
@@ -14,6 +8,9 @@ use x25519_dalek::{PublicKey, StaticSecret};
use pnet::packet::ipv4::MutableIpv4Packet;
use pnet::packet::ipv6::MutableIpv6Packet;
+use super::dummy;
+use super::wireguard::WireGuard;
+
pub fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec<u8> {
// expand pseudo random payload
let mut rng: _ = ChaCha8Rng::seed_from_u64(id);
@@ -58,10 +55,6 @@ fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
-fn wait() {
- thread::sleep(Duration::from_millis(500));
-}
-
/* Create and configure two matching pure instances of WireGuard
*/
#[test]
@@ -71,12 +64,12 @@ fn test_pure_wireguard() {
// create WG instances for dummy TUN devices
let (fake1, tun_reader1, tun_writer1, _) = dummy::TunTest::create(true);
- let wg1: Wireguard<dummy::TunTest, dummy::PairBind> = Wireguard::new(tun_writer1);
+ let wg1: WireGuard<dummy::TunTest, dummy::PairBind> = WireGuard::new(tun_writer1);
wg1.add_tun_reader(tun_reader1);
wg1.up(1500);
let (fake2, tun_reader2, tun_writer2, _) = dummy::TunTest::create(true);
- let wg2: Wireguard<dummy::TunTest, dummy::PairBind> = Wireguard::new(tun_writer2);
+ let wg2: WireGuard<dummy::TunTest, dummy::PairBind> = WireGuard::new(tun_writer2);
wg2.add_tun_reader(tun_reader2);
wg2.up(1500);
diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs
index 2fa14fc..bf550ef 100644
--- a/src/wireguard/wireguard.rs
+++ b/src/wireguard/wireguard.rs
@@ -58,33 +58,33 @@ pub struct WireguardInner<T: Tun, B: UDP> {
// handshake related state
pub handshake: RwLock<handshake::Device>,
- pub last_under_load: AtomicUsize,
- pub pending: AtomicUsize, // num of pending handshake packets in queue
+ pub last_under_load: Mutex<Instant>,
+ pub pending: AtomicUsize, // number of pending handshake packets in queue
pub queue: ParallelQueue<HandshakeJob<B::Endpoint>>,
}
-pub struct Wireguard<T: Tun, B: UDP> {
+pub struct WireGuard<T: Tun, B: UDP> {
inner: Arc<WireguardInner<T, B>>,
}
pub struct WaitCounter(StdMutex<usize>, Condvar);
-impl<T: Tun, B: UDP> fmt::Display for Wireguard<T, B> {
+impl<T: Tun, B: UDP> fmt::Display for WireGuard<T, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "wireguard({:x})", self.id)
}
}
-impl<T: Tun, B: UDP> Deref for Wireguard<T, B> {
+impl<T: Tun, B: UDP> Deref for WireGuard<T, B> {
type Target = WireguardInner<T, B>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
-impl<T: Tun, B: UDP> Clone for Wireguard<T, B> {
+impl<T: Tun, B: UDP> Clone for WireGuard<T, B> {
fn clone(&self) -> Self {
- Wireguard {
+ WireGuard {
inner: self.inner.clone(),
}
}
@@ -116,7 +116,7 @@ impl WaitCounter {
}
}
-impl<T: Tun, B: UDP> Wireguard<T, B> {
+impl<T: Tun, B: UDP> WireGuard<T, B> {
/// Brings the WireGuard device down.
/// Usually called when the associated interface is brought down.
///
@@ -307,7 +307,7 @@ impl<T: Tun, B: UDP> Wireguard<T, B> {
self.tun_readers.wait();
}
- pub fn new(writer: T::Writer) -> Wireguard<T, B> {
+ pub fn new(writer: T::Writer) -> WireGuard<T, B> {
// workers equal to number of physical cores
let cpus = num_cpus::get();
@@ -318,14 +318,14 @@ impl<T: Tun, B: UDP> Wireguard<T, B> {
let (tx, mut rxs) = ParallelQueue::new(cpus, 128);
// create arc to state
- let wg = Wireguard {
+ let wg = WireGuard {
inner: Arc::new(WireguardInner {
enabled: RwLock::new(false),
tun_readers: WaitCounter::new(),
id: rng.gen(),
mtu: AtomicUsize::new(0),
peers: RwLock::new(HashMap::new()),
- last_under_load: AtomicUsize::new(0), // TODO
+ last_under_load: Mutex::new(Instant::now() - TIME_HORIZON),
send: RwLock::new(None),
router: router::Device::new(num_cpus::get(), writer), // router owns the writing half
pending: AtomicUsize::new(0),
diff --git a/src/wireguard/workers.rs b/src/wireguard/workers.rs
index b65f49a..aeb6063 100644
--- a/src/wireguard/workers.rs
+++ b/src/wireguard/workers.rs
@@ -25,7 +25,7 @@ use super::handshake::MAX_HANDSHAKE_MSG_SIZE;
use super::handshake::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};
use super::router::{CAPACITY_MESSAGE_POSTFIX, SIZE_MESSAGE_PREFIX, TYPE_TRANSPORT};
-use super::Wireguard;
+use super::wireguard::WireGuard;
pub enum HandshakeJob<E> {
Message(Vec<u8>, E),
@@ -54,7 +54,7 @@ const fn padding(size: usize, mtu: usize) -> usize {
min(mtu, size + (pad - size % pad) % pad)
}
-pub fn tun_worker<T: Tun, B: UDP>(wg: &Wireguard<T, B>, reader: T::Reader) {
+pub fn tun_worker<T: Tun, B: UDP>(wg: &WireGuard<T, B>, reader: T::Reader) {
loop {
// create vector big enough for any transport message (based on MTU)
let mtu = wg.mtu.load(Ordering::Relaxed);
@@ -100,7 +100,7 @@ pub fn tun_worker<T: Tun, B: UDP>(wg: &Wireguard<T, B>, reader: T::Reader) {
}
}
-pub fn udp_worker<T: Tun, B: UDP>(wg: &Wireguard<T, B>, reader: B::Reader) {
+pub fn udp_worker<T: Tun, B: UDP>(wg: &WireGuard<T, B>, reader: B::Reader) {
let mut last_under_load = Instant::now() - TIME_HORIZON;
loop {
@@ -160,7 +160,7 @@ pub fn udp_worker<T: Tun, B: UDP>(wg: &Wireguard<T, B>, reader: B::Reader) {
}
pub fn handshake_worker<T: Tun, B: UDP>(
- wg: &Wireguard<T, B>,
+ wg: &WireGuard<T, B>,
rx: Receiver<HandshakeJob<B::Endpoint>>,
) {
debug!("{} : handshake worker, started", wg);
@@ -170,30 +170,38 @@ pub fn handshake_worker<T: Tun, B: UDP>(
// process elements from the handshake queue
for job in rx {
- // decrement pending pakcets (under_load)
+ // check if under load
let job: HandshakeJob<B::Endpoint> = job;
- wg.pending.fetch_sub(1, Ordering::SeqCst);
+ let pending = wg.pending.fetch_sub(1, Ordering::SeqCst);
+ let mut under_load = false;
+
+ // immediate go under load if too many handshakes pending
+ if pending > THRESHOLD_UNDER_LOAD {
+ *wg.last_under_load.lock() = Instant::now();
+ under_load = true;
+ }
+
+ // remain under load for a while
+ if !under_load {
+ let elapsed = wg.last_under_load.lock().elapsed();
+ if elapsed > DURATION_UNDER_LOAD {
+ under_load = true;
+ }
+ }
- // demultiplex staged handshake jobs and handshake messages
+ // de-multiplex staged handshake jobs and handshake messages
match job {
HandshakeJob::Message(msg, src) => {
- // feed message to handshake device
- let src_validate = (&src).into_address(); // TODO avoid
-
// process message
let device = wg.handshake.read();
match device.process(
&mut rng,
&msg[..],
- None,
- /*
- if wg.under_load.load(Ordering::Relaxed) {
- debug!("{} : handshake worker, under load", wg);
- Some(&src_validate)
+ if under_load {
+ Some(src.into_address())
} else {
None
}
- */
) {
Ok((pk, resp, keypair)) => {
// send response (might be cookie reply or handshake response)