aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2020-02-16 18:12:43 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2020-02-16 18:12:43 +0100
commit106c5e8b5c865c8396f824f4f5aa14d1bf0952b1 (patch)
tree68101553c62d301921b84776a9e18fc627c7a731 /src
parentWork on reducing context switches (diff)
downloadwireguard-rs-106c5e8b5c865c8396f824f4f5aa14d1bf0952b1.tar.xz
wireguard-rs-106c5e8b5c865c8396f824f4f5aa14d1bf0952b1.zip
Work on router optimizationsrouter
Diffstat (limited to 'src')
-rw-r--r--src/wireguard/queue.rs3
-rw-r--r--src/wireguard/router/constants.rs4
-rw-r--r--src/wireguard/router/device.rs100
-rw-r--r--src/wireguard/router/inbound.rs190
-rw-r--r--src/wireguard/router/mod.rs5
-rw-r--r--src/wireguard/router/outbound.rs110
-rw-r--r--src/wireguard/router/peer.rs192
-rw-r--r--src/wireguard/router/pool.rs164
-rw-r--r--src/wireguard/router/queue.rs92
-rw-r--r--src/wireguard/router/receive.rs184
-rw-r--r--src/wireguard/router/runq.rs164
-rw-r--r--src/wireguard/router/send.rs95
-rw-r--r--src/wireguard/router/worker.rs30
-rw-r--r--src/wireguard/router/workers.rs257
-rw-r--r--src/wireguard/wireguard.rs2
15 files changed, 350 insertions, 1242 deletions
diff --git a/src/wireguard/queue.rs b/src/wireguard/queue.rs
index 75b9104..eea1ccf 100644
--- a/src/wireguard/queue.rs
+++ b/src/wireguard/queue.rs
@@ -1,6 +1,7 @@
-use crossbeam_channel::{bounded, Receiver, Sender};
use std::sync::Mutex;
+use crossbeam_channel::{bounded, Receiver, Sender};
+
pub struct ParallelQueue<T> {
queue: Mutex<Option<Sender<T>>>,
}
diff --git a/src/wireguard/router/constants.rs b/src/wireguard/router/constants.rs
index af76299..f083811 100644
--- a/src/wireguard/router/constants.rs
+++ b/src/wireguard/router/constants.rs
@@ -4,6 +4,6 @@ pub const MAX_QUEUED_PACKETS: usize = 1024;
// performance constants
-pub const PARALLEL_QUEUE_SIZE: usize = MAX_QUEUED_PACKETS;
+pub const PARALLEL_QUEUE_SIZE: usize = 4 * MAX_QUEUED_PACKETS;
+
pub const INORDER_QUEUE_SIZE: usize = MAX_QUEUED_PACKETS;
-pub const MAX_INORDER_CONSUME: usize = INORDER_QUEUE_SIZE;
diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs
index f903a8e..b8e3821 100644
--- a/src/wireguard/router/device.rs
+++ b/src/wireguard/router/device.rs
@@ -10,19 +10,16 @@ use spin::{Mutex, RwLock};
use zerocopy::LayoutVerified;
use super::anti_replay::AntiReplay;
-use super::pool::Job;
use super::constants::PARALLEL_QUEUE_SIZE;
-use super::inbound;
-use super::outbound;
-
use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::{new_peer, Peer, PeerHandle};
use super::types::{Callbacks, RouterError};
use super::SIZE_MESSAGE_PREFIX;
+use super::receive::ReceiveJob;
use super::route::RoutingTable;
-use super::runq::RunQueue;
+use super::worker::{worker, JobUnion};
use super::super::{tun, udp, Endpoint, KeyPair};
use super::ParallelQueue;
@@ -38,13 +35,8 @@ pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state
pub table: RoutingTable<Peer<E, C, T, B>>,
- // work queues
- pub queue_outbound: ParallelQueue<Job<Peer<E, C, T, B>, outbound::Outbound>>,
- pub queue_inbound: ParallelQueue<Job<Peer<E, C, T, B>, inbound::Inbound<E, C, T, B>>>,
-
- // run queues
- pub run_inbound: RunQueue<Peer<E, C, T, B>>,
- pub run_outbound: RunQueue<Peer<E, C, T, B>>,
+ // work queue
+ pub work: ParallelQueue<JobUnion<E, C, T, B>>,
}
pub struct EncryptionState {
@@ -101,13 +93,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop
fn drop(&mut self) {
debug!("router: dropping device");
- // close worker queues
- self.state.queue_outbound.close();
- self.state.queue_inbound.close();
-
- // close run queues
- self.state.run_outbound.close();
- self.state.run_inbound.close();
+ // close worker queue
+ self.state.work.close();
// join all worker threads
while match self.handles.pop() {
@@ -118,24 +105,17 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop
}
_ => false,
} {}
-
- debug!("router: device dropped");
}
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<E, C, T, B> {
pub fn new(num_workers: usize, tun: T) -> DeviceHandle<E, C, T, B> {
- // allocate shared device state
- let (queue_outbound, mut outrx) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE);
- let (queue_inbound, mut inrx) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE);
+ let (work, mut consumers) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE);
let device = Device {
inner: Arc::new(DeviceInner {
+ work,
inbound: tun,
- queue_inbound,
outbound: RwLock::new((true, None)),
- queue_outbound,
- run_inbound: RunQueue::new(),
- run_outbound: RunQueue::new(),
recv: RwLock::new(HashMap::new()),
table: RoutingTable::new(),
}),
@@ -143,52 +123,10 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
// start worker threads
let mut threads = Vec::with_capacity(num_workers);
-
- // inbound/decryption workers
- for _ in 0..num_workers {
- // parallel workers (parallel processing)
- {
- let device = device.clone();
- let rx = inrx.pop().unwrap();
- threads.push(thread::spawn(move || {
- log::debug!("inbound parallel router worker started");
- inbound::parallel(device, rx)
- }));
- }
-
- // sequential workers (in-order processing)
- {
- let device = device.clone();
- threads.push(thread::spawn(move || {
- log::debug!("inbound sequential router worker started");
- inbound::sequential(device)
- }));
- }
+ while let Some(rx) = consumers.pop() {
+ threads.push(thread::spawn(move || worker(rx)));
}
-
- // outbound/encryption workers
- for _ in 0..num_workers {
- // parallel workers (parallel processing)
- {
- let device = device.clone();
- let rx = outrx.pop().unwrap();
- threads.push(thread::spawn(move || {
- log::debug!("outbound parallel router worker started");
- outbound::parallel(device, rx)
- }));
- }
-
- // sequential workers (in-order processing)
- {
- let device = device.clone();
- threads.push(thread::spawn(move || {
- log::debug!("outbound sequential router worker started");
- outbound::sequential(device)
- }));
- }
- }
-
- debug_assert_eq!(threads.len(), num_workers * 4);
+ debug_assert_eq!(threads.len(), num_workers);
// return exported device handle
DeviceHandle {
@@ -250,10 +188,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
.ok_or(RouterError::NoCryptoKeyRoute)?;
// schedule for encryption and transmission to peer
- if let Some(job) = peer.send_job(msg, true) {
- self.state.queue_outbound.send(job);
- }
-
+ peer.send(msg, true);
Ok(())
}
@@ -297,10 +232,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
.get(&header.f_receiver.get())
.ok_or(RouterError::UnknownReceiverId)?;
- // schedule for decryption and TUN write
- if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
- log::trace!("schedule decryption of transport message");
- self.state.queue_inbound.send(job);
+ // create inbound job
+ let job = ReceiveJob::new(msg, dec.clone(), src);
+
+ // 1. add to sequential queue (drop if full)
+ // 2. then add to parallel work queue (wait if full)
+ if dec.peer.inbound.push(job.clone()) {
+ self.state.work.send(JobUnion::Inbound(job));
}
Ok(())
}
diff --git a/src/wireguard/router/inbound.rs b/src/wireguard/router/inbound.rs
deleted file mode 100644
index dc2c44e..0000000
--- a/src/wireguard/router/inbound.rs
+++ /dev/null
@@ -1,190 +0,0 @@
-use std::mem;
-use std::sync::atomic::Ordering;
-use std::sync::Arc;
-
-use crossbeam_channel::Receiver;
-use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
-use zerocopy::{AsBytes, LayoutVerified};
-
-use super::constants::MAX_INORDER_CONSUME;
-use super::device::DecryptionState;
-use super::device::Device;
-use super::messages::TransportHeader;
-use super::peer::Peer;
-use super::pool::*;
-use super::types::Callbacks;
-use super::{tun, udp, Endpoint};
-use super::{REJECT_AFTER_MESSAGES, SIZE_TAG};
-
-pub struct Inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
- msg: Vec<u8>,
- failed: bool,
- state: Arc<DecryptionState<E, C, T, B>>,
- endpoint: Option<E>,
-}
-
-impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Inbound<E, C, T, B> {
- pub fn new(
- msg: Vec<u8>,
- state: Arc<DecryptionState<E, C, T, B>>,
- endpoint: E,
- ) -> Inbound<E, C, T, B> {
- Inbound {
- msg,
- state,
- failed: false,
- endpoint: Some(endpoint),
- }
- }
-}
-
-#[inline(always)]
-pub fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- device: Device<E, C, T, B>,
- receiver: Receiver<Job<Peer<E, C, T, B>, Inbound<E, C, T, B>>>,
-) {
- // parallel work to apply
- #[inline(always)]
- fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- peer: &Peer<E, C, T, B>,
- body: &mut Inbound<E, C, T, B>,
- ) {
- log::trace!("worker, parallel section, obtained job");
-
- // cast to header followed by payload
- let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
- match LayoutVerified::new_from_prefix(&mut body.msg[..]) {
- Some(v) => v,
- None => {
- log::debug!("inbound worker: failed to parse message");
- return;
- }
- };
-
- // authenticate and decrypt payload
- {
- // create nonce object
- let mut nonce = [0u8; 12];
- debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
- nonce[4..].copy_from_slice(header.f_counter.as_bytes());
- let nonce = Nonce::assume_unique_for_key(nonce);
-
- // do the weird ring AEAD dance
- let key = LessSafeKey::new(
- UnboundKey::new(&CHACHA20_POLY1305, &body.state.keypair.recv.key[..]).unwrap(),
- );
-
- // attempt to open (and authenticate) the body
- match key.open_in_place(nonce, Aad::empty(), packet) {
- Ok(_) => (),
- Err(_) => {
- // fault and return early
- log::trace!("inbound worker: authentication failure");
- body.failed = true;
- return;
- }
- }
- }
-
- // check that counter not after reject
- if header.f_counter.get() >= REJECT_AFTER_MESSAGES {
- body.failed = true;
- return;
- }
-
- // cryptokey route and strip padding
- let inner_len = {
- let length = packet.len() - SIZE_TAG;
- if length > 0 {
- peer.device.table.check_route(&peer, &packet[..length])
- } else {
- Some(0)
- }
- };
-
- // truncate to remove tag
- match inner_len {
- None => {
- log::trace!("inbound worker: cryptokey routing failed");
- body.failed = true;
- }
- Some(len) => {
- log::trace!(
- "inbound worker: good route, length = {} {}",
- len,
- if len == 0 { "(keepalive)" } else { "" }
- );
- body.msg.truncate(mem::size_of::<TransportHeader>() + len);
- }
- }
- }
-
- worker_parallel(device, |dev| &dev.run_inbound, receiver, work)
-}
-
-#[inline(always)]
-pub fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- device: Device<E, C, T, B>,
-) {
- // sequential work to apply
- fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- peer: &Peer<E, C, T, B>,
- body: &mut Inbound<E, C, T, B>,
- ) {
- log::trace!("worker, sequential section, obtained job");
-
- // decryption failed, return early
- if body.failed {
- log::trace!("job faulted, remove from queue and ignore");
- return;
- }
-
- // cast transport header
- let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
- match LayoutVerified::new_from_prefix(&body.msg[..]) {
- Some(v) => v,
- None => {
- log::debug!("inbound worker: failed to parse message");
- return;
- }
- };
-
- // check for replay
- if !body.state.protector.lock().update(header.f_counter.get()) {
- log::debug!("inbound worker: replay detected");
- return;
- }
-
- // check for confirms key
- if !body.state.confirmed.swap(true, Ordering::SeqCst) {
- log::debug!("inbound worker: message confirms key");
- peer.confirm_key(&body.state.keypair);
- }
-
- // update endpoint
- *peer.endpoint.lock() = body.endpoint.take();
-
- // check if should be written to TUN
- let mut sent = false;
- if packet.len() > 0 {
- sent = match peer.device.inbound.write(&packet[..]) {
- Err(e) => {
- log::debug!("failed to write inbound packet to TUN: {:?}", e);
- false
- }
- Ok(_) => true,
- }
- } else {
- log::debug!("inbound worker: received keepalive")
- }
-
- // trigger callback
- C::recv(&peer.opaque, body.msg.len(), sent, &body.state.keypair);
- }
-
- // handle message from the peers inbound queue
- device.run_inbound.run(|peer| {
- peer.inbound
- .handle(|body| work(&peer, body), MAX_INORDER_CONSUME)
- });
-}
diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs
index ec5cc63..699c621 100644
--- a/src/wireguard/router/mod.rs
+++ b/src/wireguard/router/mod.rs
@@ -1,14 +1,10 @@
mod anti_replay;
mod constants;
mod device;
-mod inbound;
mod ip;
mod messages;
-mod outbound;
mod peer;
-mod pool;
mod route;
-mod runq;
mod types;
mod queue;
@@ -25,7 +21,6 @@ use std::mem;
use super::constants::REJECT_AFTER_MESSAGES;
use super::queue::ParallelQueue;
use super::types::*;
-use super::{tun, udp, Endpoint};
pub const SIZE_TAG: usize = 16;
pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>();
diff --git a/src/wireguard/router/outbound.rs b/src/wireguard/router/outbound.rs
deleted file mode 100644
index 1edb2fb..0000000
--- a/src/wireguard/router/outbound.rs
+++ /dev/null
@@ -1,110 +0,0 @@
-use std::sync::Arc;
-
-use crossbeam_channel::Receiver;
-use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
-use zerocopy::{AsBytes, LayoutVerified};
-
-use super::constants::MAX_INORDER_CONSUME;
-use super::device::Device;
-use super::messages::{TransportHeader, TYPE_TRANSPORT};
-use super::peer::Peer;
-use super::pool::*;
-use super::types::Callbacks;
-use super::KeyPair;
-use super::{tun, udp, Endpoint};
-use super::{REJECT_AFTER_MESSAGES, SIZE_TAG};
-
-pub struct Outbound {
- msg: Vec<u8>,
- keypair: Arc<KeyPair>,
- counter: u64,
-}
-
-impl Outbound {
- pub fn new(msg: Vec<u8>, keypair: Arc<KeyPair>, counter: u64) -> Outbound {
- Outbound {
- msg,
- keypair,
- counter,
- }
- }
-}
-
-#[inline(always)]
-pub fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- device: Device<E, C, T, B>,
- receiver: Receiver<Job<Peer<E, C, T, B>, Outbound>>,
-) {
- #[inline(always)]
- fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- _peer: &Peer<E, C, T, B>,
- body: &mut Outbound,
- ) {
- log::trace!("worker, parallel section, obtained job");
-
- // make space for the tag
- body.msg.extend([0u8; SIZE_TAG].iter());
-
- // cast to header (should never fail)
- let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
- LayoutVerified::new_from_prefix(&mut body.msg[..])
- .expect("earlier code should ensure that there is ample space");
-
- // set header fields
- debug_assert!(
- body.counter < REJECT_AFTER_MESSAGES,
- "should be checked when assigning counters"
- );
- header.f_type.set(TYPE_TRANSPORT);
- header.f_receiver.set(body.keypair.send.id);
- header.f_counter.set(body.counter);
-
- // create a nonce object
- let mut nonce = [0u8; 12];
- debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
- nonce[4..].copy_from_slice(header.f_counter.as_bytes());
- let nonce = Nonce::assume_unique_for_key(nonce);
-
- // do the weird ring AEAD dance
- let key = LessSafeKey::new(
- UnboundKey::new(&CHACHA20_POLY1305, &body.keypair.send.key[..]).unwrap(),
- );
-
- // encrypt content of transport message in-place
- let end = packet.len() - SIZE_TAG;
- let tag = key
- .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end])
- .unwrap();
-
- // append tag
- packet[end..].copy_from_slice(tag.as_ref());
- }
-
- worker_parallel(device, |dev| &dev.run_outbound, receiver, work);
-}
-
-#[inline(always)]
-pub fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- device: Device<E, C, T, B>,
-) {
- device.run_outbound.run(|peer| {
- peer.outbound.handle(
- |body| {
- log::trace!("worker, sequential section, obtained job");
-
- // send to peer
- let xmit = peer.send(&body.msg[..]).is_ok();
-
- // trigger callback
- C::send(
- &peer.opaque,
- body.msg.len(),
- xmit,
- &body.keypair,
- body.counter,
- );
- },
- MAX_INORDER_CONSUME,
- )
- });
-}
diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs
index 7312bc7..710cf32 100644
--- a/src/wireguard/router/peer.rs
+++ b/src/wireguard/router/peer.rs
@@ -1,13 +1,3 @@
-use std::mem;
-use std::net::{IpAddr, SocketAddr};
-use std::ops::Deref;
-use std::sync::atomic::AtomicBool;
-use std::sync::Arc;
-
-use arraydeque::{ArrayDeque, Wrapping};
-use log::debug;
-use spin::Mutex;
-
use super::super::constants::*;
use super::super::{tun, udp, Endpoint, KeyPair};
@@ -15,20 +5,25 @@ use super::anti_replay::AntiReplay;
use super::device::DecryptionState;
use super::device::Device;
use super::device::EncryptionState;
-use super::messages::TransportHeader;
use super::constants::*;
-use super::runq::ToKey;
use super::types::{Callbacks, RouterError};
use super::SIZE_MESSAGE_PREFIX;
-// worker pool related
-use super::inbound::Inbound;
-use super::outbound::Outbound;
use super::queue::Queue;
-
-use super::send::SendJob;
use super::receive::ReceiveJob;
+use super::send::SendJob;
+use super::worker::JobUnion;
+
+use std::mem;
+use std::net::{IpAddr, SocketAddr};
+use std::ops::Deref;
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
+
+use arraydeque::{ArrayDeque, Wrapping};
+use log::debug;
+use spin::Mutex;
pub struct KeyWheel {
next: Option<Arc<KeyPair>>, // next key state (unconfirmed)
@@ -44,7 +39,7 @@ pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E
pub inbound: Queue<ReceiveJob<E, C, T, B>>,
pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_QUEUED_PACKETS], Wrapping>>,
pub keys: Mutex<KeyWheel>,
- pub ekey: Mutex<Option<EncryptionState>>,
+ pub enc_key: Mutex<Option<EncryptionState>>,
pub endpoint: Mutex<Option<E>>,
}
@@ -69,13 +64,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq for
}
}
-impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ToKey for Peer<E, C, T, B> {
- type Key = usize;
- fn to_key(&self) -> usize {
- Arc::downgrade(&self.inner).into_raw() as usize
- }
-}
-
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Eq for Peer<E, C, T, B> {}
/* A peer is transparently dereferenced to the inner type
@@ -157,7 +145,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Peer
keys.current = None;
keys.previous = None;
- *peer.ekey.lock() = None;
+ *peer.enc_key.lock() = None;
*peer.endpoint.lock() = None;
debug!("peer dropped & removed from device");
@@ -175,9 +163,9 @@ pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
inner: Arc::new(PeerInner {
opaque,
device,
- inbound: InorderQueue::new(),
- outbound: InorderQueue::new(),
- ekey: spin::Mutex::new(None),
+ inbound: Queue::new(),
+ outbound: Queue::new(),
+ enc_key: spin::Mutex::new(None),
endpoint: spin::Mutex::new(None),
keys: spin::Mutex::new(KeyWheel {
next: None,
@@ -203,7 +191,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E,
/// # Returns
///
/// Unit if packet was sent, or an error indicating why sending failed
- pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> {
+ pub fn send_raw(&self, msg: &[u8]) -> Result<(), RouterError> {
debug!("peer.send");
// send to endpoint (if known)
@@ -226,6 +214,57 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E,
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, B> {
+ /// Encrypt and send a message to the peer
+ ///
+ /// Arguments:
+ ///
+ /// - `msg` : A padded vector holding the message (allows in-place construction of the transport header)
+ /// - `stage`: Should the message be staged if no key is available
+ ///
+ pub(super) fn send(&self, msg: Vec<u8>, stage: bool) {
+ // check if key available
+ let (job, need_key) = {
+ let mut enc_key = self.enc_key.lock();
+ match enc_key.as_mut() {
+ None => {
+ if stage {
+ self.staged_packets.lock().push_back(msg);
+ };
+ (None, true)
+ }
+ Some(mut state) => {
+ // avoid integer overflow in nonce
+ if state.nonce >= REJECT_AFTER_MESSAGES - 1 {
+ *enc_key = None;
+ if stage {
+ self.staged_packets.lock().push_back(msg);
+ }
+ (None, true)
+ } else {
+ debug!("encryption state available, nonce = {}", state.nonce);
+ let job =
+ SendJob::new(msg, state.nonce, state.keypair.clone(), self.clone());
+ if self.outbound.push(job.clone()) {
+ state.nonce += 1;
+ (Some(job), false)
+ } else {
+ (None, false)
+ }
+ }
+ }
+ }
+ };
+
+ if need_key {
+ debug_assert!(job.is_none());
+ C::need_key(&self.opaque);
+ };
+
+ if let Some(job) = job {
+ self.device.work.send(JobUnion::Outbound(job))
+ }
+ }
+
// Transmit all staged packets
fn send_staged(&self) -> bool {
debug!("peer.send_staged");
@@ -235,28 +274,14 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
match staged.pop_front() {
Some(msg) => {
sent = true;
- self.send_raw(msg);
+ self.send(msg, false);
}
None => break sent,
}
}
}
- // Treat the msg as the payload of a transport message
- // Unlike device.send, peer.send_raw does not buffer messages when a key is not available.
- fn send_raw(&self, msg: Vec<u8>) -> bool {
- log::debug!("peer.send_raw");
- match self.send_job(msg, false) {
- Some(job) => {
- self.device.queue_outbound.send(job);
- debug!("send_raw: got obtained send_job");
- true
- }
- None => false,
- }
- }
-
- pub fn confirm_key(&self, keypair: &Arc<KeyPair>) {
+ pub(super) fn confirm_key(&self, keypair: &Arc<KeyPair>) {
debug!("peer.confirm_key");
{
// take lock and check keypair = keys.next
@@ -284,68 +309,12 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
C::key_confirmed(&self.opaque);
// set new key for encryption
- *self.ekey.lock() = ekey;
+ *self.enc_key.lock() = ekey;
}
// start transmission of staged packets
self.send_staged();
}
-
- pub fn send_job(&self, msg: Vec<u8>, stage: bool) -> Option<SendJob<E, C, T, B>> {
- debug!("peer.send_job");
- debug_assert!(
- msg.len() >= mem::size_of::<TransportHeader>(),
- "received TUN message with size: {:}",
- msg.len()
- );
-
- // check if has key
- let (keypair, counter) = {
- let keypair = {
- // TODO: consider using atomic ptr for ekey state
- let mut ekey = self.ekey.lock();
- match ekey.as_mut() {
- None => None,
- Some(mut state) => {
- // avoid integer overflow in nonce
- if state.nonce >= REJECT_AFTER_MESSAGES - 1 {
- *ekey = None;
- None
- } else {
- debug!("encryption state available, nonce = {}", state.nonce);
- let counter = state.nonce;
- state.nonce += 1;
-
- SendJob::new(
- msg,
- state.nonce,
- state.keypair.clone(),
- self.clone()
- );
-
- Some((state.keypair.clone(), counter))
- }
- }
- }
- };
-
- // If not suitable key was found:
- // 1. Stage packet for later transmission
- // 2. Request new key
- if keypair.is_none() && stage {
- self.staged_packets.lock().push_back(msg);
- C::need_key(&self.opaque);
- return None;
- };
-
- keypair
- }?;
-
- // add job to in-order queue and return sender to device for inclusion in worker pool
- let job = Job::new(self.clone(), Outbound::new(msg, keypair, counter));
- self.outbound.send(job.clone());
- Some(job)
- }
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, C, T, B> {
@@ -397,7 +366,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E,
}
// clear encryption state
- *self.peer.ekey.lock() = None;
+ *self.peer.enc_key.lock() = None;
}
pub fn down(&self) {
@@ -434,7 +403,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E,
// update key-wheel
if new.initiator {
// start using key for encryption
- *self.peer.ekey.lock() = Some(EncryptionState::new(&new));
+ *self.peer.enc_key.lock() = Some(EncryptionState::new(&new));
// move current into previous
keys.previous = keys.current.as_ref().map(|v| v.clone());
@@ -468,16 +437,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E,
// schedule confirmation
if initiator {
- debug_assert!(self.peer.ekey.lock().is_some());
+ debug_assert!(self.peer.enc_key.lock().is_some());
debug!("peer.add_keypair: is initiator, must confirm the key");
// attempt to confirm using staged packets
if !self.peer.send_staged() {
// fall back to keepalive packet
- let ok = self.send_keepalive();
- debug!(
- "peer.add_keypair: keepalive for confirmation, sent = {}",
- ok
- );
+ self.send_keepalive();
+ debug!("peer.add_keypair: keepalive for confirmation",);
}
debug!("peer.add_keypair: key attempted confirmed");
}
@@ -489,9 +455,9 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E,
release
}
- pub fn send_keepalive(&self) -> bool {
+ pub fn send_keepalive(&self) {
debug!("peer.send_keepalive");
- self.peer.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX])
+ self.peer.send(vec![0u8; SIZE_MESSAGE_PREFIX], false)
}
/// Map a subnet to the peer
diff --git a/src/wireguard/router/pool.rs b/src/wireguard/router/pool.rs
deleted file mode 100644
index 3fc0026..0000000
--- a/src/wireguard/router/pool.rs
+++ /dev/null
@@ -1,164 +0,0 @@
-use std::mem;
-use std::sync::Arc;
-
-use arraydeque::ArrayDeque;
-use crossbeam_channel::Receiver;
-use spin::{Mutex, MutexGuard};
-
-use super::constants::INORDER_QUEUE_SIZE;
-use super::runq::{RunQueue, ToKey};
-
-pub struct InnerJob<P, B> {
- // peer (used by worker to schedule/handle inorder queue),
- // when the peer is None, the job is complete
- peer: Option<P>,
- pub body: B,
-}
-
-pub struct Job<P, B> {
- inner: Arc<Mutex<InnerJob<P, B>>>,
-}
-
-impl<P, B> Clone for Job<P, B> {
- fn clone(&self) -> Job<P, B> {
- Job {
- inner: self.inner.clone(),
- }
- }
-}
-
-impl<P, B> Job<P, B> {
- pub fn new(peer: P, body: B) -> Job<P, B> {
- Job {
- inner: Arc::new(Mutex::new(InnerJob {
- peer: Some(peer),
- body,
- })),
- }
- }
-}
-
-impl<P, B> Job<P, B> {
- /// Returns a mutex guard to the inner job if complete
- pub fn complete(&self) -> Option<MutexGuard<InnerJob<P, B>>> {
- self.inner
- .try_lock()
- .and_then(|m| if m.peer.is_none() { Some(m) } else { None })
- }
-}
-
-pub struct InorderQueue<P, B> {
- queue: Mutex<ArrayDeque<[Job<P, B>; INORDER_QUEUE_SIZE]>>,
-}
-
-impl<P, B> InorderQueue<P, B> {
- pub fn new() -> InorderQueue<P, B> {
- InorderQueue {
- queue: Mutex::new(ArrayDeque::new()),
- }
- }
-
- /// Add a new job to the in-order queue
- ///
- /// # Arguments
- ///
- /// - `job`: The job added to the back of the queue
- ///
- /// # Returns
- ///
- /// True if the element was added,
- /// false to indicate that the queue is full.
- pub fn send(&self, job: Job<P, B>) -> bool {
- self.queue.lock().push_back(job).is_ok()
- }
-
- /// Consume completed jobs from the in-order queue
- ///
- /// # Arguments
- ///
- /// - `f`: function to apply to the body of each jobof each job.
- /// - `limit`: maximum number of jobs to handle before returning
- ///
- /// # Returns
- ///
- /// A boolean indicating if the limit was reached:
- /// true indicating that the limit was reached,
- /// while false implies that the queue is empty or an uncompleted job was reached.
- #[inline(always)]
- pub fn handle<F: Fn(&mut B)>(&self, f: F, mut limit: usize) -> bool {
- // take the mutex
- let mut queue = self.queue.lock();
-
- while limit > 0 {
- // attempt to extract front element
- let front = queue.pop_front();
- let elem = match front {
- Some(elem) => elem,
- _ => {
- return false;
- }
- };
-
- // apply function if job complete
- let ret = if let Some(mut guard) = elem.complete() {
- mem::drop(queue);
- f(&mut guard.body);
- queue = self.queue.lock();
- false
- } else {
- true
- };
-
- // job not complete yet, return job to front
- if ret {
- queue.push_front(elem).unwrap();
- return false;
- }
- limit -= 1;
- }
-
- // did not complete all jobs
- true
- }
-}
-
-/// Allows easy construction of a parallel worker.
-/// Applicable for both decryption and encryption workers.
-#[inline(always)]
-pub fn worker_parallel<
- P: ToKey, // represents a peer (atomic reference counted pointer)
- B, // inner body type (message buffer, key material, ...)
- D, // device
- W: Fn(&P, &mut B),
- Q: Fn(&D) -> &RunQueue<P>,
->(
- device: D,
- queue: Q,
- receiver: Receiver<Job<P, B>>,
- work: W,
-) {
- log::trace!("router worker started");
- loop {
- // handle new job
- let peer = {
- // get next job
- let job = match receiver.recv() {
- Ok(job) => job,
- _ => return,
- };
-
- // lock the job
- let mut job = job.inner.lock();
-
- // take the peer from the job
- let peer = job.peer.take().unwrap();
-
- // process job
- work(&peer, &mut job.body);
- peer
- };
-
- // process inorder jobs for peer
- queue(&device).insert(peer);
- }
-}
diff --git a/src/wireguard/router/queue.rs b/src/wireguard/router/queue.rs
index 045fd51..ec4492e 100644
--- a/src/wireguard/router/queue.rs
+++ b/src/wireguard/router/queue.rs
@@ -4,29 +4,36 @@ use spin::Mutex;
use std::mem;
use std::sync::atomic::{AtomicUsize, Ordering};
-const QUEUE_SIZE: usize = 1024;
-
-pub trait Job: Sized {
- fn queue(&self) -> &Queue<Self>;
+use super::constants::INORDER_QUEUE_SIZE;
+pub trait SequentialJob {
fn is_ready(&self) -> bool;
- fn parallel_work(&self);
-
fn sequential_work(self);
}
+pub trait ParallelJob: Sized + SequentialJob {
+ fn queue(&self) -> &Queue<Self>;
+
+ fn parallel_work(&self);
+}
-pub struct Queue<J: Job> {
+pub struct Queue<J: SequentialJob> {
contenders: AtomicUsize,
- queue: Mutex<ArrayDeque<[J; QUEUE_SIZE]>>,
+ queue: Mutex<ArrayDeque<[J; INORDER_QUEUE_SIZE]>>,
+
+ #[cfg(debug)]
+ _flag: Mutex<()>,
}
-impl<J: Job> Queue<J> {
+impl<J: SequentialJob> Queue<J> {
pub fn new() -> Queue<J> {
Queue {
contenders: AtomicUsize::new(0),
queue: Mutex::new(ArrayDeque::new()),
+
+ #[cfg(debug)]
+ _flag: Mutex::new(()),
}
}
@@ -36,14 +43,22 @@ impl<J: Job> Queue<J> {
pub fn consume(&self) {
// check if we are the first contender
- let pos = self.contenders.fetch_add(1, Ordering::Acquire);
+ let pos = self.contenders.fetch_add(1, Ordering::SeqCst);
if pos > 0 {
- assert!(pos < usize::max_value(), "contenders overflow");
+ assert!(usize::max_value() > pos, "contenders overflow");
+ return;
}
// enter the critical section
let mut contenders = 1; // myself
while contenders > 0 {
+ // check soundness in debug builds
+ #[cfg(debug)]
+ let _flag = self
+ ._flag
+ .try_lock()
+ .expect("contenders should ensure mutual exclusion");
+
// handle every ready element
loop {
let mut queue = self.queue.lock();
@@ -69,8 +84,61 @@ impl<J: Job> Queue<J> {
job.sequential_work();
}
+ #[cfg(debug)]
+ mem::drop(_flag);
+
// decrease contenders
- contenders = self.contenders.fetch_sub(contenders, Ordering::Acquire) - contenders;
+ contenders = self.contenders.fetch_sub(contenders, Ordering::SeqCst) - contenders;
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use std::sync::Arc;
+ use std::thread;
+
+ use rand::thread_rng;
+ use rand::Rng;
+
+ struct TestJob {}
+
+ impl SequentialJob for TestJob {
+ fn is_ready(&self) -> bool {
+ true
+ }
+
+ fn sequential_work(self) {}
+ }
+
+ /* Fuzz the Queue */
+ #[test]
+ fn test_queue() {
+ fn hammer(queue: &Arc<Queue<TestJob>>) {
+ let mut rng = thread_rng();
+ for _ in 0..1_000_000 {
+ if rng.gen() {
+ queue.push(TestJob {});
+ } else {
+ queue.consume();
+ }
+ }
}
+
+ let queue = Arc::new(Queue::new());
+
+ // repeatedly apply operations randomly from concurrent threads
+ let other = {
+ let queue = queue.clone();
+ thread::spawn(move || hammer(&queue))
+ };
+ hammer(&queue);
+
+ // wait, consume and check empty
+ other.join().unwrap();
+ queue.consume();
+ assert_eq!(queue.queue.lock().len(), 0);
}
}
diff --git a/src/wireguard/router/receive.rs b/src/wireguard/router/receive.rs
index 53890e3..c5fe3da 100644
--- a/src/wireguard/router/receive.rs
+++ b/src/wireguard/router/receive.rs
@@ -1,21 +1,18 @@
-use super::queue::{Job, Queue};
-use super::KeyPair;
+use super::device::DecryptionState;
+use super::messages::TransportHeader;
+use super::queue::{ParallelJob, Queue, SequentialJob};
use super::types::Callbacks;
-use super::peer::Peer;
use super::{REJECT_AFTER_MESSAGES, SIZE_TAG};
-use super::messages::{TransportHeader, TYPE_TRANSPORT};
-use super::device::DecryptionState;
use super::super::{tun, udp, Endpoint};
+use std::mem;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
-use std::mem;
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
-use zerocopy::{AsBytes, LayoutVerified};
use spin::Mutex;
-
+use zerocopy::{AsBytes, LayoutVerified};
struct Inner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
ready: AtomicBool,
@@ -23,49 +20,49 @@ struct Inner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
state: Arc<DecryptionState<E, C, T, B>>, // decryption state (keys and replay protector)
}
-pub struct ReceiveJob<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
- inner: Arc<Inner<E, C, T, B>>,
+pub struct ReceiveJob<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ Arc<Inner<E, C, T, B>>,
+);
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone
+ for ReceiveJob<E, C, T, B>
+{
+ fn clone(&self) -> ReceiveJob<E, C, T, B> {
+ ReceiveJob(self.0.clone())
+ }
}
-impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ReceiveJob<E, C, T, B> {
- fn new(buffer: Vec<u8>, state: Arc<DecryptionState<E, C, T, B>>, endpoint: E) -> Option<ReceiveJob<E, C, T, B>> {
- // create job
- let inner = Arc::new(Inner{
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ReceiveJob<E, C, T, B> {
+ pub fn new(
+ buffer: Vec<u8>,
+ state: Arc<DecryptionState<E, C, T, B>>,
+ endpoint: E,
+ ) -> ReceiveJob<E, C, T, B> {
+ ReceiveJob(Arc::new(Inner {
ready: AtomicBool::new(false),
buffer: Mutex::new((Some(endpoint), buffer)),
- state
- });
-
- // attempt to add to queue
- if state.peer.inbound.push(ReceiveJob{ inner: inner.clone()}) {
- Some(ReceiveJob{inner})
- } else {
- None
- }
-
+ state,
+ }))
}
}
-impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for ReceiveJob<E, C, T, B> {
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ParallelJob
+ for ReceiveJob<E, C, T, B>
+{
fn queue(&self) -> &Queue<Self> {
- &self.inner.state.peer.inbound
- }
-
- fn is_ready(&self) -> bool {
- self.inner.ready.load(Ordering::Acquire)
+ &self.0.state.peer.inbound
}
fn parallel_work(&self) {
// TODO: refactor
// decrypt
{
- let job = &self.inner;
+ let job = &self.0;
let peer = &job.state.peer;
let mut msg = job.buffer.lock();
-
- let failed = || {
- // cast to header followed by payload
- let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
+
+ // cast to header followed by payload
+ let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
match LayoutVerified::new_from_prefix(&mut msg.1[..]) {
Some(v) => v,
None => {
@@ -74,73 +71,81 @@ impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for Rece
}
};
- // authenticate and decrypt payload
- {
- // create nonce object
- let mut nonce = [0u8; 12];
- debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
- nonce[4..].copy_from_slice(header.f_counter.as_bytes());
- let nonce = Nonce::assume_unique_for_key(nonce);
-
- // do the weird ring AEAD dance
- let key = LessSafeKey::new(
- UnboundKey::new(&CHACHA20_POLY1305, &job.state.keypair.recv.key[..]).unwrap(),
- );
-
- // attempt to open (and authenticate) the body
- match key.open_in_place(nonce, Aad::empty(), packet) {
- Ok(_) => (),
- Err(_) => {
- // fault and return early
- log::trace!("inbound worker: authentication failure");
- msg.1.truncate(0);
- return;
- }
+ // authenticate and decrypt payload
+ {
+ // create nonce object
+ let mut nonce = [0u8; 12];
+ debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
+ nonce[4..].copy_from_slice(header.f_counter.as_bytes());
+ let nonce = Nonce::assume_unique_for_key(nonce);
+
+ // do the weird ring AEAD dance
+ let key = LessSafeKey::new(
+ UnboundKey::new(&CHACHA20_POLY1305, &job.state.keypair.recv.key[..]).unwrap(),
+ );
+
+ // attempt to open (and authenticate) the body
+ match key.open_in_place(nonce, Aad::empty(), packet) {
+ Ok(_) => (),
+ Err(_) => {
+ // fault and return early
+ log::trace!("inbound worker: authentication failure");
+ msg.1.truncate(0);
+ return;
}
}
+ }
- // check that counter not after reject
- if header.f_counter.get() >= REJECT_AFTER_MESSAGES {
- msg.1.truncate(0);
- return;
- }
-
- // cryptokey route and strip padding
- let inner_len = {
- let length = packet.len() - SIZE_TAG;
- if length > 0 {
- peer.device.table.check_route(&peer, &packet[..length])
- } else {
- Some(0)
- }
- };
+ // check that counter not after reject
+ if header.f_counter.get() >= REJECT_AFTER_MESSAGES {
+ msg.1.truncate(0);
+ return;
+ }
- // truncate to remove tag
- match inner_len {
- None => {
- log::trace!("inbound worker: cryptokey routing failed");
- msg.1.truncate(0);
- }
- Some(len) => {
- log::trace!(
- "inbound worker: good route, length = {} {}",
- len,
- if len == 0 { "(keepalive)" } else { "" }
- );
- msg.1.truncate(mem::size_of::<TransportHeader>() + len);
- }
+ // cryptokey route and strip padding
+ let inner_len = {
+ let length = packet.len() - SIZE_TAG;
+ if length > 0 {
+ peer.device.table.check_route(&peer, &packet[..length])
+ } else {
+ Some(0)
}
};
+
+ // truncate to remove tag
+ match inner_len {
+ None => {
+ log::trace!("inbound worker: cryptokey routing failed");
+ msg.1.truncate(0);
+ }
+ Some(len) => {
+ log::trace!(
+ "inbound worker: good route, length = {} {}",
+ len,
+ if len == 0 { "(keepalive)" } else { "" }
+ );
+ msg.1.truncate(mem::size_of::<TransportHeader>() + len);
+ }
+ }
}
// mark ready
- self.inner.ready.store(true, Ordering::Release);
+ self.0.ready.store(true, Ordering::Release);
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SequentialJob
+ for ReceiveJob<E, C, T, B>
+{
+ fn is_ready(&self) -> bool {
+ self.0.ready.load(Ordering::Acquire)
}
fn sequential_work(self) {
- let job = &self.inner;
+ let job = &self.0;
let peer = &job.state.peer;
let mut msg = job.buffer.lock();
+ let endpoint = msg.0.take();
// cast transport header
let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
@@ -165,7 +170,7 @@ impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for Rece
}
// update endpoint
- *peer.endpoint.lock() = msg.0.take();
+ *peer.endpoint.lock() = endpoint;
// check if should be written to TUN
let mut sent = false;
@@ -184,5 +189,4 @@ impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for Rece
// trigger callback
C::recv(&peer.opaque, msg.1.len(), sent, &job.state.keypair);
}
-
}
diff --git a/src/wireguard/router/runq.rs b/src/wireguard/router/runq.rs
deleted file mode 100644
index 44e11a1..0000000
--- a/src/wireguard/router/runq.rs
+++ /dev/null
@@ -1,164 +0,0 @@
-use std::hash::Hash;
-use std::mem;
-use std::sync::{Condvar, Mutex};
-
-use std::collections::hash_map::Entry;
-use std::collections::HashMap;
-use std::collections::VecDeque;
-
-pub trait ToKey {
- type Key: Hash + Eq;
- fn to_key(&self) -> Self::Key;
-}
-
-pub struct RunQueue<T: ToKey> {
- cvar: Condvar,
- inner: Mutex<Inner<T>>,
-}
-
-struct Inner<T: ToKey> {
- stop: bool,
- queue: VecDeque<T>,
- members: HashMap<T::Key, usize>,
-}
-
-impl<T: ToKey> RunQueue<T> {
- pub fn close(&self) {
- let mut inner = self.inner.lock().unwrap();
- inner.stop = true;
- self.cvar.notify_all();
- }
-
- pub fn new() -> RunQueue<T> {
- RunQueue {
- cvar: Condvar::new(),
- inner: Mutex::new(Inner {
- stop: false,
- queue: VecDeque::new(),
- members: HashMap::new(),
- }),
- }
- }
-
- pub fn insert(&self, v: T) {
- let key = v.to_key();
- let mut inner = self.inner.lock().unwrap();
- match inner.members.entry(key) {
- Entry::Occupied(mut elem) => {
- *elem.get_mut() += 1;
- }
- Entry::Vacant(spot) => {
- // add entry to back of queue
- spot.insert(0);
- inner.queue.push_back(v);
-
- // wake a thread
- self.cvar.notify_one();
- }
- }
- }
-
- /// Run (consume from) the run queue using the provided function.
- /// The function should return wheter the given element should be rescheduled.
- ///
- /// # Arguments
- ///
- /// - `f` : function to apply to every element
- ///
- /// # Note
- ///
- /// The function f may be called again even when the element was not inserted back in to the
- /// queue since the last applciation and no rescheduling was requested.
- ///
- /// This happens then the function handles all work for T,
- /// but T is added to the run queue while the function is running.
- pub fn run<F: Fn(&T) -> bool>(&self, f: F) {
- let mut inner = self.inner.lock().unwrap();
- loop {
- // fetch next element
- let elem = loop {
- // run-queue closed
- if inner.stop {
- return;
- }
-
- // try to pop from queue
- match inner.queue.pop_front() {
- Some(elem) => {
- break elem;
- }
- None => (),
- };
-
- // wait for an element to be inserted
- inner = self.cvar.wait(inner).unwrap();
- };
-
- // fetch current request number
- let key = elem.to_key();
- let old_n = *inner.members.get(&key).unwrap();
- mem::drop(inner); // drop guard
-
- // handle element
- let rerun = f(&elem);
-
- // if the function requested a re-run add the element to the back of the queue
- inner = self.inner.lock().unwrap();
- if rerun {
- inner.queue.push_back(elem);
- continue;
- }
-
- // otherwise check if new requests have come in since we ran the function
- match inner.members.entry(key) {
- Entry::Occupied(occ) => {
- if *occ.get() == old_n {
- // no new requests since last, remove entry.
- occ.remove();
- } else {
- // new requests, reschedule.
- inner.queue.push_back(elem);
- }
- }
- Entry::Vacant(_) => {
- unreachable!();
- }
- }
- }
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use std::thread;
- use std::time::Duration;
-
- /*
- #[test]
- fn test_wait() {
- let queue: Arc<RunQueue<usize>> = Arc::new(RunQueue::new());
-
- {
- let queue = queue.clone();
- thread::spawn(move || {
- queue.run(|e| {
- println!("t0 {}", e);
- thread::sleep(Duration::from_millis(100));
- })
- });
- }
-
- {
- let queue = queue.clone();
- thread::spawn(move || {
- queue.run(|e| {
- println!("t1 {}", e);
- thread::sleep(Duration::from_millis(100));
- })
- });
- }
-
- }
- */
-}
diff --git a/src/wireguard/router/send.rs b/src/wireguard/router/send.rs
index 2bd4abd..8e41796 100644
--- a/src/wireguard/router/send.rs
+++ b/src/wireguard/router/send.rs
@@ -1,4 +1,4 @@
-use super::queue::{Job, Queue};
+use super::queue::{SequentialJob, ParallelJob, Queue};
use super::KeyPair;
use super::types::Callbacks;
use super::peer::Peer;
@@ -22,8 +22,14 @@ struct Inner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
peer: Peer<E, C, T, B>,
}
-pub struct SendJob<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
- inner: Arc<Inner<E, C, T, B>>,
+pub struct SendJob<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> (
+ Arc<Inner<E, C, T, B>>
+);
+
+impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for SendJob<E, C, T, B> {
+ fn clone(&self) -> SendJob<E, C, T, B> {
+ SendJob(self.0.clone())
+ }
}
impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SendJob<E, C, T, B> {
@@ -32,32 +38,53 @@ impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SendJob<E, C
counter: u64,
keypair: Arc<KeyPair>,
peer: Peer<E, C, T, B>
- ) -> Option<SendJob<E, C, T, B>> {
- // create job
- let inner = Arc::new(Inner{
+ ) -> SendJob<E, C, T, B> {
+ SendJob(Arc::new(Inner{
buffer: Mutex::new(buffer),
counter,
keypair,
peer,
ready: AtomicBool::new(false)
- });
-
- // attempt to add to queue
- if peer.outbound.push(SendJob{ inner: inner.clone()}) {
- Some(SendJob{inner})
- } else {
- None
- }
+ }))
}
}
-impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for SendJob<E, C, T, B> {
- fn queue(&self) -> &Queue<Self> {
- &self.inner.peer.outbound
+impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SequentialJob for SendJob<E, C, T, B> {
+
+ fn is_ready(&self) -> bool {
+ self.0.ready.load(Ordering::Acquire)
}
- fn is_ready(&self) -> bool {
- self.inner.ready.load(Ordering::Acquire)
+ fn sequential_work(self) {
+ debug_assert_eq!(
+ self.is_ready(),
+ true,
+ "doing sequential work
+ on an incomplete job"
+ );
+ log::trace!("processing sequential send job");
+
+ // send to peer
+ let job = &self.0;
+ let msg = job.buffer.lock();
+ let xmit = job.peer.send_raw(&msg[..]).is_ok();
+
+ // trigger callback (for timers)
+ C::send(
+ &job.peer.opaque,
+ msg.len(),
+ xmit,
+ &job.keypair,
+ job.counter,
+ );
+ }
+}
+
+
+impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ParallelJob for SendJob<E, C, T, B> {
+
+ fn queue(&self) -> &Queue<Self> {
+ &self.0.peer.outbound
}
fn parallel_work(&self) {
@@ -71,7 +98,7 @@ impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for Send
// encrypt body
{
// make space for the tag
- let job = &*self.inner;
+ let job = &*self.0;
let mut msg = job.buffer.lock();
msg.extend([0u8; SIZE_TAG].iter());
@@ -111,30 +138,6 @@ impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for Send
}
// mark ready
- self.inner.ready.store(true, Ordering::Release);
- }
-
- fn sequential_work(self) {
- debug_assert_eq!(
- self.is_ready(),
- true,
- "doing sequential work
- on an incomplete job"
- );
- log::trace!("processing sequential send job");
-
- // send to peer
- let job = &self.inner;
- let msg = job.buffer.lock();
- let xmit = job.peer.send(&msg[..]).is_ok();
-
- // trigger callback (for timers)
- C::send(
- &job.peer.opaque,
- msg.len(),
- xmit,
- &job.keypair,
- job.counter,
- );
+ self.0.ready.store(true, Ordering::Release);
}
-}
+} \ No newline at end of file
diff --git a/src/wireguard/router/worker.rs b/src/wireguard/router/worker.rs
index d95050e..bbb644c 100644
--- a/src/wireguard/router/worker.rs
+++ b/src/wireguard/router/worker.rs
@@ -1,13 +1,31 @@
-use super::Device;
-
use super::super::{tun, udp, Endpoint};
use super::types::Callbacks;
-use super::receive::ReceieveJob;
+use super::queue::ParallelJob;
+use super::receive::ReceiveJob;
use super::send::SendJob;
-fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- device: Device<E, C, T, B>,
+use crossbeam_channel::Receiver;
+
+pub enum JobUnion<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
+ Outbound(SendJob<E, C, T, B>),
+ Inbound(ReceiveJob<E, C, T, B>),
+}
+
+pub fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ receiver: Receiver<JobUnion<E, C, T, B>>,
) {
- // fetch job
+ loop {
+ match receiver.recv() {
+ Err(_) => break,
+ Ok(JobUnion::Inbound(job)) => {
+ job.parallel_work();
+ job.queue().consume();
+ }
+ Ok(JobUnion::Outbound(job)) => {
+ job.parallel_work();
+ job.queue().consume();
+ }
+ }
+ }
}
diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs
deleted file mode 100644
index 8ddc136..0000000
--- a/src/wireguard/router/workers.rs
+++ /dev/null
@@ -1,257 +0,0 @@
-use std::sync::Arc;
-
-use log::{debug, trace};
-
-use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
-
-use crossbeam_channel::Receiver;
-use std::sync::atomic::Ordering;
-use zerocopy::{AsBytes, LayoutVerified};
-
-use super::device::{DecryptionState, DeviceInner};
-use super::messages::{TransportHeader, TYPE_TRANSPORT};
-use super::peer::PeerInner;
-use super::types::Callbacks;
-
-use super::REJECT_AFTER_MESSAGES;
-
-use super::super::types::KeyPair;
-use super::super::{tun, udp, Endpoint};
-
-pub const SIZE_TAG: usize = 16;
-
-pub struct JobEncryption {
- pub msg: Vec<u8>,
- pub keypair: Arc<KeyPair>,
- pub counter: u64,
-}
-
-pub struct JobDecryption {
- pub msg: Vec<u8>,
- pub keypair: Arc<KeyPair>,
-}
-
-pub enum JobParallel {
- Encryption(oneshot::Sender<JobEncryption>, JobEncryption),
- Decryption(oneshot::Sender<Option<JobDecryption>>, JobDecryption),
-}
-
-#[allow(type_alias_bounds)]
-pub type JobInbound<E, C, T, B: udp::Writer<E>> = (
- Arc<DecryptionState<E, C, T, B>>,
- E,
- oneshot::Receiver<Option<JobDecryption>>,
-);
-
-pub type JobOutbound = oneshot::Receiver<JobEncryption>;
-
-/* TODO: Replace with run-queue
- */
-pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- device: Arc<DeviceInner<E, C, T, B>>, // related device
- peer: Arc<PeerInner<E, C, T, B>>, // related peer
- receiver: Receiver<JobInbound<E, C, T, B>>,
-) {
- loop {
- // fetch job
- let (state, endpoint, rx) = match receiver.recv() {
- Ok(v) => v,
- _ => {
- return;
- }
- };
- debug!("inbound worker: obtained job");
-
- // wait for job to complete
- let _ = rx
- .map(|buf| {
- debug!("inbound worker: job complete");
- if let Some(buf) = buf {
- // cast transport header
- let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
- match LayoutVerified::new_from_prefix(&buf.msg[..]) {
- Some(v) => v,
- None => {
- debug!("inbound worker: failed to parse message");
- return;
- }
- };
-
- debug_assert!(
- packet.len() >= CHACHA20_POLY1305.tag_len(),
- "this should be checked earlier in the pipeline (decryption should fail)"
- );
-
- // check for replay
- if !state.protector.lock().update(header.f_counter.get()) {
- debug!("inbound worker: replay detected");
- return;
- }
-
- // check for confirms key
- if !state.confirmed.swap(true, Ordering::SeqCst) {
- debug!("inbound worker: message confirms key");
- peer.confirm_key(&state.keypair);
- }
-
- // update endpoint
- *peer.endpoint.lock() = Some(endpoint);
-
- // calculate length of IP packet + padding
- let length = packet.len() - SIZE_TAG;
- debug!("inbound worker: plaintext length = {}", length);
-
- // check if should be written to TUN
- let mut sent = false;
- if length > 0 {
- if let Some(inner_len) = device.table.check_route(&peer, &packet[..length])
- {
- // TODO: Consider moving the cryptkey route check to parallel decryption worker
- debug_assert!(inner_len <= length, "should be validated earlier");
- if inner_len <= length {
- sent = match device.inbound.write(&packet[..inner_len]) {
- Err(e) => {
- debug!("failed to write inbound packet to TUN: {:?}", e);
- false
- }
- Ok(_) => true,
- }
- }
- }
- } else {
- debug!("inbound worker: received keepalive")
- }
-
- // trigger callback
- C::recv(&peer.opaque, buf.msg.len(), sent, &buf.keypair);
- } else {
- debug!("inbound worker: authentication failure")
- }
- })
- .wait();
- }
-}
-
-
-pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- peer: Arc<PeerInner<E, C, T, B>>,
- receiver: Receiver<JobOutbound>,
-) {
- loop {
- // fetch job
- let rx = match receiver.recv() {
- Ok(v) => v,
- _ => {
- return;
- }
- };
- debug!("outbound worker: obtained job");
-
- // wait for job to complete
- let _ = rx
- .map(|buf| {
- debug!("outbound worker: job complete");
-
- // send to peer
- let xmit = peer.send(&buf.msg[..]).is_ok();
-
- // trigger callback
- C::send(&peer.opaque, buf.msg.len(), xmit, &buf.keypair, buf.counter);
- })
- .wait();
- }
-}
-
-pub fn worker_parallel(receiver: Receiver<JobParallel>) {
- loop {
- // fetch next job
- let job = match receiver.recv() {
- Err(_) => {
- return;
- }
- Ok(val) => val,
- };
- trace!("parallel worker: obtained job");
-
- // handle job
- match job {
- JobParallel::Encryption(tx, mut job) => {
- job.msg.extend([0u8; SIZE_TAG].iter());
-
- // cast to header (should never fail)
- let (mut header, body): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
- LayoutVerified::new_from_prefix(&mut job.msg[..])
- .expect("earlier code should ensure that there is ample space");
-
- // set header fields
- debug_assert!(
- job.counter < REJECT_AFTER_MESSAGES,
- "should be checked when assigning counters"
- );
- header.f_type.set(TYPE_TRANSPORT);
- header.f_receiver.set(job.keypair.send.id);
- header.f_counter.set(job.counter);
-
- // create a nonce object
- let mut nonce = [0u8; 12];
- debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
- nonce[4..].copy_from_slice(header.f_counter.as_bytes());
- let nonce = Nonce::assume_unique_for_key(nonce);
-
- // do the weird ring AEAD dance
- let key = LessSafeKey::new(
- UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.send.key[..]).unwrap(),
- );
-
- // encrypt content of transport message in-place
- let end = body.len() - SIZE_TAG;
- let tag = key
- .seal_in_place_separate_tag(nonce, Aad::empty(), &mut body[..end])
- .unwrap();
-
- // append tag
- body[end..].copy_from_slice(tag.as_ref());
-
- // pass ownership
- let _ = tx.send(job);
- }
- JobParallel::Decryption(tx, mut job) => {
- // cast to header (could fail)
- let layout: Option<(LayoutVerified<&mut [u8], TransportHeader>, &mut [u8])> =
- LayoutVerified::new_from_prefix(&mut job.msg[..]);
-
- let _ = tx.send(match layout {
- Some((header, body)) => {
- debug_assert_eq!(
- header.f_type.get(),
- TYPE_TRANSPORT,
- "type and reserved bits should be checked by message de-multiplexer"
- );
- if header.f_counter.get() < REJECT_AFTER_MESSAGES {
- // create a nonce object
- let mut nonce = [0u8; 12];
- debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
- nonce[4..].copy_from_slice(header.f_counter.as_bytes());
- let nonce = Nonce::assume_unique_for_key(nonce);
-
- // do the weird ring AEAD dance
- let key = LessSafeKey::new(
- UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.recv.key[..])
- .unwrap(),
- );
-
- // attempt to open (and authenticate) the body
- match key.open_in_place(nonce, Aad::empty(), body) {
- Ok(_) => Some(job),
- Err(_) => None,
- }
- } else {
- None
- }
- }
- None => None,
- });
- }
- }
- }
-}
diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs
index 45b1fcb..94e240d 100644
--- a/src/wireguard/wireguard.rs
+++ b/src/wireguard/wireguard.rs
@@ -603,7 +603,7 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
);
let device = wg.handshake.read();
let _ = device.begin(&mut rng, &peer.pk).map(|msg| {
- let _ = peer.router.send(&msg[..]).map_err(|e| {
+ let _ = peer.router.send_raw(&msg[..]).map_err(|e| {
debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e)
});
peer.state.sent_handshake_initiation();