aboutsummaryrefslogtreecommitdiffstats
path: root/src/wireguard/router/device.rs
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-12-03 21:49:08 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-12-03 21:49:08 +0100
commit5a7f762d6ce6b5bbdbd10f5966adc909597f37d6 (patch)
treeb53fa0c1ee02c1e211d6cf94c6ba0334135ec42e /src/wireguard/router/device.rs
parentClose socket fd after getmtu ioctl (diff)
downloadwireguard-rs-5a7f762d6ce6b5bbdbd10f5966adc909597f37d6.tar.xz
wireguard-rs-5a7f762d6ce6b5bbdbd10f5966adc909597f37d6.zip
Moving away from peer threads
Diffstat (limited to 'src/wireguard/router/device.rs')
-rw-r--r--src/wireguard/router/device.rs141
1 files changed, 108 insertions, 33 deletions
diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs
index 621010b..88eeae1 100644
--- a/src/wireguard/router/device.rs
+++ b/src/wireguard/router/device.rs
@@ -1,7 +1,8 @@
use std::collections::HashMap;
+use std::ops::Deref;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::mpsc::sync_channel;
-use std::sync::mpsc::SyncSender;
+use std::sync::mpsc::{Receiver, SyncSender};
use std::sync::Arc;
use std::thread;
use std::time::Instant;
@@ -11,18 +12,61 @@ use spin::{Mutex, RwLock};
use zerocopy::LayoutVerified;
use super::anti_replay::AntiReplay;
-use super::constants::*;
+use super::pool::Job;
+
+use super::inbound;
+use super::outbound;
use super::messages::{TransportHeader, TYPE_TRANSPORT};
-use super::peer::{new_peer, Peer, PeerInner};
+use super::peer::{new_peer, Peer, PeerHandle};
use super::types::{Callbacks, RouterError};
-use super::workers::{worker_parallel, JobParallel};
use super::SIZE_MESSAGE_PREFIX;
use super::route::RoutingTable;
use super::super::{tun, udp, Endpoint, KeyPair};
+pub struct ParallelQueue<T> {
+ next: AtomicUsize, // next round-robin index
+ queues: Vec<Mutex<SyncSender<T>>>, // work queues (1 per thread)
+}
+
+impl<T> ParallelQueue<T> {
+ fn new(queues: usize) -> (Vec<Receiver<T>>, Self) {
+ let mut rxs = vec![];
+ let mut txs = vec![];
+
+ for _ in 0..queues {
+ let (tx, rx) = sync_channel(128);
+ txs.push(Mutex::new(tx));
+ rxs.push(rx);
+ }
+
+ (
+ rxs,
+ ParallelQueue {
+ next: AtomicUsize::new(0),
+ queues: txs,
+ },
+ )
+ }
+
+ pub fn send(&self, v: T) {
+ let len = self.queues.len();
+ let idx = self.next.fetch_add(1, Ordering::SeqCst);
+ let que = self.queues[idx % len].lock();
+ que.send(v).unwrap();
+ }
+
+ pub fn close(&self) {
+ for i in 0..self.queues.len() {
+ let (tx, _) = sync_channel(0);
+ let queue = &self.queues[i];
+ *queue.lock() = tx;
+ }
+ }
+}
+
pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
// inbound writer (TUN)
pub inbound: T,
@@ -32,11 +76,11 @@ pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer
// routing
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state
- pub table: RoutingTable<PeerInner<E, C, T, B>>,
+ pub table: RoutingTable<Peer<E, C, T, B>>,
// work queues
- pub queue_next: AtomicUsize, // next round-robin index
- pub queues: Mutex<Vec<SyncSender<JobParallel>>>, // work queues (1 per thread)
+ pub outbound_queue: ParallelQueue<Job<Peer<E, C, T, B>, outbound::Outbound>>,
+ pub inbound_queue: ParallelQueue<Job<Peer<E, C, T, B>, inbound::Inbound<E, C, T, B>>>,
}
pub struct EncryptionState {
@@ -49,24 +93,53 @@ pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Wr
pub keypair: Arc<KeyPair>,
pub confirmed: AtomicBool,
pub protector: Mutex<AntiReplay>,
- pub peer: Arc<PeerInner<E, C, T, B>>,
+ pub peer: Peer<E, C, T, B>,
pub death: Instant, // time when the key can no longer be used for decryption
}
pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
- state: Arc<DeviceInner<E, C, T, B>>, // reference to device state
+ inner: Arc<DeviceInner<E, C, T, B>>,
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for Device<E, C, T, B> {
+ fn clone(&self) -> Self {
+ Device {
+ inner: self.inner.clone(),
+ }
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq
+ for Device<E, C, T, B>
+{
+ fn eq(&self, other: &Self) -> bool {
+ Arc::ptr_eq(&self.inner, &other.inner)
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Eq for Device<E, C, T, B> {}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Device<E, C, T, B> {
+ type Target = DeviceInner<E, C, T, B>;
+ fn deref(&self) -> &Self::Target {
+ &self.inner
+ }
+}
+
+pub struct DeviceHandle<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
+ state: Device<E, C, T, B>, // reference to device state
handles: Vec<thread::JoinHandle<()>>, // join handles for workers
}
-impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Device<E, C, T, B> {
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop
+ for DeviceHandle<E, C, T, B>
+{
fn drop(&mut self) {
debug!("router: dropping device");
- // drop all queues
- {
- let mut queues = self.state.queues.lock();
- while queues.pop().is_some() {}
- }
+ // close worker queues
+ self.state.outbound_queue.close();
+ self.state.inbound_queue.close();
// join all worker threads
while match self.handles.pop() {
@@ -82,14 +155,16 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Devi
}
}
-impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C, T, B> {
- pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> {
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<E, C, T, B> {
+ pub fn new(num_workers: usize, tun: T) -> DeviceHandle<E, C, T, B> {
// allocate shared device state
+ let (mut outrx, outbound_queue) = ParallelQueue::new(num_workers);
+ let (mut inrx, inbound_queue) = ParallelQueue::new(num_workers);
let inner = DeviceInner {
inbound: tun,
+ inbound_queue,
outbound: RwLock::new((true, None)),
- queues: Mutex::new(Vec::with_capacity(num_workers)),
- queue_next: AtomicUsize::new(0),
+ outbound_queue,
recv: RwLock::new(HashMap::new()),
table: RoutingTable::new(),
};
@@ -97,14 +172,20 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C,
// start worker threads
let mut threads = Vec::with_capacity(num_workers);
for _ in 0..num_workers {
- let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
- inner.queues.lock().push(tx);
- threads.push(thread::spawn(move || worker_parallel(rx)));
+ let rx = inrx.pop().unwrap();
+ threads.push(thread::spawn(move || inbound::worker(rx)));
+ }
+
+ for _ in 0..num_workers {
+ let rx = outrx.pop().unwrap();
+ threads.push(thread::spawn(move || outbound::worker(rx)));
}
// return exported device handle
- Device {
- state: Arc::new(inner),
+ DeviceHandle {
+ state: Device {
+ inner: Arc::new(inner),
+ },
handles: threads,
}
}
@@ -131,7 +212,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C,
/// # Returns
///
/// A atomic ref. counted peer (with liftime matching the device)
- pub fn new_peer(&self, opaque: C::Opaque) -> Peer<E, C, T, B> {
+ pub fn new_peer(&self, opaque: C::Opaque) -> PeerHandle<E, C, T, B> {
new_peer(self.state.clone(), opaque)
}
@@ -160,10 +241,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C,
// schedule for encryption and transmission to peer
if let Some(job) = peer.send_job(msg, true) {
- // add job to worker queue
- let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
- let queues = self.state.queues.lock();
- queues[idx % queues.len()].send(job).unwrap();
+ self.state.outbound_queue.send(job);
}
Ok(())
@@ -209,10 +287,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C,
// schedule for decryption and TUN write
if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
- // add job to worker queue
- let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
- let queues = self.state.queues.lock();
- queues[idx % queues.len()].send(job).unwrap();
+ self.state.inbound_queue.send(job);
}
Ok(())