From 3bff078e3f1c59454d8db14e5dc7603e6fdbeaba Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sun, 24 Nov 2019 18:41:43 +0100 Subject: Make IO traits suitable for Tun events (up/down) --- src/wireguard/mod.rs | 2 +- src/wireguard/peer.rs | 16 ++++++------- src/wireguard/router/device.rs | 12 +++++----- src/wireguard/router/peer.rs | 18 +++++++-------- src/wireguard/router/tests.rs | 2 +- src/wireguard/router/workers.rs | 8 +++---- src/wireguard/tests.rs | 14 +++++++----- src/wireguard/timers.rs | 8 +++---- src/wireguard/wireguard.rs | 50 ++++++++++++++++++++++++++++------------- 9 files changed, 77 insertions(+), 53 deletions(-) (limited to 'src/wireguard') diff --git a/src/wireguard/mod.rs b/src/wireguard/mod.rs index 79feed7..711aa2b 100644 --- a/src/wireguard/mod.rs +++ b/src/wireguard/mod.rs @@ -20,7 +20,7 @@ pub use types::dummy_keypair; #[cfg(test)] use super::platform::dummy; -use super::platform::{bind, tun, Endpoint}; +use super::platform::{tun, udp, Endpoint}; use peer::PeerInner; use types::KeyPair; use wireguard::HandshakeJob; diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs index 92844b6..5bcd070 100644 --- a/src/wireguard/peer.rs +++ b/src/wireguard/peer.rs @@ -2,8 +2,8 @@ use super::router; use super::timers::{Events, Timers}; use super::HandshakeJob; -use super::bind::Bind; use super::tun::Tun; +use super::udp::UDP; use super::wireguard::WireguardInner; use std::fmt; @@ -17,12 +17,12 @@ use spin::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard}; use crossbeam_channel::Sender; use x25519_dalek::PublicKey; -pub struct Peer { +pub struct Peer { pub router: Arc, T::Writer, B::Writer>>, pub state: Arc>, } -pub struct PeerInner { +pub struct PeerInner { // internal id (for logging) pub id: u64, @@ -44,7 +44,7 @@ pub struct PeerInner { pub timers: RwLock, } -impl Clone for Peer { +impl Clone for Peer { fn clone(&self) -> Peer { Peer { router: self.router.clone(), @@ -53,7 +53,7 @@ impl Clone for Peer { } } -impl PeerInner { +impl PeerInner { #[inline(always)] pub fn timers(&self) -> RwLockReadGuard { self.timers.read() @@ -65,20 +65,20 @@ impl PeerInner { } } -impl fmt::Display for Peer { +impl fmt::Display for Peer { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "peer(id = {})", self.id) } } -impl Deref for Peer { +impl Deref for Peer { type Target = PeerInner; fn deref(&self) -> &Self::Target { &self.state } } -impl Peer { +impl Peer { /// Bring the peer down. Causing: /// /// - Timers to be stopped and disabled. diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index 34273d5..621010b 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -21,9 +21,9 @@ use super::SIZE_MESSAGE_PREFIX; use super::route::RoutingTable; -use super::super::{bind, tun, Endpoint, KeyPair}; +use super::super::{tun, udp, Endpoint, KeyPair}; -pub struct DeviceInner> { +pub struct DeviceInner> { // inbound writer (TUN) pub inbound: T, @@ -45,7 +45,7 @@ pub struct EncryptionState { pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout) } -pub struct DecryptionState> { +pub struct DecryptionState> { pub keypair: Arc, pub confirmed: AtomicBool, pub protector: Mutex, @@ -53,12 +53,12 @@ pub struct DecryptionState> { +pub struct Device> { state: Arc>, // reference to device state handles: Vec>, // join handles for workers } -impl> Drop for Device { +impl> Drop for Device { fn drop(&mut self) { debug!("router: dropping device"); @@ -82,7 +82,7 @@ impl> Drop for Dev } } -impl> Device { +impl> Device { pub fn new(num_workers: usize, tun: T) -> Device { // allocate shared device state let inner = DeviceInner { diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index c09e786..fff4dfc 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -12,7 +12,7 @@ use log::debug; use spin::Mutex; use super::super::constants::*; -use super::super::{bind, tun, Endpoint, KeyPair}; +use super::super::{tun, udp, Endpoint, KeyPair}; use super::anti_replay::AntiReplay; use super::device::DecryptionState; @@ -36,7 +36,7 @@ pub struct KeyWheel { retired: Vec, // retired ids } -pub struct PeerInner> { +pub struct PeerInner> { pub device: Arc>, pub opaque: C::Opaque, pub outbound: Mutex>, @@ -47,13 +47,13 @@ pub struct PeerInner>, } -pub struct Peer> { +pub struct Peer> { state: Arc>, thread_outbound: Option>, thread_inbound: Option>, } -impl> Deref for Peer { +impl> Deref for Peer { type Target = Arc>; fn deref(&self) -> &Self::Target { @@ -71,7 +71,7 @@ impl EncryptionState { } } -impl> DecryptionState { +impl> DecryptionState { fn new( peer: &Arc>, keypair: &Arc, @@ -86,7 +86,7 @@ impl> DecryptionSt } } -impl> Drop for Peer { +impl> Drop for Peer { fn drop(&mut self) { let peer = &self.state; @@ -133,7 +133,7 @@ impl> Drop for Pee } } -pub fn new_peer>( +pub fn new_peer>( device: Arc>, opaque: C::Opaque, ) -> Peer { @@ -180,7 +180,7 @@ pub fn new_peer>( } } -impl> PeerInner { +impl> PeerInner { /// Send a raw message to the peer (used for handshake messages) /// /// # Arguments @@ -352,7 +352,7 @@ impl> PeerInner> Peer { +impl> Peer { /// Set the endpoint of the peer /// /// # Arguments diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs index 24c1b56..2d6bb63 100644 --- a/src/wireguard/router/tests.rs +++ b/src/wireguard/router/tests.rs @@ -7,10 +7,10 @@ use std::time::Duration; use num_cpus; -use super::super::bind::*; use super::super::dummy; use super::super::dummy_keypair; use super::super::tests::make_packet_dst; +use super::super::udp::*; use super::KeyPair; use super::SIZE_MESSAGE_PREFIX; use super::{Callbacks, Device}; diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs index cd8015b..3ed6311 100644 --- a/src/wireguard/router/workers.rs +++ b/src/wireguard/router/workers.rs @@ -19,7 +19,7 @@ use super::types::Callbacks; use super::REJECT_AFTER_MESSAGES; use super::super::types::KeyPair; -use super::super::{bind, tun, Endpoint}; +use super::super::{tun, udp, Endpoint}; pub const SIZE_TAG: usize = 16; @@ -40,7 +40,7 @@ pub enum JobParallel { } #[allow(type_alias_bounds)] -pub type JobInbound> = ( +pub type JobInbound> = ( Arc>, E, oneshot::Receiver>, @@ -50,7 +50,7 @@ pub type JobOutbound = oneshot::Receiver; /* TODO: Replace with run-queue */ -pub fn worker_inbound>( +pub fn worker_inbound>( device: Arc>, // related device peer: Arc>, // related peer receiver: Receiver>, @@ -137,7 +137,7 @@ pub fn worker_inbound>( +pub fn worker_outbound>( peer: Arc>, receiver: Receiver, ) { diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs index 83ef594..8217d72 100644 --- a/src/wireguard/tests.rs +++ b/src/wireguard/tests.rs @@ -1,5 +1,5 @@ use super::wireguard::Wireguard; -use super::{bind, dummy, tun}; +use super::{dummy, tun, udp}; use std::net::IpAddr; use std::thread; @@ -84,13 +84,17 @@ fn test_pure_wireguard() { // create WG instances for dummy TUN devices - let (fake1, tun_reader1, tun_writer1, mtu1) = dummy::TunTest::create(1500, true); + let (fake1, tun_reader1, tun_writer1, _) = dummy::TunTest::create(1500, true); let wg1: Wireguard = - Wireguard::new(vec![tun_reader1], tun_writer1, mtu1); + Wireguard::new(vec![tun_reader1], tun_writer1); - let (fake2, tun_reader2, tun_writer2, mtu2) = dummy::TunTest::create(1500, true); + wg1.set_mtu(1500); + + let (fake2, tun_reader2, tun_writer2, _) = dummy::TunTest::create(1500, true); let wg2: Wireguard = - Wireguard::new(vec![tun_reader2], tun_writer2, mtu2); + Wireguard::new(vec![tun_reader2], tun_writer2); + + wg2.set_mtu(1500); // create pair bind to connect the interfaces "over the internet" diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index 5eb69dc..18f49bf 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -9,7 +9,7 @@ use hjul::{Runner, Timer}; use super::constants::*; use super::router::{message_data_len, Callbacks}; use super::{Peer, PeerInner}; -use super::{bind, tun}; +use super::{udp, tun}; use super::types::KeyPair; pub struct Timers { @@ -35,7 +35,7 @@ impl Timers { } } -impl PeerInner { +impl PeerInner { pub fn get_keepalive_interval(&self) -> u64 { self.timers().keepalive_interval @@ -224,7 +224,7 @@ impl Timers { pub fn new(runner: &Runner, peer: Peer) -> Timers where T: tun::Tun, - B: bind::Bind, + B: udp::UDP, { // create a timer instance for the provided peer Timers { @@ -335,7 +335,7 @@ impl Timers { pub struct Events(PhantomData<(T, B)>); -impl Callbacks for Events { +impl Callbacks for Events { type Opaque = Arc>; /* Called after the router encrypts a transport message destined for the peer. diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index eb43512..41f6857 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -4,9 +4,13 @@ use super::router; use super::timers::{Events, Timers}; use super::{Peer, PeerInner}; -use super::bind::Reader as BindReader; -use super::bind::{Bind, Writer}; -use super::tun::{Reader, Tun, MTU}; +use super::tun; +use super::tun::Reader as TunReader; + +use super::udp; +use super::udp::Reader as UDPReader; +use super::udp::Writer as UDPWriter; + use super::Endpoint; use hjul::Runner; @@ -34,13 +38,15 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128; const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4; const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000); -pub struct WireguardInner { +pub struct WireguardInner { // identifier (for logging) id: u32, start: Instant, + // current MTU + mtu: AtomicUsize, + // provides access to the MTU value of the tun device - mtu: T::MTU, send: RwLock>, // identity and configuration map @@ -56,7 +62,7 @@ pub struct WireguardInner { queue: Mutex>>, } -impl PeerInner { +impl PeerInner { /* Queue a handshake request for the parallel workers * (if one does not already exist) * @@ -87,20 +93,20 @@ pub enum HandshakeJob { New(PublicKey), } -impl fmt::Display for WireguardInner { +impl fmt::Display for WireguardInner { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "wireguard({:x})", self.id) } } -impl Deref for Wireguard { +impl Deref for Wireguard { type Target = Arc>; fn deref(&self) -> &Self::Target { &self.state } } -pub struct Wireguard { +pub struct Wireguard { runner: Runner, state: Arc>, } @@ -127,7 +133,7 @@ const fn padding(size: usize, mtu: usize) -> usize { min(mtu, size + (pad - size % pad) % pad) } -impl Wireguard { +impl Wireguard { /// Brings the WireGuard device down. /// Usually called when the associated interface is brought down. /// @@ -269,7 +275,8 @@ impl Wireguard { loop { // create vector big enough for any message given current MTU - let size = wg.mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE; + let mtu = wg.mtu.load(Ordering::Relaxed); + let size = mtu + handshake::MAX_HANDSHAKE_MSG_SIZE; let mut msg: Vec = Vec::with_capacity(size); msg.resize(size, 0); @@ -283,6 +290,11 @@ impl Wireguard { }; msg.truncate(size); + // TODO: start device down + if mtu == 0 { + continue; + } + // message type de-multiplexer if msg.len() < std::mem::size_of::() { continue; @@ -326,13 +338,17 @@ impl Wireguard { }); } + pub fn set_mtu(&self, mtu: usize) { + self.mtu.store(mtu, Ordering::Relaxed); + } + pub fn set_writer(&self, writer: B::Writer) { // TODO: Consider unifying these and avoid Clone requirement on writer *self.state.send.write() = Some(writer.clone()); self.state.router.set_outbound_writer(writer); } - pub fn new(mut readers: Vec, writer: T::Writer, mtu: T::MTU) -> Wireguard { + pub fn new(mut readers: Vec, writer: T::Writer) -> Wireguard { // create device state let mut rng = OsRng::new().unwrap(); @@ -342,7 +358,7 @@ impl Wireguard { let wg = Arc::new(WireguardInner { start: Instant::now(), id: rng.gen(), - mtu: mtu.clone(), + mtu: AtomicUsize::new(0), peers: RwLock::new(HashMap::new()), send: RwLock::new(None), router: router::Device::new(num_cpus::get(), writer), // router owns the writing half @@ -475,10 +491,9 @@ impl Wireguard { ); while let Some(reader) = readers.pop() { let wg = wg.clone(); - let mtu = mtu.clone(); thread::spawn(move || loop { // create vector big enough for any transport message (based on MTU) - let mtu = mtu.mtu(); + let mtu = wg.mtu.load(Ordering::Relaxed); let size = mtu + router::SIZE_MESSAGE_PREFIX; let mut msg: Vec = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); msg.resize(size, 0); @@ -493,6 +508,11 @@ impl Wireguard { }; debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu); + // TODO: start device down + if mtu == 0 { + continue; + } + // truncate padding let padded = padding(payload, mtu); log::trace!( -- cgit v1.2.3-59-g8ed1b