From f228b6f98b141940a3302d4cd1978f56f5edb13e Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Mon, 25 Nov 2019 13:33:00 +0100 Subject: Enable up/down from configuration interface --- src/configuration/config.rs | 169 +++++++++++++++++++++++++++++------------- src/configuration/uapi/set.rs | 2 +- src/main.rs | 53 +++++++++---- src/platform/dummy/tun.rs | 5 +- src/platform/linux/tun.rs | 5 +- src/platform/tun.rs | 3 +- src/wireguard/router/tests.rs | 8 +- src/wireguard/tests.rs | 8 +- src/wireguard/wireguard.rs | 12 +-- 9 files changed, 180 insertions(+), 85 deletions(-) (limited to 'src') diff --git a/src/configuration/config.rs b/src/configuration/config.rs index 6ab173c..e7d1ba5 100644 --- a/src/configuration/config.rs +++ b/src/configuration/config.rs @@ -1,7 +1,8 @@ -use spin::Mutex; use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::Ordering; +use std::sync::{Arc, Mutex, MutexGuard}; use std::time::{Duration, SystemTime}; + use x25519_dalek::{PublicKey, StaticSecret}; use super::udp::Owner; @@ -23,27 +24,57 @@ pub struct PeerState { pub allowed_ips: Vec<(IpAddr, u32)>, pub endpoint: Option, pub persistent_keepalive_interval: u64, - pub preshared_key: [u8; 32], // 0^32 is the "default value" + pub preshared_key: [u8; 32], // 0^32 is the "default value" (though treated like any other psk) } -pub struct WireguardConfig { +pub struct WireguardConfig(Arc>>); + +struct State { + port: u16, + bind: Option, + fwmark: Option, +} + +struct Inner { wireguard: Wireguard, - fwmark: Mutex>, - network: Mutex>, + port: u16, + bind: Option, + fwmark: Option, +} + +impl WireguardConfig { + fn lock(&self) -> MutexGuard> { + self.0.lock().unwrap() + } } impl WireguardConfig { pub fn new(wg: Wireguard) -> WireguardConfig { - WireguardConfig { + WireguardConfig(Arc::new(Mutex::new(Inner { wireguard: wg, - fwmark: Mutex::new(None), - network: Mutex::new(None), - } + port: 0, + bind: None, + fwmark: None, + }))) + } +} + +impl Clone for WireguardConfig { + fn clone(&self) -> Self { + WireguardConfig(self.0.clone()) } } /// Exposed configuration interface pub trait Configuration { + fn up(&self, mtu: usize); + + fn down(&self); + + fn start_listener(&self) -> Result<(), ConfigError>; + + fn stop_listener(&self) -> Result<(), ConfigError>; + /// Updates the private key of the device /// /// # Arguments @@ -65,7 +96,7 @@ pub trait Configuration { /// An integer indicating the protocol version fn get_protocol_version(&self) -> usize; - fn set_listen_port(&self, port: Option) -> Result<(), ConfigError>; + fn set_listen_port(&self, port: u16) -> Result<(), ConfigError>; /// Set the firewall mark (or similar, depending on platform) /// @@ -171,19 +202,24 @@ pub trait Configuration { } impl Configuration for WireguardConfig { + fn up(&self, mtu: usize) { + self.lock().wireguard.up(mtu); + } + + fn down(&self) { + self.lock().wireguard.down(); + } + fn get_fwmark(&self) -> Option { - self.network - .lock() - .as_ref() - .and_then(|bind| bind.get_fwmark()) + self.lock().bind.as_ref().and_then(|own| own.get_fwmark()) } fn set_private_key(&self, sk: Option) { - self.wireguard.set_key(sk) + self.lock().wireguard.set_key(sk) } fn get_private_key(&self) -> Option { - self.wireguard.get_sk() + self.lock().wireguard.get_sk() } fn get_protocol_version(&self) -> usize { @@ -191,49 +227,75 @@ impl Configuration for WireguardConfig { } fn get_listen_port(&self) -> Option { - let bind = self.network.lock(); - log::trace!("Config, Get listen port, bound: {}", bind.is_some()); - bind.as_ref().map(|bind| bind.get_port()) + let st = self.lock(); + log::trace!("Config, Get listen port, bound: {}", st.bind.is_some()); + st.bind.as_ref().map(|bind| bind.get_port()) } - fn set_listen_port(&self, port: Option) -> Result<(), ConfigError> { - log::trace!("Config, Set listen port: {:?}", port); + fn stop_listener(&self) -> Result<(), ConfigError> { + self.lock().bind = None; + Ok(()) + } - let mut bind = self.network.lock(); + fn start_listener(&self) -> Result<(), ConfigError> { + let mut cfg = self.lock(); - // close the current listener - *bind = None; + // check if already listening + if cfg.bind.is_some() { + return Ok(()); + } - // bind to new port - if let Some(port) = port { - // create new listener - let (mut readers, writer, mut owner) = match B::bind(port) { - Ok(r) => r, - Err(_) => { - return Err(ConfigError::FailedToBind); - } - }; + // create new listener + let (mut readers, writer, mut owner) = match B::bind(cfg.port) { + Ok(r) => r, + Err(_) => { + return Err(ConfigError::FailedToBind); + } + }; - // set fwmark - let _ = owner.set_fwmark(*self.fwmark.lock()); // TODO: handle + // set fwmark + let _ = owner.set_fwmark(cfg.fwmark); // TODO: handle - // add readers/writer to wireguard - self.wireguard.set_writer(writer); - while let Some(reader) = readers.pop() { - self.wireguard.add_reader(reader); - } + // set writer on wireguard + cfg.wireguard.set_writer(writer); - // create new UDP state - *bind = Some(owner); + // add readers + while let Some(reader) = readers.pop() { + cfg.wireguard.add_reader(reader); } + // create new UDP state + cfg.bind = Some(owner); Ok(()) } + fn set_listen_port(&self, port: u16) -> Result<(), ConfigError> { + log::trace!("Config, Set listen port: {:?}", port); + + // update port + let listen: bool = { + let mut cfg = self.lock(); + cfg.port = port; + if cfg.bind.is_some() { + cfg.bind = None; + true + } else { + false + } + }; + + // restart listener if bound + if listen { + self.start_listener() + } else { + Ok(()) + } + } + fn set_fwmark(&self, mark: Option) -> Result<(), ConfigError> { log::trace!("Config, Set fwmark: {:?}", mark); - match self.network.lock().as_mut() { + match self.lock().bind.as_mut() { Some(bind) => { bind.set_fwmark(mark).unwrap(); // TODO: handle Ok(()) @@ -243,47 +305,48 @@ impl Configuration for WireguardConfig { } fn replace_peers(&self) { - self.wireguard.clear_peers(); + self.lock().wireguard.clear_peers(); } fn remove_peer(&self, peer: &PublicKey) { - self.wireguard.remove_peer(peer); + self.lock().wireguard.remove_peer(peer); } fn add_peer(&self, peer: &PublicKey) -> bool { - self.wireguard.add_peer(*peer) + self.lock().wireguard.add_peer(*peer) } fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) { - self.wireguard.set_psk(*peer, psk); + self.lock().wireguard.set_psk(*peer, psk); } fn set_endpoint(&self, peer: &PublicKey, addr: SocketAddr) { - if let Some(peer) = self.wireguard.lookup_peer(peer) { + if let Some(peer) = self.lock().wireguard.lookup_peer(peer) { peer.router.set_endpoint(B::Endpoint::from_address(addr)); } } fn set_persistent_keepalive_interval(&self, peer: &PublicKey, secs: u64) { - if let Some(peer) = self.wireguard.lookup_peer(peer) { + if let Some(peer) = self.lock().wireguard.lookup_peer(peer) { peer.set_persistent_keepalive_interval(secs); } } fn replace_allowed_ips(&self, peer: &PublicKey) { - if let Some(peer) = self.wireguard.lookup_peer(peer) { + if let Some(peer) = self.lock().wireguard.lookup_peer(peer) { peer.router.remove_allowed_ips(); } } fn add_allowed_ip(&self, peer: &PublicKey, ip: IpAddr, masklen: u32) { - if let Some(peer) = self.wireguard.lookup_peer(peer) { + if let Some(peer) = self.lock().wireguard.lookup_peer(peer) { peer.router.add_allowed_ip(ip, masklen); } } fn get_peers(&self) -> Vec { - let peers = self.wireguard.list_peers(); + let cfg = self.lock(); + let peers = cfg.wireguard.list_peers(); let mut state = Vec::with_capacity(peers.len()); for p in peers { @@ -295,7 +358,7 @@ impl Configuration for WireguardConfig { Some((duration.as_secs(), duration.subsec_nanos() as u64)) }); - if let Some(psk) = self.wireguard.get_psk(&p.pk) { + if let Some(psk) = cfg.wireguard.get_psk(&p.pk) { // extract state into PeerState state.push(PeerState { preshared_key: psk, diff --git a/src/configuration/uapi/set.rs b/src/configuration/uapi/set.rs index b44ee1c..e110692 100644 --- a/src/configuration/uapi/set.rs +++ b/src/configuration/uapi/set.rs @@ -116,7 +116,7 @@ impl<'a, C: Configuration> LineParser<'a, C> { // opt: set listen port "listen_port" => match value.parse() { Ok(port) => { - self.config.set_listen_port(Some(port))?; + self.config.set_listen_port(port)?; Ok(()) } Err(_) => Err(ConfigError::InvalidPortNumber), diff --git a/src/main.rs b/src/main.rs index c566f81..1a9650b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,7 @@ use log; use daemonize::Daemonize; + use std::env; use std::process::exit; use std::thread; @@ -12,7 +13,9 @@ mod configuration; mod platform; mod wireguard; -use platform::tun::PlatformTun; +use configuration::Configuration; + +use platform::tun::{PlatformTun, Status}; use platform::uapi::{BindUAPI, PlatformUAPI}; use platform::*; @@ -81,34 +84,56 @@ fn main() { // create WireGuard device let wg: wireguard::Wireguard = wireguard::Wireguard::new(readers, writer); - wg.set_mtu(1420); + // wrap in configuration interface + let cfg = configuration::WireguardConfig::new(wg); // start Tun event thread - /* { - let wg = wg.clone(); + let cfg = cfg.clone(); let mut status = status; thread::spawn(move || loop { match status.event() { - Err(_) => break, + Err(e) => { + log::info!("Tun device error {}", e); + exit(0); + } Ok(tun::TunEvent::Up(mtu)) => { - wg.mtu.store(mtu, Ordering::Relaxed); + log::info!("Tun up (mtu = {})", mtu); + + // bring the wireguard device up + cfg.up(mtu); + + // start listening on UDP + let _ = cfg + .start_listener() + .map_err(|e| log::info!("Failed to start UDP listener: {}", e)); + } + Ok(tun::TunEvent::Down) => { + log::info!("Tun down"); + + // set wireguard device down + cfg.down(); + + // close UDP listener + let _ = cfg + .stop_listener() + .map_err(|e| log::info!("Failed to stop UDP listener {}", e)); } - Ok(tun::TunEvent::Down) => {} } }); } - */ - // handle TUN updates up/down - - // wrap in configuration interface and start UAPI server - let cfg = configuration::WireguardConfig::new(wg); + // start UAPI server loop { match uapi.connect() { - Ok(mut stream) => configuration::uapi::handle(&mut stream, &cfg), + Ok(mut stream) => { + let cfg = cfg.clone(); + thread::spawn(move || { + configuration::uapi::handle(&mut stream, &cfg); + }); + } Err(err) => { - log::info!("UAPI error: {:}", err); + log::info!("UAPI error: {}", err); break; } } diff --git a/src/platform/dummy/tun.rs b/src/platform/dummy/tun.rs index 6ddf7d5..5d13628 100644 --- a/src/platform/dummy/tun.rs +++ b/src/platform/dummy/tun.rs @@ -150,7 +150,6 @@ impl Status for TunStatus { impl Tun for TunTest { type Writer = TunWriter; type Reader = TunReader; - type Status = TunStatus; type Error = TunError; } @@ -167,7 +166,7 @@ impl TunFakeIO { } impl TunTest { - pub fn create(mtu: usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunStatus) { + pub fn create(store: bool) -> (TunFakeIO, TunReader, TunWriter, TunStatus) { let (tx1, rx1) = if store { sync_channel(32) } else { @@ -200,6 +199,8 @@ impl TunTest { } impl PlatformTun for TunTest { + type Status = TunStatus; + fn create(_name: &str) -> Result<(Vec, Self::Writer, Self::Status), Self::Error> { Err(TunError::Disconnected) } diff --git a/src/platform/linux/tun.rs b/src/platform/linux/tun.rs index 604fad9..82eb469 100644 --- a/src/platform/linux/tun.rs +++ b/src/platform/linux/tun.rs @@ -87,10 +87,12 @@ impl Reader for LinuxTunReader { type Error = LinuxTunError; fn read(&self, buf: &mut [u8], offset: usize) -> Result { + /* debug_assert!( offset < buf.len(), "There is no space for the body of the read" ); + */ let n: isize = unsafe { read(self.fd, buf[offset..].as_mut_ptr() as _, buf.len() - offset) }; if n < 0 { @@ -132,10 +134,11 @@ impl Tun for LinuxTun { type Error = LinuxTunError; type Reader = LinuxTunReader; type Writer = LinuxTunWriter; - type Status = LinuxTunStatus; } impl PlatformTun for LinuxTun { + type Status = LinuxTunStatus; + fn create(name: &str) -> Result<(Vec, Self::Writer, Self::Status), Self::Error> { // construct request struct let mut req = Ifreq { diff --git a/src/platform/tun.rs b/src/platform/tun.rs index fda17fd..801754e 100644 --- a/src/platform/tun.rs +++ b/src/platform/tun.rs @@ -51,11 +51,12 @@ pub trait Reader: Send + 'static { pub trait Tun: Send + Sync + 'static { type Writer: Writer; type Reader: Reader; - type Status: Status; type Error: Error; } /// On some platforms the application can create the TUN device itself. pub trait PlatformTun: Tun { + type Status: Status; + fn create(name: &str) -> Result<(Vec, Self::Writer, Self::Status), Self::Error>; } diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs index 2d6bb63..d96dc90 100644 --- a/src/wireguard/router/tests.rs +++ b/src/wireguard/router/tests.rs @@ -139,7 +139,7 @@ mod tests { } // create device - let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false); + let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(false); let router: Device<_, BencherCallbacks, dummy::TunWriter, dummy::VoidBind> = Device::new(num_cpus::get(), tun_writer); @@ -169,7 +169,7 @@ mod tests { init(); // create device - let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false); + let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(false); let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer); router.set_outbound_writer(dummy::VoidBind::new()); @@ -315,8 +315,8 @@ mod tests { dummy::PairBind::pair(); // create matching device - let (_fake, _, tun_writer1, _) = dummy::TunTest::create(1500, false); - let (_fake, _, tun_writer2, _) = dummy::TunTest::create(1500, false); + let (_fake, _, tun_writer1, _) = dummy::TunTest::create(false); + let (_fake, _, tun_writer2, _) = dummy::TunTest::create(false); let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1); router1.set_outbound_writer(bind_writer1); diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs index 8217d72..7a18005 100644 --- a/src/wireguard/tests.rs +++ b/src/wireguard/tests.rs @@ -84,17 +84,17 @@ fn test_pure_wireguard() { // create WG instances for dummy TUN devices - let (fake1, tun_reader1, tun_writer1, _) = dummy::TunTest::create(1500, true); + let (fake1, tun_reader1, tun_writer1, _) = dummy::TunTest::create(true); let wg1: Wireguard = Wireguard::new(vec![tun_reader1], tun_writer1); - wg1.set_mtu(1500); + wg1.up(1500); - let (fake2, tun_reader2, tun_writer2, _) = dummy::TunTest::create(1500, true); + let (fake2, tun_reader2, tun_writer2, _) = dummy::TunTest::create(true); let wg2: Wireguard = Wireguard::new(vec![tun_reader2], tun_writer2); - wg2.set_mtu(1500); + wg2.up(1500); // create pair bind to connect the interfaces "over the internet" diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 41f6857..61f6428 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -147,6 +147,9 @@ impl Wireguard { // ensure exclusive access (to avoid race with "up" call) let peers = self.peers.write(); + // set mtu + self.state.mtu.store(0, Ordering::Relaxed); + // avoid tranmission from router self.router.down(); @@ -158,10 +161,13 @@ impl Wireguard { /// Brings the WireGuard device up. /// Usually called when the associated interface is brought up. - pub fn up(&self) { + pub fn up(&self, mtu: usize) { // ensure exclusive access (to avoid race with "down" call) let peers = self.peers.write(); + // set mtu + self.state.mtu.store(mtu, Ordering::Relaxed); + // enable tranmission from router self.router.up(); @@ -338,10 +344,6 @@ impl Wireguard { }); } - pub fn set_mtu(&self, mtu: usize) { - self.mtu.store(mtu, Ordering::Relaxed); - } - pub fn set_writer(&self, writer: B::Writer) { // TODO: Consider unifying these and avoid Clone requirement on writer *self.state.send.write() = Some(writer.clone()); -- cgit v1.2.3-59-g8ed1b