From 05710c455f1c759cf9bc1a1eaa6307fe564f15cc Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Fri, 15 Nov 2019 15:32:36 +0100 Subject: Update UAPI semantics for remove --- src/configuration/config.rs | 36 ++++----- src/configuration/mod.rs | 2 +- src/configuration/uapi/get.rs | 10 ++- src/configuration/uapi/mod.rs | 29 ++++--- src/configuration/uapi/set.rs | 156 ++++++++++++++++++++++---------------- src/main.rs | 38 +++++++--- src/platform/bind.rs | 2 +- src/platform/dummy/bind.rs | 2 +- src/platform/dummy/tun.rs | 2 +- src/platform/linux/mod.rs | 6 +- src/platform/linux/tun.rs | 2 +- src/platform/linux/uapi.rs | 31 ++++++++ src/platform/linux/udp.rs | 76 ++++++++++++++++--- src/platform/mod.rs | 5 +- src/platform/tun.rs | 2 +- src/platform/uapi.rs | 16 ++++ src/wireguard/handshake/device.rs | 11 +-- src/wireguard/wireguard.rs | 2 +- 18 files changed, 288 insertions(+), 140 deletions(-) create mode 100644 src/platform/linux/uapi.rs create mode 100644 src/platform/uapi.rs diff --git a/src/configuration/config.rs b/src/configuration/config.rs index b1c0121..50fdfb8 100644 --- a/src/configuration/config.rs +++ b/src/configuration/config.rs @@ -19,16 +19,16 @@ pub struct PeerState { pub last_handshake_time_nsec: u64, pub public_key: PublicKey, pub allowed_ips: Vec<(IpAddr, u32)>, - pub preshared_key: Option<[u8; 32]>, + pub preshared_key: [u8; 32], // 0^32 is the "default value" } -pub struct WireguardConfig { +pub struct WireguardConfig { wireguard: Wireguard, network: Mutex>, } -impl WireguardConfig { - fn new(wg: Wireguard) -> WireguardConfig { +impl WireguardConfig { + pub fn new(wg: Wireguard) -> WireguardConfig { WireguardConfig { wireguard: wg, network: Mutex::new(None), @@ -110,7 +110,7 @@ pub trait Configuration { /// # Returns /// /// An error if no such peer exists - fn set_preshared_key(&self, peer: &PublicKey, psk: Option<[u8; 32]>) -> Option; + fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) -> Option; /// Update the endpoint of the /// @@ -170,7 +170,7 @@ pub trait Configuration { fn get_fwmark(&self) -> Option; } -impl Configuration for WireguardConfig { +impl Configuration for WireguardConfig { fn get_fwmark(&self) -> Option { self.network .lock() @@ -246,7 +246,7 @@ impl Configuration for WireguardConfig { false } - fn set_preshared_key(&self, peer: &PublicKey, psk: Option<[u8; 32]>) -> Option { + fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) -> Option { if self.wireguard.set_psk(*peer, psk) { None } else { @@ -308,16 +308,18 @@ impl Configuration for WireguardConfig { .duration_since(SystemTime::UNIX_EPOCH) .unwrap_or(Duration::from_secs(0)); // any time before epoch is mapped to epoch - // extract state into PeerState - state.push(PeerState { - preshared_key: self.wireguard.get_psk(&p.pk), - rx_bytes: p.rx_bytes.load(Ordering::Relaxed), - tx_bytes: p.tx_bytes.load(Ordering::Relaxed), - allowed_ips: p.router.list_allowed_ips(), - last_handshake_time_nsec: last_handshake.subsec_nanos() as u64, - last_handshake_time_sec: last_handshake.as_secs(), - public_key: p.pk, - }) + if let Some(psk) = self.wireguard.get_psk(&p.pk) { + // extract state into PeerState + state.push(PeerState { + preshared_key: psk, + rx_bytes: p.rx_bytes.load(Ordering::Relaxed), + tx_bytes: p.tx_bytes.load(Ordering::Relaxed), + allowed_ips: p.router.list_allowed_ips(), + last_handshake_time_nsec: last_handshake.subsec_nanos() as u64, + last_handshake_time_sec: last_handshake.as_secs(), + public_key: p.pk, + }) + } } state } diff --git a/src/configuration/mod.rs b/src/configuration/mod.rs index 26f0c6e..dc1d93a 100644 --- a/src/configuration/mod.rs +++ b/src/configuration/mod.rs @@ -1,6 +1,6 @@ mod config; mod error; -mod uapi; +pub mod uapi; use super::platform::Endpoint; use super::platform::{bind, tun}; diff --git a/src/configuration/uapi/get.rs b/src/configuration/uapi/get.rs index 9b05421..0874cfc 100644 --- a/src/configuration/uapi/get.rs +++ b/src/configuration/uapi/get.rs @@ -1,6 +1,8 @@ use hex::FromHex; use subtle::ConstantTimeEq; +use log; + use super::Configuration; use std::io; @@ -8,9 +10,11 @@ pub fn serialize(writer: &mut W, config: &C) -> let mut write = |key: &'static str, value: String| { debug_assert!(value.is_ascii()); debug_assert!(key.is_ascii()); + log::trace!("UAPI: return : {} = {}", key, value); writer.write(key.as_ref())?; writer.write(b"=")?; - writer.write(value.as_ref()) + writer.write(value.as_ref())?; + writer.write(b"\n") }; // serialize interface @@ -40,9 +44,7 @@ pub fn serialize(writer: &mut W, config: &C) -> p.last_handshake_time_nsec.to_string(), )?; write("public_key", hex::encode(p.public_key.as_bytes()))?; - if let Some(psk) = p.preshared_key { - write("preshared_key", hex::encode(psk))?; - } + write("preshared_key", hex::encode(p.preshared_key))?; for (ip, cidr) in p.allowed_ips { write("allowed_ip", ip.to_string() + "/" + &cidr.to_string())?; } diff --git a/src/configuration/uapi/mod.rs b/src/configuration/uapi/mod.rs index 117d970..4261e7d 100644 --- a/src/configuration/uapi/mod.rs +++ b/src/configuration/uapi/mod.rs @@ -1,6 +1,7 @@ mod get; mod set; +use log; use std::io::{Read, Write}; use super::{ConfigError, Configuration}; @@ -10,10 +11,9 @@ use set::LineParser; const MAX_LINE_LENGTH: usize = 256; -pub fn process(reader: &mut R, writer: &mut W, config: &C) { - fn operation( - reader: &mut R, - writer: &mut W, +pub fn handle(stream: &mut S, config: &C) { + fn operation( + stream: &mut S, config: &C, ) -> Result<(), ConfigError> { // read string up to maximum length (why is this not in std?) @@ -23,6 +23,7 @@ pub fn process(reader: &mut R, writer: &mut while let Ok(_) = reader.read_exact(&mut m) { let c = m[0] as char; if c == '\n' { + log::trace!("UAPI, line: {}", l); return Ok(l); }; l.push(c); @@ -43,12 +44,16 @@ pub fn process(reader: &mut R, writer: &mut }; // read operation line - match readline(reader)?.as_str() { - "get=1" => serialize(writer, config).map_err(|_| ConfigError::IOError), + match readline(stream)?.as_str() { + "get=1" => { + log::debug!("UAPI, Get operation"); + serialize(stream, config).map_err(|_| ConfigError::IOError) + } "set=1" => { + log::debug!("UAPI, Set operation"); let mut parser = LineParser::new(config); loop { - let ln = readline(reader)?; + let ln = readline(stream)?; if ln == "" { break Ok(()); }; @@ -61,17 +66,17 @@ pub fn process(reader: &mut R, writer: &mut } // process operation - let res = operation(reader, writer, config); - log::debug!("{:?}", res); + let res = operation(stream, config); + log::debug!("UAPI, Result of operation: {:?}", res); // return errno - let _ = writer.write("errno=".as_ref()); - let _ = writer.write( + let _ = stream.write("errno=".as_ref()); + let _ = stream.write( match res { Err(e) => e.errno().to_string(), Ok(()) => "0".to_owned(), } .as_ref(), ); - let _ = writer.write("\n\n".as_ref()); + let _ = stream.write("\n\n".as_ref()); } diff --git a/src/configuration/uapi/set.rs b/src/configuration/uapi/set.rs index 4c2c554..e449edd 100644 --- a/src/configuration/uapi/set.rs +++ b/src/configuration/uapi/set.rs @@ -1,18 +1,27 @@ use hex::FromHex; +use std::net::{IpAddr, SocketAddr}; use subtle::ConstantTimeEq; use x25519_dalek::{PublicKey, StaticSecret}; use super::{ConfigError, Configuration}; -#[derive(Copy, Clone)] enum ParserState { - Peer { - public_key: PublicKey, - update_only: bool, - }, + Peer(ParsedPeer), Interface, } +struct ParsedPeer { + public_key: PublicKey, + update_only: bool, + allowed_ips: Vec<(IpAddr, u32)>, + remove: bool, + preshared_key: Option<[u8; 32]>, + replace_allowed_ips: bool, + persistent_keepalive_interval: Option, + protocol_version: Option, + endpoint: Option, +} + pub struct LineParser<'a, C: Configuration> { config: &'a C, state: ParserState, @@ -28,45 +37,71 @@ impl<'a, C: Configuration> LineParser<'a, C> { fn new_peer(value: &str) -> Result { match <[u8; 32]>::from_hex(value) { - Ok(pk) => Ok(ParserState::Peer { + Ok(pk) => Ok(ParserState::Peer(ParsedPeer { public_key: PublicKey::from(pk), + remove: false, update_only: false, - }), + allowed_ips: vec![], + preshared_key: None, + replace_allowed_ips: false, + persistent_keepalive_interval: None, + protocol_version: None, + endpoint: None, + })), Err(_) => Err(ConfigError::InvalidHexValue), } } pub fn parse_line(&mut self, key: &str, value: &str) -> Result<(), ConfigError> { - // add the peer if not update_only - let flush_peer = |st: ParserState| -> ParserState { - match st { - ParserState::Peer { - public_key, - update_only: false, - } => { - self.config.add_peer(&public_key); - ParserState::Peer { - public_key, - update_only: true, - } + // flush peer updates to configuration + fn flush_peer(config: &C, peer: &ParsedPeer) -> Option { + if peer.remove { + config.remove_peer(&peer.public_key); + return None; + } + + if !peer.update_only { + config.add_peer(&peer.public_key); + } + + for (ip, masklen) in &peer.allowed_ips { + config.add_allowed_ip(&peer.public_key, *ip, *masklen); + } + + if let Some(psk) = peer.preshared_key { + config.set_preshared_key(&peer.public_key, psk); + } + + if let Some(secs) = peer.persistent_keepalive_interval { + config.set_persistent_keepalive_interval(&peer.public_key, secs); + } + + if let Some(version) = peer.protocol_version { + if version == 0 || version > config.get_protocol_version() { + return Some(ConfigError::UnsupportedProtocolVersion); } - _ => st, } + + if let Some(endpoint) = peer.endpoint { + config.set_endpoint(&peer.public_key, endpoint); + }; + + None }; // parse line and update parser state - self.state = match self.state { + match self.state { // configure the interface ParserState::Interface => match key { // opt: set private key "private_key" => match <[u8; 32]>::from_hex(value) { Ok(sk) => { - self.config.set_private_key(if sk == [0u8; 32] { + self.config.set_private_key(if sk.ct_eq(&[0u8; 32]).into() { None } else { Some(StaticSecret::from(sk)) }); - Ok(self.state) + Ok(()) } Err(_) => Err(ConfigError::InvalidHexValue), }, @@ -75,7 +110,7 @@ impl<'a, C: Configuration> LineParser<'a, C> { "listen_port" => match value.parse() { Ok(port) => { self.config.set_listen_port(Some(port)); - Ok(self.state) + Ok(()) } Err(_) => Err(ConfigError::InvalidPortNumber), }, @@ -85,7 +120,7 @@ impl<'a, C: Configuration> LineParser<'a, C> { Ok(fwmark) => { self.config .set_fwmark(if fwmark == 0 { None } else { Some(fwmark) }); - Ok(self.state) + Ok(()) } Err(_) => Err(ConfigError::InvalidFwmark), }, @@ -96,51 +131,47 @@ impl<'a, C: Configuration> LineParser<'a, C> { for p in self.config.get_peers() { self.config.remove_peer(&p.public_key) } - Ok(self.state) + Ok(()) } _ => Err(ConfigError::UnsupportedValue), }, // opt: transition to peer configuration - "public_key" => Self::new_peer(value), + "public_key" => { + self.state = Self::new_peer(value)?; + Ok(()) + } // unknown key _ => Err(ConfigError::InvalidKey), }, // configure peers - ParserState::Peer { public_key, .. } => match key { + ParserState::Peer(ref mut peer) => match key { // opt: new peer "public_key" => { - flush_peer(self.state); - Self::new_peer(value) + flush_peer(self.config, &peer); + self.state = Self::new_peer(value)?; + Ok(()) } // opt: remove peer "remove" => { - self.config.remove_peer(&public_key); - Ok(self.state) + peer.remove = true; + Ok(()) } // opt: update only - "update_only" => Ok(ParserState::Peer { - public_key, - update_only: true, - }), + "update_only" => { + peer.update_only = true; + Ok(()) + } // opt: set preshared key "preshared_key" => match <[u8; 32]>::from_hex(value) { Ok(psk) => { - let st = flush_peer(self.state); - self.config.set_preshared_key( - &public_key, - if psk.ct_eq(&[0u8; 32]).into() { - None - } else { - Some(psk) - }, - ); - Ok(st) + peer.preshared_key = Some(psk); + Ok(()) } Err(_) => Err(ConfigError::InvalidHexValue), }, @@ -148,9 +179,8 @@ impl<'a, C: Configuration> LineParser<'a, C> { // opt: set endpoint "endpoint" => match value.parse() { Ok(endpoint) => { - let st = flush_peer(self.state); - self.config.set_endpoint(&public_key, endpoint); - Ok(st) + peer.endpoint = Some(endpoint); + Ok(()) } Err(_) => Err(ConfigError::InvalidSocketAddr), }, @@ -158,19 +188,17 @@ impl<'a, C: Configuration> LineParser<'a, C> { // opt: set persistent keepalive interval "persistent_keepalive_interval" => match value.parse() { Ok(secs) => { - let st = flush_peer(self.state); - self.config - .set_persistent_keepalive_interval(&public_key, secs); - Ok(st) + peer.persistent_keepalive_interval = Some(secs); + Ok(()) } Err(_) => Err(ConfigError::InvalidKeepaliveInterval), }, // opt replace allowed ips "replace_allowed_ips" => { - let st = flush_peer(self.state); - self.config.replace_allowed_ips(&public_key); - Ok(st) + peer.replace_allowed_ips = true; + peer.allowed_ips.clear(); + Ok(()) } // opt add allowed ips @@ -180,9 +208,8 @@ impl<'a, C: Configuration> LineParser<'a, C> { let cidr = split.next().and_then(|x| x.parse().ok()); match (addr, cidr) { (Some(addr), Some(cidr)) => { - let st = flush_peer(self.state); - self.config.add_allowed_ip(&public_key, addr, cidr); - Ok(st) + peer.allowed_ips.push((addr, cidr)); + Ok(()) } _ => Err(ConfigError::InvalidAllowedIp), } @@ -193,11 +220,8 @@ impl<'a, C: Configuration> LineParser<'a, C> { let parse_res: Result = value.parse(); match parse_res { Ok(version) => { - if version == 0 || version > self.config.get_protocol_version() { - Err(ConfigError::UnsupportedProtocolVersion) - } else { - Ok(self.state) - } + peer.protocol_version = Some(version); + Ok(()) } Err(_) => Err(ConfigError::UnsupportedProtocolVersion), } @@ -206,8 +230,6 @@ impl<'a, C: Configuration> LineParser<'a, C> { // unknown key _ => Err(ConfigError::InvalidKey), }, - }?; - - Ok(()) + } } } diff --git a/src/main.rs b/src/main.rs index 4d34b39..e17a127 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,21 +10,35 @@ mod configuration; mod platform; mod wireguard; -use platform::tun; +use platform::tun::PlatformTun; +use platform::uapi::PlatformUAPI; +use platform::*; -use configuration::WireguardConfig; +use std::sync::Arc; +use std::thread; +use std::time::Duration; fn main() { - /* - let (mut readers, writer, mtu) = platform::TunInstance::create("test").unwrap(); - let wg = wireguard::Wireguard::new(readers, writer, mtu); - */ -} + let name = "wg0"; + + let _ = env_logger::builder().is_test(true).try_init(); + + // create UAPI socket + let uapi = plt::UAPI::bind(name).unwrap(); + + // create TUN device + let (readers, writer, mtu) = plt::Tun::create(name).unwrap(); + + // create WireGuard device + let wg: wireguard::Wireguard = + wireguard::Wireguard::new(readers, writer, mtu); -/* -fn test_wg_configuration() { - let (mut readers, writer, mtu) = platform::dummy:: + // wrap in configuration interface and start UAPI server + let cfg = configuration::WireguardConfig::new(wg); + loop { + let mut stream = uapi.accept().unwrap(); + configuration::uapi::handle(&mut stream.0, &cfg); + } - let wg = wireguard::Wireguard::new(readers, writer, mtu); + thread::sleep(Duration::from_secs(600)); } -*/ diff --git a/src/platform/bind.rs b/src/platform/bind.rs index 1a234c7..1055f37 100644 --- a/src/platform/bind.rs +++ b/src/platform/bind.rs @@ -37,7 +37,7 @@ pub trait Owner: Send { /// On some platforms the application can itself bind to a socket. /// This enables configuration using the UAPI interface. -pub trait Platform: Bind { +pub trait PlatformBind: Bind { type Owner: Owner; /// Bind to a new port, returning the reader/writer and diff --git a/src/platform/dummy/bind.rs b/src/platform/dummy/bind.rs index 2c30486..b42483a 100644 --- a/src/platform/dummy/bind.rs +++ b/src/platform/dummy/bind.rs @@ -216,7 +216,7 @@ impl Owner for VoidOwner { } } -impl Platform for PairBind { +impl PlatformBind for PairBind { type Owner = VoidOwner; fn bind(_port: u16) -> Result<(Vec, Self::Writer, Self::Owner), Self::Error> { Err(BindError::Disconnected) diff --git a/src/platform/dummy/tun.rs b/src/platform/dummy/tun.rs index 185b328..569bf1c 100644 --- a/src/platform/dummy/tun.rs +++ b/src/platform/dummy/tun.rs @@ -192,7 +192,7 @@ impl TunTest { } } -impl Platform for TunTest { +impl PlatformTun for TunTest { fn create(_name: &str) -> Result<(Vec, Self::Writer, Self::MTU), Self::Error> { Err(TunError::Disconnected) } diff --git a/src/platform/linux/mod.rs b/src/platform/linux/mod.rs index 7d6a61c..82731de 100644 --- a/src/platform/linux/mod.rs +++ b/src/platform/linux/mod.rs @@ -1,5 +1,7 @@ mod tun; +mod uapi; mod udp; -pub use tun::LinuxTun; -pub use udp::LinuxBind; +pub use tun::LinuxTun as Tun; +pub use uapi::LinuxUAPI as UAPI; +pub use udp::LinuxBind as Bind; diff --git a/src/platform/linux/tun.rs b/src/platform/linux/tun.rs index 090569a..0bbae81 100644 --- a/src/platform/linux/tun.rs +++ b/src/platform/linux/tun.rs @@ -125,7 +125,7 @@ impl Tun for LinuxTun { type MTU = LinuxTunMTU; } -impl Platform for LinuxTun { +impl PlatformTun for LinuxTun { fn create(name: &str) -> Result<(Vec, Self::Writer, Self::MTU), Self::Error> { // construct request struct let mut req = Ifreq { diff --git a/src/platform/linux/uapi.rs b/src/platform/linux/uapi.rs new file mode 100644 index 0000000..fdf2bf0 --- /dev/null +++ b/src/platform/linux/uapi.rs @@ -0,0 +1,31 @@ +use super::super::uapi::*; + +use std::fs; +use std::io; +use std::os::unix::net::{UnixListener, UnixStream}; + +const SOCK_DIR: &str = "/var/run/wireguard/"; + +pub struct LinuxUAPI {} + +impl PlatformUAPI for LinuxUAPI { + type Error = io::Error; + type Bind = UnixListener; + + fn bind(name: &str) -> Result { + let socket_path = format!("{}{}.sock", SOCK_DIR, name); + let _ = fs::create_dir_all(SOCK_DIR); + let _ = fs::remove_file(&socket_path); + UnixListener::bind(socket_path) + } +} + +impl BindUAPI for UnixListener { + type Stream = UnixStream; + type Error = io::Error; + + fn accept(&self) -> Result { + let (stream, _) = self.accept()?; + Ok(stream) + } +} diff --git a/src/platform/linux/udp.rs b/src/platform/linux/udp.rs index 52e4c45..d3d61b6 100644 --- a/src/platform/linux/udp.rs +++ b/src/platform/linux/udp.rs @@ -1,26 +1,82 @@ use super::super::bind::*; use super::super::Endpoint; -use std::net::SocketAddr; +use std::io; +use std::net::{SocketAddr, UdpSocket}; +use std::sync::Arc; -pub struct LinuxEndpoint {} +#[derive(Clone)] +pub struct LinuxBind(Arc); -pub struct LinuxBind {} +pub struct LinuxOwner(Arc); -impl Endpoint for LinuxEndpoint { +impl Endpoint for SocketAddr { fn clear_src(&mut self) {} fn from_address(addr: SocketAddr) -> Self { - LinuxEndpoint {} + addr } fn into_address(&self) -> SocketAddr { - "127.0.0.1:6060".parse().unwrap() + *self } } -/* -impl Bind for PlatformBind { - type Endpoint = PlatformEndpoint; +impl Reader for LinuxBind { + type Error = io::Error; + + fn read(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> { + self.0.recv_from(buf) + } +} + +impl Writer for LinuxBind { + type Error = io::Error; + + fn write(&self, buf: &[u8], dst: &SocketAddr) -> Result<(), Self::Error> { + self.0.send_to(buf, dst)?; + Ok(()) + } +} + +impl Owner for LinuxOwner { + type Error = io::Error; + + fn get_port(&self) -> u16 { + 1337 + } + + fn get_fwmark(&self) -> Option { + None + } + + fn set_fwmark(&mut self, value: Option) -> Option { + None + } +} + +impl Drop for LinuxOwner { + fn drop(&mut self) {} +} + +impl Bind for LinuxBind { + type Error = io::Error; + type Endpoint = SocketAddr; + type Reader = LinuxBind; + type Writer = LinuxBind; +} + +impl PlatformBind for LinuxBind { + type Owner = LinuxOwner; + + fn bind(port: u16) -> Result<(Vec, Self::Writer, Self::Owner), Self::Error> { + let socket = UdpSocket::bind(format!("0.0.0.0:{}", port))?; + let socket = Arc::new(socket); + + Ok(( + vec![LinuxBind(socket.clone())], + LinuxBind(socket.clone()), + LinuxOwner(socket), + )) + } } -*/ diff --git a/src/platform/mod.rs b/src/platform/mod.rs index ecd559a..99707e3 100644 --- a/src/platform/mod.rs +++ b/src/platform/mod.rs @@ -2,14 +2,15 @@ mod endpoint; pub mod bind; pub mod tun; +pub mod uapi; pub use endpoint::Endpoint; #[cfg(target_os = "linux")] -mod linux; +pub mod linux; #[cfg(test)] pub mod dummy; #[cfg(target_os = "linux")] -pub use linux::LinuxTun as TunInstance; +pub use linux as plt; diff --git a/src/platform/tun.rs b/src/platform/tun.rs index f49d4af..c92304a 100644 --- a/src/platform/tun.rs +++ b/src/platform/tun.rs @@ -56,6 +56,6 @@ pub trait Tun: Send + Sync + 'static { } /// On some platforms the application can create the TUN device itself. -pub trait Platform: Tun { +pub trait PlatformTun: Tun { fn create(name: &str) -> Result<(Vec, Self::Writer, Self::MTU), Self::Error>; } diff --git a/src/platform/uapi.rs b/src/platform/uapi.rs new file mode 100644 index 0000000..6922a9c --- /dev/null +++ b/src/platform/uapi.rs @@ -0,0 +1,16 @@ +use std::error::Error; +use std::io::{Read, Write}; + +pub trait BindUAPI { + type Stream: Read + Write; + type Error: Error; + + fn accept(&self) -> Result; +} + +pub trait PlatformUAPI { + type Error: Error; + type Bind: BindUAPI; + + fn bind(name: &str) -> Result; +} diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs index f65692c..02e6929 100644 --- a/src/wireguard/handshake/device.rs +++ b/src/wireguard/handshake/device.rs @@ -178,13 +178,10 @@ impl Device { /// # Returns /// /// The call might fail if the public key is not found - pub fn set_psk(&mut self, pk: PublicKey, psk: Option) -> Result<(), ConfigError> { + pub fn set_psk(&mut self, pk: PublicKey, psk: Psk) -> Result<(), ConfigError> { match self.pk_map.get_mut(pk.as_bytes()) { Some(mut peer) => { - peer.psk = match psk { - Some(v) => v, - None => [0u8; 32], - }; + peer.psk = psk; Ok(()) } _ => Err(ConfigError::new("No such public key")), @@ -466,8 +463,8 @@ mod tests { dev1.add(pk2).unwrap(); dev2.add(pk1).unwrap(); - dev1.set_psk(pk2, Some(psk)).unwrap(); - dev2.set_psk(pk1, Some(psk)).unwrap(); + dev1.set_psk(pk2, psk).unwrap(); + dev2.set_psk(pk1, psk).unwrap(); (pk1, dev1, pk2, dev2) } diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 77be9f8..c0a8d9d 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -201,7 +201,7 @@ impl Wireguard { .map(|sk| StaticSecret::from(sk.to_bytes())) } - pub fn set_psk(&self, pk: PublicKey, psk: Option<[u8; 32]>) -> bool { + pub fn set_psk(&self, pk: PublicKey, psk: [u8; 32]) -> bool { self.state.handshake.write().set_psk(pk, psk).is_ok() } pub fn get_psk(&self, pk: &PublicKey) -> Option<[u8; 32]> { -- cgit v1.2.3-59-g8ed1b