aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-11-15 15:32:36 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-11-15 15:32:36 +0100
commit05710c455f1c759cf9bc1a1eaa6307fe564f15cc (patch)
treec671d703d0db93a9bc8f27d0e2b02d0422120352
parentInitial version of full UAPI parser (diff)
downloadwireguard-rs-05710c455f1c759cf9bc1a1eaa6307fe564f15cc.tar.xz
wireguard-rs-05710c455f1c759cf9bc1a1eaa6307fe564f15cc.zip
Update UAPI semantics for remove
Diffstat (limited to '')
-rw-r--r--src/configuration/config.rs36
-rw-r--r--src/configuration/mod.rs2
-rw-r--r--src/configuration/uapi/get.rs10
-rw-r--r--src/configuration/uapi/mod.rs29
-rw-r--r--src/configuration/uapi/set.rs156
-rw-r--r--src/main.rs38
-rw-r--r--src/platform/bind.rs2
-rw-r--r--src/platform/dummy/bind.rs2
-rw-r--r--src/platform/dummy/tun.rs2
-rw-r--r--src/platform/linux/mod.rs6
-rw-r--r--src/platform/linux/tun.rs2
-rw-r--r--src/platform/linux/uapi.rs31
-rw-r--r--src/platform/linux/udp.rs76
-rw-r--r--src/platform/mod.rs5
-rw-r--r--src/platform/tun.rs2
-rw-r--r--src/platform/uapi.rs16
-rw-r--r--src/wireguard/handshake/device.rs11
-rw-r--r--src/wireguard/wireguard.rs2
18 files changed, 288 insertions, 140 deletions
diff --git a/src/configuration/config.rs b/src/configuration/config.rs
index b1c0121..50fdfb8 100644
--- a/src/configuration/config.rs
+++ b/src/configuration/config.rs
@@ -19,16 +19,16 @@ pub struct PeerState {
pub last_handshake_time_nsec: u64,
pub public_key: PublicKey,
pub allowed_ips: Vec<(IpAddr, u32)>,
- pub preshared_key: Option<[u8; 32]>,
+ pub preshared_key: [u8; 32], // 0^32 is the "default value"
}
-pub struct WireguardConfig<T: tun::Tun, B: bind::Platform> {
+pub struct WireguardConfig<T: tun::Tun, B: bind::PlatformBind> {
wireguard: Wireguard<T, B>,
network: Mutex<Option<B::Owner>>,
}
-impl<T: tun::Tun, B: bind::Platform> WireguardConfig<T, B> {
- fn new(wg: Wireguard<T, B>) -> WireguardConfig<T, B> {
+impl<T: tun::Tun, B: bind::PlatformBind> WireguardConfig<T, B> {
+ pub fn new(wg: Wireguard<T, B>) -> WireguardConfig<T, B> {
WireguardConfig {
wireguard: wg,
network: Mutex::new(None),
@@ -110,7 +110,7 @@ pub trait Configuration {
/// # Returns
///
/// An error if no such peer exists
- fn set_preshared_key(&self, peer: &PublicKey, psk: Option<[u8; 32]>) -> Option<ConfigError>;
+ fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) -> Option<ConfigError>;
/// Update the endpoint of the
///
@@ -170,7 +170,7 @@ pub trait Configuration {
fn get_fwmark(&self) -> Option<u32>;
}
-impl<T: tun::Tun, B: bind::Platform> Configuration for WireguardConfig<T, B> {
+impl<T: tun::Tun, B: bind::PlatformBind> Configuration for WireguardConfig<T, B> {
fn get_fwmark(&self) -> Option<u32> {
self.network
.lock()
@@ -246,7 +246,7 @@ impl<T: tun::Tun, B: bind::Platform> Configuration for WireguardConfig<T, B> {
false
}
- fn set_preshared_key(&self, peer: &PublicKey, psk: Option<[u8; 32]>) -> Option<ConfigError> {
+ fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) -> Option<ConfigError> {
if self.wireguard.set_psk(*peer, psk) {
None
} else {
@@ -308,16 +308,18 @@ impl<T: tun::Tun, B: bind::Platform> Configuration for WireguardConfig<T, B> {
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or(Duration::from_secs(0)); // any time before epoch is mapped to epoch
- // extract state into PeerState
- state.push(PeerState {
- preshared_key: self.wireguard.get_psk(&p.pk),
- rx_bytes: p.rx_bytes.load(Ordering::Relaxed),
- tx_bytes: p.tx_bytes.load(Ordering::Relaxed),
- allowed_ips: p.router.list_allowed_ips(),
- last_handshake_time_nsec: last_handshake.subsec_nanos() as u64,
- last_handshake_time_sec: last_handshake.as_secs(),
- public_key: p.pk,
- })
+ if let Some(psk) = self.wireguard.get_psk(&p.pk) {
+ // extract state into PeerState
+ state.push(PeerState {
+ preshared_key: psk,
+ rx_bytes: p.rx_bytes.load(Ordering::Relaxed),
+ tx_bytes: p.tx_bytes.load(Ordering::Relaxed),
+ allowed_ips: p.router.list_allowed_ips(),
+ last_handshake_time_nsec: last_handshake.subsec_nanos() as u64,
+ last_handshake_time_sec: last_handshake.as_secs(),
+ public_key: p.pk,
+ })
+ }
}
state
}
diff --git a/src/configuration/mod.rs b/src/configuration/mod.rs
index 26f0c6e..dc1d93a 100644
--- a/src/configuration/mod.rs
+++ b/src/configuration/mod.rs
@@ -1,6 +1,6 @@
mod config;
mod error;
-mod uapi;
+pub mod uapi;
use super::platform::Endpoint;
use super::platform::{bind, tun};
diff --git a/src/configuration/uapi/get.rs b/src/configuration/uapi/get.rs
index 9b05421..0874cfc 100644
--- a/src/configuration/uapi/get.rs
+++ b/src/configuration/uapi/get.rs
@@ -1,6 +1,8 @@
use hex::FromHex;
use subtle::ConstantTimeEq;
+use log;
+
use super::Configuration;
use std::io;
@@ -8,9 +10,11 @@ pub fn serialize<C: Configuration, W: io::Write>(writer: &mut W, config: &C) ->
let mut write = |key: &'static str, value: String| {
debug_assert!(value.is_ascii());
debug_assert!(key.is_ascii());
+ log::trace!("UAPI: return : {} = {}", key, value);
writer.write(key.as_ref())?;
writer.write(b"=")?;
- writer.write(value.as_ref())
+ writer.write(value.as_ref())?;
+ writer.write(b"\n")
};
// serialize interface
@@ -40,9 +44,7 @@ pub fn serialize<C: Configuration, W: io::Write>(writer: &mut W, config: &C) ->
p.last_handshake_time_nsec.to_string(),
)?;
write("public_key", hex::encode(p.public_key.as_bytes()))?;
- if let Some(psk) = p.preshared_key {
- write("preshared_key", hex::encode(psk))?;
- }
+ write("preshared_key", hex::encode(p.preshared_key))?;
for (ip, cidr) in p.allowed_ips {
write("allowed_ip", ip.to_string() + "/" + &cidr.to_string())?;
}
diff --git a/src/configuration/uapi/mod.rs b/src/configuration/uapi/mod.rs
index 117d970..4261e7d 100644
--- a/src/configuration/uapi/mod.rs
+++ b/src/configuration/uapi/mod.rs
@@ -1,6 +1,7 @@
mod get;
mod set;
+use log;
use std::io::{Read, Write};
use super::{ConfigError, Configuration};
@@ -10,10 +11,9 @@ use set::LineParser;
const MAX_LINE_LENGTH: usize = 256;
-pub fn process<R: Read, W: Write, C: Configuration>(reader: &mut R, writer: &mut W, config: &C) {
- fn operation<R: Read, W: Write, C: Configuration>(
- reader: &mut R,
- writer: &mut W,
+pub fn handle<S: Read + Write, C: Configuration>(stream: &mut S, config: &C) {
+ fn operation<S: Read + Write, C: Configuration>(
+ stream: &mut S,
config: &C,
) -> Result<(), ConfigError> {
// read string up to maximum length (why is this not in std?)
@@ -23,6 +23,7 @@ pub fn process<R: Read, W: Write, C: Configuration>(reader: &mut R, writer: &mut
while let Ok(_) = reader.read_exact(&mut m) {
let c = m[0] as char;
if c == '\n' {
+ log::trace!("UAPI, line: {}", l);
return Ok(l);
};
l.push(c);
@@ -43,12 +44,16 @@ pub fn process<R: Read, W: Write, C: Configuration>(reader: &mut R, writer: &mut
};
// read operation line
- match readline(reader)?.as_str() {
- "get=1" => serialize(writer, config).map_err(|_| ConfigError::IOError),
+ match readline(stream)?.as_str() {
+ "get=1" => {
+ log::debug!("UAPI, Get operation");
+ serialize(stream, config).map_err(|_| ConfigError::IOError)
+ }
"set=1" => {
+ log::debug!("UAPI, Set operation");
let mut parser = LineParser::new(config);
loop {
- let ln = readline(reader)?;
+ let ln = readline(stream)?;
if ln == "" {
break Ok(());
};
@@ -61,17 +66,17 @@ pub fn process<R: Read, W: Write, C: Configuration>(reader: &mut R, writer: &mut
}
// process operation
- let res = operation(reader, writer, config);
- log::debug!("{:?}", res);
+ let res = operation(stream, config);
+ log::debug!("UAPI, Result of operation: {:?}", res);
// return errno
- let _ = writer.write("errno=".as_ref());
- let _ = writer.write(
+ let _ = stream.write("errno=".as_ref());
+ let _ = stream.write(
match res {
Err(e) => e.errno().to_string(),
Ok(()) => "0".to_owned(),
}
.as_ref(),
);
- let _ = writer.write("\n\n".as_ref());
+ let _ = stream.write("\n\n".as_ref());
}
diff --git a/src/configuration/uapi/set.rs b/src/configuration/uapi/set.rs
index 4c2c554..e449edd 100644
--- a/src/configuration/uapi/set.rs
+++ b/src/configuration/uapi/set.rs
@@ -1,18 +1,27 @@
use hex::FromHex;
+use std::net::{IpAddr, SocketAddr};
use subtle::ConstantTimeEq;
use x25519_dalek::{PublicKey, StaticSecret};
use super::{ConfigError, Configuration};
-#[derive(Copy, Clone)]
enum ParserState {
- Peer {
- public_key: PublicKey,
- update_only: bool,
- },
+ Peer(ParsedPeer),
Interface,
}
+struct ParsedPeer {
+ public_key: PublicKey,
+ update_only: bool,
+ allowed_ips: Vec<(IpAddr, u32)>,
+ remove: bool,
+ preshared_key: Option<[u8; 32]>,
+ replace_allowed_ips: bool,
+ persistent_keepalive_interval: Option<u64>,
+ protocol_version: Option<usize>,
+ endpoint: Option<SocketAddr>,
+}
+
pub struct LineParser<'a, C: Configuration> {
config: &'a C,
state: ParserState,
@@ -28,45 +37,71 @@ impl<'a, C: Configuration> LineParser<'a, C> {
fn new_peer(value: &str) -> Result<ParserState, ConfigError> {
match <[u8; 32]>::from_hex(value) {
- Ok(pk) => Ok(ParserState::Peer {
+ Ok(pk) => Ok(ParserState::Peer(ParsedPeer {
public_key: PublicKey::from(pk),
+ remove: false,
update_only: false,
- }),
+ allowed_ips: vec![],
+ preshared_key: None,
+ replace_allowed_ips: false,
+ persistent_keepalive_interval: None,
+ protocol_version: None,
+ endpoint: None,
+ })),
Err(_) => Err(ConfigError::InvalidHexValue),
}
}
pub fn parse_line(&mut self, key: &str, value: &str) -> Result<(), ConfigError> {
- // add the peer if not update_only
- let flush_peer = |st: ParserState| -> ParserState {
- match st {
- ParserState::Peer {
- public_key,
- update_only: false,
- } => {
- self.config.add_peer(&public_key);
- ParserState::Peer {
- public_key,
- update_only: true,
- }
+ // flush peer updates to configuration
+ fn flush_peer<C: Configuration>(config: &C, peer: &ParsedPeer) -> Option<ConfigError> {
+ if peer.remove {
+ config.remove_peer(&peer.public_key);
+ return None;
+ }
+
+ if !peer.update_only {
+ config.add_peer(&peer.public_key);
+ }
+
+ for (ip, masklen) in &peer.allowed_ips {
+ config.add_allowed_ip(&peer.public_key, *ip, *masklen);
+ }
+
+ if let Some(psk) = peer.preshared_key {
+ config.set_preshared_key(&peer.public_key, psk);
+ }
+
+ if let Some(secs) = peer.persistent_keepalive_interval {
+ config.set_persistent_keepalive_interval(&peer.public_key, secs);
+ }
+
+ if let Some(version) = peer.protocol_version {
+ if version == 0 || version > config.get_protocol_version() {
+ return Some(ConfigError::UnsupportedProtocolVersion);
}
- _ => st,
}
+
+ if let Some(endpoint) = peer.endpoint {
+ config.set_endpoint(&peer.public_key, endpoint);
+ };
+
+ None
};
// parse line and update parser state
- self.state = match self.state {
+ match self.state {
// configure the interface
ParserState::Interface => match key {
// opt: set private key
"private_key" => match <[u8; 32]>::from_hex(value) {
Ok(sk) => {
- self.config.set_private_key(if sk == [0u8; 32] {
+ self.config.set_private_key(if sk.ct_eq(&[0u8; 32]).into() {
None
} else {
Some(StaticSecret::from(sk))
});
- Ok(self.state)
+ Ok(())
}
Err(_) => Err(ConfigError::InvalidHexValue),
},
@@ -75,7 +110,7 @@ impl<'a, C: Configuration> LineParser<'a, C> {
"listen_port" => match value.parse() {
Ok(port) => {
self.config.set_listen_port(Some(port));
- Ok(self.state)
+ Ok(())
}
Err(_) => Err(ConfigError::InvalidPortNumber),
},
@@ -85,7 +120,7 @@ impl<'a, C: Configuration> LineParser<'a, C> {
Ok(fwmark) => {
self.config
.set_fwmark(if fwmark == 0 { None } else { Some(fwmark) });
- Ok(self.state)
+ Ok(())
}
Err(_) => Err(ConfigError::InvalidFwmark),
},
@@ -96,51 +131,47 @@ impl<'a, C: Configuration> LineParser<'a, C> {
for p in self.config.get_peers() {
self.config.remove_peer(&p.public_key)
}
- Ok(self.state)
+ Ok(())
}
_ => Err(ConfigError::UnsupportedValue),
},
// opt: transition to peer configuration
- "public_key" => Self::new_peer(value),
+ "public_key" => {
+ self.state = Self::new_peer(value)?;
+ Ok(())
+ }
// unknown key
_ => Err(ConfigError::InvalidKey),
},
// configure peers
- ParserState::Peer { public_key, .. } => match key {
+ ParserState::Peer(ref mut peer) => match key {
// opt: new peer
"public_key" => {
- flush_peer(self.state);
- Self::new_peer(value)
+ flush_peer(self.config, &peer);
+ self.state = Self::new_peer(value)?;
+ Ok(())
}
// opt: remove peer
"remove" => {
- self.config.remove_peer(&public_key);
- Ok(self.state)
+ peer.remove = true;
+ Ok(())
}
// opt: update only
- "update_only" => Ok(ParserState::Peer {
- public_key,
- update_only: true,
- }),
+ "update_only" => {
+ peer.update_only = true;
+ Ok(())
+ }
// opt: set preshared key
"preshared_key" => match <[u8; 32]>::from_hex(value) {
Ok(psk) => {
- let st = flush_peer(self.state);
- self.config.set_preshared_key(
- &public_key,
- if psk.ct_eq(&[0u8; 32]).into() {
- None
- } else {
- Some(psk)
- },
- );
- Ok(st)
+ peer.preshared_key = Some(psk);
+ Ok(())
}
Err(_) => Err(ConfigError::InvalidHexValue),
},
@@ -148,9 +179,8 @@ impl<'a, C: Configuration> LineParser<'a, C> {
// opt: set endpoint
"endpoint" => match value.parse() {
Ok(endpoint) => {
- let st = flush_peer(self.state);
- self.config.set_endpoint(&public_key, endpoint);
- Ok(st)
+ peer.endpoint = Some(endpoint);
+ Ok(())
}
Err(_) => Err(ConfigError::InvalidSocketAddr),
},
@@ -158,19 +188,17 @@ impl<'a, C: Configuration> LineParser<'a, C> {
// opt: set persistent keepalive interval
"persistent_keepalive_interval" => match value.parse() {
Ok(secs) => {
- let st = flush_peer(self.state);
- self.config
- .set_persistent_keepalive_interval(&public_key, secs);
- Ok(st)
+ peer.persistent_keepalive_interval = Some(secs);
+ Ok(())
}
Err(_) => Err(ConfigError::InvalidKeepaliveInterval),
},
// opt replace allowed ips
"replace_allowed_ips" => {
- let st = flush_peer(self.state);
- self.config.replace_allowed_ips(&public_key);
- Ok(st)
+ peer.replace_allowed_ips = true;
+ peer.allowed_ips.clear();
+ Ok(())
}
// opt add allowed ips
@@ -180,9 +208,8 @@ impl<'a, C: Configuration> LineParser<'a, C> {
let cidr = split.next().and_then(|x| x.parse().ok());
match (addr, cidr) {
(Some(addr), Some(cidr)) => {
- let st = flush_peer(self.state);
- self.config.add_allowed_ip(&public_key, addr, cidr);
- Ok(st)
+ peer.allowed_ips.push((addr, cidr));
+ Ok(())
}
_ => Err(ConfigError::InvalidAllowedIp),
}
@@ -193,11 +220,8 @@ impl<'a, C: Configuration> LineParser<'a, C> {
let parse_res: Result<usize, _> = value.parse();
match parse_res {
Ok(version) => {
- if version == 0 || version > self.config.get_protocol_version() {
- Err(ConfigError::UnsupportedProtocolVersion)
- } else {
- Ok(self.state)
- }
+ peer.protocol_version = Some(version);
+ Ok(())
}
Err(_) => Err(ConfigError::UnsupportedProtocolVersion),
}
@@ -206,8 +230,6 @@ impl<'a, C: Configuration> LineParser<'a, C> {
// unknown key
_ => Err(ConfigError::InvalidKey),
},
- }?;
-
- Ok(())
+ }
}
}
diff --git a/src/main.rs b/src/main.rs
index 4d34b39..e17a127 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -10,21 +10,35 @@ mod configuration;
mod platform;
mod wireguard;
-use platform::tun;
+use platform::tun::PlatformTun;
+use platform::uapi::PlatformUAPI;
+use platform::*;
-use configuration::WireguardConfig;
+use std::sync::Arc;
+use std::thread;
+use std::time::Duration;
fn main() {
- /*
- let (mut readers, writer, mtu) = platform::TunInstance::create("test").unwrap();
- let wg = wireguard::Wireguard::new(readers, writer, mtu);
- */
-}
+ let name = "wg0";
+
+ let _ = env_logger::builder().is_test(true).try_init();
+
+ // create UAPI socket
+ let uapi = plt::UAPI::bind(name).unwrap();
+
+ // create TUN device
+ let (readers, writer, mtu) = plt::Tun::create(name).unwrap();
+
+ // create WireGuard device
+ let wg: wireguard::Wireguard<plt::Tun, plt::Bind> =
+ wireguard::Wireguard::new(readers, writer, mtu);
-/*
-fn test_wg_configuration() {
- let (mut readers, writer, mtu) = platform::dummy::
+ // wrap in configuration interface and start UAPI server
+ let cfg = configuration::WireguardConfig::new(wg);
+ loop {
+ let mut stream = uapi.accept().unwrap();
+ configuration::uapi::handle(&mut stream.0, &cfg);
+ }
- let wg = wireguard::Wireguard::new(readers, writer, mtu);
+ thread::sleep(Duration::from_secs(600));
}
-*/
diff --git a/src/platform/bind.rs b/src/platform/bind.rs
index 1a234c7..1055f37 100644
--- a/src/platform/bind.rs
+++ b/src/platform/bind.rs
@@ -37,7 +37,7 @@ pub trait Owner: Send {
/// On some platforms the application can itself bind to a socket.
/// This enables configuration using the UAPI interface.
-pub trait Platform: Bind {
+pub trait PlatformBind: Bind {
type Owner: Owner;
/// Bind to a new port, returning the reader/writer and
diff --git a/src/platform/dummy/bind.rs b/src/platform/dummy/bind.rs
index 2c30486..b42483a 100644
--- a/src/platform/dummy/bind.rs
+++ b/src/platform/dummy/bind.rs
@@ -216,7 +216,7 @@ impl Owner for VoidOwner {
}
}
-impl Platform for PairBind {
+impl PlatformBind for PairBind {
type Owner = VoidOwner;
fn bind(_port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> {
Err(BindError::Disconnected)
diff --git a/src/platform/dummy/tun.rs b/src/platform/dummy/tun.rs
index 185b328..569bf1c 100644
--- a/src/platform/dummy/tun.rs
+++ b/src/platform/dummy/tun.rs
@@ -192,7 +192,7 @@ impl TunTest {
}
}
-impl Platform for TunTest {
+impl PlatformTun for TunTest {
fn create(_name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error> {
Err(TunError::Disconnected)
}
diff --git a/src/platform/linux/mod.rs b/src/platform/linux/mod.rs
index 7d6a61c..82731de 100644
--- a/src/platform/linux/mod.rs
+++ b/src/platform/linux/mod.rs
@@ -1,5 +1,7 @@
mod tun;
+mod uapi;
mod udp;
-pub use tun::LinuxTun;
-pub use udp::LinuxBind;
+pub use tun::LinuxTun as Tun;
+pub use uapi::LinuxUAPI as UAPI;
+pub use udp::LinuxBind as Bind;
diff --git a/src/platform/linux/tun.rs b/src/platform/linux/tun.rs
index 090569a..0bbae81 100644
--- a/src/platform/linux/tun.rs
+++ b/src/platform/linux/tun.rs
@@ -125,7 +125,7 @@ impl Tun for LinuxTun {
type MTU = LinuxTunMTU;
}
-impl Platform for LinuxTun {
+impl PlatformTun for LinuxTun {
fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error> {
// construct request struct
let mut req = Ifreq {
diff --git a/src/platform/linux/uapi.rs b/src/platform/linux/uapi.rs
new file mode 100644
index 0000000..fdf2bf0
--- /dev/null
+++ b/src/platform/linux/uapi.rs
@@ -0,0 +1,31 @@
+use super::super::uapi::*;
+
+use std::fs;
+use std::io;
+use std::os::unix::net::{UnixListener, UnixStream};
+
+const SOCK_DIR: &str = "/var/run/wireguard/";
+
+pub struct LinuxUAPI {}
+
+impl PlatformUAPI for LinuxUAPI {
+ type Error = io::Error;
+ type Bind = UnixListener;
+
+ fn bind(name: &str) -> Result<UnixListener, io::Error> {
+ let socket_path = format!("{}{}.sock", SOCK_DIR, name);
+ let _ = fs::create_dir_all(SOCK_DIR);
+ let _ = fs::remove_file(&socket_path);
+ UnixListener::bind(socket_path)
+ }
+}
+
+impl BindUAPI for UnixListener {
+ type Stream = UnixStream;
+ type Error = io::Error;
+
+ fn accept(&self) -> Result<UnixStream, io::Error> {
+ let (stream, _) = self.accept()?;
+ Ok(stream)
+ }
+}
diff --git a/src/platform/linux/udp.rs b/src/platform/linux/udp.rs
index 52e4c45..d3d61b6 100644
--- a/src/platform/linux/udp.rs
+++ b/src/platform/linux/udp.rs
@@ -1,26 +1,82 @@
use super::super::bind::*;
use super::super::Endpoint;
-use std::net::SocketAddr;
+use std::io;
+use std::net::{SocketAddr, UdpSocket};
+use std::sync::Arc;
-pub struct LinuxEndpoint {}
+#[derive(Clone)]
+pub struct LinuxBind(Arc<UdpSocket>);
-pub struct LinuxBind {}
+pub struct LinuxOwner(Arc<UdpSocket>);
-impl Endpoint for LinuxEndpoint {
+impl Endpoint for SocketAddr {
fn clear_src(&mut self) {}
fn from_address(addr: SocketAddr) -> Self {
- LinuxEndpoint {}
+ addr
}
fn into_address(&self) -> SocketAddr {
- "127.0.0.1:6060".parse().unwrap()
+ *self
}
}
-/*
-impl Bind for PlatformBind {
- type Endpoint = PlatformEndpoint;
+impl Reader<SocketAddr> for LinuxBind {
+ type Error = io::Error;
+
+ fn read(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> {
+ self.0.recv_from(buf)
+ }
+}
+
+impl Writer<SocketAddr> for LinuxBind {
+ type Error = io::Error;
+
+ fn write(&self, buf: &[u8], dst: &SocketAddr) -> Result<(), Self::Error> {
+ self.0.send_to(buf, dst)?;
+ Ok(())
+ }
+}
+
+impl Owner for LinuxOwner {
+ type Error = io::Error;
+
+ fn get_port(&self) -> u16 {
+ 1337
+ }
+
+ fn get_fwmark(&self) -> Option<u32> {
+ None
+ }
+
+ fn set_fwmark(&mut self, value: Option<u32>) -> Option<Self::Error> {
+ None
+ }
+}
+
+impl Drop for LinuxOwner {
+ fn drop(&mut self) {}
+}
+
+impl Bind for LinuxBind {
+ type Error = io::Error;
+ type Endpoint = SocketAddr;
+ type Reader = LinuxBind;
+ type Writer = LinuxBind;
+}
+
+impl PlatformBind for LinuxBind {
+ type Owner = LinuxOwner;
+
+ fn bind(port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> {
+ let socket = UdpSocket::bind(format!("0.0.0.0:{}", port))?;
+ let socket = Arc::new(socket);
+
+ Ok((
+ vec![LinuxBind(socket.clone())],
+ LinuxBind(socket.clone()),
+ LinuxOwner(socket),
+ ))
+ }
}
-*/
diff --git a/src/platform/mod.rs b/src/platform/mod.rs
index ecd559a..99707e3 100644
--- a/src/platform/mod.rs
+++ b/src/platform/mod.rs
@@ -2,14 +2,15 @@ mod endpoint;
pub mod bind;
pub mod tun;
+pub mod uapi;
pub use endpoint::Endpoint;
#[cfg(target_os = "linux")]
-mod linux;
+pub mod linux;
#[cfg(test)]
pub mod dummy;
#[cfg(target_os = "linux")]
-pub use linux::LinuxTun as TunInstance;
+pub use linux as plt;
diff --git a/src/platform/tun.rs b/src/platform/tun.rs
index f49d4af..c92304a 100644
--- a/src/platform/tun.rs
+++ b/src/platform/tun.rs
@@ -56,6 +56,6 @@ pub trait Tun: Send + Sync + 'static {
}
/// On some platforms the application can create the TUN device itself.
-pub trait Platform: Tun {
+pub trait PlatformTun: Tun {
fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error>;
}
diff --git a/src/platform/uapi.rs b/src/platform/uapi.rs
new file mode 100644
index 0000000..6922a9c
--- /dev/null
+++ b/src/platform/uapi.rs
@@ -0,0 +1,16 @@
+use std::error::Error;
+use std::io::{Read, Write};
+
+pub trait BindUAPI {
+ type Stream: Read + Write;
+ type Error: Error;
+
+ fn accept(&self) -> Result<Self::Stream, Self::Error>;
+}
+
+pub trait PlatformUAPI {
+ type Error: Error;
+ type Bind: BindUAPI;
+
+ fn bind(name: &str) -> Result<Self::Bind, Self::Error>;
+}
diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs
index f65692c..02e6929 100644
--- a/src/wireguard/handshake/device.rs
+++ b/src/wireguard/handshake/device.rs
@@ -178,13 +178,10 @@ impl Device {
/// # Returns
///
/// The call might fail if the public key is not found
- pub fn set_psk(&mut self, pk: PublicKey, psk: Option<Psk>) -> Result<(), ConfigError> {
+ pub fn set_psk(&mut self, pk: PublicKey, psk: Psk) -> Result<(), ConfigError> {
match self.pk_map.get_mut(pk.as_bytes()) {
Some(mut peer) => {
- peer.psk = match psk {
- Some(v) => v,
- None => [0u8; 32],
- };
+ peer.psk = psk;
Ok(())
}
_ => Err(ConfigError::new("No such public key")),
@@ -466,8 +463,8 @@ mod tests {
dev1.add(pk2).unwrap();
dev2.add(pk1).unwrap();
- dev1.set_psk(pk2, Some(psk)).unwrap();
- dev2.set_psk(pk1, Some(psk)).unwrap();
+ dev1.set_psk(pk2, psk).unwrap();
+ dev2.set_psk(pk1, psk).unwrap();
(pk1, dev1, pk2, dev2)
}
diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs
index 77be9f8..c0a8d9d 100644
--- a/src/wireguard/wireguard.rs
+++ b/src/wireguard/wireguard.rs
@@ -201,7 +201,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
.map(|sk| StaticSecret::from(sk.to_bytes()))
}
- pub fn set_psk(&self, pk: PublicKey, psk: Option<[u8; 32]>) -> bool {
+ pub fn set_psk(&self, pk: PublicKey, psk: [u8; 32]) -> bool {
self.state.handshake.write().set_psk(pk, psk).is_ok()
}
pub fn get_psk(&self, pk: &PublicKey) -> Option<[u8; 32]> {