diff options
Diffstat (limited to '')
-rw-r--r-- | src/configuration/config.rs | 8 | ||||
-rw-r--r-- | src/configuration/mod.rs | 2 | ||||
-rw-r--r-- | src/main.rs | 25 | ||||
-rw-r--r-- | src/platform/dummy/bind.rs | 8 | ||||
-rw-r--r-- | src/platform/dummy/tun.rs | 35 | ||||
-rw-r--r-- | src/platform/linux/mod.rs | 2 | ||||
-rw-r--r-- | src/platform/linux/tun.rs | 42 | ||||
-rw-r--r-- | src/platform/linux/udp.rs | 24 | ||||
-rw-r--r-- | src/platform/mod.rs | 2 | ||||
-rw-r--r-- | src/platform/tun.rs | 30 | ||||
-rw-r--r-- | src/platform/udp.rs (renamed from src/platform/bind.rs) | 4 | ||||
-rw-r--r-- | src/wireguard/mod.rs | 2 | ||||
-rw-r--r-- | src/wireguard/peer.rs | 16 | ||||
-rw-r--r-- | src/wireguard/router/device.rs | 12 | ||||
-rw-r--r-- | src/wireguard/router/peer.rs | 18 | ||||
-rw-r--r-- | src/wireguard/router/tests.rs | 2 | ||||
-rw-r--r-- | src/wireguard/router/workers.rs | 8 | ||||
-rw-r--r-- | src/wireguard/tests.rs | 14 | ||||
-rw-r--r-- | src/wireguard/timers.rs | 8 | ||||
-rw-r--r-- | src/wireguard/wireguard.rs | 50 |
20 files changed, 186 insertions, 126 deletions
diff --git a/src/configuration/config.rs b/src/configuration/config.rs index 2a149ee..6ab173c 100644 --- a/src/configuration/config.rs +++ b/src/configuration/config.rs @@ -4,8 +4,8 @@ use std::sync::atomic::Ordering; use std::time::{Duration, SystemTime}; use x25519_dalek::{PublicKey, StaticSecret}; +use super::udp::Owner; use super::*; -use bind::Owner; /// The goal of the configuration interface is, among others, /// to hide the IO implementations (over which the WG device is generic), @@ -26,13 +26,13 @@ pub struct PeerState { pub preshared_key: [u8; 32], // 0^32 is the "default value" } -pub struct WireguardConfig<T: tun::Tun, B: bind::PlatformBind> { +pub struct WireguardConfig<T: tun::Tun, B: udp::PlatformUDP> { wireguard: Wireguard<T, B>, fwmark: Mutex<Option<u32>>, network: Mutex<Option<B::Owner>>, } -impl<T: tun::Tun, B: bind::PlatformBind> WireguardConfig<T, B> { +impl<T: tun::Tun, B: udp::PlatformUDP> WireguardConfig<T, B> { pub fn new(wg: Wireguard<T, B>) -> WireguardConfig<T, B> { WireguardConfig { wireguard: wg, @@ -170,7 +170,7 @@ pub trait Configuration { fn get_fwmark(&self) -> Option<u32>; } -impl<T: tun::Tun, B: bind::PlatformBind> Configuration for WireguardConfig<T, B> { +impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireguardConfig<T, B> { fn get_fwmark(&self) -> Option<u32> { self.network .lock() diff --git a/src/configuration/mod.rs b/src/configuration/mod.rs index dc1d93a..d7524d9 100644 --- a/src/configuration/mod.rs +++ b/src/configuration/mod.rs @@ -3,7 +3,7 @@ mod error; pub mod uapi; use super::platform::Endpoint; -use super::platform::{bind, tun}; +use super::platform::{tun, udp}; use super::wireguard::Wireguard; pub use error::ConfigError; diff --git a/src/main.rs b/src/main.rs index aa02321..c566f81 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,7 @@ use log; use daemonize::Daemonize; use std::env; use std::process::exit; +use std::thread; mod configuration; mod platform; @@ -52,7 +53,7 @@ fn main() { }); // create TUN device - let (readers, writer, mtu) = plt::Tun::create(name.as_str()).unwrap_or_else(|e| { + let (readers, writer, status) = plt::Tun::create(name.as_str()).unwrap_or_else(|e| { eprintln!("Failed to create TUN device: {}", e); exit(-3); }); @@ -78,8 +79,26 @@ fn main() { if drop_privileges {} // create WireGuard device - let wg: wireguard::Wireguard<plt::Tun, plt::Bind> = - wireguard::Wireguard::new(readers, writer, mtu); + let wg: wireguard::Wireguard<plt::Tun, plt::UDP> = wireguard::Wireguard::new(readers, writer); + + wg.set_mtu(1420); + + // start Tun event thread + /* + { + let wg = wg.clone(); + let mut status = status; + thread::spawn(move || loop { + match status.event() { + Err(_) => break, + Ok(tun::TunEvent::Up(mtu)) => { + wg.mtu.store(mtu, Ordering::Relaxed); + } + Ok(tun::TunEvent::Down) => {} + } + }); + } + */ // handle TUN updates up/down diff --git a/src/platform/dummy/bind.rs b/src/platform/dummy/bind.rs index d69e6a4..3146af8 100644 --- a/src/platform/dummy/bind.rs +++ b/src/platform/dummy/bind.rs @@ -11,7 +11,7 @@ use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::Arc; use std::sync::Mutex; -use super::super::bind::*; +use super::super::udp::*; use super::UnitEndpoint; @@ -82,7 +82,7 @@ impl Writer<UnitEndpoint> for VoidBind { } } -impl Bind for VoidBind { +impl UDP for VoidBind { type Error = BindError; type Endpoint = UnitEndpoint; @@ -193,7 +193,7 @@ impl PairBind { } } -impl Bind for PairBind { +impl UDP for PairBind { type Error = BindError; type Endpoint = UnitEndpoint; type Reader = PairReader<Self::Endpoint>; @@ -216,7 +216,7 @@ impl Owner for VoidOwner { } } -impl PlatformBind for PairBind { +impl PlatformUDP for PairBind { type Owner = VoidOwner; fn bind(_port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> { Err(BindError::Disconnected) diff --git a/src/platform/dummy/tun.rs b/src/platform/dummy/tun.rs index 569bf1c..6ddf7d5 100644 --- a/src/platform/dummy/tun.rs +++ b/src/platform/dummy/tun.rs @@ -10,6 +10,8 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::Arc; use std::sync::Mutex; +use std::thread; +use std::time::Duration; use super::super::tun::*; @@ -83,9 +85,8 @@ pub struct TunWriter { tx: Mutex<SyncSender<Vec<u8>>>, } -#[derive(Clone)] -pub struct TunMTU { - mtu: Arc<AtomicUsize>, +pub struct TunStatus { + first: bool, } impl Reader for TunReader { @@ -131,16 +132,25 @@ impl Writer for TunWriter { } } -impl MTU for TunMTU { - fn mtu(&self) -> usize { - self.mtu.load(Ordering::Acquire) +impl Status for TunStatus { + type Error = TunError; + + fn event(&mut self) -> Result<TunEvent, Self::Error> { + if self.first { + self.first = false; + return Ok(TunEvent::Up(1420)); + } + + loop { + thread::sleep(Duration::from_secs(60 * 60)); + } } } impl Tun for TunTest { type Writer = TunWriter; type Reader = TunReader; - type MTU = TunMTU; + type Status = TunStatus; type Error = TunError; } @@ -157,7 +167,7 @@ impl TunFakeIO { } impl TunTest { - pub fn create(mtu: usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunMTU) { + pub fn create(mtu: usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunStatus) { let (tx1, rx1) = if store { sync_channel(32) } else { @@ -184,16 +194,13 @@ impl TunTest { tx: Mutex::new(tx2), store, }; - let mtu = TunMTU { - mtu: Arc::new(AtomicUsize::new(mtu)), - }; - - (fake, reader, writer, mtu) + let status = TunStatus { first: true }; + (fake, reader, writer, status) } } impl PlatformTun for TunTest { - fn create(_name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error> { + fn create(_name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Status), Self::Error> { Err(TunError::Disconnected) } } diff --git a/src/platform/linux/mod.rs b/src/platform/linux/mod.rs index 82731de..d28391e 100644 --- a/src/platform/linux/mod.rs +++ b/src/platform/linux/mod.rs @@ -4,4 +4,4 @@ mod udp; pub use tun::LinuxTun as Tun; pub use uapi::LinuxUAPI as UAPI; -pub use udp::LinuxBind as Bind; +pub use udp::LinuxUDP as UDP; diff --git a/src/platform/linux/tun.rs b/src/platform/linux/tun.rs index 0bbae81..604fad9 100644 --- a/src/platform/linux/tun.rs +++ b/src/platform/linux/tun.rs @@ -6,8 +6,8 @@ use std::error::Error; use std::fmt; use std::os::raw::c_short; use std::os::unix::io::RawFd; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; +use std::thread; +use std::time::Duration; const IFNAMSIZ: usize = 16; const TUNSETIFF: u64 = 0x4004_54ca; @@ -30,7 +30,9 @@ struct Ifreq { _pad: [u8; 64], } -pub struct LinuxTun {} +pub struct LinuxTun { + events: Vec<TunEvent>, +} pub struct LinuxTunReader { fd: RawFd, @@ -44,8 +46,8 @@ pub struct LinuxTunWriter { * announcing an MTU update for the interface */ #[derive(Clone)] -pub struct LinuxTunMTU { - value: Arc<AtomicUsize>, +pub struct LinuxTunStatus { + first: bool, } #[derive(Debug)] @@ -81,13 +83,6 @@ impl Error for LinuxTunError { } } -impl MTU for LinuxTunMTU { - #[inline(always)] - fn mtu(&self) -> usize { - self.value.load(Ordering::Relaxed) - } -} - impl Reader for LinuxTunReader { type Error = LinuxTunError; @@ -118,15 +113,30 @@ impl Writer for LinuxTunWriter { } } +impl Status for LinuxTunStatus { + type Error = LinuxTunError; + + fn event(&mut self) -> Result<TunEvent, Self::Error> { + if self.first { + self.first = false; + return Ok(TunEvent::Up(1420)); + } + + loop { + thread::sleep(Duration::from_secs(60 * 60)); + } + } +} + impl Tun for LinuxTun { type Error = LinuxTunError; type Reader = LinuxTunReader; type Writer = LinuxTunWriter; - type MTU = LinuxTunMTU; + type Status = LinuxTunStatus; } impl PlatformTun for LinuxTun { - fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error> { + fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Status), Self::Error> { // construct request struct let mut req = Ifreq { name: [0u8; libc::IFNAMSIZ], @@ -157,9 +167,7 @@ impl PlatformTun for LinuxTun { Ok(( vec![LinuxTunReader { fd }], // TODO: enable multi-queue for Linux LinuxTunWriter { fd }, - LinuxTunMTU { - value: Arc::new(AtomicUsize::new(1500)), // TODO: fetch and update - }, + LinuxTunStatus { first: true }, )) } } diff --git a/src/platform/linux/udp.rs b/src/platform/linux/udp.rs index a291d1a..f871bce 100644 --- a/src/platform/linux/udp.rs +++ b/src/platform/linux/udp.rs @@ -1,4 +1,4 @@ -use super::super::bind::*; +use super::super::udp::*; use super::super::Endpoint; use std::io; @@ -6,7 +6,7 @@ use std::net::{SocketAddr, UdpSocket}; use std::sync::Arc; #[derive(Clone)] -pub struct LinuxBind(Arc<UdpSocket>); +pub struct LinuxUDP(Arc<UdpSocket>); pub struct LinuxOwner(Arc<UdpSocket>); @@ -22,7 +22,7 @@ impl Endpoint for SocketAddr { } } -impl Reader<SocketAddr> for LinuxBind { +impl Reader<SocketAddr> for LinuxUDP { type Error = io::Error; fn read(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> { @@ -30,7 +30,7 @@ impl Reader<SocketAddr> for LinuxBind { } } -impl Writer<SocketAddr> for LinuxBind { +impl Writer<SocketAddr> for LinuxUDP { type Error = io::Error; fn write(&self, buf: &[u8], dst: &SocketAddr) -> Result<(), Self::Error> { @@ -56,17 +56,19 @@ impl Owner for LinuxOwner { } impl Drop for LinuxOwner { - fn drop(&mut self) {} + fn drop(&mut self) { + // TODO: close udp bind + } } -impl Bind for LinuxBind { +impl UDP for LinuxUDP { type Error = io::Error; type Endpoint = SocketAddr; - type Reader = LinuxBind; - type Writer = LinuxBind; + type Reader = Self; + type Writer = Self; } -impl PlatformBind for LinuxBind { +impl PlatformUDP for LinuxUDP { type Owner = LinuxOwner; fn bind(port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> { @@ -74,8 +76,8 @@ impl PlatformBind for LinuxBind { let socket = Arc::new(socket); Ok(( - vec![LinuxBind(socket.clone())], - LinuxBind(socket.clone()), + vec![LinuxUDP(socket.clone())], + LinuxUDP(socket.clone()), LinuxOwner(socket), )) } diff --git a/src/platform/mod.rs b/src/platform/mod.rs index 99707e3..6b8fa0e 100644 --- a/src/platform/mod.rs +++ b/src/platform/mod.rs @@ -1,8 +1,8 @@ mod endpoint; -pub mod bind; pub mod tun; pub mod uapi; +pub mod udp; pub use endpoint::Endpoint; diff --git a/src/platform/tun.rs b/src/platform/tun.rs index c92304a..fda17fd 100644 --- a/src/platform/tun.rs +++ b/src/platform/tun.rs @@ -1,5 +1,18 @@ use std::error::Error; +pub enum TunEvent { + Up(usize), // interface is up (supply MTU) + Down, // interface is down +} + +pub trait Status: Send + 'static { + type Error: Error; + + /// Returns status updates for the interface + /// When the status is unchanged the method blocks + fn event(&mut self) -> Result<TunEvent, Self::Error>; +} + pub trait Writer: Send + Sync + 'static { type Error: Error; @@ -35,27 +48,14 @@ pub trait Reader: Send + 'static { fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error>; } -pub trait MTU: Send + Sync + Clone + 'static { - /// Returns the MTU of the device - /// - /// This function needs to be efficient (called for every read). - /// The goto implementation strategy is to .load an atomic variable, - /// then use e.g. netlink to update the variable in a separate thread. - /// - /// # Returns - /// - /// The MTU of the interface in bytes - fn mtu(&self) -> usize; -} - pub trait Tun: Send + Sync + 'static { type Writer: Writer; type Reader: Reader; - type MTU: MTU; + type Status: Status; type Error: Error; } /// On some platforms the application can create the TUN device itself. pub trait PlatformTun: Tun { - fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error>; + fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Status), Self::Error>; } diff --git a/src/platform/bind.rs b/src/platform/udp.rs index 9487dfd..3671229 100644 --- a/src/platform/bind.rs +++ b/src/platform/udp.rs @@ -13,7 +13,7 @@ pub trait Writer<E: Endpoint>: Send + Sync + Clone + 'static { fn write(&self, buf: &[u8], dst: &E) -> Result<(), Self::Error>; } -pub trait Bind: Send + Sync + 'static { +pub trait UDP: Send + Sync + 'static { type Error: Error; type Endpoint: Endpoint; @@ -37,7 +37,7 @@ pub trait Owner: Send { /// On some platforms the application can itself bind to a socket. /// This enables configuration using the UAPI interface. -pub trait PlatformBind: Bind { +pub trait PlatformUDP: UDP { type Owner: Owner; /// Bind to a new port, returning the reader/writer and diff --git a/src/wireguard/mod.rs b/src/wireguard/mod.rs index 79feed7..711aa2b 100644 --- a/src/wireguard/mod.rs +++ b/src/wireguard/mod.rs @@ -20,7 +20,7 @@ pub use types::dummy_keypair; #[cfg(test)] use super::platform::dummy; -use super::platform::{bind, tun, Endpoint}; +use super::platform::{tun, udp, Endpoint}; use peer::PeerInner; use types::KeyPair; use wireguard::HandshakeJob; diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs index 92844b6..5bcd070 100644 --- a/src/wireguard/peer.rs +++ b/src/wireguard/peer.rs @@ -2,8 +2,8 @@ use super::router; use super::timers::{Events, Timers}; use super::HandshakeJob; -use super::bind::Bind; use super::tun::Tun; +use super::udp::UDP; use super::wireguard::WireguardInner; use std::fmt; @@ -17,12 +17,12 @@ use spin::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard}; use crossbeam_channel::Sender; use x25519_dalek::PublicKey; -pub struct Peer<T: Tun, B: Bind> { +pub struct Peer<T: Tun, B: UDP> { pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>, pub state: Arc<PeerInner<T, B>>, } -pub struct PeerInner<T: Tun, B: Bind> { +pub struct PeerInner<T: Tun, B: UDP> { // internal id (for logging) pub id: u64, @@ -44,7 +44,7 @@ pub struct PeerInner<T: Tun, B: Bind> { pub timers: RwLock<Timers>, } -impl<T: Tun, B: Bind> Clone for Peer<T, B> { +impl<T: Tun, B: UDP> Clone for Peer<T, B> { fn clone(&self) -> Peer<T, B> { Peer { router: self.router.clone(), @@ -53,7 +53,7 @@ impl<T: Tun, B: Bind> Clone for Peer<T, B> { } } -impl<T: Tun, B: Bind> PeerInner<T, B> { +impl<T: Tun, B: UDP> PeerInner<T, B> { #[inline(always)] pub fn timers(&self) -> RwLockReadGuard<Timers> { self.timers.read() @@ -65,20 +65,20 @@ impl<T: Tun, B: Bind> PeerInner<T, B> { } } -impl<T: Tun, B: Bind> fmt::Display for Peer<T, B> { +impl<T: Tun, B: UDP> fmt::Display for Peer<T, B> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "peer(id = {})", self.id) } } -impl<T: Tun, B: Bind> Deref for Peer<T, B> { +impl<T: Tun, B: UDP> Deref for Peer<T, B> { type Target = PeerInner<T, B>; fn deref(&self) -> &Self::Target { &self.state } } -impl<T: Tun, B: Bind> Peer<T, B> { +impl<T: Tun, B: UDP> Peer<T, B> { /// Bring the peer down. Causing: /// /// - Timers to be stopped and disabled. diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index 34273d5..621010b 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -21,9 +21,9 @@ use super::SIZE_MESSAGE_PREFIX; use super::route::RoutingTable; -use super::super::{bind, tun, Endpoint, KeyPair}; +use super::super::{tun, udp, Endpoint, KeyPair}; -pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { +pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { // inbound writer (TUN) pub inbound: T, @@ -45,7 +45,7 @@ pub struct EncryptionState { pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout) } -pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { +pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { pub keypair: Arc<KeyPair>, pub confirmed: AtomicBool, pub protector: Mutex<AntiReplay>, @@ -53,12 +53,12 @@ pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::W pub death: Instant, // time when the key can no longer be used for decryption } -pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { +pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { state: Arc<DeviceInner<E, C, T, B>>, // reference to device state handles: Vec<thread::JoinHandle<()>>, // join handles for workers } -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Device<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Device<E, C, T, B> { fn drop(&mut self) { debug!("router: dropping device"); @@ -82,7 +82,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Dev } } -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C, T, B> { pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> { // allocate shared device state let inner = DeviceInner { diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index c09e786..fff4dfc 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -12,7 +12,7 @@ use log::debug; use spin::Mutex; use super::super::constants::*; -use super::super::{bind, tun, Endpoint, KeyPair}; +use super::super::{tun, udp, Endpoint, KeyPair}; use super::anti_replay::AntiReplay; use super::device::DecryptionState; @@ -36,7 +36,7 @@ pub struct KeyWheel { retired: Vec<u32>, // retired ids } -pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { +pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { pub device: Arc<DeviceInner<E, C, T, B>>, pub opaque: C::Opaque, pub outbound: Mutex<SyncSender<JobOutbound>>, @@ -47,13 +47,13 @@ pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer< pub endpoint: Mutex<Option<E>>, } -pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { +pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { state: Arc<PeerInner<E, C, T, B>>, thread_outbound: Option<thread::JoinHandle<()>>, thread_inbound: Option<thread::JoinHandle<()>>, } -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Deref for Peer<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Peer<E, C, T, B> { type Target = Arc<PeerInner<E, C, T, B>>; fn deref(&self) -> &Self::Target { @@ -71,7 +71,7 @@ impl EncryptionState { } } -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionState<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DecryptionState<E, C, T, B> { fn new( peer: &Arc<PeerInner<E, C, T, B>>, keypair: &Arc<KeyPair>, @@ -86,7 +86,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionSt } } -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Peer<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Peer<E, C, T, B> { fn drop(&mut self) { let peer = &self.state; @@ -133,7 +133,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Pee } } -pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( +pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( device: Arc<DeviceInner<E, C, T, B>>, opaque: C::Opaque, ) -> Peer<E, C, T, B> { @@ -180,7 +180,7 @@ pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( } } -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, C, T, B> { /// Send a raw message to the peer (used for handshake messages) /// /// # Arguments @@ -352,7 +352,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E, } } -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, B> { /// Set the endpoint of the peer /// /// # Arguments diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs index 24c1b56..2d6bb63 100644 --- a/src/wireguard/router/tests.rs +++ b/src/wireguard/router/tests.rs @@ -7,10 +7,10 @@ use std::time::Duration; use num_cpus; -use super::super::bind::*; use super::super::dummy; use super::super::dummy_keypair; use super::super::tests::make_packet_dst; +use super::super::udp::*; use super::KeyPair; use super::SIZE_MESSAGE_PREFIX; use super::{Callbacks, Device}; diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs index cd8015b..3ed6311 100644 --- a/src/wireguard/router/workers.rs +++ b/src/wireguard/router/workers.rs @@ -19,7 +19,7 @@ use super::types::Callbacks; use super::REJECT_AFTER_MESSAGES; use super::super::types::KeyPair; -use super::super::{bind, tun, Endpoint}; +use super::super::{tun, udp, Endpoint}; pub const SIZE_TAG: usize = 16; @@ -40,7 +40,7 @@ pub enum JobParallel { } #[allow(type_alias_bounds)] -pub type JobInbound<E, C, T, B: bind::Writer<E>> = ( +pub type JobInbound<E, C, T, B: udp::Writer<E>> = ( Arc<DecryptionState<E, C, T, B>>, E, oneshot::Receiver<Option<JobDecryption>>, @@ -50,7 +50,7 @@ pub type JobOutbound = oneshot::Receiver<JobEncryption>; /* TODO: Replace with run-queue */ -pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( +pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( device: Arc<DeviceInner<E, C, T, B>>, // related device peer: Arc<PeerInner<E, C, T, B>>, // related peer receiver: Receiver<JobInbound<E, C, T, B>>, @@ -137,7 +137,7 @@ pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer /* TODO: Replace with run-queue */ -pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( +pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( peer: Arc<PeerInner<E, C, T, B>>, receiver: Receiver<JobOutbound>, ) { diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs index 83ef594..8217d72 100644 --- a/src/wireguard/tests.rs +++ b/src/wireguard/tests.rs @@ -1,5 +1,5 @@ use super::wireguard::Wireguard; -use super::{bind, dummy, tun}; +use super::{dummy, tun, udp}; use std::net::IpAddr; use std::thread; @@ -84,13 +84,17 @@ fn test_pure_wireguard() { // create WG instances for dummy TUN devices - let (fake1, tun_reader1, tun_writer1, mtu1) = dummy::TunTest::create(1500, true); + let (fake1, tun_reader1, tun_writer1, _) = dummy::TunTest::create(1500, true); let wg1: Wireguard<dummy::TunTest, dummy::PairBind> = - Wireguard::new(vec![tun_reader1], tun_writer1, mtu1); + Wireguard::new(vec![tun_reader1], tun_writer1); - let (fake2, tun_reader2, tun_writer2, mtu2) = dummy::TunTest::create(1500, true); + wg1.set_mtu(1500); + + let (fake2, tun_reader2, tun_writer2, _) = dummy::TunTest::create(1500, true); let wg2: Wireguard<dummy::TunTest, dummy::PairBind> = - Wireguard::new(vec![tun_reader2], tun_writer2, mtu2); + Wireguard::new(vec![tun_reader2], tun_writer2); + + wg2.set_mtu(1500); // create pair bind to connect the interfaces "over the internet" diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index 5eb69dc..18f49bf 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -9,7 +9,7 @@ use hjul::{Runner, Timer}; use super::constants::*; use super::router::{message_data_len, Callbacks}; use super::{Peer, PeerInner}; -use super::{bind, tun}; +use super::{udp, tun}; use super::types::KeyPair; pub struct Timers { @@ -35,7 +35,7 @@ impl Timers { } } -impl<T: tun::Tun, B: bind::Bind> PeerInner<T, B> { +impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> { pub fn get_keepalive_interval(&self) -> u64 { self.timers().keepalive_interval @@ -224,7 +224,7 @@ impl Timers { pub fn new<T, B>(runner: &Runner, peer: Peer<T, B>) -> Timers where T: tun::Tun, - B: bind::Bind, + B: udp::UDP, { // create a timer instance for the provided peer Timers { @@ -335,7 +335,7 @@ impl Timers { pub struct Events<T, B>(PhantomData<(T, B)>); -impl<T: tun::Tun, B: bind::Bind> Callbacks for Events<T, B> { +impl<T: tun::Tun, B: udp::UDP> Callbacks for Events<T, B> { type Opaque = Arc<PeerInner<T, B>>; /* Called after the router encrypts a transport message destined for the peer. diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index eb43512..41f6857 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -4,9 +4,13 @@ use super::router; use super::timers::{Events, Timers}; use super::{Peer, PeerInner}; -use super::bind::Reader as BindReader; -use super::bind::{Bind, Writer}; -use super::tun::{Reader, Tun, MTU}; +use super::tun; +use super::tun::Reader as TunReader; + +use super::udp; +use super::udp::Reader as UDPReader; +use super::udp::Writer as UDPWriter; + use super::Endpoint; use hjul::Runner; @@ -34,13 +38,15 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128; const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4; const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000); -pub struct WireguardInner<T: Tun, B: Bind> { +pub struct WireguardInner<T: tun::Tun, B: udp::UDP> { // identifier (for logging) id: u32, start: Instant, + // current MTU + mtu: AtomicUsize, + // provides access to the MTU value of the tun device - mtu: T::MTU, send: RwLock<Option<B::Writer>>, // identity and configuration map @@ -56,7 +62,7 @@ pub struct WireguardInner<T: Tun, B: Bind> { queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, } -impl<T: Tun, B: Bind> PeerInner<T, B> { +impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> { /* Queue a handshake request for the parallel workers * (if one does not already exist) * @@ -87,20 +93,20 @@ pub enum HandshakeJob<E> { New(PublicKey), } -impl<T: Tun, B: Bind> fmt::Display for WireguardInner<T, B> { +impl<T: tun::Tun, B: udp::UDP> fmt::Display for WireguardInner<T, B> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "wireguard({:x})", self.id) } } -impl<T: Tun, B: Bind> Deref for Wireguard<T, B> { +impl<T: tun::Tun, B: udp::UDP> Deref for Wireguard<T, B> { type Target = Arc<WireguardInner<T, B>>; fn deref(&self) -> &Self::Target { &self.state } } -pub struct Wireguard<T: Tun, B: Bind> { +pub struct Wireguard<T: tun::Tun, B: udp::UDP> { runner: Runner, state: Arc<WireguardInner<T, B>>, } @@ -127,7 +133,7 @@ const fn padding(size: usize, mtu: usize) -> usize { min(mtu, size + (pad - size % pad) % pad) } -impl<T: Tun, B: Bind> Wireguard<T, B> { +impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> { /// Brings the WireGuard device down. /// Usually called when the associated interface is brought down. /// @@ -269,7 +275,8 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { loop { // create vector big enough for any message given current MTU - let size = wg.mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE; + let mtu = wg.mtu.load(Ordering::Relaxed); + let size = mtu + handshake::MAX_HANDSHAKE_MSG_SIZE; let mut msg: Vec<u8> = Vec::with_capacity(size); msg.resize(size, 0); @@ -283,6 +290,11 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { }; msg.truncate(size); + // TODO: start device down + if mtu == 0 { + continue; + } + // message type de-multiplexer if msg.len() < std::mem::size_of::<u32>() { continue; @@ -326,13 +338,17 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { }); } + pub fn set_mtu(&self, mtu: usize) { + self.mtu.store(mtu, Ordering::Relaxed); + } + pub fn set_writer(&self, writer: B::Writer) { // TODO: Consider unifying these and avoid Clone requirement on writer *self.state.send.write() = Some(writer.clone()); self.state.router.set_outbound_writer(writer); } - pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> { + pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer) -> Wireguard<T, B> { // create device state let mut rng = OsRng::new().unwrap(); @@ -342,7 +358,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { let wg = Arc::new(WireguardInner { start: Instant::now(), id: rng.gen(), - mtu: mtu.clone(), + mtu: AtomicUsize::new(0), peers: RwLock::new(HashMap::new()), send: RwLock::new(None), router: router::Device::new(num_cpus::get(), writer), // router owns the writing half @@ -475,10 +491,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { ); while let Some(reader) = readers.pop() { let wg = wg.clone(); - let mtu = mtu.clone(); thread::spawn(move || loop { // create vector big enough for any transport message (based on MTU) - let mtu = mtu.mtu(); + let mtu = wg.mtu.load(Ordering::Relaxed); let size = mtu + router::SIZE_MESSAGE_PREFIX; let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); msg.resize(size, 0); @@ -493,6 +508,11 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { }; debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu); + // TODO: start device down + if mtu == 0 { + continue; + } + // truncate padding let padded = padding(payload, mtu); log::trace!( |