diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-11-25 13:33:00 +0100 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-11-25 13:33:00 +0100 |
commit | f228b6f98b141940a3302d4cd1978f56f5edb13e (patch) | |
tree | f0486d26f494ce7f5d507205aa5cd475d05385c1 /src/configuration/config.rs | |
parent | Make IO traits suitable for Tun events (up/down) (diff) | |
download | wireguard-rs-f228b6f98b141940a3302d4cd1978f56f5edb13e.tar.xz wireguard-rs-f228b6f98b141940a3302d4cd1978f56f5edb13e.zip |
Enable up/down from configuration interface
Diffstat (limited to 'src/configuration/config.rs')
-rw-r--r-- | src/configuration/config.rs | 169 |
1 files changed, 116 insertions, 53 deletions
diff --git a/src/configuration/config.rs b/src/configuration/config.rs index 6ab173c..e7d1ba5 100644 --- a/src/configuration/config.rs +++ b/src/configuration/config.rs @@ -1,7 +1,8 @@ -use spin::Mutex; use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::Ordering; +use std::sync::{Arc, Mutex, MutexGuard}; use std::time::{Duration, SystemTime}; + use x25519_dalek::{PublicKey, StaticSecret}; use super::udp::Owner; @@ -23,27 +24,57 @@ pub struct PeerState { pub allowed_ips: Vec<(IpAddr, u32)>, pub endpoint: Option<SocketAddr>, pub persistent_keepalive_interval: u64, - pub preshared_key: [u8; 32], // 0^32 is the "default value" + pub preshared_key: [u8; 32], // 0^32 is the "default value" (though treated like any other psk) } -pub struct WireguardConfig<T: tun::Tun, B: udp::PlatformUDP> { +pub struct WireguardConfig<T: tun::Tun, B: udp::PlatformUDP>(Arc<Mutex<Inner<T, B>>>); + +struct State<B: udp::PlatformUDP> { + port: u16, + bind: Option<B::Owner>, + fwmark: Option<u32>, +} + +struct Inner<T: tun::Tun, B: udp::PlatformUDP> { wireguard: Wireguard<T, B>, - fwmark: Mutex<Option<u32>>, - network: Mutex<Option<B::Owner>>, + port: u16, + bind: Option<B::Owner>, + fwmark: Option<u32>, +} + +impl<T: tun::Tun, B: udp::PlatformUDP> WireguardConfig<T, B> { + fn lock(&self) -> MutexGuard<Inner<T, B>> { + self.0.lock().unwrap() + } } impl<T: tun::Tun, B: udp::PlatformUDP> WireguardConfig<T, B> { pub fn new(wg: Wireguard<T, B>) -> WireguardConfig<T, B> { - WireguardConfig { + WireguardConfig(Arc::new(Mutex::new(Inner { wireguard: wg, - fwmark: Mutex::new(None), - network: Mutex::new(None), - } + port: 0, + bind: None, + fwmark: None, + }))) + } +} + +impl<T: tun::Tun, B: udp::PlatformUDP> Clone for WireguardConfig<T, B> { + fn clone(&self) -> Self { + WireguardConfig(self.0.clone()) } } /// Exposed configuration interface pub trait Configuration { + fn up(&self, mtu: usize); + + fn down(&self); + + fn start_listener(&self) -> Result<(), ConfigError>; + + fn stop_listener(&self) -> Result<(), ConfigError>; + /// Updates the private key of the device /// /// # Arguments @@ -65,7 +96,7 @@ pub trait Configuration { /// An integer indicating the protocol version fn get_protocol_version(&self) -> usize; - fn set_listen_port(&self, port: Option<u16>) -> Result<(), ConfigError>; + fn set_listen_port(&self, port: u16) -> Result<(), ConfigError>; /// Set the firewall mark (or similar, depending on platform) /// @@ -171,19 +202,24 @@ pub trait Configuration { } impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireguardConfig<T, B> { + fn up(&self, mtu: usize) { + self.lock().wireguard.up(mtu); + } + + fn down(&self) { + self.lock().wireguard.down(); + } + fn get_fwmark(&self) -> Option<u32> { - self.network - .lock() - .as_ref() - .and_then(|bind| bind.get_fwmark()) + self.lock().bind.as_ref().and_then(|own| own.get_fwmark()) } fn set_private_key(&self, sk: Option<StaticSecret>) { - self.wireguard.set_key(sk) + self.lock().wireguard.set_key(sk) } fn get_private_key(&self) -> Option<StaticSecret> { - self.wireguard.get_sk() + self.lock().wireguard.get_sk() } fn get_protocol_version(&self) -> usize { @@ -191,49 +227,75 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireguardConfig<T, B> { } fn get_listen_port(&self) -> Option<u16> { - let bind = self.network.lock(); - log::trace!("Config, Get listen port, bound: {}", bind.is_some()); - bind.as_ref().map(|bind| bind.get_port()) + let st = self.lock(); + log::trace!("Config, Get listen port, bound: {}", st.bind.is_some()); + st.bind.as_ref().map(|bind| bind.get_port()) } - fn set_listen_port(&self, port: Option<u16>) -> Result<(), ConfigError> { - log::trace!("Config, Set listen port: {:?}", port); + fn stop_listener(&self) -> Result<(), ConfigError> { + self.lock().bind = None; + Ok(()) + } - let mut bind = self.network.lock(); + fn start_listener(&self) -> Result<(), ConfigError> { + let mut cfg = self.lock(); - // close the current listener - *bind = None; + // check if already listening + if cfg.bind.is_some() { + return Ok(()); + } - // bind to new port - if let Some(port) = port { - // create new listener - let (mut readers, writer, mut owner) = match B::bind(port) { - Ok(r) => r, - Err(_) => { - return Err(ConfigError::FailedToBind); - } - }; + // create new listener + let (mut readers, writer, mut owner) = match B::bind(cfg.port) { + Ok(r) => r, + Err(_) => { + return Err(ConfigError::FailedToBind); + } + }; - // set fwmark - let _ = owner.set_fwmark(*self.fwmark.lock()); // TODO: handle + // set fwmark + let _ = owner.set_fwmark(cfg.fwmark); // TODO: handle - // add readers/writer to wireguard - self.wireguard.set_writer(writer); - while let Some(reader) = readers.pop() { - self.wireguard.add_reader(reader); - } + // set writer on wireguard + cfg.wireguard.set_writer(writer); - // create new UDP state - *bind = Some(owner); + // add readers + while let Some(reader) = readers.pop() { + cfg.wireguard.add_reader(reader); } + // create new UDP state + cfg.bind = Some(owner); Ok(()) } + fn set_listen_port(&self, port: u16) -> Result<(), ConfigError> { + log::trace!("Config, Set listen port: {:?}", port); + + // update port + let listen: bool = { + let mut cfg = self.lock(); + cfg.port = port; + if cfg.bind.is_some() { + cfg.bind = None; + true + } else { + false + } + }; + + // restart listener if bound + if listen { + self.start_listener() + } else { + Ok(()) + } + } + fn set_fwmark(&self, mark: Option<u32>) -> Result<(), ConfigError> { log::trace!("Config, Set fwmark: {:?}", mark); - match self.network.lock().as_mut() { + match self.lock().bind.as_mut() { Some(bind) => { bind.set_fwmark(mark).unwrap(); // TODO: handle Ok(()) @@ -243,47 +305,48 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireguardConfig<T, B> { } fn replace_peers(&self) { - self.wireguard.clear_peers(); + self.lock().wireguard.clear_peers(); } fn remove_peer(&self, peer: &PublicKey) { - self.wireguard.remove_peer(peer); + self.lock().wireguard.remove_peer(peer); } fn add_peer(&self, peer: &PublicKey) -> bool { - self.wireguard.add_peer(*peer) + self.lock().wireguard.add_peer(*peer) } fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) { - self.wireguard.set_psk(*peer, psk); + self.lock().wireguard.set_psk(*peer, psk); } fn set_endpoint(&self, peer: &PublicKey, addr: SocketAddr) { - if let Some(peer) = self.wireguard.lookup_peer(peer) { + if let Some(peer) = self.lock().wireguard.lookup_peer(peer) { peer.router.set_endpoint(B::Endpoint::from_address(addr)); } } fn set_persistent_keepalive_interval(&self, peer: &PublicKey, secs: u64) { - if let Some(peer) = self.wireguard.lookup_peer(peer) { + if let Some(peer) = self.lock().wireguard.lookup_peer(peer) { peer.set_persistent_keepalive_interval(secs); } } fn replace_allowed_ips(&self, peer: &PublicKey) { - if let Some(peer) = self.wireguard.lookup_peer(peer) { + if let Some(peer) = self.lock().wireguard.lookup_peer(peer) { peer.router.remove_allowed_ips(); } } fn add_allowed_ip(&self, peer: &PublicKey, ip: IpAddr, masklen: u32) { - if let Some(peer) = self.wireguard.lookup_peer(peer) { + if let Some(peer) = self.lock().wireguard.lookup_peer(peer) { peer.router.add_allowed_ip(ip, masklen); } } fn get_peers(&self) -> Vec<PeerState> { - let peers = self.wireguard.list_peers(); + let cfg = self.lock(); + let peers = cfg.wireguard.list_peers(); let mut state = Vec::with_capacity(peers.len()); for p in peers { @@ -295,7 +358,7 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireguardConfig<T, B> { Some((duration.as_secs(), duration.subsec_nanos() as u64)) }); - if let Some(psk) = self.wireguard.get_psk(&p.pk) { + if let Some(psk) = cfg.wireguard.get_psk(&p.pk) { // extract state into PeerState state.push(PeerState { preshared_key: psk, |