From 3fa928b3158ce33a57e9ba2c1913485eb409ff4b Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Wed, 23 Oct 2019 10:32:18 +0200 Subject: Work on platform specific code (Linux) --- src/wireguard/config.rs | 186 -------------------------------------- src/wireguard/handshake/device.rs | 6 +- src/wireguard/mod.rs | 13 +-- src/wireguard/types/dummy.rs | 54 +++++++---- src/wireguard/types/endpoint.rs | 2 +- src/wireguard/types/mod.rs | 5 +- src/wireguard/wireguard.rs | 50 ++++++++-- 7 files changed, 93 insertions(+), 223 deletions(-) delete mode 100644 src/wireguard/config.rs (limited to 'src/wireguard') diff --git a/src/wireguard/config.rs b/src/wireguard/config.rs deleted file mode 100644 index 0f2953d..0000000 --- a/src/wireguard/config.rs +++ /dev/null @@ -1,186 +0,0 @@ -use std::net::{IpAddr, SocketAddr}; -use x25519_dalek::{PublicKey, StaticSecret}; - -use super::wireguard::Wireguard; -use super::types::bind::Bind; -use super::types::tun::Tun; - -/// The goal of the configuration interface is, among others, -/// to hide the IO implementations (over which the WG device is generic), -/// from the configuration and UAPI code. - -/// Describes a snapshot of the state of a peer -pub struct PeerState { - rx_bytes: u64, - tx_bytes: u64, - last_handshake_time_sec: u64, - last_handshake_time_nsec: u64, - public_key: PublicKey, - allowed_ips: Vec<(IpAddr, u32)>, -} - -pub enum ConfigError { - NoSuchPeer -} - -impl ConfigError { - - fn errno(&self) -> i32 { - match self { - NoSuchPeer => 1, - } - } -} - -/// Exposed configuration interface -pub trait Configuration { - /// Updates the private key of the device - /// - /// # Arguments - /// - /// - `sk`: The new private key (or None, if the private key should be cleared) - fn set_private_key(&self, sk: Option); - - /// Returns the private key of the device - /// - /// # Returns - /// - /// The private if set, otherwise None. - fn get_private_key(&self) -> Option; - - /// Returns the protocol version of the device - /// - /// # Returns - /// - /// An integer indicating the protocol version - fn get_protocol_version(&self) -> usize; - - fn set_listen_port(&self, port: u16) -> Option; - - /// Set the firewall mark (or similar, depending on platform) - /// - /// # Arguments - /// - /// - `mark`: The fwmark value - /// - /// # Returns - /// - /// An error if this operation is not supported by the underlying - /// "bind" implementation. - fn set_fwmark(&self, mark: Option) -> Option; - - /// Removes all peers from the device - fn replace_peers(&self); - - /// Remove the peer from the - /// - /// # Arguments - /// - /// - `peer`: The public key of the peer to remove - /// - /// # Returns - /// - /// If the peer does not exists this operation is a noop - fn remove_peer(&self, peer: PublicKey); - - /// Adds a new peer to the device - /// - /// # Arguments - /// - /// - `peer`: The public key of the peer to add - /// - /// # Returns - /// - /// A bool indicating if the peer was added. - /// - /// If the peer already exists this operation is a noop - fn add_peer(&self, peer: PublicKey) -> bool; - - /// Update the psk of a peer - /// - /// # Arguments - /// - /// - `peer`: The public key of the peer - /// - `psk`: The new psk or None if the psk should be unset - /// - /// # Returns - /// - /// An error if no such peer exists - fn set_preshared_key(&self, peer: PublicKey, psk: Option<[u8; 32]>) -> Option; - - /// Update the endpoint of the - /// - /// # Arguments - /// - /// - `peer': The public key of the peer - /// - `psk` - fn set_endpoint(&self, peer: PublicKey, addr: SocketAddr) -> Option; - - /// Update the endpoint of the - /// - /// # Arguments - /// - /// - `peer': The public key of the peer - /// - `psk` - fn set_persistent_keepalive_interval(&self, peer: PublicKey) -> Option; - - /// Remove all allowed IPs from the peer - /// - /// # Arguments - /// - /// - `peer': The public key of the peer - /// - /// # Returns - /// - /// An error if no such peer exists - fn replace_allowed_ips(&self, peer: PublicKey) -> Option; - - /// Add a new allowed subnet to the peer - /// - /// # Arguments - /// - /// - `peer`: The public key of the peer - /// - `ip`: Subnet mask - /// - `masklen`: - /// - /// # Returns - /// - /// An error if the peer does not exist - /// - /// # Note: - /// - /// The API must itself sanitize the (ip, masklen) set: - /// The ip should be masked to remove any set bits right of the first "masklen" bits. - fn add_allowed_ip(&self, peer: PublicKey, ip: IpAddr, masklen: u32) -> Option; - - /// Returns the state of all peers - /// - /// # Returns - /// - /// A list of structures describing the state of each peer - fn get_peers(&self) -> Vec; -} - -impl Configuration for Wireguard { - - fn set_private_key(&self, sk : Option) { - self.set_key(sk) - } - - fn get_private_key(&self) -> Option { - self.get_sk() - } - - fn get_protocol_version(&self) -> usize { - 1 - } - - fn set_listen_port(&self, port : u16) -> Option { - None - } - - fn set_fwmark(&self, mark: Option) -> Option { - None - } - -} \ No newline at end of file diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs index 6a55f6e..c2e3a6e 100644 --- a/src/wireguard/handshake/device.rs +++ b/src/wireguard/handshake/device.rs @@ -77,9 +77,9 @@ impl Device { } /// Return the secret key of the device - /// + /// /// # Returns - /// + /// /// A secret key (x25519 scalar) pub fn get_sk(&self) -> StaticSecret { StaticSecret::from(self.sk.to_bytes()) @@ -95,7 +95,7 @@ impl Device { pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> { // check that the pk is not added twice if let Some(_) = self.pk_map.get(pk.as_bytes()) { - return Err(ConfigError::new("Duplicate public key")); + return Ok(()); }; // check that the pk is not that of the device diff --git a/src/wireguard/mod.rs b/src/wireguard/mod.rs index 9417e57..563a22f 100644 --- a/src/wireguard/mod.rs +++ b/src/wireguard/mod.rs @@ -1,7 +1,6 @@ -mod wireguard; -// mod config; mod constants; mod timers; +mod wireguard; mod handshake; mod router; @@ -12,12 +11,14 @@ mod tests; /// The WireGuard sub-module contains a pure, configurable implementation of WireGuard. /// The implementation is generic over: -/// +/// /// - TUN type, specifying how packets are received on the interface side: a reader/writer and MTU reporting interface. /// - Bind type, specifying how WireGuard messages are sent/received from the internet and what constitutes an "endpoint" - -pub use wireguard::{Wireguard, Peer}; +pub use wireguard::{Peer, Wireguard}; pub use types::bind; pub use types::tun; -pub use types::Endpoint; \ No newline at end of file +pub use types::Endpoint; + +#[cfg(test)] +pub use types::dummy; diff --git a/src/wireguard/types/dummy.rs b/src/wireguard/types/dummy.rs index 2403c9b..384f123 100644 --- a/src/wireguard/types/dummy.rs +++ b/src/wireguard/types/dummy.rs @@ -2,11 +2,11 @@ use std::error::Error; use std::fmt; use std::marker; use std::net::SocketAddr; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::Arc; use std::sync::Mutex; use std::time::Instant; -use std::sync::atomic::{Ordering, AtomicUsize}; use super::*; @@ -43,7 +43,7 @@ impl fmt::Display for BindError { #[derive(Debug)] pub enum TunError { - Disconnected + Disconnected, } impl Error for TunError { @@ -76,7 +76,7 @@ impl Endpoint for UnitEndpoint { "127.0.0.1:8080".parse().unwrap() } - fn clear_src(&self) {} + fn clear_src(&mut self) {} } impl UnitEndpoint { @@ -92,21 +92,21 @@ pub struct TunTest {} pub struct TunFakeIO { store: bool, tx: SyncSender>, - rx: Receiver> + rx: Receiver>, } pub struct TunReader { - rx: Receiver> + rx: Receiver>, } pub struct TunWriter { store: bool, - tx: Mutex>> + tx: Mutex>>, } #[derive(Clone)] pub struct TunMTU { - mtu: Arc + mtu: Arc, } impl tun::Reader for TunReader { @@ -118,7 +118,7 @@ impl tun::Reader for TunReader { buf[offset..].copy_from_slice(&m[..]); Ok(m.len()) } - Err(_) => Err(TunError::Disconnected) + Err(_) => Err(TunError::Disconnected), } } } @@ -131,7 +131,7 @@ impl tun::Writer for TunWriter { let m = src.to_owned(); match self.tx.lock().unwrap().send(m) { Ok(_) => Ok(()), - Err(_) => Err(TunError::Disconnected) + Err(_) => Err(TunError::Disconnected), } } else { Ok(()) @@ -153,7 +153,7 @@ impl tun::Tun for TunTest { } impl TunFakeIO { - pub fn write(&self, msg : Vec) { + pub fn write(&self, msg: Vec) { if self.store { self.tx.send(msg).unwrap(); } @@ -165,15 +165,31 @@ impl TunFakeIO { } impl TunTest { - pub fn create(mtu : usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunMTU) { - - let (tx1, rx1) = if store { sync_channel(32) } else { sync_channel(1) }; - let (tx2, rx2) = if store { sync_channel(32) } else { sync_channel(1) }; - - let fake = TunFakeIO{tx: tx1, rx: rx2, store}; - let reader = TunReader{rx : rx1}; - let writer = TunWriter{tx : Mutex::new(tx2), store}; - let mtu = TunMTU{mtu : Arc::new(AtomicUsize::new(mtu))}; + pub fn create(mtu: usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunMTU) { + let (tx1, rx1) = if store { + sync_channel(32) + } else { + sync_channel(1) + }; + let (tx2, rx2) = if store { + sync_channel(32) + } else { + sync_channel(1) + }; + + let fake = TunFakeIO { + tx: tx1, + rx: rx2, + store, + }; + let reader = TunReader { rx: rx1 }; + let writer = TunWriter { + tx: Mutex::new(tx2), + store, + }; + let mtu = TunMTU { + mtu: Arc::new(AtomicUsize::new(mtu)), + }; (fake, reader, writer, mtu) } diff --git a/src/wireguard/types/endpoint.rs b/src/wireguard/types/endpoint.rs index f4f93da..4702aab 100644 --- a/src/wireguard/types/endpoint.rs +++ b/src/wireguard/types/endpoint.rs @@ -3,5 +3,5 @@ use std::net::SocketAddr; pub trait Endpoint: Send + 'static { fn from_address(addr: SocketAddr) -> Self; fn into_address(&self) -> SocketAddr; - fn clear_src(&self); + fn clear_src(&mut self); } diff --git a/src/wireguard/types/mod.rs b/src/wireguard/types/mod.rs index e0725f3..20a1238 100644 --- a/src/wireguard/types/mod.rs +++ b/src/wireguard/types/mod.rs @@ -1,10 +1,11 @@ mod endpoint; mod keys; -pub mod tun; + pub mod bind; +pub mod tun; #[cfg(test)] pub mod dummy; pub use endpoint::Endpoint; -pub use keys::{Key, KeyPair}; \ No newline at end of file +pub use keys::{Key, KeyPair}; diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 1363c27..9bcac0a 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -54,7 +54,7 @@ pub struct PeerInner { pub handshake_queued: AtomicBool, pub queue: Mutex>>, // handshake queue - pub pk: PublicKey, // DISCUSS: Change layout in handshake module (adopt pattern of router), to avoid this. + pub pk: PublicKey, // DISCUSS: Change layout in handshake module (adopt pattern of router), to avoid this. TODO: remove pub timers: RwLock, // } @@ -99,7 +99,7 @@ pub enum HandshakeJob { New(PublicKey), } -struct WireguardInner { +pub struct WireguardInner { // provides access to the MTU value of the tun device // (otherwise owned solely by the router and a dedicated read IO thread) mtu: T::MTU, @@ -118,9 +118,21 @@ struct WireguardInner { queue: Mutex>>, } +#[derive(Clone)] +pub struct WireguardHandle { + inner: Arc>, +} + +impl Deref for WireguardHandle { + type Target = Arc>; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + pub struct Wireguard { runner: Runner, - state: Arc>, + state: WireguardHandle, } /* Returns the padded length of a message: @@ -146,6 +158,24 @@ const fn padding(size: usize, mtu: usize) -> usize { } impl Wireguard { + pub fn clear_peers(&self) { + self.state.peers.write().clear(); + } + + pub fn remove_peer(&self, pk: PublicKey) { + self.state.peers.write().remove(pk.as_bytes()); + } + + pub fn list_peers(&self) -> Vec> { + let peers = self.state.peers.read(); + let mut list = Vec::with_capacity(peers.len()); + for (k, v) in peers.iter() { + debug_assert!(k == v.pk.as_bytes()); + list.push(v.clone()); + } + list + } + pub fn set_key(&self, sk: Option) { let mut handshake = self.state.handshake.write(); match sk { @@ -170,7 +200,7 @@ impl Wireguard { } } - pub fn new_peer(&self, pk: PublicKey) -> Peer { + pub fn new_peer(&self, pk: PublicKey) { let state = Arc::new(PeerInner { pk, last_handshake: Mutex::new(SystemTime::UNIX_EPOCH), @@ -182,8 +212,13 @@ impl Wireguard { timers: RwLock::new(Timers::dummy(&self.runner)), }); + // create a router peer let router = Arc::new(self.state.router.new_peer(state.clone())); + // add to the handshake device + self.state.handshake.write().device.add(pk).unwrap(); // TODO: handle adding of public key for interface + + // form WireGuard peer let peer = Peer { router, state }; /* The need for dummy timers arises from the chicken-egg @@ -193,7 +228,10 @@ impl Wireguard { * TODO: Consider the ease of using atomic pointers instead. */ *peer.timers.write() = Timers::new(&self.runner, peer.clone()); - peer + + // finally, add the peer to the wireguard device + let mut peers = self.state.peers.write(); + peers.entry(*pk.as_bytes()).or_insert(peer); } /* Begin consuming messages from the reader. @@ -417,7 +455,7 @@ impl Wireguard { } Wireguard { - state: wg, + state: WireguardHandle { inner: wg }, runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY), } } -- cgit v1.2.3-59-g8ed1b