From 761c46064d7510303f08cde27c9e13b07293f3af Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Wed, 9 Oct 2019 15:08:26 +0200 Subject: Restructure IO traits. --- src/config.rs | 183 ++++++++++++++++++++++++++++++++++++++++++++ src/handshake/device.rs | 9 +++ src/main.rs | 7 +- src/router/device.rs | 55 +++++++------ src/router/peer.rs | 65 ++++++++-------- src/router/tests.rs | 46 ++++++----- src/router/types.rs | 2 + src/router/workers.rs | 45 ++++++----- src/timers.rs | 8 +- src/types/bind.rs | 79 ++++--------------- src/types/dummy.rs | 172 +++++++++++++++++++++++++++-------------- src/types/endpoint.rs | 2 +- src/types/mod.rs | 8 +- src/types/tun.rs | 43 +++++++---- src/types/udp.rs | 29 ------- src/wireguard.rs | 200 +++++++++++++++++++++++++++++++----------------- 16 files changed, 610 insertions(+), 343 deletions(-) create mode 100644 src/config.rs delete mode 100644 src/types/udp.rs (limited to 'src') diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..60faf43 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,183 @@ +use std::error::Error; +use std::net::{IpAddr, SocketAddr}; +use x25519_dalek::{PublicKey, StaticSecret}; + +use crate::wireguard::Wireguard; +use crate::types::{Bind, Endpoint, Tun}; + +/// +/// The goal of the configuration interface is, among others, +/// to hide the IO implementations (over which the WG device is generic), +/// from the configuration and UAPI code. + +/// Describes a snapshot of the state of a peer +pub struct PeerState { + rx_bytes: u64, + tx_bytes: u64, + last_handshake_time_sec: u64, + last_handshake_time_nsec: u64, + public_key: PublicKey, + allowed_ips: Vec<(IpAddr, u32)>, +} + +pub enum ConfigError { + NoSuchPeer +} + +impl ConfigError { + + fn errno(&self) -> i32 { + match self { + NoSuchPeer => 1, + } + } +} + +/// Exposed configuration interface +pub trait Configuration { + /// Updates the private key of the device + /// + /// # Arguments + /// + /// - `sk`: The new private key (or None, if the private key should be cleared) + fn set_private_key(&self, sk: Option); + + /// Returns the private key of the device + /// + /// # Returns + /// + /// The private if set, otherwise None. + fn get_private_key(&self) -> Option; + + /// Returns the protocol version of the device + /// + /// # Returns + /// + /// An integer indicating the protocol version + fn get_protocol_version(&self) -> usize; + + fn set_listen_port(&self, port: u16) -> Option; + + /// Set the firewall mark (or similar, depending on platform) + /// + /// # Arguments + /// + /// - `mark`: The fwmark value + /// + /// # Returns + /// + /// An error if this operation is not supported by the underlying + /// "bind" implementation. + fn set_fwmark(&self, mark: Option) -> Option; + + /// Removes all peers from the device + fn replace_peers(&self); + + /// Remove the peer from the + /// + /// # Arguments + /// + /// - `peer`: The public key of the peer to remove + /// + /// # Returns + /// + /// If the peer does not exists this operation is a noop + fn remove_peer(&self, peer: PublicKey); + + /// Adds a new peer to the device + /// + /// # Arguments + /// + /// - `peer`: The public key of the peer to add + /// + /// # Returns + /// + /// A bool indicating if the peer was added. + /// + /// If the peer already exists this operation is a noop + fn add_peer(&self, peer: PublicKey) -> bool; + + /// Update the psk of a peer + /// + /// # Arguments + /// + /// - `peer`: The public key of the peer + /// - `psk`: The new psk or None if the psk should be unset + /// + /// # Returns + /// + /// An error if no such peer exists + fn set_preshared_key(&self, peer: PublicKey, psk: Option<[u8; 32]>) -> Option; + + /// Update the endpoint of the + /// + /// # Arguments + /// + /// - `peer': The public key of the peer + /// - `psk` + fn set_endpoint(&self, peer: PublicKey, addr: SocketAddr) -> Option; + + /// Update the endpoint of the + /// + /// # Arguments + /// + /// - `peer': The public key of the peer + /// - `psk` + fn set_persistent_keepalive_interval(&self, peer: PublicKey) -> Option; + + /// Remove all allowed IPs from the peer + /// + /// # Arguments + /// + /// - `peer': The public key of the peer + /// + /// # Returns + /// + /// An error if no such peer exists + fn replace_allowed_ips(&self, peer: PublicKey) -> Option; + + /// Add a new allowed subnet to the peer + /// + /// # Arguments + /// + /// - `peer`: The public key of the peer + /// - `ip`: Subnet mask + /// - `masklen`: + /// + /// # Returns + /// + /// An error if the peer does not exist + /// + /// # Note: + /// + /// The API must itself sanitize the (ip, masklen) set: + /// The ip should be masked to remove any set bits right of the first "masklen" bits. + fn add_allowed_ip(&self, peer: PublicKey, ip: IpAddr, masklen: u32) -> Option; + + /// Returns the state of all peers + /// + /// # Returns + /// + /// A list of structures describing the state of each peer + fn get_peers(&self) -> Vec; +} + +impl Configuration for Wireguard { + + fn set_private_key(&self, sk : Option) { + self.set_key(sk) + } + + fn get_private_key(&self) -> Option { + self.get_sk() + } + + fn get_protocol_version(&self) -> usize { + 1 + } + + fn set_listen_port(&self, port : u16) -> Option { + + } + +} \ No newline at end of file diff --git a/src/handshake/device.rs b/src/handshake/device.rs index 6178831..6a55f6e 100644 --- a/src/handshake/device.rs +++ b/src/handshake/device.rs @@ -76,6 +76,15 @@ impl Device { } } + /// Return the secret key of the device + /// + /// # Returns + /// + /// A secret key (x25519 scalar) + pub fn get_sk(&self) -> StaticSecret { + StaticSecret::from(self.sk.to_bytes()) + } + /// Add a new public key to the state machine /// To remove public keys, you must create a new machine instance /// diff --git a/src/main.rs b/src/main.rs index 6133884..7a31119 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,7 @@ extern crate jemallocator; #[global_allocator] static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc; +// mod config; mod constants; mod handshake; mod router; @@ -14,7 +15,8 @@ mod wireguard; #[cfg(test)] mod tests { - use crate::types::{dummy, Bind}; + use crate::types::tun::Tun; + use crate::types::{bind, dummy, tun}; use crate::wireguard::Wireguard; use std::thread; @@ -27,7 +29,8 @@ mod tests { #[test] fn test_pure_wireguard() { init(); - let wg = Wireguard::new(dummy::TunTest::new(), dummy::VoidBind::new()); + let (reader, writer, mtu) = dummy::TunTest::create("name").unwrap(); + let wg: Wireguard = Wireguard::new(reader, writer, mtu); thread::sleep(Duration::from_millis(500)); } } diff --git a/src/router/device.rs b/src/router/device.rs index d126959..989c2c2 100644 --- a/src/router/device.rs +++ b/src/router/device.rs @@ -17,21 +17,23 @@ use super::constants::*; use super::ip::*; use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::peer::{new_peer, Peer, PeerInner}; -use super::types::{Callbacks, Opaque, RouterError}; +use super::types::{Callbacks, RouterError}; use super::workers::{worker_parallel, JobParallel, Operation}; use super::SIZE_MESSAGE_PREFIX; -use super::super::types::{Bind, KeyPair, Tun}; +use super::super::types::{KeyPair, Endpoint, bind, tun}; -pub struct DeviceInner { - // IO & timer callbacks - pub tun: T, - pub bind: B, +pub struct DeviceInner> { + // inbound writer (TUN) + pub inbound: T, + + // outbound writer (Bind) + pub outbound: RwLock>, // routing - pub recv: RwLock>>>, // receiver id -> decryption state - pub ipv4: RwLock>>>, // ipv4 cryptkey routing - pub ipv6: RwLock>>>, // ipv6 cryptkey routing + pub recv: RwLock>>>, // receiver id -> decryption state + pub ipv4: RwLock>>>, // ipv4 cryptkey routing + pub ipv6: RwLock>>>, // ipv6 cryptkey routing // work queues pub queue_next: AtomicUsize, // next round-robin index @@ -45,20 +47,20 @@ pub struct EncryptionState { pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout) } -pub struct DecryptionState { +pub struct DecryptionState> { pub keypair: Arc, pub confirmed: AtomicBool, pub protector: Mutex, - pub peer: Arc>, + pub peer: Arc>, pub death: Instant, // time when the key can no longer be used for decryption } -pub struct Device { - state: Arc>, // reference to device state +pub struct Device> { + state: Arc>, // reference to device state handles: Vec>, // join handles for workers } -impl Drop for Device { +impl> Drop for Device { fn drop(&mut self) { debug!("router: dropping device"); @@ -83,10 +85,10 @@ impl Drop for Device { } #[inline(always)] -fn get_route( - device: &Arc>, +fn get_route>( + device: &Arc>, packet: &[u8], -) -> Option>> { +) -> Option>> { // ensure version access within bounds if packet.len() < 1 { return None; @@ -122,12 +124,12 @@ fn get_route( } } -impl Device { - pub fn new(num_workers: usize, tun: T, bind: B) -> Device { +impl> Device { + pub fn new(num_workers: usize, tun: T) -> Device { // allocate shared device state let mut inner = DeviceInner { - tun, - bind, + inbound: tun, + outbound: RwLock::new(None), queues: Mutex::new(Vec::with_capacity(num_workers)), queue_next: AtomicUsize::new(0), recv: RwLock::new(HashMap::new()), @@ -159,7 +161,7 @@ impl Device { /// # Returns /// /// A atomic ref. counted peer (with liftime matching the device) - pub fn new_peer(&self, opaque: C::Opaque) -> Peer { + pub fn new_peer(&self, opaque: C::Opaque) -> Peer { new_peer(self.state.clone(), opaque) } @@ -199,7 +201,7 @@ impl Device { /// # Returns /// /// - pub fn recv(&self, src: B::Endpoint, msg: Vec) -> Result<(), RouterError> { + pub fn recv(&self, src: E, msg: Vec) -> Result<(), RouterError> { // parse / cast let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) { Some(v) => v, @@ -231,4 +233,11 @@ impl Device { Ok(()) } + + /// Set outbound writer + /// + /// + pub fn set_outbound_writer(&self, new : B) { + *self.state.outbound.write() = Some(new); + } } diff --git a/src/router/peer.rs b/src/router/peer.rs index 86723bb..189904c 100644 --- a/src/router/peer.rs +++ b/src/router/peer.rs @@ -14,7 +14,7 @@ use treebitmap::IpLookupTable; use zerocopy::LayoutVerified; use super::super::constants::*; -use super::super::types::{Bind, Endpoint, KeyPair, Tun}; +use super::super::types::{Endpoint, KeyPair, bind, tun}; use super::anti_replay::AntiReplay; use super::device::DecryptionState; @@ -39,28 +39,28 @@ pub struct KeyWheel { retired: Vec, // retired ids } -pub struct PeerInner { - pub device: Arc>, +pub struct PeerInner> { + pub device: Arc>, pub opaque: C::Opaque, pub outbound: Mutex>, - pub inbound: Mutex>>, + pub inbound: Mutex>>, pub staged_packets: Mutex; MAX_STAGED_PACKETS], Wrapping>>, pub keys: Mutex, pub ekey: Mutex>, - pub endpoint: Mutex>, + pub endpoint: Mutex>, } -pub struct Peer { - state: Arc>, +pub struct Peer> { + state: Arc>, thread_outbound: Option>, thread_inbound: Option>, } -fn treebit_list( - peer: &Arc>, - table: &spin::RwLock>>>, - callback: Box E>, -) -> Vec +fn treebit_list>( + peer: &Arc>, + table: &spin::RwLock>>>, + callback: Box R>, +) -> Vec where A: Address, { @@ -74,9 +74,9 @@ where res } -fn treebit_remove( - peer: &Peer, - table: &spin::RwLock>>>, +fn treebit_remove>( + peer: &Peer, + table: &spin::RwLock>>>, ) { let mut m = table.write(); @@ -107,8 +107,8 @@ impl EncryptionState { } } -impl DecryptionState { - fn new(peer: &Arc>, keypair: &Arc) -> DecryptionState { +impl> DecryptionState { + fn new(peer: &Arc>, keypair: &Arc) -> DecryptionState { DecryptionState { confirmed: AtomicBool::new(keypair.initiator), keypair: keypair.clone(), @@ -119,7 +119,7 @@ impl DecryptionState { } } -impl Drop for Peer { +impl> Drop for Peer { fn drop(&mut self) { let peer = &self.state; @@ -167,10 +167,10 @@ impl Drop for Peer { } } -pub fn new_peer( - device: Arc>, +pub fn new_peer>( + device: Arc>, opaque: C::Opaque, -) -> Peer { +) -> Peer { let (out_tx, out_rx) = sync_channel(128); let (in_tx, in_rx) = sync_channel(128); @@ -215,7 +215,7 @@ pub fn new_peer( } } -impl PeerInner { +impl> PeerInner { fn send_staged(&self) -> bool { debug!("peer.send_staged"); let mut sent = false; @@ -286,8 +286,8 @@ impl PeerInner { pub fn recv_job( &self, - src: B::Endpoint, - dec: Arc>, + src: E, + dec: Arc>, mut msg: Vec, ) -> Option { let (tx, rx) = oneshot(); @@ -370,7 +370,7 @@ impl PeerInner { } } -impl Peer { +impl> Peer { /// Set the endpoint of the peer /// /// # Arguments @@ -381,9 +381,9 @@ impl Peer { /// /// This API still permits support for the "sticky socket" behavior, /// as sockets should be "unsticked" when manually updating the endpoint - pub fn set_endpoint(&self, address: SocketAddr) { + pub fn set_endpoint(&self, endpoint: E) { debug!("peer.set_endpoint"); - *self.state.endpoint.lock() = Some(B::Endpoint::from_address(address)); + *self.state.endpoint.lock() = Some(endpoint); } /// Returns the current endpoint of the peer (for configuration) @@ -591,11 +591,12 @@ impl Peer { debug!("peer.send"); let inner = &self.state; match inner.endpoint.lock().as_ref() { - Some(endpoint) => inner - .device - .bind - .send(msg, endpoint) - .map_err(|_| RouterError::SendError), + Some(endpoint) => inner.device + .outbound + .read() + .as_ref() + .ok_or(RouterError::SendError) + .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError) ), None => Err(RouterError::NoEndpoint), } } diff --git a/src/router/tests.rs b/src/router/tests.rs index f42e1f6..3b6b941 100644 --- a/src/router/tests.rs +++ b/src/router/tests.rs @@ -1,18 +1,18 @@ -use std::error::Error; -use std::fmt; -use std::net::{IpAddr, SocketAddr}; +use std::net::IpAddr; use std::sync::atomic::Ordering; -use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::Arc; use std::sync::Mutex; use std::thread; -use std::time::{Duration, Instant}; +use std::time::Duration; use num_cpus; use pnet::packet::ipv4::MutableIpv4Packet; use pnet::packet::ipv6::MutableIpv6Packet; -use super::super::types::{dummy, Bind, Endpoint, Key, KeyPair, Tun}; +use super::super::types::bind::*; +use super::super::types::tun::*; +use super::super::types::*; + use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX}; extern crate test; @@ -145,8 +145,9 @@ mod tests { } // create device - let router: Device = - Device::new(num_cpus::get(), dummy::TunTest {}, dummy::VoidBind::new()); + let (_reader, tun_writer, _mtu) = dummy::TunTest::create("name").unwrap(); + let router: Device<_, BencherCallbacks, dummy::TunTest, dummy::VoidBind> = + Device::new(num_cpus::get(), tun_writer); // add new peer let opaque = Arc::new(AtomicUsize::new(0)); @@ -174,8 +175,9 @@ mod tests { init(); // create device - let router: Device = - Device::new(1, dummy::TunTest::new(), dummy::VoidBind::new()); + let (_reader, tun_writer, _mtu) = dummy::TunTest::create("name").unwrap(); + let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer); + router.set_outbound_writer(dummy::VoidBind::new()); let tests = vec![ ("192.168.1.0", 24, "192.168.1.20", true), @@ -315,12 +317,18 @@ mod tests { ]; for (stage, p1, p2) in tests.iter() { - // create matching devices - let (bind1, bind2) = dummy::PairBind::pair(); - let router1: Device = - Device::new(1, dummy::TunTest::new(), bind1.clone()); - let router2: Device = - Device::new(1, dummy::TunTest::new(), bind2.clone()); + let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) = + dummy::PairBind::pair(); + + // create matching device + let (tun_writer1, _, _) = dummy::TunTest::create("tun1").unwrap(); + let (tun_writer2, _, _) = dummy::TunTest::create("tun1").unwrap(); + + let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1); + router1.set_outbound_writer(bind_writer1); + + let router2: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer2); + router2.set_outbound_writer(bind_writer2); // prepare opaque values for tracing callbacks @@ -339,7 +347,7 @@ mod tests { let peer2 = router2.new_peer(opaq2.clone()); let mask: IpAddr = mask.parse().unwrap(); peer2.add_subnet(mask, *len); - peer2.set_endpoint("127.0.0.1:8080".parse().unwrap()); + peer2.set_endpoint(dummy::UnitEndpoint::new()); if *stage { // stage a packet which can be used for confirmation (in place of a keepalive) @@ -372,7 +380,7 @@ mod tests { // read confirming message received by the other end ("across the internet") let mut buf = vec![0u8; 2048]; - let (len, from) = bind1.recv(&mut buf).unwrap(); + let (len, from) = bind_reader1.read(&mut buf).unwrap(); buf.truncate(len); router1.recv(from, buf).unwrap(); @@ -411,7 +419,7 @@ mod tests { // receive ("across the internet") on the other end let mut buf = vec![0u8; 2048]; - let (len, from) = bind2.recv(&mut buf).unwrap(); + let (len, from) = bind_reader2.read(&mut buf).unwrap(); buf.truncate(len); router2.recv(from, buf).unwrap(); diff --git a/src/router/types.rs b/src/router/types.rs index b7c3ae0..4a72c27 100644 --- a/src/router/types.rs +++ b/src/router/types.rs @@ -1,6 +1,8 @@ use std::error::Error; use std::fmt; +use super::super::types::Endpoint; + pub trait Opaque: Send + Sync + 'static {} impl Opaque for T where T: Send + Sync + 'static {} diff --git a/src/router/workers.rs b/src/router/workers.rs index 6710816..2e89bb0 100644 --- a/src/router/workers.rs +++ b/src/router/workers.rs @@ -17,7 +17,7 @@ use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::peer::PeerInner; use super::types::Callbacks; -use super::super::types::{Bind, Tun}; +use super::super::types::{Endpoint, tun, bind}; use super::ip::*; const SIZE_TAG: usize = 16; @@ -38,18 +38,18 @@ pub struct JobBuffer { pub type JobParallel = (oneshot::Sender, JobBuffer); #[allow(type_alias_bounds)] -pub type JobInbound = ( - Arc>, - B::Endpoint, +pub type JobInbound> = ( + Arc>, + E, oneshot::Receiver, ); pub type JobOutbound = oneshot::Receiver; #[inline(always)] -fn check_route( - device: &Arc>, - peer: &Arc>, +fn check_route>( + device: &Arc>, + peer: &Arc>, packet: &[u8], ) -> Option { match packet[0] >> 4 { @@ -93,10 +93,10 @@ fn check_route( } } -pub fn worker_inbound( - device: Arc>, // related device - peer: Arc>, // related peer - receiver: Receiver>, +pub fn worker_inbound>( + device: Arc>, // related device + peer: Arc>, // related peer + receiver: Receiver>, ) { loop { // fetch job @@ -153,7 +153,7 @@ pub fn worker_inbound( if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) { debug_assert!(inner_len <= length, "should be validated"); if inner_len <= length { - sent = match device.tun.write(&packet[..inner_len]) { + sent = match device.inbound.write(&packet[..inner_len]) { Err(e) => { debug!("failed to write inbound packet to TUN: {:?}", e); false @@ -176,9 +176,9 @@ pub fn worker_inbound( } } -pub fn worker_outbound( - device: Arc>, // related device - peer: Arc>, // related peer +pub fn worker_outbound>( + device: Arc>, // related device + peer: Arc>, // related peer receiver: Receiver, ) { loop { @@ -198,12 +198,17 @@ pub fn worker_outbound( if buf.okay { // write to UDP bind let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() { - match device.bind.send(&buf.msg[..], dst) { - Err(e) => { - debug!("failed to send outbound packet: {:?}", e); - false + let send : &Option = &*device.outbound.read(); + if let Some(writer) = send.as_ref() { + match writer.write(&buf.msg[..], dst) { + Err(e) => { + debug!("failed to send outbound packet: {:?}", e); + false + } + Ok(_) => true, } - Ok(_) => true, + } else { + false } } else { false diff --git a/src/timers.rs b/src/timers.rs index 303fd35..67ece06 100644 --- a/src/timers.rs +++ b/src/timers.rs @@ -7,7 +7,7 @@ use hjul::{Runner, Timer}; use crate::constants::*; use crate::router::Callbacks; -use crate::types::{Bind, Tun}; +use crate::types::{tun, bind}; use crate::wireguard::{Peer, PeerInner}; pub struct Timers { @@ -23,8 +23,8 @@ pub struct Timers { impl Timers { pub fn new(runner: &Runner, peer: Peer) -> Timers where - T: Tun, - B: Bind, + T: tun::Tun, + B: bind::Bind, { // create a timer instance for the provided peer Timers { @@ -103,7 +103,7 @@ impl Timers { pub struct Events(PhantomData<(T, B)>); -impl Callbacks for Events { +impl Callbacks for Events { type Opaque = Arc>; fn send(peer: &Self::Opaque, size: usize, data: bool, sent: bool) { diff --git a/src/types/bind.rs b/src/types/bind.rs index 62adbbb..fcc38c8 100644 --- a/src/types/bind.rs +++ b/src/types/bind.rs @@ -1,73 +1,28 @@ use super::Endpoint; -use std::error; +use std::error::Error; -/// Traits representing the "internet facing" end of the VPN. -/// -/// In practice this is a UDP socket (but the router interface is agnostic). -/// Often these traits will be implemented on the same type. +pub trait Reader: Send + Sync { + type Error: Error; -/// Bind interface provided to the router code -pub trait RouterBind: Send + Sync { - type Error: error::Error; - type Endpoint: Endpoint; + fn read(&self, buf: &mut [u8]) -> Result<(usize, E), Self::Error>; +} - /// Receive a buffer on the bind - /// - /// # Arguments - /// - /// - `buf`, buffer for storing the packet. If the buffer is too short, the packet should just be truncated. - /// - /// # Note - /// - /// The size of the buffer is derieved from the MTU of the Tun device. - fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error>; +pub trait Writer: Send + Sync + Clone + 'static { + type Error: Error; - /// Send a buffer to the endpoint - /// - /// # Arguments - /// - /// - `buf`, packet src buffer (in practice the body of a UDP datagram) - /// - `dst`, destination endpoint (in practice, src: (ip, port) + dst: (ip, port) for sticky sockets) - /// - /// # Returns - /// - /// The unit type or an error if transmission failed - fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error>; + fn write(&self, buf: &[u8], dst: &E) -> Result<(), Self::Error>; } -/// Bind interface provided for configuration (setting / getting the port) -pub trait ConfigBind { - type Error: error::Error; - - /// Return a new (unbound) instance of a configuration bind - fn new() -> Self; +pub trait Bind: Send + Sync + 'static { + type Error: Error; + type Endpoint: Endpoint; - /// Updates the port of the bind - /// - /// # Arguments - /// - /// - `port`, the new port to bind to. 0 means any available port. - /// - /// # Returns - /// - /// The unit type or an error, if binding fails - fn set_port(&self, port: u16) -> Result<(), Self::Error>; + /* Until Rust gets type equality constraints these have to be generic */ + type Writer: Writer; + type Reader: Reader; - /// Returns the current port of the bind - fn get_port(&self) -> Option; + /* Used to close the reader/writer when binding to a new port */ + type Closer; - /// Set the mark (e.g. on Linus this is the fwmark) on the bind - /// - /// # Arguments - /// - /// - `mark`, the mark to set - /// - /// # Note - /// - /// The mark should be retained accross calls to `set_port`. - /// - /// # Returns - /// - /// The unit type or an error, if the operation fails due to permission errors - fn set_mark(&self, mark: u16) -> Result<(), Self::Error>; + fn bind(port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error>; } diff --git a/src/types/dummy.rs b/src/types/dummy.rs index e15abb0..40a3bdd 100644 --- a/src/types/dummy.rs +++ b/src/types/dummy.rs @@ -5,8 +5,9 @@ use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::Arc; use std::sync::Mutex; use std::time::Instant; +use std::marker; -use super::{Bind, Endpoint, Key, KeyPair, Tun}; +use super::*; /* This submodule provides pure/dummy implementations of the IO interfaces * for use in unit tests thoughout the project. @@ -72,104 +73,103 @@ impl Endpoint for UnitEndpoint { } } +impl UnitEndpoint { + pub fn new() -> UnitEndpoint { + UnitEndpoint{} + } +} + +/* */ + #[derive(Clone, Copy)] pub struct TunTest {} -impl Tun for TunTest { +impl tun::Reader for TunTest { type Error = TunError; + fn read(&self, _buf: &mut [u8], _offset: usize) -> Result { + Ok(0) + } +} + +impl tun::MTU for TunTest { fn mtu(&self) -> usize { 1500 } +} - fn read(&self, _buf: &mut [u8], _offset: usize) -> Result { - Ok(0) - } +impl tun::Writer for TunTest { + type Error = TunError; fn write(&self, _src: &[u8]) -> Result<(), Self::Error> { Ok(()) } } +impl tun::Tun for TunTest { + type Writer = TunTest; + type Reader = TunTest; + type MTU = TunTest; + type Error = TunError; +} + impl TunTest { - pub fn new() -> TunTest { - TunTest {} + pub fn create(_name: &str) -> Result<(TunTest, TunTest, TunTest), TunError> { + Ok((TunTest {},TunTest {}, TunTest{})) } } -/* Bind implemenentations */ +/* Void Bind */ #[derive(Clone, Copy)] pub struct VoidBind {} -impl Bind for VoidBind { +impl bind::Reader for VoidBind { type Error = BindError; - type Endpoint = UnitEndpoint; - fn new() -> VoidBind { - VoidBind {} - } - - fn set_port(&self, _port: u16) -> Result<(), Self::Error> { - Ok(()) - } - - fn get_port(&self) -> Option { - None - } - - fn recv(&self, _buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> { + fn read(&self, _buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> { Ok((0, UnitEndpoint {})) } - - fn send(&self, _buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> { - Ok(()) - } } -#[derive(Clone)] -pub struct PairBind { - send: Arc>>>, - recv: Arc>>>, -} +impl bind::Writer for VoidBind { + type Error = BindError; -impl PairBind { - pub fn pair() -> (PairBind, PairBind) { - let (tx1, rx1) = sync_channel(128); - let (tx2, rx2) = sync_channel(128); - ( - PairBind { - send: Arc::new(Mutex::new(tx1)), - recv: Arc::new(Mutex::new(rx2)), - }, - PairBind { - send: Arc::new(Mutex::new(tx2)), - recv: Arc::new(Mutex::new(rx1)), - }, - ) + fn write(&self, _buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { + Ok(()) } } -impl Bind for PairBind { +impl bind::Bind for VoidBind { type Error = BindError; type Endpoint = UnitEndpoint; - fn new() -> PairBind { - PairBind { - send: Arc::new(Mutex::new(sync_channel(0).0)), - recv: Arc::new(Mutex::new(sync_channel(0).1)), - } - } + type Reader = VoidBind; + type Writer = VoidBind; + type Closer = (); - fn set_port(&self, _port: u16) -> Result<(), Self::Error> { - Ok(()) + fn bind(_ : u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error> { + Ok((VoidBind{}, VoidBind{}, (), 2600)) } +} - fn get_port(&self) -> Option { - None +impl VoidBind { + pub fn new() -> VoidBind { + VoidBind{} } +} - fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> { +/* Pair Bind */ + +#[derive(Clone)] +pub struct PairReader { + recv: Arc>>>, + _marker: marker::PhantomData, +} + +impl bind::Reader for PairReader { + type Error = BindError; + fn read(&self, buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> { let vec = self .recv .lock() @@ -180,8 +180,11 @@ impl Bind for PairBind { buf[..len].copy_from_slice(&vec[..]); Ok((vec.len(), UnitEndpoint {})) } +} - fn send(&self, buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> { +impl bind::Writer for PairWriter { + type Error = BindError; + fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { let owned = buf.to_owned(); match self.send.lock().unwrap().send(owned) { Err(_) => Err(BindError::Disconnected), @@ -190,6 +193,57 @@ impl Bind for PairBind { } } +#[derive(Clone)] +pub struct PairWriter { + send: Arc>>>, + _marker: marker::PhantomData, +} + +#[derive(Clone)] +pub struct PairBind {} + +impl PairBind { + pub fn pair() -> ((PairReader, PairWriter), (PairReader, PairWriter)) { + let (tx1, rx1) = sync_channel(128); + let (tx2, rx2) = sync_channel(128); + ( + ( + PairReader{ + + recv: Arc::new(Mutex::new(rx1)), + _marker: marker::PhantomData + }, + PairWriter{ + send: Arc::new(Mutex::new(tx2)), + _marker: marker::PhantomData + } + ), + ( + PairReader{ + recv: Arc::new(Mutex::new(rx2)), + _marker: marker::PhantomData + }, + PairWriter{ + send: Arc::new(Mutex::new(tx1)), + _marker: marker::PhantomData + } + ), + ) + } +} + +impl bind::Bind for PairBind { + type Closer = (); + type Error = BindError; + type Endpoint = UnitEndpoint; + type Reader = PairReader; + type Writer = PairWriter; + + fn bind(_port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error> { + Err(BindError::Disconnected) + } +} + pub fn keypair(initiator: bool) -> KeyPair { let k1 = Key { key: [0x53u8; 32], diff --git a/src/types/endpoint.rs b/src/types/endpoint.rs index 261203f..74796aa 100644 --- a/src/types/endpoint.rs +++ b/src/types/endpoint.rs @@ -1,6 +1,6 @@ use std::net::SocketAddr; -pub trait Endpoint: Send { +pub trait Endpoint: Send + 'static { fn from_address(addr: SocketAddr) -> Self; fn into_address(&self) -> SocketAddr; } diff --git a/src/types/mod.rs b/src/types/mod.rs index 07ca44d..e0725f3 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,12 +1,10 @@ mod endpoint; mod keys; -mod tun; -mod udp; +pub mod tun; +pub mod bind; #[cfg(test)] pub mod dummy; pub use endpoint::Endpoint; -pub use keys::{Key, KeyPair}; -pub use tun::Tun; -pub use udp::Bind; +pub use keys::{Key, KeyPair}; \ No newline at end of file diff --git a/src/types/tun.rs b/src/types/tun.rs index fc8044a..2ba16ff 100644 --- a/src/types/tun.rs +++ b/src/types/tun.rs @@ -1,18 +1,22 @@ -use std::error; +use std::error::Error; -pub trait Tun: Send + Sync + Clone + 'static { - type Error: error::Error; +pub trait Writer: Send + Sync + 'static { + type Error: Error; - /// Returns the MTU of the device + /// Receive a cryptkey routed IP packet /// - /// This function needs to be efficient (called for every read). - /// The goto implementation strategy is to .load an atomic variable, - /// then use e.g. netlink to update the variable in a separate thread. + /// # Arguments + /// + /// - src: Buffer containing the IP packet to be written /// /// # Returns /// - /// The MTU of the interface in bytes - fn mtu(&self) -> usize; + /// Unit type or an error + fn write(&self, src: &[u8]) -> Result<(), Self::Error>; +} + +pub trait Reader: Send + 'static { + type Error: Error; /// Reads an IP packet into dst[offset:] from the tunnel device /// @@ -29,15 +33,24 @@ pub trait Tun: Send + Sync + Clone + 'static { /// /// The size of the IP packet (ignoring the header) or an std::error::Error instance: fn read(&self, buf: &mut [u8], offset: usize) -> Result; +} - /// Writes an IP packet to the tunnel device - /// - /// # Arguments +pub trait MTU: Send + Sync + Clone + 'static { + /// Returns the MTU of the device /// - /// - src: Buffer containing the IP packet to be written + /// This function needs to be efficient (called for every read). + /// The goto implementation strategy is to .load an atomic variable, + /// then use e.g. netlink to update the variable in a separate thread. /// /// # Returns /// - /// Unit type or an error - fn write(&self, src: &[u8]) -> Result<(), Self::Error>; + /// The MTU of the interface in bytes + fn mtu(&self) -> usize; +} + +pub trait Tun: Send + Sync + 'static { + type Writer: Writer; + type Reader: Reader; + type MTU: MTU; + type Error: Error; } diff --git a/src/types/udp.rs b/src/types/udp.rs deleted file mode 100644 index 943bf94..0000000 --- a/src/types/udp.rs +++ /dev/null @@ -1,29 +0,0 @@ -use super::Endpoint; -use std::error; - -/* Often times an a file descriptor in an atomic might suffice. - */ -pub trait Bind: Send + Sync + Clone + 'static { - type Error: error::Error + Send; - type Endpoint: Endpoint; - - fn new() -> Self; - - /// Updates the port of the Bind - /// - /// # Arguments - /// - /// - port, The new port to bind to. 0 means any available port. - /// - /// # Returns - /// - /// The unit type or an error, if binding fails - fn set_port(&self, port: u16) -> Result<(), Self::Error>; - - /// Returns the current port of the bind - fn get_port(&self) -> Option; - - fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error>; - - fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error>; -} diff --git a/src/wireguard.rs b/src/wireguard.rs index ea600d0..ba81f47 100644 --- a/src/wireguard.rs +++ b/src/wireguard.rs @@ -2,11 +2,13 @@ use crate::constants::*; use crate::handshake; use crate::router; use crate::timers::{Events, Timers}; -use crate::types::{Bind, Endpoint, Tun}; + +use crate::types::Endpoint; +use crate::types::tun::{Tun, Reader, MTU}; +use crate::types::bind::{Bind, Writer}; use hjul::Runner; -use std::cmp; use std::ops::Deref; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; @@ -27,12 +29,20 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128; const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4; const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000); -#[derive(Clone)] pub struct Peer { - pub router: Arc, T, B>>, + pub router: Arc, T::Writer, B::Writer>>, pub state: Arc>, } +impl Clone for Peer { + fn clone(&self) -> Peer { + Peer{ + router: self.router.clone(), + state: self.state.clone() + } + } +} + pub struct PeerInner { pub keepalive: AtomicUsize, // keepalive interval pub rx_bytes: AtomicU64, @@ -66,20 +76,22 @@ pub enum HandshakeJob { } struct WireguardInner { + // provides access to the MTU value of the tun device + // (otherwise owned solely by the router and a dedicated read IO thread) + mtu: T::MTU, + send: RwLock>, + // identify and configuration map peers: RwLock>>, // cryptkey router - router: router::Device, T, B>, + router: router::Device, T::Writer, B::Writer>, // handshake related state handshake: RwLock, under_load: AtomicBool, pending: AtomicUsize, // num of pending handshake packets in queue queue: Mutex>>, - - // IO - bind: B, } pub struct Wireguard { @@ -87,6 +99,17 @@ pub struct Wireguard { state: Arc>, } +/* Returns the padded length of a message: + * + * # Arguments + * + * - `size` : Size of unpadded message + * - `mtu` : Maximum transmission unit of the device + * + * # Returns + * + * The padded length (always less than or equal to the MTU) + */ #[inline(always)] const fn padding(size: usize, mtu: usize) -> usize { #[inline(always)] @@ -114,6 +137,15 @@ impl Wireguard { } } + pub fn get_sk(&self) -> Option { + let mut handshake = self.state.handshake.read(); + if handshake.active { + Some(handshake.device.get_sk()) + } else { + None + } + } + pub fn new_peer(&self, pk: PublicKey) -> Peer { let state = Arc::new(PeerInner { pk, @@ -137,20 +169,92 @@ impl Wireguard { peer } - pub fn new(tun: T, bind: B) -> Wireguard { + pub fn new_bind( + reader: B::Reader, + writer: B::Writer, + closer: B::Closer + ) { + + // drop existing closer + + + // swap IO thread for new reader + + + // start UDP read IO thread + + /* + { + let wg = wg.clone(); + let mtu = mtu.clone(); + thread::spawn(move || { + let mut last_under_load = + Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); + + loop { + // create vector big enough for any message given current MTU + let size = mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE; + let mut msg: Vec = Vec::with_capacity(size); + msg.resize(size, 0); + + // read UDP packet into vector + let (size, src) = reader.read(&mut msg).unwrap(); // TODO handle error + msg.truncate(size); + + // message type de-multiplexer + if msg.len() < std::mem::size_of::() { + continue; + } + match LittleEndian::read_u32(&msg[..]) { + handshake::TYPE_COOKIE_REPLY + | handshake::TYPE_INITIATION + | handshake::TYPE_RESPONSE => { + // update under_load flag + if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD { + last_under_load = Instant::now(); + wg.under_load.store(true, Ordering::SeqCst); + } else if last_under_load.elapsed() > DURATION_UNDER_LOAD { + wg.under_load.store(false, Ordering::SeqCst); + } + + wg.queue + .lock() + .send(HandshakeJob::Message(msg, src)) + .unwrap(); + } + router::TYPE_TRANSPORT => { + // transport message + let _ = wg.router.recv(src, msg); + } + _ => (), + } + } + }); + } + */ + + + } + + pub fn new( + reader: T::Reader, + writer: T::Writer, + mtu: T::MTU, + ) -> Wireguard { // create device state let mut rng = OsRng::new().unwrap(); let (tx, rx): (Sender>, _) = bounded(SIZE_HANDSHAKE_QUEUE); let wg = Arc::new(WireguardInner { + mtu: mtu.clone(), peers: RwLock::new(HashMap::new()), - router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()), + send: RwLock::new(None), + router: router::Device::new(num_cpus::get(), writer), // router owns the writing half pending: AtomicUsize::new(0), handshake: RwLock::new(Handshake { device: handshake::Device::new(StaticSecret::new(&mut rng)), active: false, }), under_load: AtomicBool::new(false), - bind: bind.clone(), queue: Mutex::new(tx), }); @@ -158,7 +262,6 @@ impl Wireguard { for _ in 0..num_cpus::get() { let wg = wg.clone(); let rx = rx.clone(); - let bind = bind.clone(); thread::spawn(move || { // prepare OsRng instance for this thread let mut rng = OsRng::new().unwrap(); @@ -189,19 +292,22 @@ impl Wireguard { Ok((pk, msg, keypair)) => { // send response if let Some(msg) = msg { - let _ = bind.send(&msg[..], &src).map_err(|e| { - debug!( - "handshake worker, failed to send response, error = {:?}", - e - ) - }); + let send : &Option = &*wg.send.read(); + if let Some(writer) = send.as_ref() { + let _ = writer.write(&msg[..], &src).map_err(|e| { + debug!( + "handshake worker, failed to send response, error = {:?}", + e + ) + }); + } } // update timers if let Some(pk) = pk { if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { - // update endpoint (DISCUSS: right semantics?) - peer.router.set_endpoint(src_validate); + // update endpoint + peer.router.set_endpoint(src); // add keypair to peer and free any unused ids if let Some(keypair) = keypair { @@ -227,68 +333,18 @@ impl Wireguard { }); } - // start UDP read IO thread - { - let wg = wg.clone(); - let tun = tun.clone(); - let bind = bind.clone(); - thread::spawn(move || { - let mut last_under_load = - Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); - - loop { - // create vector big enough for any message given current MTU - let size = tun.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE; - let mut msg: Vec = Vec::with_capacity(size); - msg.resize(size, 0); - - // read UDP packet into vector - let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error - msg.truncate(size); - - // message type de-multiplexer - if msg.len() < std::mem::size_of::() { - continue; - } - match LittleEndian::read_u32(&msg[..]) { - handshake::TYPE_COOKIE_REPLY - | handshake::TYPE_INITIATION - | handshake::TYPE_RESPONSE => { - // update under_load flag - if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD { - last_under_load = Instant::now(); - wg.under_load.store(true, Ordering::SeqCst); - } else if last_under_load.elapsed() > DURATION_UNDER_LOAD { - wg.under_load.store(false, Ordering::SeqCst); - } - - wg.queue - .lock() - .send(HandshakeJob::Message(msg, src)) - .unwrap(); - } - router::TYPE_TRANSPORT => { - // transport message - let _ = wg.router.recv(src, msg); - } - _ => (), - } - } - }); - } - // start TUN read IO thread { let wg = wg.clone(); thread::spawn(move || loop { // create vector big enough for any transport message (based on MTU) - let mtu = tun.mtu(); + let mtu = mtu.mtu(); let size = mtu + router::SIZE_MESSAGE_PREFIX; let mut msg: Vec = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); msg.resize(size, 0); // read a new IP packet - let payload = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap(); + let payload = reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap(); debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu); // truncate padding -- cgit v1.2.3-59-g8ed1b