diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-12-03 21:49:08 +0100 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-12-03 21:49:08 +0100 |
commit | 5a7f762d6ce6b5bbdbd10f5966adc909597f37d6 (patch) | |
tree | b53fa0c1ee02c1e211d6cf94c6ba0334135ec42e /src/wireguard/router/device.rs | |
parent | Close socket fd after getmtu ioctl (diff) | |
download | wireguard-rs-5a7f762d6ce6b5bbdbd10f5966adc909597f37d6.tar.xz wireguard-rs-5a7f762d6ce6b5bbdbd10f5966adc909597f37d6.zip |
Moving away from peer threads
Diffstat (limited to 'src/wireguard/router/device.rs')
-rw-r--r-- | src/wireguard/router/device.rs | 141 |
1 files changed, 108 insertions, 33 deletions
diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index 621010b..88eeae1 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -1,7 +1,8 @@ use std::collections::HashMap; +use std::ops::Deref; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::mpsc::sync_channel; -use std::sync::mpsc::SyncSender; +use std::sync::mpsc::{Receiver, SyncSender}; use std::sync::Arc; use std::thread; use std::time::Instant; @@ -11,18 +12,61 @@ use spin::{Mutex, RwLock}; use zerocopy::LayoutVerified; use super::anti_replay::AntiReplay; -use super::constants::*; +use super::pool::Job; + +use super::inbound; +use super::outbound; use super::messages::{TransportHeader, TYPE_TRANSPORT}; -use super::peer::{new_peer, Peer, PeerInner}; +use super::peer::{new_peer, Peer, PeerHandle}; use super::types::{Callbacks, RouterError}; -use super::workers::{worker_parallel, JobParallel}; use super::SIZE_MESSAGE_PREFIX; use super::route::RoutingTable; use super::super::{tun, udp, Endpoint, KeyPair}; +pub struct ParallelQueue<T> { + next: AtomicUsize, // next round-robin index + queues: Vec<Mutex<SyncSender<T>>>, // work queues (1 per thread) +} + +impl<T> ParallelQueue<T> { + fn new(queues: usize) -> (Vec<Receiver<T>>, Self) { + let mut rxs = vec![]; + let mut txs = vec![]; + + for _ in 0..queues { + let (tx, rx) = sync_channel(128); + txs.push(Mutex::new(tx)); + rxs.push(rx); + } + + ( + rxs, + ParallelQueue { + next: AtomicUsize::new(0), + queues: txs, + }, + ) + } + + pub fn send(&self, v: T) { + let len = self.queues.len(); + let idx = self.next.fetch_add(1, Ordering::SeqCst); + let que = self.queues[idx % len].lock(); + que.send(v).unwrap(); + } + + pub fn close(&self) { + for i in 0..self.queues.len() { + let (tx, _) = sync_channel(0); + let queue = &self.queues[i]; + *queue.lock() = tx; + } + } +} + pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { // inbound writer (TUN) pub inbound: T, @@ -32,11 +76,11 @@ pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer // routing pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state - pub table: RoutingTable<PeerInner<E, C, T, B>>, + pub table: RoutingTable<Peer<E, C, T, B>>, // work queues - pub queue_next: AtomicUsize, // next round-robin index - pub queues: Mutex<Vec<SyncSender<JobParallel>>>, // work queues (1 per thread) + pub outbound_queue: ParallelQueue<Job<Peer<E, C, T, B>, outbound::Outbound>>, + pub inbound_queue: ParallelQueue<Job<Peer<E, C, T, B>, inbound::Inbound<E, C, T, B>>>, } pub struct EncryptionState { @@ -49,24 +93,53 @@ pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Wr pub keypair: Arc<KeyPair>, pub confirmed: AtomicBool, pub protector: Mutex<AntiReplay>, - pub peer: Arc<PeerInner<E, C, T, B>>, + pub peer: Peer<E, C, T, B>, pub death: Instant, // time when the key can no longer be used for decryption } pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { - state: Arc<DeviceInner<E, C, T, B>>, // reference to device state + inner: Arc<DeviceInner<E, C, T, B>>, +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for Device<E, C, T, B> { + fn clone(&self) -> Self { + Device { + inner: self.inner.clone(), + } + } +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq + for Device<E, C, T, B> +{ + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.inner, &other.inner) + } +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Eq for Device<E, C, T, B> {} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Device<E, C, T, B> { + type Target = DeviceInner<E, C, T, B>; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +pub struct DeviceHandle<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { + state: Device<E, C, T, B>, // reference to device state handles: Vec<thread::JoinHandle<()>>, // join handles for workers } -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Device<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop + for DeviceHandle<E, C, T, B> +{ fn drop(&mut self) { debug!("router: dropping device"); - // drop all queues - { - let mut queues = self.state.queues.lock(); - while queues.pop().is_some() {} - } + // close worker queues + self.state.outbound_queue.close(); + self.state.inbound_queue.close(); // join all worker threads while match self.handles.pop() { @@ -82,14 +155,16 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Devi } } -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C, T, B> { - pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<E, C, T, B> { + pub fn new(num_workers: usize, tun: T) -> DeviceHandle<E, C, T, B> { // allocate shared device state + let (mut outrx, outbound_queue) = ParallelQueue::new(num_workers); + let (mut inrx, inbound_queue) = ParallelQueue::new(num_workers); let inner = DeviceInner { inbound: tun, + inbound_queue, outbound: RwLock::new((true, None)), - queues: Mutex::new(Vec::with_capacity(num_workers)), - queue_next: AtomicUsize::new(0), + outbound_queue, recv: RwLock::new(HashMap::new()), table: RoutingTable::new(), }; @@ -97,14 +172,20 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C, // start worker threads let mut threads = Vec::with_capacity(num_workers); for _ in 0..num_workers { - let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE); - inner.queues.lock().push(tx); - threads.push(thread::spawn(move || worker_parallel(rx))); + let rx = inrx.pop().unwrap(); + threads.push(thread::spawn(move || inbound::worker(rx))); + } + + for _ in 0..num_workers { + let rx = outrx.pop().unwrap(); + threads.push(thread::spawn(move || outbound::worker(rx))); } // return exported device handle - Device { - state: Arc::new(inner), + DeviceHandle { + state: Device { + inner: Arc::new(inner), + }, handles: threads, } } @@ -131,7 +212,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C, /// # Returns /// /// A atomic ref. counted peer (with liftime matching the device) - pub fn new_peer(&self, opaque: C::Opaque) -> Peer<E, C, T, B> { + pub fn new_peer(&self, opaque: C::Opaque) -> PeerHandle<E, C, T, B> { new_peer(self.state.clone(), opaque) } @@ -160,10 +241,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C, // schedule for encryption and transmission to peer if let Some(job) = peer.send_job(msg, true) { - // add job to worker queue - let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); - let queues = self.state.queues.lock(); - queues[idx % queues.len()].send(job).unwrap(); + self.state.outbound_queue.send(job); } Ok(()) @@ -209,10 +287,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C, // schedule for decryption and TUN write if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) { - // add job to worker queue - let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); - let queues = self.state.queues.lock(); - queues[idx % queues.len()].send(job).unwrap(); + self.state.inbound_queue.send(job); } Ok(()) |