diff options
-rw-r--r-- | Cargo.lock | 20 | ||||
-rw-r--r-- | Cargo.toml | 1 | ||||
-rw-r--r-- | src/wireguard/mod.rs | 2 | ||||
-rw-r--r-- | src/wireguard/peer.rs | 3 | ||||
-rw-r--r-- | src/wireguard/queue.rs | 64 | ||||
-rw-r--r-- | src/wireguard/router/device.rs | 6 | ||||
-rw-r--r-- | src/wireguard/router/mod.rs | 4 | ||||
-rw-r--r-- | src/wireguard/router/queue.rs | 46 | ||||
-rw-r--r-- | src/wireguard/wireguard.rs | 30 |
9 files changed, 84 insertions, 92 deletions
@@ -166,23 +166,6 @@ dependencies = [ ] [[package]] -name = "crossbeam-channel" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "crossbeam-utils 0.6.6 (registry+https://github.com/rust-lang/crates.io-index)", -] - -[[package]] -name = "crossbeam-utils" -version = "0.6.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "cfg-if 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)", - "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", -] - -[[package]] name = "crypto-mac" version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1204,7 +1187,6 @@ dependencies = [ "byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)", "chacha20poly1305 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "clear_on_drop 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", - "crossbeam-channel 0.3.9 (registry+https://github.com/rust-lang/crates.io-index)", "daemonize 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)", "digest 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)", "env_logger 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1295,8 +1277,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum chacha20poly1305 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "40cd3ddeae0b0ea7fe848a06e4fbf3f02463648b9395bd1139368ce42b44543e" "checksum clear_on_drop 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "97276801e127ffb46b66ce23f35cc96bd454fa311294bced4bbace7baa8b1d17" "checksum cloudabi 0.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "ddfc5b9aa5d4507acaf872de71051dfd0e309860e88966e1051e462a077aac4f" -"checksum crossbeam-channel 0.3.9 (registry+https://github.com/rust-lang/crates.io-index)" = "c8ec7fcd21571dc78f96cc96243cab8d8f035247c3efd16c687be154c3fa9efa" -"checksum crossbeam-utils 0.6.6 (registry+https://github.com/rust-lang/crates.io-index)" = "04973fa96e96579258a5091af6003abde64af786b860f18622b82e026cca60e6" "checksum crypto-mac 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "4434400df11d95d556bac068ddfedd482915eb18fe8bea89bc80b6e4b1c179e5" "checksum curve25519-dalek 1.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "8b7dcd30ba50cdf88b55b033456138b7c0ac4afdc436d82e1b79f370f24cc66d" "checksum daemonize 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)" = "70c24513e34f53b640819f0ac9f705b673fcf4006d7aab8778bee72ebfc89815" @@ -18,7 +18,6 @@ byteorder = "1.3.1" digest = "0.8.0" arraydeque = "0.4.5" treebitmap = "^0.4" -crossbeam-channel = "0.3.9" hjul = "0.2.1" ring = "0.16.7" chacha20poly1305 = "^0.1" diff --git a/src/wireguard/mod.rs b/src/wireguard/mod.rs index 711aa2b..f899359 100644 --- a/src/wireguard/mod.rs +++ b/src/wireguard/mod.rs @@ -5,6 +5,7 @@ mod wireguard; mod endpoint; mod handshake; mod peer; +mod queue; mod router; mod types; @@ -23,4 +24,3 @@ use super::platform::dummy; use super::platform::{tun, udp, Endpoint}; use peer::PeerInner; use types::KeyPair; -use wireguard::HandshakeJob; diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs index 04622fd..448db96 100644 --- a/src/wireguard/peer.rs +++ b/src/wireguard/peer.rs @@ -1,6 +1,5 @@ use super::router; use super::timers::{Events, Timers}; -use super::HandshakeJob; use super::tun::Tun; use super::udp::UDP; @@ -14,7 +13,6 @@ use std::time::{Instant, SystemTime}; use spin::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard}; -use crossbeam_channel::Sender; use x25519_dalek::PublicKey; pub struct Peer<T: Tun, B: UDP> { @@ -33,7 +31,6 @@ pub struct PeerInner<T: Tun, B: UDP> { pub walltime_last_handshake: Mutex<Option<SystemTime>>, pub last_handshake_sent: Mutex<Instant>, // instant for last handshake pub handshake_queued: AtomicBool, // is a handshake job currently queued for the peer? - pub queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue // stats and configuration pub pk: PublicKey, // public key, DISCUSS: avoid this. TODO: remove diff --git a/src/wireguard/queue.rs b/src/wireguard/queue.rs new file mode 100644 index 0000000..a0fcf03 --- /dev/null +++ b/src/wireguard/queue.rs @@ -0,0 +1,64 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::mpsc::sync_channel; +use std::sync::mpsc::{Receiver, SyncSender}; +use std::sync::Mutex; + +/// A simple parallel queue used to pass work to a worker pool. +/// +/// Unlike e.g. the crossbeam multi-producer multi-consumer queue +/// the ParallelQueue offers fewer features and instead improves speed: +/// +/// The crossbeam channel ensures that elements are consumed +/// even if not every Receiver is being read from. +/// This is not ensured by ParallelQueue. +pub struct ParallelQueue<T> { + next: AtomicUsize, // next round-robin index + queues: Vec<Mutex<Option<SyncSender<T>>>>, // work queues (1 per thread) +} + +impl<T> ParallelQueue<T> { + /// Create a new ParallelQueue instance + /// + /// # Arguments + /// + /// - `queues`: number of readers/writers + /// - `capacity`: capacity of each internal queue + /// + pub fn new(queues: usize, capacity: usize) -> (Self, Vec<Receiver<T>>) { + let mut rxs = vec![]; + let mut txs = vec![]; + + for _ in 0..queues { + let (tx, rx) = sync_channel(capacity); + txs.push(Mutex::new(Some(tx))); + rxs.push(rx); + } + + ( + ParallelQueue { + next: AtomicUsize::new(0), + queues: txs, + }, + rxs, + ) + } + + pub fn send(&self, v: T) { + let len = self.queues.len(); + let idx = self.next.fetch_add(1, Ordering::SeqCst); + match self.queues[idx % len].lock().unwrap().as_ref() { + Some(que) => { + // TODO: consider best way to propergate Result + let _ = que.send(v); + } + _ => (), + } + } + + pub fn close(&self) { + for i in 0..self.queues.len() { + let queue = &self.queues[i]; + *queue.lock().unwrap() = None; + } + } +} diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index febea45..1d3b743 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -24,7 +24,7 @@ use super::route::RoutingTable; use super::runq::RunQueue; use super::super::{tun, udp, Endpoint, KeyPair}; -use super::queue::ParallelQueue; +use super::ParallelQueue; pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { // inbound writer (TUN) @@ -125,8 +125,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<E, C, T, B> { pub fn new(num_workers: usize, tun: T) -> DeviceHandle<E, C, T, B> { // allocate shared device state - let (mut outrx, queue_outbound) = ParallelQueue::new(num_workers); - let (mut inrx, queue_inbound) = ParallelQueue::new(num_workers); + let (queue_outbound, mut outrx) = ParallelQueue::new(num_workers, 128); + let (queue_inbound, mut inrx) = ParallelQueue::new(num_workers, 128); let device = Device { inner: Arc::new(DeviceInner { inbound: tun, diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs index 49a4f96..8238d32 100644 --- a/src/wireguard/router/mod.rs +++ b/src/wireguard/router/mod.rs @@ -7,13 +7,10 @@ mod messages; mod outbound; mod peer; mod pool; -mod queue; mod route; mod runq; mod types; -// mod workers; - #[cfg(test)] mod tests; @@ -21,6 +18,7 @@ use messages::TransportHeader; use std::mem; use super::constants::REJECT_AFTER_MESSAGES; +use super::queue::ParallelQueue; use super::types::*; use super::{tun, udp, Endpoint}; diff --git a/src/wireguard/router/queue.rs b/src/wireguard/router/queue.rs deleted file mode 100644 index 5d0165c..0000000 --- a/src/wireguard/router/queue.rs +++ /dev/null @@ -1,46 +0,0 @@ -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::mpsc::sync_channel; -use std::sync::mpsc::{Receiver, SyncSender}; - -use spin::Mutex; - -pub struct ParallelQueue<T> { - next: AtomicUsize, // next round-robin index - queues: Vec<Mutex<SyncSender<T>>>, // work queues (1 per thread) -} - -impl<T> ParallelQueue<T> { - pub fn new(queues: usize) -> (Vec<Receiver<T>>, Self) { - let mut rxs = vec![]; - let mut txs = vec![]; - - for _ in 0..queues { - let (tx, rx) = sync_channel(128); - txs.push(Mutex::new(tx)); - rxs.push(rx); - } - - ( - rxs, - ParallelQueue { - next: AtomicUsize::new(0), - queues: txs, - }, - ) - } - - pub fn send(&self, v: T) { - let len = self.queues.len(); - let idx = self.next.fetch_add(1, Ordering::SeqCst); - let que = self.queues[idx % len].lock(); - que.send(v).unwrap(); - } - - pub fn close(&self) { - for i in 0..self.queues.len() { - let (tx, _) = sync_channel(0); - let queue = &self.queues[i]; - *queue.lock() = tx; - } - } -} diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 2b0e779..d0c0e53 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -4,6 +4,7 @@ use super::router; use super::timers::{Events, Timers}; use super::{Peer, PeerInner}; +use super::queue::ParallelQueue; use super::tun; use super::tun::Reader as TunReader; @@ -35,7 +36,6 @@ use rand::Rng; use spin::{Mutex, RwLock}; use byteorder::{ByteOrder, LittleEndian}; -use crossbeam_channel::{bounded, Sender}; use x25519_dalek::{PublicKey, StaticSecret}; const SIZE_HANDSHAKE_QUEUE: usize = 128; @@ -99,7 +99,7 @@ pub struct WireguardInner<T: tun::Tun, B: udp::UDP> { handshake: RwLock<handshake::Device>, under_load: AtomicBool, pending: AtomicUsize, // num of pending handshake packets in queue - queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, + queue: ParallelQueue<HandshakeJob<B::Endpoint>>, } impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> { @@ -123,7 +123,7 @@ impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> { if !self.handshake_queued.swap(true, Ordering::SeqCst) { self.wg.pending.fetch_add(1, Ordering::SeqCst); - self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap(); + self.wg.queue.send(HandshakeJob::New(self.pk)); } } } @@ -288,7 +288,6 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> { walltime_last_handshake: Mutex::new(None), last_handshake_sent: Mutex::new(Instant::now() - TIME_HORIZON), handshake_queued: AtomicBool::new(false), - queue: Mutex::new(self.state.queue.lock().clone()), rx_bytes: AtomicU64::new(0), tx_bytes: AtomicU64::new(0), timers: RwLock::new(Timers::dummy(&self.runner)), @@ -386,10 +385,7 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> { } // add to handshake queue - wg.queue - .lock() - .send(HandshakeJob::Message(msg, src)) - .unwrap(); + wg.queue.send(HandshakeJob::Message(msg, src)); } router::TYPE_TRANSPORT => { debug!("{} : reader, received transport message", wg); @@ -477,12 +473,14 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> { } pub fn new(writer: T::Writer) -> Wireguard<T, B> { + // workers equal to number of physical cores + let cpus = num_cpus::get(); + // create device state let mut rng = OsRng::new().unwrap(); // handshake queue - let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE); - + let (tx, mut rxs) = ParallelQueue::new(cpus, 128); let wg = Arc::new(WireguardInner { enabled: RwLock::new(false), tun_readers: WaitHandle::new(), @@ -494,13 +492,12 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> { pending: AtomicUsize::new(0), handshake: RwLock::new(handshake::Device::new()), under_load: AtomicBool::new(false), - queue: Mutex::new(tx), + queue: tx, }); // start handshake workers - for _ in 0..num_cpus::get() { + while let Some(rx) = rxs.pop() { let wg = wg.clone(); - let rx = rx.clone(); thread::spawn(move || { debug!("{} : handshake worker, started", wg); @@ -509,16 +506,18 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> { // process elements from the handshake queue for job in rx { - // decrement pending + // decrement pending pakcets (under_load) + let job: HandshakeJob<B::Endpoint> = job; wg.pending.fetch_sub(1, Ordering::SeqCst); - let device = wg.handshake.read(); + // demultiplex staged handshake jobs and handshake messages match job { HandshakeJob::Message(msg, src) => { // feed message to handshake device let src_validate = (&src).into_address(); // TODO avoid // process message + let device = wg.handshake.read(); match device.process( &mut rng, &msg[..], @@ -599,6 +598,7 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> { "{} : handshake worker, new handshake requested for {}", wg, peer ); + let device = wg.handshake.read(); let _ = device.begin(&mut rng, &peer.pk).map(|msg| { let _ = peer.router.send(&msg[..]).map_err(|e| { debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e) |