From aabefa50436af8d614520bb219d675953eeba6eb Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sat, 21 Dec 2019 00:17:31 +0100 Subject: Remove unused test code. - make naming consistent with the kernel module. - better distribution of functionality from src/wireguard.rs - more consistent "import pattern" throughout the project. - remove unused test code. --- src/configuration/config.rs | 10 +- src/configuration/error.rs | 39 ++-- src/main.rs | 7 +- src/platform/dummy/bind.rs | 224 ------------------- src/platform/dummy/mod.rs | 4 +- src/platform/dummy/tun.rs | 75 +++---- src/platform/dummy/udp.rs | 201 +++++++++++++++++ src/platform/linux/tun.rs | 2 +- src/wireguard/constants.rs | 39 +++- src/wireguard/mod.rs | 19 +- src/wireguard/peer.rs | 43 +++- src/wireguard/router/runq.rs | 35 --- src/wireguard/router/tests.rs | 2 +- src/wireguard/tests.rs | 16 +- src/wireguard/timers.rs | 15 +- src/wireguard/wireguard.rs | 488 +++++++++--------------------------------- src/wireguard/workers.rs | 280 ++++++++++++++++++++++++ 17 files changed, 746 insertions(+), 753 deletions(-) delete mode 100644 src/platform/dummy/bind.rs create mode 100644 src/platform/dummy/udp.rs create mode 100644 src/wireguard/workers.rs diff --git a/src/configuration/config.rs b/src/configuration/config.rs index 94b79f7..ac6e9a1 100644 --- a/src/configuration/config.rs +++ b/src/configuration/config.rs @@ -288,13 +288,15 @@ impl Configuration for WireguardConfig { fn set_fwmark(&self, mark: Option) -> Result<(), ConfigError> { log::trace!("Config, Set fwmark: {:?}", mark); - match self.lock().bind.as_mut() { Some(bind) => { - bind.set_fwmark(mark).unwrap(); // TODO: handle - Ok(()) + if bind.set_fwmark(mark).is_err() { + Err(ConfigError::IOError) + } else { + Ok(()) + } } - None => Err(ConfigError::NotListening), + None => Ok(()), } } diff --git a/src/configuration/error.rs b/src/configuration/error.rs index fca194f..de790e2 100644 --- a/src/configuration/error.rs +++ b/src/configuration/error.rs @@ -1,9 +1,11 @@ use std::error::Error; use std::fmt; +#[cfg(unix)] +use libc::*; + #[derive(Debug)] pub enum ConfigError { - NotListening, FailedToBind, InvalidHexValue, InvalidPortNumber, @@ -35,24 +37,31 @@ impl Error for ConfigError { } } +#[cfg(unix)] impl ConfigError { pub fn errno(&self) -> i32 { // TODO: obtain the correct errorno values match self { - ConfigError::NotListening => 2, - ConfigError::FailedToBind => 3, - ConfigError::InvalidHexValue => 4, - ConfigError::InvalidPortNumber => 5, - ConfigError::InvalidFwmark => 6, - ConfigError::InvalidSocketAddr => 10, - ConfigError::InvalidKeepaliveInterval => 11, - ConfigError::InvalidAllowedIp => 12, - ConfigError::InvalidOperation => 15, - ConfigError::UnsupportedValue => 7, - ConfigError::LineTooLong => 13, - ConfigError::InvalidKey => 8, - ConfigError::UnsupportedProtocolVersion => 9, - ConfigError::IOError => 14, + // insufficient perms + ConfigError::FailedToBind => EPERM, + + // parsing of value failed + ConfigError::InvalidHexValue => EINVAL, + ConfigError::InvalidPortNumber => EINVAL, + ConfigError::InvalidFwmark => EINVAL, + ConfigError::InvalidSocketAddr => EINVAL, + ConfigError::InvalidKeepaliveInterval => EINVAL, + ConfigError::InvalidAllowedIp => EINVAL, + ConfigError::InvalidOperation => EINVAL, + ConfigError::UnsupportedValue => EINVAL, + + // other protocol errors + ConfigError::LineTooLong => EPROTO, + ConfigError::InvalidKey => EPROTO, + ConfigError::UnsupportedProtocolVersion => EPROTO, + + // IO + ConfigError::IOError => EIO, } } } diff --git a/src/main.rs b/src/main.rs index 0cf7ae6..e9dbe82 100644 --- a/src/main.rs +++ b/src/main.rs @@ -125,11 +125,8 @@ fn main() { wg.add_tun_reader(reader); } - // obtain handle for waiting - let wait = wg.wait(); - // wrap in configuration interface - let cfg = configuration::WireguardConfig::new(wg); + let cfg = configuration::WireguardConfig::new(wg.clone()); // start Tun event thread { @@ -187,6 +184,6 @@ fn main() { }); // block until all tun readers closed - wait.wait(); + wg.wait(); profiler_stop(); } diff --git a/src/platform/dummy/bind.rs b/src/platform/dummy/bind.rs deleted file mode 100644 index 3146af8..0000000 --- a/src/platform/dummy/bind.rs +++ /dev/null @@ -1,224 +0,0 @@ -use hex; -use std::error::Error; -use std::fmt; -use std::marker; - -use log::debug; -use rand::rngs::OsRng; -use rand::Rng; - -use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; -use std::sync::Arc; -use std::sync::Mutex; - -use super::super::udp::*; - -use super::UnitEndpoint; - -pub struct VoidOwner {} - -#[derive(Debug)] -pub enum BindError { - Disconnected, -} - -impl Error for BindError { - fn description(&self) -> &str { - "Generic Bind Error" - } - - fn source(&self) -> Option<&(dyn Error + 'static)> { - None - } -} - -impl fmt::Display for BindError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - BindError::Disconnected => write!(f, "PairBind disconnected"), - } - } -} - -/* TUN implementation */ - -#[derive(Debug)] -pub enum TunError { - Disconnected, -} - -impl Error for TunError { - fn description(&self) -> &str { - "Generic Tun Error" - } - - fn source(&self) -> Option<&(dyn Error + 'static)> { - None - } -} - -impl fmt::Display for TunError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Not Possible") - } -} - -#[derive(Clone, Copy)] -pub struct VoidBind {} - -impl Reader for VoidBind { - type Error = BindError; - - fn read(&self, _buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> { - Ok((0, UnitEndpoint {})) - } -} - -impl Writer for VoidBind { - type Error = BindError; - - fn write(&self, _buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { - Ok(()) - } -} - -impl UDP for VoidBind { - type Error = BindError; - type Endpoint = UnitEndpoint; - - type Reader = VoidBind; - type Writer = VoidBind; -} - -impl VoidBind { - pub fn new() -> VoidBind { - VoidBind {} - } -} - -/* Pair Bind */ - -#[derive(Clone)] -pub struct PairReader { - id: u32, - recv: Arc>>>, - _marker: marker::PhantomData, -} - -impl Reader for PairReader { - type Error = BindError; - fn read(&self, buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> { - let vec = self - .recv - .lock() - .unwrap() - .recv() - .map_err(|_| BindError::Disconnected)?; - let len = vec.len(); - buf[..len].copy_from_slice(&vec[..]); - debug!( - "dummy({}): read ({}, {})", - self.id, - len, - hex::encode(&buf[..len]) - ); - Ok((len, UnitEndpoint {})) - } -} - -impl Writer for PairWriter { - type Error = BindError; - fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { - debug!( - "dummy({}): write ({}, {})", - self.id, - buf.len(), - hex::encode(buf) - ); - let owned = buf.to_owned(); - match self.send.lock().unwrap().send(owned) { - Err(_) => Err(BindError::Disconnected), - Ok(_) => Ok(()), - } - } -} - -#[derive(Clone)] -pub struct PairWriter { - id: u32, - send: Arc>>>, - _marker: marker::PhantomData, -} - -#[derive(Clone)] -pub struct PairBind {} - -impl PairBind { - pub fn pair() -> ( - (PairReader, PairWriter), - (PairReader, PairWriter), - ) { - let mut rng = OsRng::new().unwrap(); - let id1: u32 = rng.gen(); - let id2: u32 = rng.gen(); - - let (tx1, rx1) = sync_channel(128); - let (tx2, rx2) = sync_channel(128); - ( - ( - PairReader { - id: id1, - recv: Arc::new(Mutex::new(rx1)), - _marker: marker::PhantomData, - }, - PairWriter { - id: id1, - send: Arc::new(Mutex::new(tx2)), - _marker: marker::PhantomData, - }, - ), - ( - PairReader { - id: id2, - recv: Arc::new(Mutex::new(rx2)), - _marker: marker::PhantomData, - }, - PairWriter { - id: id2, - send: Arc::new(Mutex::new(tx1)), - _marker: marker::PhantomData, - }, - ), - ) - } -} - -impl UDP for PairBind { - type Error = BindError; - type Endpoint = UnitEndpoint; - type Reader = PairReader; - type Writer = PairWriter; -} - -impl Owner for VoidOwner { - type Error = BindError; - - fn set_fwmark(&mut self, _value: Option) -> Result<(), Self::Error> { - Ok(()) - } - - fn get_port(&self) -> u16 { - 0 - } - - fn get_fwmark(&self) -> Option { - None - } -} - -impl PlatformUDP for PairBind { - type Owner = VoidOwner; - fn bind(_port: u16) -> Result<(Vec, Self::Writer, Self::Owner), Self::Error> { - Err(BindError::Disconnected) - } -} diff --git a/src/platform/dummy/mod.rs b/src/platform/dummy/mod.rs index 884bd7e..2d2e7c6 100644 --- a/src/platform/dummy/mod.rs +++ b/src/platform/dummy/mod.rs @@ -1,4 +1,4 @@ -mod bind; +mod udp; mod endpoint; mod tun; @@ -8,6 +8,6 @@ mod tun; * the configuration interface and the UAPI parser. */ -pub use bind::*; pub use endpoint::*; pub use tun::*; +pub use udp::*; diff --git a/src/platform/dummy/tun.rs b/src/platform/dummy/tun.rs index 50c6654..9836b48 100644 --- a/src/platform/dummy/tun.rs +++ b/src/platform/dummy/tun.rs @@ -13,38 +13,51 @@ use std::time::Duration; use super::super::tun::*; -/* This submodule provides pure/dummy implementations of the IO interfaces - * for use in unit tests thoughout the project. - */ - -/* Error implementation */ - #[derive(Debug)] -pub enum BindError { +pub enum TunError { Disconnected, } -impl Error for BindError { - fn description(&self) -> &str { - "Generic Bind Error" +pub struct TunTest {} + +pub struct TunFakeIO { + id: u32, + store: bool, + tx: SyncSender>, + rx: Receiver>, +} + +pub struct TunReader { + id: u32, + rx: Receiver>, +} + +pub struct TunWriter { + id: u32, + store: bool, + tx: Mutex>>, +} + +impl fmt::Display for TunFakeIO { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "FakeIO({})", self.id) } +} - fn source(&self) -> Option<&(dyn Error + 'static)> { - None +impl fmt::Display for TunReader { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "TunReader({})", self.id) } } -impl fmt::Display for BindError { +impl fmt::Display for TunWriter { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - BindError::Disconnected => write!(f, "PairBind disconnected"), - } + write!(f, "TunWriter({})", self.id) } } -#[derive(Debug)] -pub enum TunError { - Disconnected, +pub struct TunStatus { + first: bool, } impl Error for TunError { @@ -63,30 +76,6 @@ impl fmt::Display for TunError { } } -pub struct TunTest {} - -pub struct TunFakeIO { - id: u32, - store: bool, - tx: SyncSender>, - rx: Receiver>, -} - -pub struct TunReader { - id: u32, - rx: Receiver>, -} - -pub struct TunWriter { - id: u32, - store: bool, - tx: Mutex>>, -} - -pub struct TunStatus { - first: bool, -} - impl Reader for TunReader { type Error = TunError; diff --git a/src/platform/dummy/udp.rs b/src/platform/dummy/udp.rs new file mode 100644 index 0000000..35c905d --- /dev/null +++ b/src/platform/dummy/udp.rs @@ -0,0 +1,201 @@ +use hex; +use std::error::Error; +use std::fmt; +use std::marker; + +use log::debug; +use rand::rngs::OsRng; +use rand::Rng; + +use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; +use std::sync::Arc; +use std::sync::Mutex; + +use super::super::udp::*; + +use super::UnitEndpoint; + +pub struct VoidOwner {} + +#[derive(Debug)] +pub enum BindError { + Disconnected, +} + +impl Error for BindError { + fn description(&self) -> &str { + "Generic Bind Error" + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} + +impl fmt::Display for BindError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BindError::Disconnected => write!(f, "PairBind disconnected"), + } + } +} + +#[derive(Clone, Copy)] +pub struct VoidBind {} + +impl Reader for VoidBind { + type Error = BindError; + + fn read(&self, _buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> { + Ok((0, UnitEndpoint {})) + } +} + +impl Writer for VoidBind { + type Error = BindError; + + fn write(&self, _buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { + Ok(()) + } +} + +impl UDP for VoidBind { + type Error = BindError; + type Endpoint = UnitEndpoint; + + type Reader = VoidBind; + type Writer = VoidBind; +} + +impl VoidBind { + pub fn new() -> VoidBind { + VoidBind {} + } +} + +/* Pair Bind */ + +#[derive(Clone)] +pub struct PairReader { + id: u32, + recv: Arc>>>, + _marker: marker::PhantomData, +} + +impl Reader for PairReader { + type Error = BindError; + fn read(&self, buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> { + let vec = self + .recv + .lock() + .unwrap() + .recv() + .map_err(|_| BindError::Disconnected)?; + let len = vec.len(); + buf[..len].copy_from_slice(&vec[..]); + debug!( + "dummy({}): read ({}, {})", + self.id, + len, + hex::encode(&buf[..len]) + ); + Ok((len, UnitEndpoint {})) + } +} + +impl Writer for PairWriter { + type Error = BindError; + fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { + debug!( + "dummy({}): write ({}, {})", + self.id, + buf.len(), + hex::encode(buf) + ); + let owned = buf.to_owned(); + match self.send.lock().unwrap().send(owned) { + Err(_) => Err(BindError::Disconnected), + Ok(_) => Ok(()), + } + } +} + +#[derive(Clone)] +pub struct PairWriter { + id: u32, + send: Arc>>>, + _marker: marker::PhantomData, +} + +#[derive(Clone)] +pub struct PairBind {} + +impl PairBind { + pub fn pair() -> ( + (PairReader, PairWriter), + (PairReader, PairWriter), + ) { + let mut rng = OsRng::new().unwrap(); + let id1: u32 = rng.gen(); + let id2: u32 = rng.gen(); + + let (tx1, rx1) = sync_channel(128); + let (tx2, rx2) = sync_channel(128); + ( + ( + PairReader { + id: id1, + recv: Arc::new(Mutex::new(rx1)), + _marker: marker::PhantomData, + }, + PairWriter { + id: id1, + send: Arc::new(Mutex::new(tx2)), + _marker: marker::PhantomData, + }, + ), + ( + PairReader { + id: id2, + recv: Arc::new(Mutex::new(rx2)), + _marker: marker::PhantomData, + }, + PairWriter { + id: id2, + send: Arc::new(Mutex::new(tx1)), + _marker: marker::PhantomData, + }, + ), + ) + } +} + +impl UDP for PairBind { + type Error = BindError; + type Endpoint = UnitEndpoint; + type Reader = PairReader; + type Writer = PairWriter; +} + +impl Owner for VoidOwner { + type Error = BindError; + + fn set_fwmark(&mut self, _value: Option) -> Result<(), Self::Error> { + Ok(()) + } + + fn get_port(&self) -> u16 { + 0 + } + + fn get_fwmark(&self) -> Option { + None + } +} + +impl PlatformUDP for PairBind { + type Owner = VoidOwner; + fn bind(_port: u16) -> Result<(Vec, Self::Writer, Self::Owner), Self::Error> { + Err(BindError::Disconnected) + } +} diff --git a/src/platform/linux/tun.rs b/src/platform/linux/tun.rs index 3c98c34..9ccda86 100644 --- a/src/platform/linux/tun.rs +++ b/src/platform/linux/tun.rs @@ -299,7 +299,7 @@ impl LinuxTunStatus { Err(LinuxTunError::Closed) } else { Ok(LinuxTunStatus { - events: vec![], + events: vec![TunEvent::Up(1500)], index: get_ifindex(&name), fd, name, diff --git a/src/wireguard/constants.rs b/src/wireguard/constants.rs index 97ce6b1..4d0ae54 100644 --- a/src/wireguard/constants.rs +++ b/src/wireguard/constants.rs @@ -10,17 +10,48 @@ pub const REKEY_ATTEMPT_TIME: Duration = Duration::from_secs(90); pub const REKEY_TIMEOUT: Duration = Duration::from_secs(5); pub const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10); -pub const MAX_TIMER_HANDSHAKES: usize = 18; +pub const MAX_TIMER_HANDSHAKES: usize = + (REKEY_ATTEMPT_TIME.as_secs() / REKEY_TIMEOUT.as_secs()) as usize; +// Semantics: +// Maximum number of buffered handshake requests +// (either from outside message or handshake requests triggered locally) +pub const MAX_QUEUED_INCOMING_HANDSHAKES: usize = 4096; + +// Semantics: +// When the number of queued handshake requests exceeds this number +// the device is considered under load and DoS mitigation is triggered. +pub const THRESHOLD_UNDER_LOAD: usize = MAX_QUEUED_INCOMING_HANDSHAKES / 8; + +// Semantics: +// When a device is detected to go under load, +// it will remain under load for at least the following duration. +pub const DURATION_UNDER_LOAD: Duration = Duration::from_secs(1); + +// Semantics: +// The payload of transport messages are padded to this multiple +pub const MESSAGE_PADDING_MULTIPLE: usize = 16; + +// Semantics: +// Longest possible duration of any WireGuard timer pub const TIMER_MAX_DURATION: Duration = Duration::from_secs(200); + +// Semantics: +// Resolution of the timer-wheel pub const TIMERS_TICK: Duration = Duration::from_millis(100); + +// Semantics: +// Resulting number of slots in the wheel pub const TIMERS_SLOTS: usize = (TIMER_MAX_DURATION.as_micros() / TIMERS_TICK.as_micros()) as usize; -pub const TIMERS_CAPACITY: usize = 1024; -pub const MESSAGE_PADDING_MULTIPLE: usize = 16; +// Performance: +// Initial capacity of timer-wheel (grows to accommodate more timers) +pub const TIMERS_CAPACITY: usize = 16; /* A long duration (compared to the WireGuard time constants), * used in places to avoid Option by instead using a long "expired" Instant: * (Instant::now() - TIME_HORIZON) + * + * Note, this duration need not fit inside the timer wheel. */ -pub const TIME_HORIZON: Duration = Duration::from_secs(60 * 60 * 24); +pub const TIME_HORIZON: Duration = Duration::from_secs(TIMER_MAX_DURATION.as_secs() * 2); diff --git a/src/wireguard/mod.rs b/src/wireguard/mod.rs index ac7d9be..5310e96 100644 --- a/src/wireguard/mod.rs +++ b/src/wireguard/mod.rs @@ -1,17 +1,29 @@ +/* The wireguard sub-module represents a full, pure, WireGuard implementation: + * + * The WireGuard device described here does not depend on particular IO implementations + * or UAPI, and can be instantiated in unit-tests with the dummy IO implementation. + * + * The code at this level serves to "glue" the handshake state-machine + * and the crypto-key router code together, + * e.g. every WireGuard peer consists of a handshake and router peer. + */ mod constants; -mod timers; -mod wireguard; - mod handshake; mod peer; mod queue; mod router; +mod timers; mod types; +mod wireguard; +mod workers; #[cfg(test)] mod tests; +// represents a peer pub use peer::Peer; + +// represents a WireGuard interface pub use wireguard::Wireguard; #[cfg(test)] @@ -21,5 +33,4 @@ pub use types::dummy_keypair; use super::platform::dummy; use super::platform::{tun, udp, Endpoint}; -use peer::PeerInner; use types::KeyPair; diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs index 85e340f..5d15cf3 100644 --- a/src/wireguard/peer.rs +++ b/src/wireguard/peer.rs @@ -3,11 +3,14 @@ use super::timers::{Events, Timers}; use super::tun::Tun; use super::udp::UDP; -use super::wireguard::WireguardInner; +use super::Wireguard; + +use super::constants::REKEY_TIMEOUT; +use super::workers::HandshakeJob; use std::fmt; use std::ops::Deref; -use std::sync::atomic::{AtomicBool, AtomicU64}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::Arc; use std::time::{Instant, SystemTime}; @@ -15,17 +18,12 @@ use spin::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard}; use x25519_dalek::PublicKey; -pub struct Peer { - pub router: Arc, T::Writer, B::Writer>>, - pub state: Arc>, -} - pub struct PeerInner { // internal id (for logging) pub id: u64, // wireguard device state - pub wg: Arc>, + pub wg: Wireguard, // handshake state pub walltime_last_handshake: Mutex>, // walltime for last handshake (for UAPI status) @@ -41,6 +39,11 @@ pub struct PeerInner { pub timers: RwLock, } +pub struct Peer { + pub router: Arc, T::Writer, B::Writer>>, + pub state: Arc>, +} + impl Clone for Peer { fn clone(&self) -> Peer { Peer { @@ -51,6 +54,30 @@ impl Clone for Peer { } impl PeerInner { + /* Queue a handshake request for the parallel workers + * (if one does not already exist) + * + * The function is ratelimited. + */ + pub fn packet_send_handshake_initiation(&self) { + // the function is rate limited + + { + let mut lhs = self.last_handshake_sent.lock(); + if lhs.elapsed() < REKEY_TIMEOUT { + return; + } + *lhs = Instant::now(); + } + + // create a new handshake job for the peer + + if !self.handshake_queued.swap(true, Ordering::SeqCst) { + self.wg.pending.fetch_add(1, Ordering::SeqCst); + self.wg.queue.send(HandshakeJob::New(self.pk)); + } + } + #[inline(always)] pub fn timers(&self) -> RwLockReadGuard { self.timers.read() diff --git a/src/wireguard/router/runq.rs b/src/wireguard/router/runq.rs index 44e11a1..4c848cd 100644 --- a/src/wireguard/router/runq.rs +++ b/src/wireguard/router/runq.rs @@ -127,38 +127,3 @@ impl RunQueue { } } } - -#[cfg(test)] -mod tests { - use super::*; - use std::thread; - use std::time::Duration; - - /* - #[test] - fn test_wait() { - let queue: Arc> = Arc::new(RunQueue::new()); - - { - let queue = queue.clone(); - thread::spawn(move || { - queue.run(|e| { - println!("t0 {}", e); - thread::sleep(Duration::from_millis(100)); - }) - }); - } - - { - let queue = queue.clone(); - thread::spawn(move || { - queue.run(|e| { - println!("t1 {}", e); - thread::sleep(Duration::from_millis(100)); - }) - }); - } - - } - */ -} diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs index 3d5c79b..8d1e812 100644 --- a/src/wireguard/router/tests.rs +++ b/src/wireguard/router/tests.rs @@ -105,7 +105,7 @@ mod tests { // wait for scheduling fn wait() { - thread::sleep(Duration::from_millis(15)); + thread::sleep(Duration::from_millis(30)); } fn init() { diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs index 3cccb42..f71576a 100644 --- a/src/wireguard/tests.rs +++ b/src/wireguard/tests.rs @@ -14,20 +14,6 @@ use x25519_dalek::{PublicKey, StaticSecret}; use pnet::packet::ipv4::MutableIpv4Packet; use pnet::packet::ipv6::MutableIpv6Packet; -pub fn make_packet_src(size: usize, src: IpAddr, id: u64) -> Vec { - match src { - IpAddr::V4(_) => make_packet(size, src, "127.0.0.1".parse().unwrap(), id), - IpAddr::V6(_) => make_packet(size, src, "::1".parse().unwrap(), id), - } -} - -pub fn make_packet_dst(size: usize, dst: IpAddr, id: u64) -> Vec { - match dst { - IpAddr::V4(_) => make_packet(size, "127.0.0.1".parse().unwrap(), dst, id), - IpAddr::V6(_) => make_packet(size, "::1".parse().unwrap(), dst, id), - } -} - pub fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec { // expand pseudo random payload let mut rng: _ = ChaCha8Rng::seed_from_u64(id); @@ -104,7 +90,7 @@ fn test_pure_wireguard() { wg1.add_udp_reader(bind_reader1); wg2.add_udp_reader(bind_reader2); - // generate (public, pivate) key pairs + // generate (public, private) key pairs let sk1 = StaticSecret::from([ 0x3f, 0x69, 0x86, 0xd1, 0xc0, 0xec, 0x25, 0xa0, 0x9c, 0x8e, 0x56, 0xb5, 0x1d, 0xb7, 0x3c, diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index 8f8a244..b8c6d99 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -7,10 +7,11 @@ use hjul::{Runner, Timer}; use log::debug; use super::constants::*; +use super::peer::{Peer, PeerInner}; use super::router::{message_data_len, Callbacks}; +use super::tun::Tun; use super::types::KeyPair; -use super::{tun, udp}; -use super::{Peer, PeerInner}; +use super::udp::UDP; pub struct Timers { // only updated during configuration @@ -35,7 +36,7 @@ impl Timers { } } -impl PeerInner { +impl PeerInner { pub fn get_keepalive_interval(&self) -> u64 { self.timers().keepalive_interval } @@ -221,11 +222,7 @@ impl PeerInner { } impl Timers { - pub fn new(runner: &Runner, running: bool, peer: Peer) -> Timers - where - T: tun::Tun, - B: udp::UDP, - { + pub fn new(runner: &Runner, running: bool, peer: Peer) -> Timers { // create a timer instance for the provided peer Timers { enabled: running, @@ -338,7 +335,7 @@ impl Timers { pub struct Events(PhantomData<(T, B)>); -impl Callbacks for Events { +impl Callbacks for Events { type Opaque = Arc>; /* Called after the router encrypts a transport message destined for the peer. diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 45b1fcb..2fa14fc 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -1,179 +1,122 @@ use super::constants::*; use super::handshake; +use super::peer::{Peer, PeerInner}; use super::router; use super::timers::{Events, Timers}; -use super::{Peer, PeerInner}; use super::queue::ParallelQueue; -use super::tun; -use super::tun::Reader as TunReader; +use super::workers::HandshakeJob; -use super::udp; -use super::udp::Reader as UDPReader; -use super::udp::Writer as UDPWriter; +use super::tun::Tun; +use super::udp::UDP; -use super::Endpoint; - -use hjul::Runner; +use super::workers::{handshake_worker, tun_worker, udp_worker}; use std::fmt; use std::ops::Deref; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; -use std::thread; -use std::time::{Duration, Instant}; - -// TODO: avoid use std::sync::Condvar; use std::sync::Mutex as StdMutex; +use std::thread; +use std::time::Instant; use std::collections::hash_map::Entry; use std::collections::HashMap; -use log::debug; +use hjul::Runner; use rand::rngs::OsRng; use rand::Rng; use spin::{Mutex, RwLock}; -use byteorder::{ByteOrder, LittleEndian}; use x25519_dalek::{PublicKey, StaticSecret}; -const SIZE_HANDSHAKE_QUEUE: usize = 128; -const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4; -const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000); - -#[derive(Clone)] -pub struct WaitHandle(Arc<(StdMutex, Condvar)>); - -impl WaitHandle { - pub fn wait(&self) { - let (lock, cvar) = &*self.0; - let mut nread = lock.lock().unwrap(); - while *nread > 0 { - nread = cvar.wait(nread).unwrap(); - } - } - - fn new() -> Self { - Self(Arc::new((StdMutex::new(0), Condvar::new()))) - } - - fn decrease(&self) { - let (lock, cvar) = &*self.0; - let mut nread = lock.lock().unwrap(); - assert!(*nread > 0); - *nread -= 1; - cvar.notify_all(); - } - - fn increase(&self) { - let (lock, _) = &*self.0; - let mut nread = lock.lock().unwrap(); - *nread += 1; - } -} - -pub struct WireguardInner { +pub struct WireguardInner { // identifier (for logging) - id: u32, + pub id: u32, + + // timer wheel + pub runner: Mutex, // device enabled - enabled: RwLock, + pub enabled: RwLock, - // enables waiting for all readers to finish - tun_readers: WaitHandle, + // number of tun readers + pub tun_readers: WaitCounter, // current MTU - mtu: AtomicUsize, + pub mtu: AtomicUsize, // outbound writer - send: RwLock>, + pub send: RwLock>, // identity and configuration map - peers: RwLock>>, + pub peers: RwLock>>, // cryptokey router - router: router::Device, T::Writer, B::Writer>, + pub router: router::Device, T::Writer, B::Writer>, // handshake related state - handshake: RwLock, - under_load: AtomicBool, - pending: AtomicUsize, // num of pending handshake packets in queue - queue: ParallelQueue>, + pub handshake: RwLock, + pub last_under_load: AtomicUsize, + pub pending: AtomicUsize, // num of pending handshake packets in queue + pub queue: ParallelQueue>, } -impl PeerInner { - /* Queue a handshake request for the parallel workers - * (if one does not already exist) - * - * The function is ratelimited. - */ - pub fn packet_send_handshake_initiation(&self) { - // the function is rate limited - - { - let mut lhs = self.last_handshake_sent.lock(); - if lhs.elapsed() < REKEY_TIMEOUT { - return; - } - *lhs = Instant::now(); - } - - // create a new handshake job for the peer - - if !self.handshake_queued.swap(true, Ordering::SeqCst) { - self.wg.pending.fetch_add(1, Ordering::SeqCst); - self.wg.queue.send(HandshakeJob::New(self.pk)); - } - } +pub struct Wireguard { + inner: Arc>, } -pub enum HandshakeJob { - Message(Vec, E), - New(PublicKey), -} +pub struct WaitCounter(StdMutex, Condvar); -impl fmt::Display for WireguardInner { +impl fmt::Display for Wireguard { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "wireguard({:x})", self.id) } } -impl Deref for Wireguard { - type Target = Arc>; +impl Deref for Wireguard { + type Target = WireguardInner; fn deref(&self) -> &Self::Target { - &self.state + &self.inner } } -pub struct Wireguard { - runner: Runner, - state: Arc>, +impl Clone for Wireguard { + fn clone(&self) -> Self { + Wireguard { + inner: self.inner.clone(), + } + } } -/* Returns the padded length of a message: - * - * # Arguments - * - * - `size` : Size of unpadded message - * - `mtu` : Maximum transmission unit of the device - * - * # Returns - * - * The padded length (always less than or equal to the MTU) - */ -#[inline(always)] -const fn padding(size: usize, mtu: usize) -> usize { - #[inline(always)] - const fn min(a: usize, b: usize) -> usize { - let m = (a < b) as usize; - a * m + (1 - m) * b +impl WaitCounter { + pub fn wait(&self) { + let mut nread = self.0.lock().unwrap(); + while *nread > 0 { + nread = self.1.wait(nread).unwrap(); + } + } + + fn new() -> Self { + Self(StdMutex::new(0), Condvar::new()) + } + + fn decrease(&self) { + let mut nread = self.0.lock().unwrap(); + assert!(*nread > 0); + *nread -= 1; + if *nread == 0 { + self.1.notify_all(); + } + } + + fn increase(&self) { + *self.0.lock().unwrap() += 1; } - let pad = MESSAGE_PADDING_MULTIPLE; - min(mtu, size + (pad - size % pad) % pad) } -impl Wireguard { +impl Wireguard { /// Brings the WireGuard device down. /// Usually called when the associated interface is brought down. /// @@ -193,7 +136,7 @@ impl Wireguard { } // set mtu - self.state.mtu.store(0, Ordering::Relaxed); + self.mtu.store(0, Ordering::Relaxed); // avoid tranmission from router self.router.down(); @@ -213,7 +156,7 @@ impl Wireguard { let mut enabled = self.enabled.write(); // set mtu - self.state.mtu.store(mtu, Ordering::Relaxed); + self.mtu.store(mtu, Ordering::Relaxed); // check if already up if *enabled { @@ -232,25 +175,21 @@ impl Wireguard { } pub fn clear_peers(&self) { - self.state.peers.write().clear(); + self.peers.write().clear(); } pub fn remove_peer(&self, pk: &PublicKey) { if self.handshake.write().remove(pk).is_ok() { - self.state.peers.write().remove(pk.as_bytes()); + self.peers.write().remove(pk.as_bytes()); } } pub fn lookup_peer(&self, pk: &PublicKey) -> Option> { - self.state - .peers - .read() - .get(pk.as_bytes()) - .map(|p| p.clone()) + self.peers.read().get(pk.as_bytes()).map(|p| p.clone()) } pub fn list_peers(&self) -> Vec> { - let peers = self.state.peers.read(); + let peers = self.peers.read(); let mut list = Vec::with_capacity(peers.len()); for (k, v) in peers.iter() { debug_assert!(k == v.pk.as_bytes()); @@ -274,14 +213,14 @@ impl Wireguard { } pub fn set_psk(&self, pk: PublicKey, psk: [u8; 32]) -> bool { - self.state.handshake.write().set_psk(pk, psk).is_ok() + self.handshake.write().set_psk(pk, psk).is_ok() } pub fn get_psk(&self, pk: &PublicKey) -> Option<[u8; 32]> { - self.state.handshake.read().get_psk(pk).ok() + self.handshake.read().get_psk(pk).ok() } pub fn add_peer(&self, pk: PublicKey) -> bool { - if self.state.peers.read().contains_key(pk.as_bytes()) { + if self.peers.read().contains_key(pk.as_bytes()) { return false; } @@ -289,28 +228,28 @@ impl Wireguard { let state = Arc::new(PeerInner { id: rng.gen(), pk, - wg: self.state.clone(), + wg: self.clone(), walltime_last_handshake: Mutex::new(None), last_handshake_sent: Mutex::new(Instant::now() - TIME_HORIZON), handshake_queued: AtomicBool::new(false), rx_bytes: AtomicU64::new(0), tx_bytes: AtomicU64::new(0), - timers: RwLock::new(Timers::dummy(&self.runner)), + timers: RwLock::new(Timers::dummy(&*self.runner.lock())), }); // create a router peer - let router = Arc::new(self.state.router.new_peer(state.clone())); + let router = Arc::new(self.router.new_peer(state.clone())); // form WireGuard peer let peer = Peer { router, state }; // finally, add the peer to the wireguard device - let mut peers = self.state.peers.write(); + let mut peers = self.peers.write(); match peers.entry(*pk.as_bytes()) { Entry::Occupied(_) => false, Entry::Vacant(vacancy) => { // check that the public key does not cause conflict with the private key of the device - let ok_pk = self.state.handshake.write().add(pk).is_ok(); + let ok_pk = self.handshake.write().add(pk).is_ok(); if !ok_pk { return false; } @@ -324,7 +263,7 @@ impl Wireguard { * This is in fact the only place where the write lock is ever taken. * TODO: Consider the ease of using atomic pointers instead. */ - *peer.timers.write() = Timers::new(&self.runner, *enabled, peer.clone()); + *peer.timers.write() = Timers::new(&*self.runner.lock(), *enabled, peer.clone()); // insert into peer map (takes ownership and ensures that the peer is not dropped) vacancy.insert(peer); @@ -339,140 +278,33 @@ impl Wireguard { /// Any previous reader thread is stopped by closing the previous reader, /// which unblocks the thread and causes an error on reader.read pub fn add_udp_reader(&self, reader: B::Reader) { - let wg = self.state.clone(); + let wg = self.clone(); thread::spawn(move || { - let mut last_under_load = - Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); - - loop { - // create vector big enough for any message given current MTU - let mtu = wg.mtu.load(Ordering::Relaxed); - let size = mtu + handshake::MAX_HANDSHAKE_MSG_SIZE; - let mut msg: Vec = vec![0; size]; - - // read UDP packet into vector - let (size, src) = match reader.read(&mut msg) { - Err(e) => { - debug!("Bind reader closed with {}", e); - return; - } - Ok(v) => v, - }; - msg.truncate(size); - - // TODO: start device down - if mtu == 0 { - continue; - } - - // message type de-multiplexer - if msg.len() < std::mem::size_of::() { - continue; - } - match LittleEndian::read_u32(&msg[..]) { - handshake::TYPE_COOKIE_REPLY - | handshake::TYPE_INITIATION - | handshake::TYPE_RESPONSE => { - debug!("{} : reader, received handshake message", wg); - - // add one to pending - let pending = wg.pending.fetch_add(1, Ordering::SeqCst); - - // update under_load flag - if pending > THRESHOLD_UNDER_LOAD { - debug!("{} : reader, set under load (pending = {})", wg, pending); - last_under_load = Instant::now(); - wg.under_load.store(true, Ordering::SeqCst); - } else if last_under_load.elapsed() > DURATION_UNDER_LOAD { - debug!("{} : reader, clear under load", wg); - wg.under_load.store(false, Ordering::SeqCst); - } - - // add to handshake queue - wg.queue.send(HandshakeJob::Message(msg, src)); - } - router::TYPE_TRANSPORT => { - debug!("{} : reader, received transport message", wg); - - // transport message - let _ = wg.router.recv(src, msg).map_err(|e| { - debug!("Failed to handle incoming transport message: {}", e); - }); - } - _ => (), - } - } + udp_worker(&wg, reader); }); } pub fn set_writer(&self, writer: B::Writer) { // TODO: Consider unifying these and avoid Clone requirement on writer - *self.state.send.write() = Some(writer.clone()); - self.state.router.set_outbound_writer(writer); + *self.send.write() = Some(writer.clone()); + self.router.set_outbound_writer(writer); } pub fn add_tun_reader(&self, reader: T::Reader) { - fn worker(wg: &Arc>, reader: T::Reader) { - loop { - // create vector big enough for any transport message (based on MTU) - let mtu = wg.mtu.load(Ordering::Relaxed); - let size = mtu + router::SIZE_MESSAGE_PREFIX + 1; - let mut msg: Vec = vec![0; size + router::CAPACITY_MESSAGE_POSTFIX]; - - // read a new IP packet - let payload = match reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX) { - Ok(payload) => payload, - Err(e) => { - debug!("TUN worker, failed to read from tun device: {}", e); - break; - } - }; - debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu); - - // check if device is down - if mtu == 0 { - continue; - } - - // truncate padding - let padded = padding(payload, mtu); - log::trace!( - "TUN worker, payload length = {}, padded length = {}", - payload, - padded - ); - msg.truncate(router::SIZE_MESSAGE_PREFIX + padded); - debug_assert!(padded <= mtu); - debug_assert_eq!( - if padded < mtu { - (msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE - } else { - 0 - }, - 0 - ); - - // crypt-key route - let e = wg.router.send(msg); - debug!("TUN worker, router returned {:?}", e); - } - } - - // start a thread for every reader - let wg = self.state.clone(); + let wg = self.clone(); // increment reader count wg.tun_readers.increase(); // start worker thread::spawn(move || { - worker(&wg, reader); + tun_worker(&wg, reader); wg.tun_readers.decrease(); }); } - pub fn wait(&self) -> WaitHandle { - self.state.tun_readers.clone() + pub fn wait(&self) { + self.tun_readers.wait(); } pub fn new(writer: T::Writer) -> Wireguard { @@ -482,143 +314,33 @@ impl Wireguard { // create device state let mut rng = OsRng::new().unwrap(); - // handshake queue + // create handshake queue let (tx, mut rxs) = ParallelQueue::new(cpus, 128); - let wg = Arc::new(WireguardInner { - enabled: RwLock::new(false), - tun_readers: WaitHandle::new(), - id: rng.gen(), - mtu: AtomicUsize::new(0), - peers: RwLock::new(HashMap::new()), - send: RwLock::new(None), - router: router::Device::new(num_cpus::get(), writer), // router owns the writing half - pending: AtomicUsize::new(0), - handshake: RwLock::new(handshake::Device::new()), - under_load: AtomicBool::new(false), - queue: tx, - }); + + // create arc to state + let wg = Wireguard { + inner: Arc::new(WireguardInner { + enabled: RwLock::new(false), + tun_readers: WaitCounter::new(), + id: rng.gen(), + mtu: AtomicUsize::new(0), + peers: RwLock::new(HashMap::new()), + last_under_load: AtomicUsize::new(0), // TODO + send: RwLock::new(None), + router: router::Device::new(num_cpus::get(), writer), // router owns the writing half + pending: AtomicUsize::new(0), + handshake: RwLock::new(handshake::Device::new()), + runner: Mutex::new(Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY)), + queue: tx, + }), + }; // start handshake workers while let Some(rx) = rxs.pop() { let wg = wg.clone(); - thread::spawn(move || { - debug!("{} : handshake worker, started", wg); - - // prepare OsRng instance for this thread - let mut rng = OsRng::new().expect("Unable to obtain a CSPRNG"); - - // process elements from the handshake queue - for job in rx { - // decrement pending pakcets (under_load) - let job: HandshakeJob = job; - wg.pending.fetch_sub(1, Ordering::SeqCst); - - // demultiplex staged handshake jobs and handshake messages - match job { - HandshakeJob::Message(msg, src) => { - // feed message to handshake device - let src_validate = (&src).into_address(); // TODO avoid - - // process message - let device = wg.handshake.read(); - match device.process( - &mut rng, - &msg[..], - if wg.under_load.load(Ordering::Relaxed) { - debug!("{} : handshake worker, under load", wg); - Some(&src_validate) - } else { - None - }, - ) { - Ok((pk, resp, keypair)) => { - // send response (might be cookie reply or handshake response) - let mut resp_len: u64 = 0; - if let Some(msg) = resp { - resp_len = msg.len() as u64; - let send: &Option = &*wg.send.read(); - if let Some(writer) = send.as_ref() { - debug!( - "{} : handshake worker, send response ({} bytes)", - wg, resp_len - ); - let _ = writer.write(&msg[..], &src).map_err(|e| { - debug!( - "{} : handshake worker, failed to send response, error = {}", - wg, - e - ) - }); - } - } - - // update peer state - if let Some(pk) = pk { - // authenticated handshake packet received - if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { - // add to rx_bytes and tx_bytes - let req_len = msg.len() as u64; - peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed); - peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed); - - // update endpoint - peer.router.set_endpoint(src); - - if resp_len > 0 { - // update timers after sending handshake response - debug!("{} : handshake worker, handshake response sent", wg); - peer.state.sent_handshake_response(); - } else { - // update timers after receiving handshake response - debug!("{} : handshake worker, handshake response was received", wg); - peer.state.timers_handshake_complete(); - } - - // add any new keypair to peer - keypair.map(|kp| { - debug!( - "{} : handshake worker, new keypair for {}", - wg, peer - ); - - // this means that a handshake response was processed or sent - peer.timers_session_derived(); - - // free any unused ids - for id in peer.router.add_keypair(kp) { - device.release(id); - } - }); - } - } - } - Err(e) => debug!("{} : handshake worker, error = {:?}", wg, e), - } - } - HandshakeJob::New(pk) => { - if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { - debug!( - "{} : handshake worker, new handshake requested for {}", - wg, peer - ); - let device = wg.handshake.read(); - let _ = device.begin(&mut rng, &peer.pk).map(|msg| { - let _ = peer.router.send(&msg[..]).map_err(|e| { - debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e) - }); - peer.state.sent_handshake_initiation(); - }); - peer.handshake_queued.store(false, Ordering::SeqCst); - } - } - } - } - }); + thread::spawn(move || handshake_worker(&wg, rx)); } - Wireguard { - state: wg, - runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY), - } + wg } } diff --git a/src/wireguard/workers.rs b/src/wireguard/workers.rs new file mode 100644 index 0000000..b65f49a --- /dev/null +++ b/src/wireguard/workers.rs @@ -0,0 +1,280 @@ +use std::sync::atomic::Ordering; +use std::time::Instant; + +use byteorder::{ByteOrder, LittleEndian}; +use crossbeam_channel::Receiver; +use log::debug; +use rand::rngs::OsRng; +use x25519_dalek::PublicKey; + +// IO traits +use super::Endpoint; + +use super::tun::Reader as TunReader; +use super::tun::Tun; + +use super::udp::Reader as UDPReader; +use super::udp::Writer as UDPWriter; +use super::udp::UDP; + +// constants +use super::constants::{ + DURATION_UNDER_LOAD, MESSAGE_PADDING_MULTIPLE, THRESHOLD_UNDER_LOAD, TIME_HORIZON, +}; +use super::handshake::MAX_HANDSHAKE_MSG_SIZE; +use super::handshake::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE}; +use super::router::{CAPACITY_MESSAGE_POSTFIX, SIZE_MESSAGE_PREFIX, TYPE_TRANSPORT}; + +use super::Wireguard; + +pub enum HandshakeJob { + Message(Vec, E), + New(PublicKey), +} + +/* Returns the padded length of a message: + * + * # Arguments + * + * - `size` : Size of unpadded message + * - `mtu` : Maximum transmission unit of the device + * + * # Returns + * + * The padded length (always less than or equal to the MTU) + */ +#[inline(always)] +const fn padding(size: usize, mtu: usize) -> usize { + #[inline(always)] + const fn min(a: usize, b: usize) -> usize { + let m = (a < b) as usize; + a * m + (1 - m) * b + } + let pad = MESSAGE_PADDING_MULTIPLE; + min(mtu, size + (pad - size % pad) % pad) +} + +pub fn tun_worker(wg: &Wireguard, reader: T::Reader) { + loop { + // create vector big enough for any transport message (based on MTU) + let mtu = wg.mtu.load(Ordering::Relaxed); + let size = mtu + SIZE_MESSAGE_PREFIX + 1; + let mut msg: Vec = vec![0; size + CAPACITY_MESSAGE_POSTFIX]; + + // read a new IP packet + let payload = match reader.read(&mut msg[..], SIZE_MESSAGE_PREFIX) { + Ok(payload) => payload, + Err(e) => { + debug!("TUN worker, failed to read from tun device: {}", e); + break; + } + }; + debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu); + + // check if device is down + if mtu == 0 { + continue; + } + + // truncate padding + let padded = padding(payload, mtu); + log::trace!( + "TUN worker, payload length = {}, padded length = {}", + payload, + padded + ); + msg.truncate(SIZE_MESSAGE_PREFIX + padded); + debug_assert!(padded <= mtu); + debug_assert_eq!( + if padded < mtu { + (msg.len() - SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE + } else { + 0 + }, + 0 + ); + + // crypt-key route + let e = wg.router.send(msg); + debug!("TUN worker, router returned {:?}", e); + } +} + +pub fn udp_worker(wg: &Wireguard, reader: B::Reader) { + let mut last_under_load = Instant::now() - TIME_HORIZON; + + loop { + // create vector big enough for any message given current MTU + let mtu = wg.mtu.load(Ordering::Relaxed); + let size = mtu + MAX_HANDSHAKE_MSG_SIZE; + let mut msg: Vec = vec![0; size]; + + // read UDP packet into vector + let (size, src) = match reader.read(&mut msg) { + Err(e) => { + debug!("Bind reader closed with {}", e); + return; + } + Ok(v) => v, + }; + msg.truncate(size); + + // TODO: start device down + if mtu == 0 { + continue; + } + + // message type de-multiplexer + if msg.len() < std::mem::size_of::() { + continue; + } + match LittleEndian::read_u32(&msg[..]) { + TYPE_COOKIE_REPLY | TYPE_INITIATION | TYPE_RESPONSE => { + debug!("{} : reader, received handshake message", wg); + + // add one to pending + let pending = wg.pending.fetch_add(1, Ordering::SeqCst); + + // update under_load flag + if pending > THRESHOLD_UNDER_LOAD { + debug!("{} : reader, set under load (pending = {})", wg, pending); + last_under_load = Instant::now(); + } else if last_under_load.elapsed() > DURATION_UNDER_LOAD { + debug!("{} : reader, clear under load", wg); + } + + // add to handshake queue + wg.queue.send(HandshakeJob::Message(msg, src)); + } + TYPE_TRANSPORT => { + debug!("{} : reader, received transport message", wg); + + // transport message + let _ = wg.router.recv(src, msg).map_err(|e| { + debug!("Failed to handle incoming transport message: {}", e); + }); + } + _ => (), + } + } +} + +pub fn handshake_worker( + wg: &Wireguard, + rx: Receiver>, +) { + debug!("{} : handshake worker, started", wg); + + // prepare OsRng instance for this thread + let mut rng = OsRng::new().expect("Unable to obtain a CSPRNG"); + + // process elements from the handshake queue + for job in rx { + // decrement pending pakcets (under_load) + let job: HandshakeJob = job; + wg.pending.fetch_sub(1, Ordering::SeqCst); + + // demultiplex staged handshake jobs and handshake messages + match job { + HandshakeJob::Message(msg, src) => { + // feed message to handshake device + let src_validate = (&src).into_address(); // TODO avoid + + // process message + let device = wg.handshake.read(); + match device.process( + &mut rng, + &msg[..], + None, + /* + if wg.under_load.load(Ordering::Relaxed) { + debug!("{} : handshake worker, under load", wg); + Some(&src_validate) + } else { + None + } + */ + ) { + Ok((pk, resp, keypair)) => { + // send response (might be cookie reply or handshake response) + let mut resp_len: u64 = 0; + if let Some(msg) = resp { + resp_len = msg.len() as u64; + let send: &Option = &*wg.send.read(); + if let Some(writer) = send.as_ref() { + debug!( + "{} : handshake worker, send response ({} bytes)", + wg, resp_len + ); + let _ = writer.write(&msg[..], &src).map_err(|e| { + debug!( + "{} : handshake worker, failed to send response, error = {}", + wg, + e + ) + }); + } + } + + // update peer state + if let Some(pk) = pk { + // authenticated handshake packet received + if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { + // add to rx_bytes and tx_bytes + let req_len = msg.len() as u64; + peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed); + peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed); + + // update endpoint + peer.router.set_endpoint(src); + + if resp_len > 0 { + // update timers after sending handshake response + debug!("{} : handshake worker, handshake response sent", wg); + peer.state.sent_handshake_response(); + } else { + // update timers after receiving handshake response + debug!( + "{} : handshake worker, handshake response was received", + wg + ); + peer.state.timers_handshake_complete(); + } + + // add any new keypair to peer + keypair.map(|kp| { + debug!("{} : handshake worker, new keypair for {}", wg, peer); + + // this means that a handshake response was processed or sent + peer.timers_session_derived(); + + // free any unused ids + for id in peer.router.add_keypair(kp) { + device.release(id); + } + }); + } + } + } + Err(e) => debug!("{} : handshake worker, error = {:?}", wg, e), + } + } + HandshakeJob::New(pk) => { + if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { + debug!( + "{} : handshake worker, new handshake requested for {}", + wg, peer + ); + let device = wg.handshake.read(); + let _ = device.begin(&mut rng, &peer.pk).map(|msg| { + let _ = peer.router.send(&msg[..]).map_err(|e| { + debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e) + }); + peer.state.sent_handshake_initiation(); + }); + peer.handshake_queued.store(false, Ordering::SeqCst); + } + } + } + } +} -- cgit v1.2.3-59-g8ed1b