summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-12-09 13:21:12 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-12-09 13:21:12 +0100
commit115fa574a807877594c3b8cf467798fc0524d007 (patch)
tree07d487f3f8ef130d17536fb93b02f937b7bf66ae
parentFixed inbound job bug (add to sequential queue) (diff)
downloadwireguard-rs-115fa574a807877594c3b8cf467798fc0524d007.tar.xz
wireguard-rs-115fa574a807877594c3b8cf467798fc0524d007.zip
Move to run queue
-rw-r--r--src/platform/linux/tun.rs2
-rw-r--r--src/wireguard/router/device.rs101
-rw-r--r--src/wireguard/router/inbound.rs242
-rw-r--r--src/wireguard/router/mod.rs1
-rw-r--r--src/wireguard/router/outbound.rs138
-rw-r--r--src/wireguard/router/peer.rs21
-rw-r--r--src/wireguard/router/pool.rs77
-rw-r--r--src/wireguard/router/runq.rs145
-rw-r--r--src/wireguard/router/tests.rs2
9 files changed, 478 insertions, 251 deletions
diff --git a/src/platform/linux/tun.rs b/src/platform/linux/tun.rs
index 39b9320..fb905b9 100644
--- a/src/platform/linux/tun.rs
+++ b/src/platform/linux/tun.rs
@@ -312,7 +312,7 @@ impl LinuxTunStatus {
Err(LinuxTunError::Closed)
} else {
Ok(LinuxTunStatus {
- events: vec![],
+ events: vec![TunEvent::Up(1500)], // TODO: for testing
index: get_ifindex(&name),
fd,
name,
diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs
index e405446..9bba199 100644
--- a/src/wireguard/router/device.rs
+++ b/src/wireguard/router/device.rs
@@ -20,6 +20,7 @@ use super::peer::{new_peer, Peer, PeerHandle};
use super::types::{Callbacks, RouterError};
use super::SIZE_MESSAGE_PREFIX;
+use super::runq::RunQueue;
use super::route::RoutingTable;
use super::super::{tun, udp, Endpoint, KeyPair};
@@ -37,8 +38,12 @@ pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer
pub table: RoutingTable<Peer<E, C, T, B>>,
// work queues
- pub outbound_queue: ParallelQueue<Job<Peer<E, C, T, B>, outbound::Outbound>>,
- pub inbound_queue: ParallelQueue<Job<Peer<E, C, T, B>, inbound::Inbound<E, C, T, B>>>,
+ pub queue_outbound: ParallelQueue<Job<Peer<E, C, T, B>, outbound::Outbound>>,
+ pub queue_inbound: ParallelQueue<Job<Peer<E, C, T, B>, inbound::Inbound<E, C, T, B>>>,
+
+ // run queues
+ pub run_inbound: RunQueue<Peer<E, C, T, B>>,
+ pub run_outbound: RunQueue<Peer<E, C, T, B>>,
}
pub struct EncryptionState {
@@ -96,8 +101,12 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop
debug!("router: dropping device");
// close worker queues
- self.state.outbound_queue.close();
- self.state.inbound_queue.close();
+ self.state.queue_outbound.close();
+ self.state.queue_inbound.close();
+
+ // close run queues
+ self.state.run_outbound.close();
+ self.state.run_inbound.close();
// join all worker threads
while match self.handles.pop() {
@@ -116,43 +125,73 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<E, C, T, B> {
pub fn new(num_workers: usize, tun: T) -> DeviceHandle<E, C, T, B> {
// allocate shared device state
- let (mut outrx, outbound_queue) = ParallelQueue::new(num_workers);
- let (mut inrx, inbound_queue) = ParallelQueue::new(num_workers);
- let inner = DeviceInner {
- inbound: tun,
- inbound_queue,
- outbound: RwLock::new((true, None)),
- outbound_queue,
- recv: RwLock::new(HashMap::new()),
- table: RoutingTable::new(),
+ let (mut outrx, queue_outbound) = ParallelQueue::new(num_workers);
+ let (mut inrx, queue_inbound) = ParallelQueue::new(num_workers);
+ let device = Device {
+ inner: Arc::new(DeviceInner {
+ inbound: tun,
+ queue_inbound,
+ outbound: RwLock::new((true, None)),
+ queue_outbound,
+ run_inbound: RunQueue::new(),
+ run_outbound: RunQueue::new(),
+ recv: RwLock::new(HashMap::new()),
+ table: RoutingTable::new(),
+ })
};
// start worker threads
let mut threads = Vec::with_capacity(num_workers);
+ // inbound/decryption workers
for _ in 0..num_workers {
- let rx = inrx.pop().unwrap();
- threads.push(thread::spawn(move || {
- log::debug!("inbound router worker started");
- inbound::worker(rx)
- }));
+ // parallel workers (parallel processing)
+ {
+ let device = device.clone();
+ let rx = inrx.pop().unwrap();
+ threads.push(thread::spawn(move || {
+ log::debug!("inbound parallel router worker started");
+ inbound::parallel(device, rx)
+ }));
+ }
+
+ // sequential workers (in-order processing)
+ {
+ let device = device.clone();
+ threads.push(thread::spawn(move || {
+ log::debug!("inbound sequential router worker started");
+ inbound::sequential(device)
+ }));
+ }
}
+ // outbound/encryption workers
for _ in 0..num_workers {
- let rx = outrx.pop().unwrap();
- threads.push(thread::spawn(move || {
- log::debug!("outbound router worker started");
- outbound::worker(rx)
- }));
+ // parallel workers (parallel processing)
+ {
+ let device = device.clone();
+ let rx = outrx.pop().unwrap();
+ threads.push(thread::spawn(move || {
+ log::debug!("outbound parallel router worker started");
+ outbound::parallel(device, rx)
+ }));
+ }
+
+ // sequential workers (in-order processing)
+ {
+ let device = device.clone();
+ threads.push(thread::spawn(move || {
+ log::debug!("outbound sequential router worker started");
+ outbound::sequential(device)
+ }));
+ }
}
- debug_assert_eq!(threads.len(), num_workers * 2);
+ debug_assert_eq!(threads.len(), num_workers * 4);
// return exported device handle
DeviceHandle {
- state: Device {
- inner: Arc::new(inner),
- },
+ state: device,
handles: threads,
}
}
@@ -192,7 +231,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
pub fn send(&self, msg: Vec<u8>) -> Result<(), RouterError> {
debug_assert!(msg.len() > SIZE_MESSAGE_PREFIX);
log::trace!(
- "Router, outbound packet = {}",
+ "send, packet = {}",
hex::encode(&msg[SIZE_MESSAGE_PREFIX..])
);
@@ -208,7 +247,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
// schedule for encryption and transmission to peer
if let Some(job) = peer.send_job(msg, true) {
- self.state.outbound_queue.send(job);
+ self.state.queue_outbound.send(job);
}
Ok(())
@@ -225,6 +264,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
///
///
pub fn recv(&self, src: E, msg: Vec<u8>) -> Result<(), RouterError> {
+ log::trace!("receive, src: {}", src.into_address());
+
// parse / cast
let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) {
Some(v) => v,
@@ -255,7 +296,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
// schedule for decryption and TUN write
if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
log::trace!("schedule decryption of transport message");
- self.state.inbound_queue.send(job);
+ self.state.queue_inbound.send(job);
}
Ok(())
}
diff --git a/src/wireguard/router/inbound.rs b/src/wireguard/router/inbound.rs
index 3d47bb7..9b15750 100644
--- a/src/wireguard/router/inbound.rs
+++ b/src/wireguard/router/inbound.rs
@@ -4,6 +4,8 @@ use super::peer::Peer;
use super::pool::*;
use super::types::Callbacks;
use super::{tun, udp, Endpoint};
+use super::device::Device;
+use super::runq::RunQueue;
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
use zerocopy::{AsBytes, LayoutVerified};
@@ -38,139 +40,151 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Inbound<E, C,
}
#[inline(always)]
-fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- peer: &Peer<E, C, T, B>,
- body: &mut Inbound<E, C, T, B>,
+pub fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ device: Device<E, C, T, B>,
+ receiver: Receiver<Job<Peer<E, C, T, B>, Inbound<E, C, T, B>>>,
) {
- log::trace!("worker, parallel section, obtained job");
+ // run queue to schedule
+ fn queue<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ device: &Device<E, C, T, B>,
+ ) -> &RunQueue<Peer<E, C, T, B>> {
+ &device.run_inbound
+ }
+
+ // parallel work to apply
+ fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ peer: &Peer<E, C, T, B>,
+ body: &mut Inbound<E, C, T, B>,
+ ) {
+ log::trace!("worker, parallel section, obtained job");
+
+ // cast to header followed by payload
+ let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
+ match LayoutVerified::new_from_prefix(&mut body.msg[..]) {
+ Some(v) => v,
+ None => {
+ log::debug!("inbound worker: failed to parse message");
+ return;
+ }
+ };
+
+ // authenticate and decrypt payload
+ {
+ // create nonce object
+ let mut nonce = [0u8; 12];
+ debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
+ nonce[4..].copy_from_slice(header.f_counter.as_bytes());
+ let nonce = Nonce::assume_unique_for_key(nonce);
+
+ // do the weird ring AEAD dance
+ let key = LessSafeKey::new(
+ UnboundKey::new(&CHACHA20_POLY1305, &body.state.keypair.recv.key[..]).unwrap(),
+ );
- // cast to header followed by payload
- let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
- match LayoutVerified::new_from_prefix(&mut body.msg[..]) {
- Some(v) => v,
- None => {
- log::debug!("inbound worker: failed to parse message");
- return;
+ // attempt to open (and authenticate) the body
+ match key.open_in_place(nonce, Aad::empty(), packet) {
+ Ok(_) => (),
+ Err(_) => {
+ // fault and return early
+ log::trace!("inbound worker: authentication failure");
+ body.failed = true;
+ return;
+ }
+ }
+ }
+
+ // cryptokey route and strip padding
+ let inner_len = {
+ let length = packet.len() - SIZE_TAG;
+ if length > 0 {
+ peer.device.table.check_route(&peer, &packet[..length])
+ } else {
+ Some(0)
}
};
- // authenticate and decrypt payload
- {
- // create nonce object
- let mut nonce = [0u8; 12];
- debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
- nonce[4..].copy_from_slice(header.f_counter.as_bytes());
- let nonce = Nonce::assume_unique_for_key(nonce);
-
- // do the weird ring AEAD dance
- let key = LessSafeKey::new(
- UnboundKey::new(&CHACHA20_POLY1305, &body.state.keypair.recv.key[..]).unwrap(),
- );
-
- // attempt to open (and authenticate) the body
- match key.open_in_place(nonce, Aad::empty(), packet) {
- Ok(_) => (),
- Err(_) => {
- // fault and return early
- log::trace!("inbound worker: authentication failure");
+ // truncate to remove tag
+ match inner_len {
+ None => {
+ log::trace!("inbound worker: cryptokey routing failed");
body.failed = true;
- return;
+ }
+ Some(len) => {
+ log::trace!(
+ "inbound worker: good route, length = {} {}",
+ len,
+ if len == 0 { "(keepalive)" } else { "" }
+ );
+ body.msg.truncate(mem::size_of::<TransportHeader>() + len);
}
}
}
- // cryptokey route and strip padding
- let inner_len = {
- let length = packet.len() - SIZE_TAG;
- if length > 0 {
- peer.device.table.check_route(&peer, &packet[..length])
- } else {
- Some(0)
- }
- };
-
- // truncate to remove tag
- match inner_len {
- None => {
- log::trace!("inbound worker: cryptokey routing failed");
- body.failed = true;
- }
- Some(len) => {
- log::trace!(
- "inbound worker: good route, length = {} {}",
- len,
- if len == 0 { "(keepalive)" } else { "" }
- );
- body.msg.truncate(mem::size_of::<TransportHeader>() + len);
- }
- }
+ worker_parallel(device, |dev| &dev.run_inbound, receiver, work)
}
#[inline(always)]
-fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- peer: &Peer<E, C, T, B>,
- body: &mut Inbound<E, C, T, B>,
+pub fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ device: Device<E, C, T, B>,
) {
- log::trace!("worker, sequential section, obtained job");
-
- // decryption failed, return early
- if body.failed {
- log::trace!("job faulted, remove from queue and ignore");
- return;
- }
-
- // cast transport header
- let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
- match LayoutVerified::new_from_prefix(&body.msg[..]) {
- Some(v) => v,
- None => {
- log::debug!("inbound worker: failed to parse message");
- return;
- }
- };
-
- // check for replay
- if !body.state.protector.lock().update(header.f_counter.get()) {
- log::debug!("inbound worker: replay detected");
- return;
- }
+ // sequential work to apply
+ fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ peer: &Peer<E, C, T, B>,
+ body: &mut Inbound<E, C, T, B>
+ ) {
+ log::trace!("worker, sequential section, obtained job");
+
+ // decryption failed, return early
+ if body.failed {
+ log::trace!("job faulted, remove from queue and ignore");
+ return;
+ }
- // check for confirms key
- if !body.state.confirmed.swap(true, Ordering::SeqCst) {
- log::debug!("inbound worker: message confirms key");
- peer.confirm_key(&body.state.keypair);
- }
+ // cast transport header
+ let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
+ match LayoutVerified::new_from_prefix(&body.msg[..]) {
+ Some(v) => v,
+ None => {
+ log::debug!("inbound worker: failed to parse message");
+ return;
+ }
+ };
+
+ // check for replay
+ if !body.state.protector.lock().update(header.f_counter.get()) {
+ log::debug!("inbound worker: replay detected");
+ return;
+ }
- // update endpoint
- *peer.endpoint.lock() = body.endpoint.take();
+ // check for confirms key
+ if !body.state.confirmed.swap(true, Ordering::SeqCst) {
+ log::debug!("inbound worker: message confirms key");
+ peer.confirm_key(&body.state.keypair);
+ }
- // check if should be written to TUN
- let mut sent = false;
- if packet.len() > 0 {
- sent = match peer.device.inbound.write(&packet[..]) {
- Err(e) => {
- log::debug!("failed to write inbound packet to TUN: {:?}", e);
- false
+ // update endpoint
+ *peer.endpoint.lock() = body.endpoint.take();
+
+ // check if should be written to TUN
+ let mut sent = false;
+ if packet.len() > 0 {
+ sent = match peer.device.inbound.write(&packet[..]) {
+ Err(e) => {
+ log::debug!("failed to write inbound packet to TUN: {:?}", e);
+ false
+ }
+ Ok(_) => true,
}
- Ok(_) => true,
+ } else {
+ log::debug!("inbound worker: received keepalive")
}
- } else {
- log::debug!("inbound worker: received keepalive")
- }
- // trigger callback
- C::recv(&peer.opaque, body.msg.len(), sent, &body.state.keypair);
-}
-
-#[inline(always)]
-fn queue<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- peer: &Peer<E, C, T, B>,
-) -> &InorderQueue<Peer<E, C, T, B>, Inbound<E, C, T, B>> {
- &peer.inbound
-}
+ // trigger callback
+ C::recv(&peer.opaque, body.msg.len(), sent, &body.state.keypair);
+ }
-pub fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- receiver: Receiver<Job<Peer<E, C, T, B>, Inbound<E, C, T, B>>>,
-) {
- worker_template(receiver, parallel, sequential, queue)
+ // handle message from the peers inbound queue
+ device.run_inbound.run(|peer| {
+ peer.inbound.handle(|body| work(&peer, body));
+ });
}
diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs
index bccb0a9..35efe4c 100644
--- a/src/wireguard/router/mod.rs
+++ b/src/wireguard/router/mod.rs
@@ -10,6 +10,7 @@ mod pool;
mod queue;
mod route;
mod types;
+mod runq;
// mod workers;
diff --git a/src/wireguard/router/outbound.rs b/src/wireguard/router/outbound.rs
index d08637b..6c42d8f 100644
--- a/src/wireguard/router/outbound.rs
+++ b/src/wireguard/router/outbound.rs
@@ -5,6 +5,7 @@ use super::types::Callbacks;
use super::KeyPair;
use super::REJECT_AFTER_MESSAGES;
use super::{tun, udp, Endpoint};
+use super::device::Device;
use std::sync::mpsc::Receiver;
use std::sync::Arc;
@@ -31,78 +32,77 @@ impl Outbound {
}
#[inline(always)]
-fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- _peer: &Peer<E, C, T, B>,
- body: &mut Outbound,
+pub fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ device: Device<E, C, T, B>,
+ receiver: Receiver<Job<Peer<E, C, T, B>, Outbound>>,
+
) {
- log::trace!("worker, parallel section, obtained job");
-
- // make space for the tag
- body.msg.extend([0u8; SIZE_TAG].iter());
-
- // cast to header (should never fail)
- let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
- LayoutVerified::new_from_prefix(&mut body.msg[..])
- .expect("earlier code should ensure that there is ample space");
-
- // set header fields
- debug_assert!(
- body.counter < REJECT_AFTER_MESSAGES,
- "should be checked when assigning counters"
- );
- header.f_type.set(TYPE_TRANSPORT);
- header.f_receiver.set(body.keypair.send.id);
- header.f_counter.set(body.counter);
-
- // create a nonce object
- let mut nonce = [0u8; 12];
- debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
- nonce[4..].copy_from_slice(header.f_counter.as_bytes());
- let nonce = Nonce::assume_unique_for_key(nonce);
-
- // do the weird ring AEAD dance
- let key =
- LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &body.keypair.send.key[..]).unwrap());
-
- // encrypt content of transport message in-place
- let end = packet.len() - SIZE_TAG;
- let tag = key
- .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end])
- .unwrap();
-
- // append tag
- packet[end..].copy_from_slice(tag.as_ref());
-}
+ fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ _peer: &Peer<E, C, T, B>,
+ body: &mut Outbound,
+ ) {
+ log::trace!("worker, parallel section, obtained job");
+
+ // make space for the tag
+ body.msg.extend([0u8; SIZE_TAG].iter());
+
+ // cast to header (should never fail)
+ let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
+ LayoutVerified::new_from_prefix(&mut body.msg[..])
+ .expect("earlier code should ensure that there is ample space");
+
+ // set header fields
+ debug_assert!(
+ body.counter < REJECT_AFTER_MESSAGES,
+ "should be checked when assigning counters"
+ );
+ header.f_type.set(TYPE_TRANSPORT);
+ header.f_receiver.set(body.keypair.send.id);
+ header.f_counter.set(body.counter);
+
+ // create a nonce object
+ let mut nonce = [0u8; 12];
+ debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
+ nonce[4..].copy_from_slice(header.f_counter.as_bytes());
+ let nonce = Nonce::assume_unique_for_key(nonce);
+
+ // do the weird ring AEAD dance
+ let key =
+ LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &body.keypair.send.key[..]).unwrap());
+
+ // encrypt content of transport message in-place
+ let end = packet.len() - SIZE_TAG;
+ let tag = key
+ .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end])
+ .unwrap();
+
+ // append tag
+ packet[end..].copy_from_slice(tag.as_ref());
+ }
-#[inline(always)]
-fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- peer: &Peer<E, C, T, B>,
- body: &mut Outbound,
-) {
- log::trace!("worker, sequential section, obtained job");
-
- // send to peer
- let xmit = peer.send(&body.msg[..]).is_ok();
-
- // trigger callback
- C::send(
- &peer.opaque,
- body.msg.len(),
- xmit,
- &body.keypair,
- body.counter,
- );
+ worker_parallel(device, |dev| &dev.run_outbound, receiver, work);
}
-#[inline(always)]
-pub fn queue<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- peer: &Peer<E, C, T, B>,
-) -> &InorderQueue<Peer<E, C, T, B>, Outbound> {
- &peer.outbound
-}
-pub fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- receiver: Receiver<Job<Peer<E, C, T, B>, Outbound>>,
+#[inline(always)]
+pub fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ device: Device<E, C, T, B>,
) {
- worker_template(receiver, parallel, sequential, queue)
-}
+ device.run_outbound.run(|peer| {
+ peer.outbound.handle(|body| {
+ log::trace!("worker, sequential section, obtained job");
+
+ // send to peer
+ let xmit = peer.send(&body.msg[..]).is_ok();
+
+ // trigger callback
+ C::send(
+ &peer.opaque,
+ body.msg.len(),
+ xmit,
+ &body.keypair,
+ body.counter,
+ );
+ });
+ });
+} \ No newline at end of file
diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs
index 40442a8..a00ce1a 100644
--- a/src/wireguard/router/peer.rs
+++ b/src/wireguard/router/peer.rs
@@ -20,6 +20,7 @@ use super::messages::TransportHeader;
use super::constants::*;
use super::types::{Callbacks, RouterError};
use super::SIZE_MESSAGE_PREFIX;
+use super::runq::ToKey;
// worker pool related
use super::inbound::Inbound;
@@ -56,14 +57,28 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for Pee
}
}
+/* Equality of peers is defined as pointer equality
+ * the atomic reference counted pointer.
+ */
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq for Peer<E, C, T, B> {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.inner, &other.inner)
}
}
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ToKey for Peer<E, C, T, B> {
+ type Key = usize;
+ fn to_key(&self) -> usize {
+ Arc::downgrade(&self.inner).into_raw() as usize
+ }
+}
+
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Eq for Peer<E, C, T, B> {}
+/* A peer is transparently dereferenced to the inner type
+ *
+ */
+
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Peer<E, C, T, B> {
type Target = PeerInner<E, C, T, B>;
fn deref(&self) -> &Self::Target {
@@ -71,6 +86,10 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Pee
}
}
+
+/* A peer handle is a specially designated peer pointer
+ * which removes the peer from the device when dropped.
+ */
pub struct PeerHandle<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
peer: Peer<E, C, T, B>,
}
@@ -227,7 +246,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
log::debug!("peer.send_raw");
match self.send_job(msg, false) {
Some(job) => {
- self.device.outbound_queue.send(job);
+ self.device.queue_outbound.send(job);
debug!("send_raw: got obtained send_job");
true
}
diff --git a/src/wireguard/router/pool.rs b/src/wireguard/router/pool.rs
index 9c72372..98b1144 100644
--- a/src/wireguard/router/pool.rs
+++ b/src/wireguard/router/pool.rs
@@ -2,6 +2,9 @@ use arraydeque::ArrayDeque;
use spin::{Mutex, MutexGuard};
use std::sync::mpsc::Receiver;
use std::sync::Arc;
+use std::mem;
+
+use super::runq::{RunQueue, ToKey};
const INORDER_QUEUE_SIZE: usize = 64;
@@ -60,51 +63,53 @@ impl<P, B> InorderQueue<P, B> {
}
#[inline(always)]
- pub fn handle<F: Fn(&mut InnerJob<P, B>)>(&self, f: F) {
+ pub fn handle<F: Fn(&mut B)>(&self, f: F) {
// take the mutex
let mut queue = self.queue.lock();
- // handle all complete messages
- while queue
- .pop_front()
- .and_then(|j| {
- // check if job is complete
- let ret = if let Some(mut guard) = j.complete() {
- f(&mut *guard);
- false
- } else {
- true
- };
-
- // return job to cyclic buffer if not complete
- if ret {
- let _res = queue.push_front(j);
- debug_assert!(_res.is_ok());
- None
- } else {
- // add job back to pool
- Some(())
+ loop {
+ // attempt to extract front element
+ let front = queue.pop_front();
+ let elem = match front {
+ Some(elem) => elem,
+ _ => {
+ return;
}
- })
- .is_some()
- {}
+ };
+
+ // apply function if job complete
+ let ret = if let Some(mut guard) = elem.complete() {
+ mem::drop(queue);
+ f(&mut guard.body);
+ queue = self.queue.lock();
+ false
+ } else {
+ true
+ };
+
+ // job not complete yet, return job to front
+ if ret {
+ queue.push_front(elem).unwrap();
+ return;
+ }
+ }
}
}
/// Allows easy construction of a semi-parallel worker.
/// Applicable for both decryption and encryption workers.
#[inline(always)]
-pub fn worker_template<
- P, // represents a peer (atomic reference counted pointer)
+pub fn worker_parallel<
+ P : ToKey, // represents a peer (atomic reference counted pointer)
B, // inner body type (message buffer, key material, ...)
+ D, // device
W: Fn(&P, &mut B),
- S: Fn(&P, &mut B),
- Q: Fn(&P) -> &InorderQueue<P, B>,
+ Q: Fn(&D) -> &RunQueue<P>,
>(
- receiver: Receiver<Job<P, B>>, // receiever for new jobs
- work_parallel: W, // perform parallel / out-of-order work on peer
- work_sequential: S, // perform sequential work on peer
- queue: Q, // resolve a peer to an inorder queue
+ device: D,
+ queue: Q,
+ receiver: Receiver<Job<P, B>>,
+ work: W,
) {
log::trace!("router worker started");
loop {
@@ -123,11 +128,11 @@ pub fn worker_template<
let peer = job.peer.take().unwrap();
// process job
- work_parallel(&peer, &mut job.body);
+ work(&peer, &mut job.body);
peer
};
-
+
// process inorder jobs for peer
- queue(&peer).handle(|j| work_sequential(&peer, &mut j.body));
+ queue(&device).insert(peer);
}
-}
+} \ No newline at end of file
diff --git a/src/wireguard/router/runq.rs b/src/wireguard/router/runq.rs
new file mode 100644
index 0000000..6d96490
--- /dev/null
+++ b/src/wireguard/router/runq.rs
@@ -0,0 +1,145 @@
+use std::mem;
+use std::sync::{Condvar, Mutex};
+use std::hash::Hash;
+
+use std::collections::hash_map::Entry;
+use std::collections::HashMap;
+use std::collections::VecDeque;
+
+pub trait ToKey {
+ type Key: Hash + Eq;
+ fn to_key(&self) -> Self::Key;
+}
+
+pub struct RunQueue<T : ToKey> {
+ cvar: Condvar,
+ inner: Mutex<Inner<T>>,
+}
+
+struct Inner<T : ToKey> {
+ stop: bool,
+ queue: VecDeque<T>,
+ members: HashMap<T::Key, usize>,
+}
+
+impl<T : ToKey> RunQueue<T> {
+ pub fn close(&self) {
+ let mut inner = self.inner.lock().unwrap();
+ inner.stop = true;
+ self.cvar.notify_all();
+ }
+
+ pub fn new() -> RunQueue<T> {
+ RunQueue {
+ cvar: Condvar::new(),
+ inner: Mutex::new(Inner {
+ stop:false,
+ queue: VecDeque::new(),
+ members: HashMap::new(),
+ }),
+ }
+ }
+
+ pub fn insert(&self, v: T) {
+ let key = v.to_key();
+ let mut inner = self.inner.lock().unwrap();
+ match inner.members.entry(key) {
+ Entry::Occupied(mut elem) => {
+ *elem.get_mut() += 1;
+ }
+ Entry::Vacant(spot) => {
+ // add entry to back of queue
+ spot.insert(0);
+ inner.queue.push_back(v);
+
+ // wake a thread
+ self.cvar.notify_one();
+ }
+ }
+ }
+
+ pub fn run<F: Fn(&T) -> ()>(&self, f: F) {
+ let mut inner = self.inner.lock().unwrap();
+ loop {
+ // fetch next element
+ let elem = loop {
+ // run-queue closed
+ if inner.stop {
+ return;
+ }
+
+ // try to pop from queue
+ match inner.queue.pop_front() {
+ Some(elem) => {
+ break elem;
+ }
+ None => (),
+ };
+
+ // wait for an element to be inserted
+ inner = self.cvar.wait(inner).unwrap();
+ };
+
+ // fetch current request number
+ let key = elem.to_key();
+ let old_n = *inner.members.get(&key).unwrap();
+ mem::drop(inner); // drop guard
+
+ // handle element
+ f(&elem);
+
+ // retake lock and check if should be added back to queue
+ inner = self.inner.lock().unwrap();
+ match inner.members.entry(key) {
+ Entry::Occupied(occ) => {
+ if *occ.get() == old_n {
+ // no new requests since last, remove entry.
+ occ.remove();
+ } else {
+ // new requests, reschedule.
+ inner.queue.push_back(elem);
+ }
+ }
+ Entry::Vacant(_) => {
+ unreachable!();
+ }
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::thread;
+ use std::sync::Arc;
+ use std::time::Duration;
+
+ /*
+ #[test]
+ fn test_wait() {
+ let queue: Arc<RunQueue<usize>> = Arc::new(RunQueue::new());
+
+ {
+ let queue = queue.clone();
+ thread::spawn(move || {
+ queue.run(|e| {
+ println!("t0 {}", e);
+ thread::sleep(Duration::from_millis(100));
+ })
+ });
+ }
+
+ {
+ let queue = queue.clone();
+ thread::spawn(move || {
+ queue.run(|e| {
+ println!("t1 {}", e);
+ thread::sleep(Duration::from_millis(100));
+ })
+ });
+ }
+
+ }
+ */
+} \ No newline at end of file
diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs
index 1f500c0..fe1fbbe 100644
--- a/src/wireguard/router/tests.rs
+++ b/src/wireguard/router/tests.rs
@@ -273,6 +273,8 @@ mod tests {
}
}
}
+
+ println!("Test complete, drop device");
}
#[test]