summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/configuration/config.rs105
-rw-r--r--src/configuration/uapi/get.rs5
-rw-r--r--src/configuration/uapi/mod.rs7
-rw-r--r--src/configuration/uapi/set.rs13
-rw-r--r--src/main.rs40
-rw-r--r--src/platform/bind.rs2
-rw-r--r--src/platform/dummy/bind.rs4
-rw-r--r--src/platform/linux/uapi.rs2
-rw-r--r--src/platform/linux/udp.rs6
-rw-r--r--src/platform/uapi.rs2
-rw-r--r--src/wireguard/handshake/device.rs16
-rw-r--r--src/wireguard/handshake/peer.rs1
-rw-r--r--src/wireguard/peer.rs3
-rw-r--r--src/wireguard/timers.rs2
-rw-r--r--src/wireguard/wireguard.rs19
15 files changed, 122 insertions, 105 deletions
diff --git a/src/configuration/config.rs b/src/configuration/config.rs
index 50fdfb8..e50aeb6 100644
--- a/src/configuration/config.rs
+++ b/src/configuration/config.rs
@@ -10,6 +10,9 @@ use bind::Owner;
/// The goal of the configuration interface is, among others,
/// to hide the IO implementations (over which the WG device is generic),
/// from the configuration and UAPI code.
+///
+/// Furthermore it forms the simpler interface for embedding WireGuard in other applications,
+/// and hides the complex types of the implementation from the host application.
/// Describes a snapshot of the state of a peer
pub struct PeerState {
@@ -24,6 +27,7 @@ pub struct PeerState {
pub struct WireguardConfig<T: tun::Tun, B: bind::PlatformBind> {
wireguard: Wireguard<T, B>,
+ fwmark: Mutex<Option<u32>>,
network: Mutex<Option<B::Owner>>,
}
@@ -31,6 +35,7 @@ impl<T: tun::Tun, B: bind::PlatformBind> WireguardConfig<T, B> {
pub fn new(wg: Wireguard<T, B>) -> WireguardConfig<T, B> {
WireguardConfig {
wireguard: wg,
+ fwmark: Mutex::new(None),
network: Mutex::new(None),
}
}
@@ -59,7 +64,7 @@ pub trait Configuration {
/// An integer indicating the protocol version
fn get_protocol_version(&self) -> usize;
- fn set_listen_port(&self, port: Option<u16>) -> Option<ConfigError>;
+ fn set_listen_port(&self, port: Option<u16>) -> Result<(), ConfigError>;
/// Set the firewall mark (or similar, depending on platform)
///
@@ -71,7 +76,7 @@ pub trait Configuration {
///
/// An error if this operation is not supported by the underlying
/// "bind" implementation.
- fn set_fwmark(&self, mark: Option<u32>) -> Option<ConfigError>;
+ fn set_fwmark(&self, mark: Option<u32>) -> Result<(), ConfigError>;
/// Removes all peers from the device
fn replace_peers(&self);
@@ -110,7 +115,7 @@ pub trait Configuration {
/// # Returns
///
/// An error if no such peer exists
- fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) -> Option<ConfigError>;
+ fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]);
/// Update the endpoint of the
///
@@ -118,7 +123,7 @@ pub trait Configuration {
///
/// - `peer': The public key of the peer
/// - `psk`
- fn set_endpoint(&self, peer: &PublicKey, addr: SocketAddr) -> Option<ConfigError>;
+ fn set_endpoint(&self, peer: &PublicKey, addr: SocketAddr);
/// Update the endpoint of the
///
@@ -126,8 +131,7 @@ pub trait Configuration {
///
/// - `peer': The public key of the peer
/// - `psk`
- fn set_persistent_keepalive_interval(&self, peer: &PublicKey, secs: u64)
- -> Option<ConfigError>;
+ fn set_persistent_keepalive_interval(&self, peer: &PublicKey, secs: u64);
/// Remove all allowed IPs from the peer
///
@@ -138,7 +142,7 @@ pub trait Configuration {
/// # Returns
///
/// An error if no such peer exists
- fn replace_allowed_ips(&self, peer: &PublicKey) -> Option<ConfigError>;
+ fn replace_allowed_ips(&self, peer: &PublicKey);
/// Add a new allowed subnet to the peer
///
@@ -151,12 +155,7 @@ pub trait Configuration {
/// # Returns
///
/// An error if the peer does not exist
- ///
- /// # Note:
- ///
- /// The API must itself sanitize the (ip, masklen) set:
- /// The ip should be masked to remove any set bits right of the first "masklen" bits.
- fn add_allowed_ip(&self, peer: &PublicKey, ip: IpAddr, masklen: u32) -> Option<ConfigError>;
+ fn add_allowed_ip(&self, peer: &PublicKey, ip: IpAddr, masklen: u32);
fn get_listen_port(&self) -> Option<u16>;
@@ -191,10 +190,14 @@ impl<T: tun::Tun, B: bind::PlatformBind> Configuration for WireguardConfig<T, B>
}
fn get_listen_port(&self) -> Option<u16> {
- self.network.lock().as_ref().map(|bind| bind.get_port())
+ let bind = self.network.lock();
+ log::trace!("Config, Get listen port, bound: {}", bind.is_some());
+ bind.as_ref().map(|bind| bind.get_port())
}
- fn set_listen_port(&self, port: Option<u16>) -> Option<ConfigError> {
+ fn set_listen_port(&self, port: Option<u16>) -> Result<(), ConfigError> {
+ log::trace!("Config, Set listen port: {:?}", port);
+
let mut bind = self.network.lock();
// close the current listener
@@ -203,13 +206,16 @@ impl<T: tun::Tun, B: bind::PlatformBind> Configuration for WireguardConfig<T, B>
// bind to new port
if let Some(port) = port {
// create new listener
- let (mut readers, writer, owner) = match B::bind(port) {
+ let (mut readers, writer, mut owner) = match B::bind(port) {
Ok(r) => r,
Err(_) => {
- return Some(ConfigError::FailedToBind);
+ return Err(ConfigError::FailedToBind);
}
};
+ // set fwmark
+ let _ = owner.set_fwmark(*self.fwmark.lock()); // TODO: handle
+
// add readers/writer to wireguard
self.wireguard.set_writer(writer);
while let Some(reader) = readers.pop() {
@@ -220,16 +226,18 @@ impl<T: tun::Tun, B: bind::PlatformBind> Configuration for WireguardConfig<T, B>
*bind = Some(owner);
}
- None
+ Ok(())
}
- fn set_fwmark(&self, mark: Option<u32>) -> Option<ConfigError> {
+ fn set_fwmark(&self, mark: Option<u32>) -> Result<(), ConfigError> {
+ log::trace!("Config, Set fwmark: {:?}", mark);
+
match self.network.lock().as_mut() {
Some(bind) => {
bind.set_fwmark(mark).unwrap(); // TODO: handle
- None
+ Ok(())
}
- None => Some(ConfigError::NotListening),
+ None => Err(ConfigError::NotListening),
}
}
@@ -242,59 +250,34 @@ impl<T: tun::Tun, B: bind::PlatformBind> Configuration for WireguardConfig<T, B>
}
fn add_peer(&self, peer: &PublicKey) -> bool {
- self.wireguard.add_peer(*peer);
- false
+ self.wireguard.add_peer(*peer)
}
- fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) -> Option<ConfigError> {
- if self.wireguard.set_psk(*peer, psk) {
- None
- } else {
- Some(ConfigError::NoSuchPeer)
- }
+ fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) {
+ self.wireguard.set_psk(*peer, psk);
}
- fn set_endpoint(&self, peer: &PublicKey, addr: SocketAddr) -> Option<ConfigError> {
- match self.wireguard.lookup_peer(peer) {
- Some(peer) => {
- peer.router.set_endpoint(B::Endpoint::from_address(addr));
- None
- }
- None => Some(ConfigError::NoSuchPeer),
+ fn set_endpoint(&self, peer: &PublicKey, addr: SocketAddr) {
+ if let Some(peer) = self.wireguard.lookup_peer(peer) {
+ peer.router.set_endpoint(B::Endpoint::from_address(addr));
}
}
- fn set_persistent_keepalive_interval(
- &self,
- peer: &PublicKey,
- secs: u64,
- ) -> Option<ConfigError> {
- match self.wireguard.lookup_peer(peer) {
- Some(peer) => {
- peer.set_persistent_keepalive_interval(secs);
- None
- }
- None => Some(ConfigError::NoSuchPeer),
+ fn set_persistent_keepalive_interval(&self, peer: &PublicKey, secs: u64) {
+ if let Some(peer) = self.wireguard.lookup_peer(peer) {
+ peer.set_persistent_keepalive_interval(secs);
}
}
- fn replace_allowed_ips(&self, peer: &PublicKey) -> Option<ConfigError> {
- match self.wireguard.lookup_peer(peer) {
- Some(peer) => {
- peer.router.remove_allowed_ips();
- None
- }
- None => Some(ConfigError::NoSuchPeer),
+ fn replace_allowed_ips(&self, peer: &PublicKey) {
+ if let Some(peer) = self.wireguard.lookup_peer(peer) {
+ peer.router.remove_allowed_ips();
}
}
- fn add_allowed_ip(&self, peer: &PublicKey, ip: IpAddr, masklen: u32) -> Option<ConfigError> {
- match self.wireguard.lookup_peer(peer) {
- Some(peer) => {
- peer.router.add_allowed_ip(ip, masklen);
- None
- }
- None => Some(ConfigError::NoSuchPeer),
+ fn add_allowed_ip(&self, peer: &PublicKey, ip: IpAddr, masklen: u32) {
+ if let Some(peer) = self.wireguard.lookup_peer(peer) {
+ peer.router.add_allowed_ip(ip, masklen);
}
}
diff --git a/src/configuration/uapi/get.rs b/src/configuration/uapi/get.rs
index 0874cfc..43d4735 100644
--- a/src/configuration/uapi/get.rs
+++ b/src/configuration/uapi/get.rs
@@ -1,10 +1,7 @@
-use hex::FromHex;
-use subtle::ConstantTimeEq;
-
use log;
+use std::io;
use super::Configuration;
-use std::io;
pub fn serialize<C: Configuration, W: io::Write>(writer: &mut W, config: &C) -> io::Result<()> {
let mut write = |key: &'static str, value: String| {
diff --git a/src/configuration/uapi/mod.rs b/src/configuration/uapi/mod.rs
index 4261e7d..3cb88c0 100644
--- a/src/configuration/uapi/mod.rs
+++ b/src/configuration/uapi/mod.rs
@@ -55,10 +55,13 @@ pub fn handle<S: Read + Write, C: Configuration>(stream: &mut S, config: &C) {
loop {
let ln = readline(stream)?;
if ln == "" {
+ // end of transcript
+ parser.parse_line("", "")?; // flush final peer
break Ok(());
+ } else {
+ let (k, v) = keypair(ln.as_str())?;
+ parser.parse_line(k, v)?;
};
- let (k, v) = keypair(ln.as_str())?;
- parser.parse_line(k, v)?;
}
}
_ => Err(ConfigError::InvalidOperation),
diff --git a/src/configuration/uapi/set.rs b/src/configuration/uapi/set.rs
index e449edd..882e4a7 100644
--- a/src/configuration/uapi/set.rs
+++ b/src/configuration/uapi/set.rs
@@ -109,7 +109,7 @@ impl<'a, C: Configuration> LineParser<'a, C> {
// opt: set listen port
"listen_port" => match value.parse() {
Ok(port) => {
- self.config.set_listen_port(Some(port));
+ self.config.set_listen_port(Some(port))?;
Ok(())
}
Err(_) => Err(ConfigError::InvalidPortNumber),
@@ -119,7 +119,7 @@ impl<'a, C: Configuration> LineParser<'a, C> {
"fwmark" => match value.parse() {
Ok(fwmark) => {
self.config
- .set_fwmark(if fwmark == 0 { None } else { Some(fwmark) });
+ .set_fwmark(if fwmark == 0 { None } else { Some(fwmark) })?;
Ok(())
}
Err(_) => Err(ConfigError::InvalidFwmark),
@@ -142,6 +142,9 @@ impl<'a, C: Configuration> LineParser<'a, C> {
Ok(())
}
+ // ignore (end of transcript)
+ "" => Ok(()),
+
// unknown key
_ => Err(ConfigError::InvalidKey),
},
@@ -227,6 +230,12 @@ impl<'a, C: Configuration> LineParser<'a, C> {
}
}
+ // flush (used at end of transcipt)
+ "" => {
+ flush_peer(self.config, &peer);
+ Ok(())
+ }
+
// unknown key
_ => Err(ConfigError::InvalidKey),
},
diff --git a/src/main.rs b/src/main.rs
index e17a127..b1762cb 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -10,24 +10,37 @@ mod configuration;
mod platform;
mod wireguard;
+use log;
+
+use std::env;
+
use platform::tun::PlatformTun;
-use platform::uapi::PlatformUAPI;
+use platform::uapi::{BindUAPI, PlatformUAPI};
use platform::*;
-use std::sync::Arc;
-use std::thread;
-use std::time::Duration;
-
fn main() {
- let name = "wg0";
+ let mut name = String::new();
+ let mut foreground = false;
+
+ for arg in env::args() {
+ if arg == "--foreground" || arg == "-f" {
+ foreground = true;
+ } else {
+ name = arg;
+ }
+ }
+
+ if name == "" {
+ return;
+ }
let _ = env_logger::builder().is_test(true).try_init();
// create UAPI socket
- let uapi = plt::UAPI::bind(name).unwrap();
+ let uapi = plt::UAPI::bind(name.as_str()).unwrap();
// create TUN device
- let (readers, writer, mtu) = plt::Tun::create(name).unwrap();
+ let (readers, writer, mtu) = plt::Tun::create(name.as_str()).unwrap();
// create WireGuard device
let wg: wireguard::Wireguard<plt::Tun, plt::Bind> =
@@ -36,9 +49,12 @@ fn main() {
// wrap in configuration interface and start UAPI server
let cfg = configuration::WireguardConfig::new(wg);
loop {
- let mut stream = uapi.accept().unwrap();
- configuration::uapi::handle(&mut stream.0, &cfg);
+ match uapi.connect() {
+ Ok(mut stream) => configuration::uapi::handle(&mut stream, &cfg),
+ Err(err) => {
+ log::info!("UAPI error: {:}", err);
+ break;
+ }
+ }
}
-
- thread::sleep(Duration::from_secs(600));
}
diff --git a/src/platform/bind.rs b/src/platform/bind.rs
index 1055f37..9487dfd 100644
--- a/src/platform/bind.rs
+++ b/src/platform/bind.rs
@@ -32,7 +32,7 @@ pub trait Owner: Send {
fn get_fwmark(&self) -> Option<u32>;
- fn set_fwmark(&mut self, value: Option<u32>) -> Option<Self::Error>;
+ fn set_fwmark(&mut self, value: Option<u32>) -> Result<(), Self::Error>;
}
/// On some platforms the application can itself bind to a socket.
diff --git a/src/platform/dummy/bind.rs b/src/platform/dummy/bind.rs
index b42483a..d69e6a4 100644
--- a/src/platform/dummy/bind.rs
+++ b/src/platform/dummy/bind.rs
@@ -203,8 +203,8 @@ impl Bind for PairBind {
impl Owner for VoidOwner {
type Error = BindError;
- fn set_fwmark(&mut self, _value: Option<u32>) -> Option<Self::Error> {
- None
+ fn set_fwmark(&mut self, _value: Option<u32>) -> Result<(), Self::Error> {
+ Ok(())
}
fn get_port(&self) -> u16 {
diff --git a/src/platform/linux/uapi.rs b/src/platform/linux/uapi.rs
index fdf2bf0..107745a 100644
--- a/src/platform/linux/uapi.rs
+++ b/src/platform/linux/uapi.rs
@@ -24,7 +24,7 @@ impl BindUAPI for UnixListener {
type Stream = UnixStream;
type Error = io::Error;
- fn accept(&self) -> Result<UnixStream, io::Error> {
+ fn connect(&self) -> Result<UnixStream, io::Error> {
let (stream, _) = self.accept()?;
Ok(stream)
}
diff --git a/src/platform/linux/udp.rs b/src/platform/linux/udp.rs
index d3d61b6..a291d1a 100644
--- a/src/platform/linux/udp.rs
+++ b/src/platform/linux/udp.rs
@@ -43,15 +43,15 @@ impl Owner for LinuxOwner {
type Error = io::Error;
fn get_port(&self) -> u16 {
- 1337
+ self.0.local_addr().unwrap().port() // todo handle
}
fn get_fwmark(&self) -> Option<u32> {
None
}
- fn set_fwmark(&mut self, value: Option<u32>) -> Option<Self::Error> {
- None
+ fn set_fwmark(&mut self, _value: Option<u32>) -> Result<(), Self::Error> {
+ Ok(())
}
}
diff --git a/src/platform/uapi.rs b/src/platform/uapi.rs
index 6922a9c..8259f67 100644
--- a/src/platform/uapi.rs
+++ b/src/platform/uapi.rs
@@ -5,7 +5,7 @@ pub trait BindUAPI {
type Stream: Read + Write;
type Error: Error;
- fn accept(&self) -> Result<Self::Stream, Self::Error>;
+ fn connect(&self) -> Result<Self::Stream, Self::Error>;
}
pub trait PlatformUAPI {
diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs
index 02e6929..030c0f8 100644
--- a/src/wireguard/handshake/device.rs
+++ b/src/wireguard/handshake/device.rs
@@ -469,6 +469,10 @@ mod tests {
(pk1, dev1, pk2, dev2)
}
+ fn wait() {
+ thread::sleep(Duration::from_millis(20));
+ }
+
/* Test longest possible handshake interaction (7 messages):
*
* 1. I -> R (initation)
@@ -502,8 +506,8 @@ mod tests {
_ => panic!("unexpected response"),
}
- // avoid initation flood
- thread::sleep(Duration::from_millis(20));
+ // avoid initation flood detection
+ wait();
// 3. device-1 : create second initation
let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
@@ -529,8 +533,8 @@ mod tests {
_ => panic!("unexpected response"),
}
- // avoid initation flood
- thread::sleep(Duration::from_millis(20));
+ // avoid initation flood detection
+ wait();
// 6. device-1 : create third initation
let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
@@ -600,8 +604,8 @@ mod tests {
dev1.release(ks_i.send.id);
dev2.release(ks_r.send.id);
- // to avoid flood detection
- thread::sleep(Duration::from_millis(20));
+ // avoid initation flood detection
+ wait();
}
dev1.remove(pk2).unwrap();
diff --git a/src/wireguard/handshake/peer.rs b/src/wireguard/handshake/peer.rs
index abb36eb..2d69244 100644
--- a/src/wireguard/handshake/peer.rs
+++ b/src/wireguard/handshake/peer.rs
@@ -7,7 +7,6 @@ use generic_array::typenum::U32;
use generic_array::GenericArray;
use x25519_dalek::PublicKey;
-use x25519_dalek::SharedSecret;
use x25519_dalek::StaticSecret;
use clear_on_drop::clear::Clear;
diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs
index 4f9d19f..7d95493 100644
--- a/src/wireguard/peer.rs
+++ b/src/wireguard/peer.rs
@@ -1,4 +1,3 @@
-use super::constants::*;
use super::router;
use super::timers::{Events, Timers};
use super::HandshakeJob;
@@ -9,7 +8,7 @@ use super::wireguard::WireguardInner;
use std::fmt;
use std::ops::Deref;
-use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
+use std::sync::atomic::{AtomicBool, AtomicU64};
use std::sync::Arc;
use std::time::{Instant, SystemTime};
diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs
index 33b089f..8f6b3ee 100644
--- a/src/wireguard/timers.rs
+++ b/src/wireguard/timers.rs
@@ -63,7 +63,7 @@ impl<T: tun::Tun, B: bind::Bind> PeerInner<T, B> {
// take a write lock preventing simultaneous "stop_timers" call
let mut timers = self.timers_mut();
- // set flag to renable timer events
+ // set flag to reenable timer events
if timers.enabled {
return;
}
diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs
index c0a8d9d..613c0a8 100644
--- a/src/wireguard/wireguard.rs
+++ b/src/wireguard/wireguard.rs
@@ -18,6 +18,7 @@ use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant, SystemTime};
+use std::collections::hash_map::Entry;
use std::collections::HashMap;
use log::debug;
@@ -208,9 +209,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
self.state.handshake.read().get_psk(pk).ok()
}
- pub fn add_peer(&self, pk: PublicKey) {
+ pub fn add_peer(&self, pk: PublicKey) -> bool {
if self.state.peers.read().contains_key(pk.as_bytes()) {
- return;
+ return false;
}
let mut rng = OsRng::new().unwrap();
@@ -243,10 +244,16 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
// finally, add the peer to the wireguard device
let mut peers = self.state.peers.write();
- peers.entry(*pk.as_bytes()).or_insert(peer);
-
- // add to the handshake device
- self.state.handshake.write().add(pk).unwrap(); // TODO: handle adding of public key for interface
+ match peers.entry(*pk.as_bytes()) {
+ Entry::Occupied(_) => false,
+ Entry::Vacant(vacancy) => {
+ let ok_pk = self.state.handshake.write().add(pk).is_ok();
+ if ok_pk {
+ vacancy.insert(peer);
+ }
+ ok_pk
+ }
+ }
}
/// Begin consuming messages from the reader.