aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2020-02-16 20:25:31 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2020-02-16 20:25:31 +0100
commitead75828cdaa5253e57b5792b51e3d99a4a78ea0 (patch)
tree97fcba5fe19efcb52c0e25cebe4ec359c0d503c8 /src
parentFixed EINVAL on read4/6 from invalid namelen (diff)
downloadwireguard-rs-ead75828cdaa5253e57b5792b51e3d99a4a78ea0.tar.xz
wireguard-rs-ead75828cdaa5253e57b5792b51e3d99a4a78ea0.zip
Simplified router code
Diffstat (limited to 'src')
-rw-r--r--src/wireguard/router/constants.rs4
-rw-r--r--src/wireguard/router/device.rs125
-rw-r--r--src/wireguard/router/inbound.rs190
-rw-r--r--src/wireguard/router/mod.rs10
-rw-r--r--src/wireguard/router/outbound.rs110
-rw-r--r--src/wireguard/router/peer.rs206
-rw-r--r--src/wireguard/router/pool.rs164
-rw-r--r--src/wireguard/router/queue.rs144
-rw-r--r--src/wireguard/router/receive.rs192
-rw-r--r--src/wireguard/router/runq.rs129
-rw-r--r--src/wireguard/router/send.rs143
-rw-r--r--src/wireguard/router/tests.rs7
-rw-r--r--src/wireguard/router/worker.rs31
-rw-r--r--src/wireguard/timers.rs4
-rw-r--r--src/wireguard/workers.rs5
15 files changed, 638 insertions, 826 deletions
diff --git a/src/wireguard/router/constants.rs b/src/wireguard/router/constants.rs
index af76299..f083811 100644
--- a/src/wireguard/router/constants.rs
+++ b/src/wireguard/router/constants.rs
@@ -4,6 +4,6 @@ pub const MAX_QUEUED_PACKETS: usize = 1024;
// performance constants
-pub const PARALLEL_QUEUE_SIZE: usize = MAX_QUEUED_PACKETS;
+pub const PARALLEL_QUEUE_SIZE: usize = 4 * MAX_QUEUED_PACKETS;
+
pub const INORDER_QUEUE_SIZE: usize = MAX_QUEUED_PACKETS;
-pub const MAX_INORDER_CONSUME: usize = INORDER_QUEUE_SIZE;
diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs
index 96b7d82..9d78178 100644
--- a/src/wireguard/router/device.rs
+++ b/src/wireguard/router/device.rs
@@ -10,19 +10,16 @@ use spin::{Mutex, RwLock};
use zerocopy::LayoutVerified;
use super::anti_replay::AntiReplay;
-use super::pool::Job;
use super::constants::PARALLEL_QUEUE_SIZE;
-use super::inbound;
-use super::outbound;
-
use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::{new_peer, Peer, PeerHandle};
use super::types::{Callbacks, RouterError};
use super::SIZE_MESSAGE_PREFIX;
+use super::receive::ReceiveJob;
use super::route::RoutingTable;
-use super::runq::RunQueue;
+use super::worker::{worker, JobUnion};
use super::super::{tun, udp, Endpoint, KeyPair};
use super::ParallelQueue;
@@ -38,13 +35,8 @@ pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state
pub table: RoutingTable<Peer<E, C, T, B>>,
- // work queues
- pub queue_outbound: ParallelQueue<Job<Peer<E, C, T, B>, outbound::Outbound>>,
- pub queue_inbound: ParallelQueue<Job<Peer<E, C, T, B>, inbound::Inbound<E, C, T, B>>>,
-
- // run queues
- pub run_inbound: RunQueue<Peer<E, C, T, B>>,
- pub run_outbound: RunQueue<Peer<E, C, T, B>>,
+ // work queue
+ pub work: ParallelQueue<JobUnion<E, C, T, B>>,
}
pub struct EncryptionState {
@@ -101,13 +93,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop
fn drop(&mut self) {
debug!("router: dropping device");
- // close worker queues
- self.state.queue_outbound.close();
- self.state.queue_inbound.close();
-
- // close run queues
- self.state.run_outbound.close();
- self.state.run_inbound.close();
+ // close worker queue
+ self.state.work.close();
// join all worker threads
while match self.handles.pop() {
@@ -118,77 +105,28 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop
}
_ => false,
} {}
-
- debug!("router: device dropped");
}
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<E, C, T, B> {
pub fn new(num_workers: usize, tun: T) -> DeviceHandle<E, C, T, B> {
- // allocate shared device state
- let (queue_outbound, mut outrx) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE);
- let (queue_inbound, mut inrx) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE);
+ let (work, mut consumers) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE);
let device = Device {
inner: Arc::new(DeviceInner {
+ work,
inbound: tun,
- queue_inbound,
outbound: RwLock::new((true, None)),
- queue_outbound,
- run_inbound: RunQueue::new(),
- run_outbound: RunQueue::new(),
recv: RwLock::new(HashMap::new()),
table: RoutingTable::new(),
}),
};
// start worker threads
- let mut threads = Vec::with_capacity(4 * num_workers);
-
- // inbound/decryption workers
- for _ in 0..num_workers {
- // parallel workers (parallel processing)
- {
- let device = device.clone();
- let rx = inrx.pop().unwrap();
- threads.push(thread::spawn(move || {
- log::debug!("inbound parallel router worker started");
- inbound::parallel(device, rx)
- }));
- }
-
- // sequential workers (in-order processing)
- {
- let device = device.clone();
- threads.push(thread::spawn(move || {
- log::debug!("inbound sequential router worker started");
- inbound::sequential(device)
- }));
- }
- }
-
- // outbound/encryption workers
- for _ in 0..num_workers {
- // parallel workers (parallel processing)
- {
- let device = device.clone();
- let rx = outrx.pop().unwrap();
- threads.push(thread::spawn(move || {
- log::debug!("outbound parallel router worker started");
- outbound::parallel(device, rx)
- }));
- }
-
- // sequential workers (in-order processing)
- {
- let device = device.clone();
- threads.push(thread::spawn(move || {
- log::debug!("outbound sequential router worker started");
- outbound::sequential(device)
- }));
- }
+ let mut threads = Vec::with_capacity(num_workers);
+ while let Some(rx) = consumers.pop() {
+ threads.push(thread::spawn(move || worker(rx)));
}
-
- debug_assert_eq!(threads.len(), num_workers * 4);
+ debug_assert_eq!(threads.len(), num_workers);
// return exported device handle
DeviceHandle {
@@ -197,6 +135,16 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
}
}
+ pub fn send_raw(&self, msg : &[u8], dst: &mut E) -> Result<(), B::Error> {
+ let bind = self.state.outbound.read();
+ if bind.0 {
+ if let Some(bind) = bind.1.as_ref() {
+ return bind.write(msg, dst);
+ }
+ }
+ return Ok(())
+ }
+
/// Brings the router down.
/// When the router is brought down it:
/// - Prevents transmission of outbound messages.
@@ -250,10 +198,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
.ok_or(RouterError::NoCryptoKeyRoute)?;
// schedule for encryption and transmission to peer
- if let Some(job) = peer.send_job(msg, true) {
- self.state.queue_outbound.send(job);
- }
-
+ peer.send(msg, true);
Ok(())
}
@@ -297,10 +242,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
.get(&header.f_receiver.get())
.ok_or(RouterError::UnknownReceiverId)?;
- // schedule for decryption and TUN write
- if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
- log::trace!("schedule decryption of transport message");
- self.state.queue_inbound.send(job);
+ // create inbound job
+ let job = ReceiveJob::new(msg, dec.clone(), src);
+
+ // 1. add to sequential queue (drop if full)
+ // 2. then add to parallel work queue (wait if full)
+ if dec.peer.inbound.push(job.clone()) {
+ self.state.work.send(JobUnion::Inbound(job));
}
Ok(())
}
@@ -311,17 +259,4 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
pub fn set_outbound_writer(&self, new: B) {
self.state.outbound.write().1 = Some(new);
}
-
- pub fn write(&self, msg: &[u8], endpoint: &mut E) -> Result<(), RouterError> {
- let outbound = self.state.outbound.read();
- if outbound.0 {
- outbound
- .1
- .as_ref()
- .ok_or(RouterError::SendError)
- .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError))
- } else {
- Ok(())
- }
- }
}
diff --git a/src/wireguard/router/inbound.rs b/src/wireguard/router/inbound.rs
deleted file mode 100644
index dc2c44e..0000000
--- a/src/wireguard/router/inbound.rs
+++ /dev/null
@@ -1,190 +0,0 @@
-use std::mem;
-use std::sync::atomic::Ordering;
-use std::sync::Arc;
-
-use crossbeam_channel::Receiver;
-use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
-use zerocopy::{AsBytes, LayoutVerified};
-
-use super::constants::MAX_INORDER_CONSUME;
-use super::device::DecryptionState;
-use super::device::Device;
-use super::messages::TransportHeader;
-use super::peer::Peer;
-use super::pool::*;
-use super::types::Callbacks;
-use super::{tun, udp, Endpoint};
-use super::{REJECT_AFTER_MESSAGES, SIZE_TAG};
-
-pub struct Inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
- msg: Vec<u8>,
- failed: bool,
- state: Arc<DecryptionState<E, C, T, B>>,
- endpoint: Option<E>,
-}
-
-impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Inbound<E, C, T, B> {
- pub fn new(
- msg: Vec<u8>,
- state: Arc<DecryptionState<E, C, T, B>>,
- endpoint: E,
- ) -> Inbound<E, C, T, B> {
- Inbound {
- msg,
- state,
- failed: false,
- endpoint: Some(endpoint),
- }
- }
-}
-
-#[inline(always)]
-pub fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- device: Device<E, C, T, B>,
- receiver: Receiver<Job<Peer<E, C, T, B>, Inbound<E, C, T, B>>>,
-) {
- // parallel work to apply
- #[inline(always)]
- fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- peer: &Peer<E, C, T, B>,
- body: &mut Inbound<E, C, T, B>,
- ) {
- log::trace!("worker, parallel section, obtained job");
-
- // cast to header followed by payload
- let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
- match LayoutVerified::new_from_prefix(&mut body.msg[..]) {
- Some(v) => v,
- None => {
- log::debug!("inbound worker: failed to parse message");
- return;
- }
- };
-
- // authenticate and decrypt payload
- {
- // create nonce object
- let mut nonce = [0u8; 12];
- debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
- nonce[4..].copy_from_slice(header.f_counter.as_bytes());
- let nonce = Nonce::assume_unique_for_key(nonce);
-
- // do the weird ring AEAD dance
- let key = LessSafeKey::new(
- UnboundKey::new(&CHACHA20_POLY1305, &body.state.keypair.recv.key[..]).unwrap(),
- );
-
- // attempt to open (and authenticate) the body
- match key.open_in_place(nonce, Aad::empty(), packet) {
- Ok(_) => (),
- Err(_) => {
- // fault and return early
- log::trace!("inbound worker: authentication failure");
- body.failed = true;
- return;
- }
- }
- }
-
- // check that counter not after reject
- if header.f_counter.get() >= REJECT_AFTER_MESSAGES {
- body.failed = true;
- return;
- }
-
- // cryptokey route and strip padding
- let inner_len = {
- let length = packet.len() - SIZE_TAG;
- if length > 0 {
- peer.device.table.check_route(&peer, &packet[..length])
- } else {
- Some(0)
- }
- };
-
- // truncate to remove tag
- match inner_len {
- None => {
- log::trace!("inbound worker: cryptokey routing failed");
- body.failed = true;
- }
- Some(len) => {
- log::trace!(
- "inbound worker: good route, length = {} {}",
- len,
- if len == 0 { "(keepalive)" } else { "" }
- );
- body.msg.truncate(mem::size_of::<TransportHeader>() + len);
- }
- }
- }
-
- worker_parallel(device, |dev| &dev.run_inbound, receiver, work)
-}
-
-#[inline(always)]
-pub fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- device: Device<E, C, T, B>,
-) {
- // sequential work to apply
- fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- peer: &Peer<E, C, T, B>,
- body: &mut Inbound<E, C, T, B>,
- ) {
- log::trace!("worker, sequential section, obtained job");
-
- // decryption failed, return early
- if body.failed {
- log::trace!("job faulted, remove from queue and ignore");
- return;
- }
-
- // cast transport header
- let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
- match LayoutVerified::new_from_prefix(&body.msg[..]) {
- Some(v) => v,
- None => {
- log::debug!("inbound worker: failed to parse message");
- return;
- }
- };
-
- // check for replay
- if !body.state.protector.lock().update(header.f_counter.get()) {
- log::debug!("inbound worker: replay detected");
- return;
- }
-
- // check for confirms key
- if !body.state.confirmed.swap(true, Ordering::SeqCst) {
- log::debug!("inbound worker: message confirms key");
- peer.confirm_key(&body.state.keypair);
- }
-
- // update endpoint
- *peer.endpoint.lock() = body.endpoint.take();
-
- // check if should be written to TUN
- let mut sent = false;
- if packet.len() > 0 {
- sent = match peer.device.inbound.write(&packet[..]) {
- Err(e) => {
- log::debug!("failed to write inbound packet to TUN: {:?}", e);
- false
- }
- Ok(_) => true,
- }
- } else {
- log::debug!("inbound worker: received keepalive")
- }
-
- // trigger callback
- C::recv(&peer.opaque, body.msg.len(), sent, &body.state.keypair);
- }
-
- // handle message from the peers inbound queue
- device.run_inbound.run(|peer| {
- peer.inbound
- .handle(|body| work(&peer, body), MAX_INORDER_CONSUME)
- });
-}
diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs
index 8238d32..699c621 100644
--- a/src/wireguard/router/mod.rs
+++ b/src/wireguard/router/mod.rs
@@ -1,16 +1,17 @@
mod anti_replay;
mod constants;
mod device;
-mod inbound;
mod ip;
mod messages;
-mod outbound;
mod peer;
-mod pool;
mod route;
-mod runq;
mod types;
+mod queue;
+mod receive;
+mod send;
+mod worker;
+
#[cfg(test)]
mod tests;
@@ -20,7 +21,6 @@ use std::mem;
use super::constants::REJECT_AFTER_MESSAGES;
use super::queue::ParallelQueue;
use super::types::*;
-use super::{tun, udp, Endpoint};
pub const SIZE_TAG: usize = 16;
pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>();
diff --git a/src/wireguard/router/outbound.rs b/src/wireguard/router/outbound.rs
deleted file mode 100644
index 1edb2fb..0000000
--- a/src/wireguard/router/outbound.rs
+++ /dev/null
@@ -1,110 +0,0 @@
-use std::sync::Arc;
-
-use crossbeam_channel::Receiver;
-use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
-use zerocopy::{AsBytes, LayoutVerified};
-
-use super::constants::MAX_INORDER_CONSUME;
-use super::device::Device;
-use super::messages::{TransportHeader, TYPE_TRANSPORT};
-use super::peer::Peer;
-use super::pool::*;
-use super::types::Callbacks;
-use super::KeyPair;
-use super::{tun, udp, Endpoint};
-use super::{REJECT_AFTER_MESSAGES, SIZE_TAG};
-
-pub struct Outbound {
- msg: Vec<u8>,
- keypair: Arc<KeyPair>,
- counter: u64,
-}
-
-impl Outbound {
- pub fn new(msg: Vec<u8>, keypair: Arc<KeyPair>, counter: u64) -> Outbound {
- Outbound {
- msg,
- keypair,
- counter,
- }
- }
-}
-
-#[inline(always)]
-pub fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- device: Device<E, C, T, B>,
- receiver: Receiver<Job<Peer<E, C, T, B>, Outbound>>,
-) {
- #[inline(always)]
- fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- _peer: &Peer<E, C, T, B>,
- body: &mut Outbound,
- ) {
- log::trace!("worker, parallel section, obtained job");
-
- // make space for the tag
- body.msg.extend([0u8; SIZE_TAG].iter());
-
- // cast to header (should never fail)
- let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
- LayoutVerified::new_from_prefix(&mut body.msg[..])
- .expect("earlier code should ensure that there is ample space");
-
- // set header fields
- debug_assert!(
- body.counter < REJECT_AFTER_MESSAGES,
- "should be checked when assigning counters"
- );
- header.f_type.set(TYPE_TRANSPORT);
- header.f_receiver.set(body.keypair.send.id);
- header.f_counter.set(body.counter);
-
- // create a nonce object
- let mut nonce = [0u8; 12];
- debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
- nonce[4..].copy_from_slice(header.f_counter.as_bytes());
- let nonce = Nonce::assume_unique_for_key(nonce);
-
- // do the weird ring AEAD dance
- let key = LessSafeKey::new(
- UnboundKey::new(&CHACHA20_POLY1305, &body.keypair.send.key[..]).unwrap(),
- );
-
- // encrypt content of transport message in-place
- let end = packet.len() - SIZE_TAG;
- let tag = key
- .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end])
- .unwrap();
-
- // append tag
- packet[end..].copy_from_slice(tag.as_ref());
- }
-
- worker_parallel(device, |dev| &dev.run_outbound, receiver, work);
-}
-
-#[inline(always)]
-pub fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- device: Device<E, C, T, B>,
-) {
- device.run_outbound.run(|peer| {
- peer.outbound.handle(
- |body| {
- log::trace!("worker, sequential section, obtained job");
-
- // send to peer
- let xmit = peer.send(&body.msg[..]).is_ok();
-
- // trigger callback
- C::send(
- &peer.opaque,
- body.msg.len(),
- xmit,
- &body.keypair,
- body.counter,
- );
- },
- MAX_INORDER_CONSUME,
- )
- });
-}
diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs
index 8fe2e1c..a20908e 100644
--- a/src/wireguard/router/peer.rs
+++ b/src/wireguard/router/peer.rs
@@ -1,13 +1,3 @@
-use std::mem;
-use std::net::{IpAddr, SocketAddr};
-use std::ops::Deref;
-use std::sync::atomic::AtomicBool;
-use std::sync::Arc;
-
-use arraydeque::{ArrayDeque, Wrapping};
-use log::debug;
-use spin::Mutex;
-
use super::super::constants::*;
use super::super::{tun, udp, Endpoint, KeyPair};
@@ -15,17 +5,25 @@ use super::anti_replay::AntiReplay;
use super::device::DecryptionState;
use super::device::Device;
use super::device::EncryptionState;
-use super::messages::TransportHeader;
use super::constants::*;
-use super::runq::ToKey;
use super::types::{Callbacks, RouterError};
use super::SIZE_MESSAGE_PREFIX;
-// worker pool related
-use super::inbound::Inbound;
-use super::outbound::Outbound;
-use super::pool::{InorderQueue, Job};
+use super::queue::Queue;
+use super::receive::ReceiveJob;
+use super::send::SendJob;
+use super::worker::JobUnion;
+
+use std::mem;
+use std::net::{IpAddr, SocketAddr};
+use std::ops::Deref;
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
+
+use arraydeque::{ArrayDeque, Wrapping};
+use log::debug;
+use spin::Mutex;
pub struct KeyWheel {
next: Option<Arc<KeyPair>>, // next key state (unconfirmed)
@@ -37,11 +35,11 @@ pub struct KeyWheel {
pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
pub device: Device<E, C, T, B>,
pub opaque: C::Opaque,
- pub outbound: InorderQueue<Peer<E, C, T, B>, Outbound>,
- pub inbound: InorderQueue<Peer<E, C, T, B>, Inbound<E, C, T, B>>,
+ pub outbound: Queue<SendJob<E, C, T, B>>,
+ pub inbound: Queue<ReceiveJob<E, C, T, B>>,
pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_QUEUED_PACKETS], Wrapping>>,
pub keys: Mutex<KeyWheel>,
- pub ekey: Mutex<Option<EncryptionState>>,
+ pub enc_key: Mutex<Option<EncryptionState>>,
pub endpoint: Mutex<Option<E>>,
}
@@ -66,13 +64,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq for
}
}
-impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ToKey for Peer<E, C, T, B> {
- type Key = usize;
- fn to_key(&self) -> usize {
- Arc::downgrade(&self.inner).into_raw() as usize
- }
-}
-
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Eq for Peer<E, C, T, B> {}
/* A peer is transparently dereferenced to the inner type
@@ -154,7 +145,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Peer
keys.current = None;
keys.previous = None;
- *peer.ekey.lock() = None;
+ *peer.enc_key.lock() = None;
*peer.endpoint.lock() = None;
debug!("peer dropped & removed from device");
@@ -172,9 +163,9 @@ pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
inner: Arc::new(PeerInner {
opaque,
device,
- inbound: InorderQueue::new(),
- outbound: InorderQueue::new(),
- ekey: spin::Mutex::new(None),
+ inbound: Queue::new(),
+ outbound: Queue::new(),
+ enc_key: spin::Mutex::new(None),
endpoint: spin::Mutex::new(None),
keys: spin::Mutex::new(KeyWheel {
next: None,
@@ -200,7 +191,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E,
/// # Returns
///
/// Unit if packet was sent, or an error indicating why sending failed
- pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> {
+ pub fn send_raw(&self, msg: &[u8]) -> Result<(), RouterError> {
debug!("peer.send");
// send to endpoint (if known)
@@ -223,6 +214,57 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E,
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, B> {
+ /// Encrypt and send a message to the peer
+ ///
+ /// Arguments:
+ ///
+ /// - `msg` : A padded vector holding the message (allows in-place construction of the transport header)
+ /// - `stage`: Should the message be staged if no key is available
+ ///
+ pub(super) fn send(&self, msg: Vec<u8>, stage: bool) {
+ // check if key available
+ let (job, need_key) = {
+ let mut enc_key = self.enc_key.lock();
+ match enc_key.as_mut() {
+ None => {
+ if stage {
+ self.staged_packets.lock().push_back(msg);
+ };
+ (None, true)
+ }
+ Some(mut state) => {
+ // avoid integer overflow in nonce
+ if state.nonce >= REJECT_AFTER_MESSAGES - 1 {
+ *enc_key = None;
+ if stage {
+ self.staged_packets.lock().push_back(msg);
+ }
+ (None, true)
+ } else {
+ debug!("encryption state available, nonce = {}", state.nonce);
+ let job =
+ SendJob::new(msg, state.nonce, state.keypair.clone(), self.clone());
+ if self.outbound.push(job.clone()) {
+ state.nonce += 1;
+ (Some(job), false)
+ } else {
+ (None, false)
+ }
+ }
+ }
+ }
+ };
+
+ if need_key {
+ debug_assert!(job.is_none());
+ C::need_key(&self.opaque);
+ };
+
+ if let Some(job) = job {
+ self.device.work.send(JobUnion::Outbound(job))
+ }
+ }
+
// Transmit all staged packets
fn send_staged(&self) -> bool {
debug!("peer.send_staged");
@@ -232,29 +274,14 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
match staged.pop_front() {
Some(msg) => {
sent = true;
- self.send_raw(msg, false);
+ self.send(msg, false);
}
None => break sent,
}
}
}
- // Treat the msg as the payload of a transport message
- //
- // Returns true if the message was queued for transmission.
- fn send_raw(&self, msg: Vec<u8>, stage: bool) -> bool {
- log::debug!("peer.send_raw");
- match self.send_job(msg, stage) {
- Some(job) => {
- self.device.queue_outbound.send(job);
- debug!("send_raw: got obtained send_job");
- true
- }
- None => false,
- }
- }
-
- pub fn confirm_key(&self, keypair: &Arc<KeyPair>) {
+ pub(super) fn confirm_key(&self, keypair: &Arc<KeyPair>) {
debug!("peer.confirm_key");
{
// take lock and check keypair = keys.next
@@ -282,76 +309,12 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
C::key_confirmed(&self.opaque);
// set new key for encryption
- *self.ekey.lock() = ekey;
+ *self.enc_key.lock() = ekey;
}
// start transmission of staged packets
self.send_staged();
}
-
- pub fn recv_job(
- &self,
- src: E,
- dec: Arc<DecryptionState<E, C, T, B>>,
- msg: Vec<u8>,
- ) -> Option<Job<Self, Inbound<E, C, T, B>>> {
- let job = Job::new(self.clone(), Inbound::new(msg, dec, src));
- self.inbound.send(job.clone());
- Some(job)
- }
-
- pub fn send_job(&self, msg: Vec<u8>, stage: bool) -> Option<Job<Self, Outbound>> {
- debug!(
- "peer.send_job, msg.len() = {}, stage = {}",
- msg.len(),
- stage
- );
- debug_assert!(
- msg.len() >= mem::size_of::<TransportHeader>(),
- "received message with size: {:}",
- msg.len()
- );
-
- // check if has key
- let (keypair, counter) = {
- let keypair = {
- // TODO: consider using atomic ptr for ekey state
- let mut ekey = self.ekey.lock();
- match ekey.as_mut() {
- None => None,
- Some(mut state) => {
- // avoid integer overflow in nonce
- if state.nonce >= REJECT_AFTER_MESSAGES - 1 {
- *ekey = None;
- None
- } else {
- debug!("encryption state available, nonce = {}", state.nonce);
- let counter = state.nonce;
- state.nonce += 1;
- Some((state.keypair.clone(), counter))
- }
- }
- }
- };
-
- // If not suitable key was found:
- // 1. Stage packet for later transmission
- // 2. Request new key
- if keypair.is_none() && stage {
- log::trace!("packet staged");
- self.staged_packets.lock().push_back(msg);
- C::need_key(&self.opaque);
- return None;
- };
-
- keypair
- }?;
-
- // add job to in-order queue and return sender to device for inclusion in worker pool
- let job = Job::new(self.clone(), Outbound::new(msg, keypair, counter));
- self.outbound.send(job.clone());
- Some(job)
- }
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, C, T, B> {
@@ -403,7 +366,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E,
}
// clear encryption state
- *self.peer.ekey.lock() = None;
+ *self.peer.enc_key.lock() = None;
}
pub fn down(&self) {
@@ -440,7 +403,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E,
// update key-wheel
if new.initiator {
// start using key for encryption
- *self.peer.ekey.lock() = Some(EncryptionState::new(&new));
+ *self.peer.enc_key.lock() = Some(EncryptionState::new(&new));
// move current into previous
keys.previous = keys.current.as_ref().map(|v| v.clone());
@@ -474,16 +437,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E,
// schedule confirmation
if initiator {
- debug_assert!(self.peer.ekey.lock().is_some());
+ debug_assert!(self.peer.enc_key.lock().is_some());
debug!("peer.add_keypair: is initiator, must confirm the key");
// attempt to confirm using staged packets
if !self.peer.send_staged() {
// fall back to keepalive packet
- let ok = self.send_keepalive();
- debug!(
- "peer.add_keypair: keepalive for confirmation, sent = {}",
- ok
- );
+ self.send_keepalive();
+ debug!("peer.add_keypair: keepalive for confirmation",);
}
debug!("peer.add_keypair: key attempted confirmed");
}
@@ -495,9 +455,9 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E,
release
}
- pub fn send_keepalive(&self) -> bool {
+ pub fn send_keepalive(&self) {
debug!("peer.send_keepalive");
- self.peer.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX], true)
+ self.peer.send(vec![0u8; SIZE_MESSAGE_PREFIX], false)
}
/// Map a subnet to the peer
diff --git a/src/wireguard/router/pool.rs b/src/wireguard/router/pool.rs
deleted file mode 100644
index 3fc0026..0000000
--- a/src/wireguard/router/pool.rs
+++ /dev/null
@@ -1,164 +0,0 @@
-use std::mem;
-use std::sync::Arc;
-
-use arraydeque::ArrayDeque;
-use crossbeam_channel::Receiver;
-use spin::{Mutex, MutexGuard};
-
-use super::constants::INORDER_QUEUE_SIZE;
-use super::runq::{RunQueue, ToKey};
-
-pub struct InnerJob<P, B> {
- // peer (used by worker to schedule/handle inorder queue),
- // when the peer is None, the job is complete
- peer: Option<P>,
- pub body: B,
-}
-
-pub struct Job<P, B> {
- inner: Arc<Mutex<InnerJob<P, B>>>,
-}
-
-impl<P, B> Clone for Job<P, B> {
- fn clone(&self) -> Job<P, B> {
- Job {
- inner: self.inner.clone(),
- }
- }
-}
-
-impl<P, B> Job<P, B> {
- pub fn new(peer: P, body: B) -> Job<P, B> {
- Job {
- inner: Arc::new(Mutex::new(InnerJob {
- peer: Some(peer),
- body,
- })),
- }
- }
-}
-
-impl<P, B> Job<P, B> {
- /// Returns a mutex guard to the inner job if complete
- pub fn complete(&self) -> Option<MutexGuard<InnerJob<P, B>>> {
- self.inner
- .try_lock()
- .and_then(|m| if m.peer.is_none() { Some(m) } else { None })
- }
-}
-
-pub struct InorderQueue<P, B> {
- queue: Mutex<ArrayDeque<[Job<P, B>; INORDER_QUEUE_SIZE]>>,
-}
-
-impl<P, B> InorderQueue<P, B> {
- pub fn new() -> InorderQueue<P, B> {
- InorderQueue {
- queue: Mutex::new(ArrayDeque::new()),
- }
- }
-
- /// Add a new job to the in-order queue
- ///
- /// # Arguments
- ///
- /// - `job`: The job added to the back of the queue
- ///
- /// # Returns
- ///
- /// True if the element was added,
- /// false to indicate that the queue is full.
- pub fn send(&self, job: Job<P, B>) -> bool {
- self.queue.lock().push_back(job).is_ok()
- }
-
- /// Consume completed jobs from the in-order queue
- ///
- /// # Arguments
- ///
- /// - `f`: function to apply to the body of each jobof each job.
- /// - `limit`: maximum number of jobs to handle before returning
- ///
- /// # Returns
- ///
- /// A boolean indicating if the limit was reached:
- /// true indicating that the limit was reached,
- /// while false implies that the queue is empty or an uncompleted job was reached.
- #[inline(always)]
- pub fn handle<F: Fn(&mut B)>(&self, f: F, mut limit: usize) -> bool {
- // take the mutex
- let mut queue = self.queue.lock();
-
- while limit > 0 {
- // attempt to extract front element
- let front = queue.pop_front();
- let elem = match front {
- Some(elem) => elem,
- _ => {
- return false;
- }
- };
-
- // apply function if job complete
- let ret = if let Some(mut guard) = elem.complete() {
- mem::drop(queue);
- f(&mut guard.body);
- queue = self.queue.lock();
- false
- } else {
- true
- };
-
- // job not complete yet, return job to front
- if ret {
- queue.push_front(elem).unwrap();
- return false;
- }
- limit -= 1;
- }
-
- // did not complete all jobs
- true
- }
-}
-
-/// Allows easy construction of a parallel worker.
-/// Applicable for both decryption and encryption workers.
-#[inline(always)]
-pub fn worker_parallel<
- P: ToKey, // represents a peer (atomic reference counted pointer)
- B, // inner body type (message buffer, key material, ...)
- D, // device
- W: Fn(&P, &mut B),
- Q: Fn(&D) -> &RunQueue<P>,
->(
- device: D,
- queue: Q,
- receiver: Receiver<Job<P, B>>,
- work: W,
-) {
- log::trace!("router worker started");
- loop {
- // handle new job
- let peer = {
- // get next job
- let job = match receiver.recv() {
- Ok(job) => job,
- _ => return,
- };
-
- // lock the job
- let mut job = job.inner.lock();
-
- // take the peer from the job
- let peer = job.peer.take().unwrap();
-
- // process job
- work(&peer, &mut job.body);
- peer
- };
-
- // process inorder jobs for peer
- queue(&device).insert(peer);
- }
-}
diff --git a/src/wireguard/router/queue.rs b/src/wireguard/router/queue.rs
new file mode 100644
index 0000000..ec4492e
--- /dev/null
+++ b/src/wireguard/router/queue.rs
@@ -0,0 +1,144 @@
+use arraydeque::ArrayDeque;
+use spin::Mutex;
+
+use std::mem;
+use std::sync::atomic::{AtomicUsize, Ordering};
+
+use super::constants::INORDER_QUEUE_SIZE;
+
+pub trait SequentialJob {
+ fn is_ready(&self) -> bool;
+
+ fn sequential_work(self);
+}
+
+pub trait ParallelJob: Sized + SequentialJob {
+ fn queue(&self) -> &Queue<Self>;
+
+ fn parallel_work(&self);
+}
+
+pub struct Queue<J: SequentialJob> {
+ contenders: AtomicUsize,
+ queue: Mutex<ArrayDeque<[J; INORDER_QUEUE_SIZE]>>,
+
+ #[cfg(debug)]
+ _flag: Mutex<()>,
+}
+
+impl<J: SequentialJob> Queue<J> {
+ pub fn new() -> Queue<J> {
+ Queue {
+ contenders: AtomicUsize::new(0),
+ queue: Mutex::new(ArrayDeque::new()),
+
+ #[cfg(debug)]
+ _flag: Mutex::new(()),
+ }
+ }
+
+ pub fn push(&self, job: J) -> bool {
+ self.queue.lock().push_back(job).is_ok()
+ }
+
+ pub fn consume(&self) {
+ // check if we are the first contender
+ let pos = self.contenders.fetch_add(1, Ordering::SeqCst);
+ if pos > 0 {
+ assert!(usize::max_value() > pos, "contenders overflow");
+ return;
+ }
+
+ // enter the critical section
+ let mut contenders = 1; // myself
+ while contenders > 0 {
+ // check soundness in debug builds
+ #[cfg(debug)]
+ let _flag = self
+ ._flag
+ .try_lock()
+ .expect("contenders should ensure mutual exclusion");
+
+ // handle every ready element
+ loop {
+ let mut queue = self.queue.lock();
+
+ // check if front job is ready
+ match queue.front() {
+ None => break,
+ Some(job) => {
+ if job.is_ready() {
+ ()
+ } else {
+ break;
+ }
+ }
+ };
+
+ // take the job out of the queue
+ let job = queue.pop_front().unwrap();
+ debug_assert!(job.is_ready());
+ mem::drop(queue);
+
+ // process element
+ job.sequential_work();
+ }
+
+ #[cfg(debug)]
+ mem::drop(_flag);
+
+ // decrease contenders
+ contenders = self.contenders.fetch_sub(contenders, Ordering::SeqCst) - contenders;
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use std::sync::Arc;
+ use std::thread;
+
+ use rand::thread_rng;
+ use rand::Rng;
+
+ struct TestJob {}
+
+ impl SequentialJob for TestJob {
+ fn is_ready(&self) -> bool {
+ true
+ }
+
+ fn sequential_work(self) {}
+ }
+
+ /* Fuzz the Queue */
+ #[test]
+ fn test_queue() {
+ fn hammer(queue: &Arc<Queue<TestJob>>) {
+ let mut rng = thread_rng();
+ for _ in 0..1_000_000 {
+ if rng.gen() {
+ queue.push(TestJob {});
+ } else {
+ queue.consume();
+ }
+ }
+ }
+
+ let queue = Arc::new(Queue::new());
+
+ // repeatedly apply operations randomly from concurrent threads
+ let other = {
+ let queue = queue.clone();
+ thread::spawn(move || hammer(&queue))
+ };
+ hammer(&queue);
+
+ // wait, consume and check empty
+ other.join().unwrap();
+ queue.consume();
+ assert_eq!(queue.queue.lock().len(), 0);
+ }
+}
diff --git a/src/wireguard/router/receive.rs b/src/wireguard/router/receive.rs
new file mode 100644
index 0000000..c5fe3da
--- /dev/null
+++ b/src/wireguard/router/receive.rs
@@ -0,0 +1,192 @@
+use super::device::DecryptionState;
+use super::messages::TransportHeader;
+use super::queue::{ParallelJob, Queue, SequentialJob};
+use super::types::Callbacks;
+use super::{REJECT_AFTER_MESSAGES, SIZE_TAG};
+
+use super::super::{tun, udp, Endpoint};
+
+use std::mem;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+
+use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
+use spin::Mutex;
+use zerocopy::{AsBytes, LayoutVerified};
+
+struct Inner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
+ ready: AtomicBool,
+ buffer: Mutex<(Option<E>, Vec<u8>)>, // endpoint & ciphertext buffer
+ state: Arc<DecryptionState<E, C, T, B>>, // decryption state (keys and replay protector)
+}
+
+pub struct ReceiveJob<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ Arc<Inner<E, C, T, B>>,
+);
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone
+ for ReceiveJob<E, C, T, B>
+{
+ fn clone(&self) -> ReceiveJob<E, C, T, B> {
+ ReceiveJob(self.0.clone())
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ReceiveJob<E, C, T, B> {
+ pub fn new(
+ buffer: Vec<u8>,
+ state: Arc<DecryptionState<E, C, T, B>>,
+ endpoint: E,
+ ) -> ReceiveJob<E, C, T, B> {
+ ReceiveJob(Arc::new(Inner {
+ ready: AtomicBool::new(false),
+ buffer: Mutex::new((Some(endpoint), buffer)),
+ state,
+ }))
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ParallelJob
+ for ReceiveJob<E, C, T, B>
+{
+ fn queue(&self) -> &Queue<Self> {
+ &self.0.state.peer.inbound
+ }
+
+ fn parallel_work(&self) {
+ // TODO: refactor
+ // decrypt
+ {
+ let job = &self.0;
+ let peer = &job.state.peer;
+ let mut msg = job.buffer.lock();
+
+ // cast to header followed by payload
+ let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
+ match LayoutVerified::new_from_prefix(&mut msg.1[..]) {
+ Some(v) => v,
+ None => {
+ log::debug!("inbound worker: failed to parse message");
+ return;
+ }
+ };
+
+ // authenticate and decrypt payload
+ {
+ // create nonce object
+ let mut nonce = [0u8; 12];
+ debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
+ nonce[4..].copy_from_slice(header.f_counter.as_bytes());
+ let nonce = Nonce::assume_unique_for_key(nonce);
+
+ // do the weird ring AEAD dance
+ let key = LessSafeKey::new(
+ UnboundKey::new(&CHACHA20_POLY1305, &job.state.keypair.recv.key[..]).unwrap(),
+ );
+
+ // attempt to open (and authenticate) the body
+ match key.open_in_place(nonce, Aad::empty(), packet) {
+ Ok(_) => (),
+ Err(_) => {
+ // fault and return early
+ log::trace!("inbound worker: authentication failure");
+ msg.1.truncate(0);
+ return;
+ }
+ }
+ }
+
+ // check that counter not after reject
+ if header.f_counter.get() >= REJECT_AFTER_MESSAGES {
+ msg.1.truncate(0);
+ return;
+ }
+
+ // cryptokey route and strip padding
+ let inner_len = {
+ let length = packet.len() - SIZE_TAG;
+ if length > 0 {
+ peer.device.table.check_route(&peer, &packet[..length])
+ } else {
+ Some(0)
+ }
+ };
+
+ // truncate to remove tag
+ match inner_len {
+ None => {
+ log::trace!("inbound worker: cryptokey routing failed");
+ msg.1.truncate(0);
+ }
+ Some(len) => {
+ log::trace!(
+ "inbound worker: good route, length = {} {}",
+ len,
+ if len == 0 { "(keepalive)" } else { "" }
+ );
+ msg.1.truncate(mem::size_of::<TransportHeader>() + len);
+ }
+ }
+ }
+
+ // mark ready
+ self.0.ready.store(true, Ordering::Release);
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SequentialJob
+ for ReceiveJob<E, C, T, B>
+{
+ fn is_ready(&self) -> bool {
+ self.0.ready.load(Ordering::Acquire)
+ }
+
+ fn sequential_work(self) {
+ let job = &self.0;
+ let peer = &job.state.peer;
+ let mut msg = job.buffer.lock();
+ let endpoint = msg.0.take();
+
+ // cast transport header
+ let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
+ match LayoutVerified::new_from_prefix(&msg.1[..]) {
+ Some(v) => v,
+ None => {
+ // also covers authentication failure
+ return;
+ }
+ };
+
+ // check for replay
+ if !job.state.protector.lock().update(header.f_counter.get()) {
+ log::debug!("inbound worker: replay detected");
+ return;
+ }
+
+ // check for confirms key
+ if !job.state.confirmed.swap(true, Ordering::SeqCst) {
+ log::debug!("inbound worker: message confirms key");
+ peer.confirm_key(&job.state.keypair);
+ }
+
+ // update endpoint
+ *peer.endpoint.lock() = endpoint;
+
+ // check if should be written to TUN
+ let mut sent = false;
+ if packet.len() > 0 {
+ sent = match peer.device.inbound.write(&packet[..]) {
+ Err(e) => {
+ log::debug!("failed to write inbound packet to TUN: {:?}", e);
+ false
+ }
+ Ok(_) => true,
+ }
+ } else {
+ log::debug!("inbound worker: received keepalive")
+ }
+
+ // trigger callback
+ C::recv(&peer.opaque, msg.1.len(), sent, &job.state.keypair);
+ }
+}
diff --git a/src/wireguard/router/runq.rs b/src/wireguard/router/runq.rs
deleted file mode 100644
index 4c848cd..0000000
--- a/src/wireguard/router/runq.rs
+++ /dev/null
@@ -1,129 +0,0 @@
-use std::hash::Hash;
-use std::mem;
-use std::sync::{Condvar, Mutex};
-
-use std::collections::hash_map::Entry;
-use std::collections::HashMap;
-use std::collections::VecDeque;
-
-pub trait ToKey {
- type Key: Hash + Eq;
- fn to_key(&self) -> Self::Key;
-}
-
-pub struct RunQueue<T: ToKey> {
- cvar: Condvar,
- inner: Mutex<Inner<T>>,
-}
-
-struct Inner<T: ToKey> {
- stop: bool,
- queue: VecDeque<T>,
- members: HashMap<T::Key, usize>,
-}
-
-impl<T: ToKey> RunQueue<T> {
- pub fn close(&self) {
- let mut inner = self.inner.lock().unwrap();
- inner.stop = true;
- self.cvar.notify_all();
- }
-
- pub fn new() -> RunQueue<T> {
- RunQueue {
- cvar: Condvar::new(),
- inner: Mutex::new(Inner {
- stop: false,
- queue: VecDeque::new(),
- members: HashMap::new(),
- }),
- }
- }
-
- pub fn insert(&self, v: T) {
- let key = v.to_key();
- let mut inner = self.inner.lock().unwrap();
- match inner.members.entry(key) {
- Entry::Occupied(mut elem) => {
- *elem.get_mut() += 1;
- }
- Entry::Vacant(spot) => {
- // add entry to back of queue
- spot.insert(0);
- inner.queue.push_back(v);
-
- // wake a thread
- self.cvar.notify_one();
- }
- }
- }
-
- /// Run (consume from) the run queue using the provided function.
- /// The function should return wheter the given element should be rescheduled.
- ///
- /// # Arguments
- ///
- /// - `f` : function to apply to every element
- ///
- /// # Note
- ///
- /// The function f may be called again even when the element was not inserted back in to the
- /// queue since the last applciation and no rescheduling was requested.
- ///
- /// This happens then the function handles all work for T,
- /// but T is added to the run queue while the function is running.
- pub fn run<F: Fn(&T) -> bool>(&self, f: F) {
- let mut inner = self.inner.lock().unwrap();
- loop {
- // fetch next element
- let elem = loop {
- // run-queue closed
- if inner.stop {
- return;
- }
-
- // try to pop from queue
- match inner.queue.pop_front() {
- Some(elem) => {
- break elem;
- }
- None => (),
- };
-
- // wait for an element to be inserted
- inner = self.cvar.wait(inner).unwrap();
- };
-
- // fetch current request number
- let key = elem.to_key();
- let old_n = *inner.members.get(&key).unwrap();
- mem::drop(inner); // drop guard
-
- // handle element
- let rerun = f(&elem);
-
- // if the function requested a re-run add the element to the back of the queue
- inner = self.inner.lock().unwrap();
- if rerun {
- inner.queue.push_back(elem);
- continue;
- }
-
- // otherwise check if new requests have come in since we ran the function
- match inner.members.entry(key) {
- Entry::Occupied(occ) => {
- if *occ.get() == old_n {
- // no new requests since last, remove entry.
- occ.remove();
- } else {
- // new requests, reschedule.
- inner.queue.push_back(elem);
- }
- }
- Entry::Vacant(_) => {
- unreachable!();
- }
- }
- }
- }
-}
diff --git a/src/wireguard/router/send.rs b/src/wireguard/router/send.rs
new file mode 100644
index 0000000..8e41796
--- /dev/null
+++ b/src/wireguard/router/send.rs
@@ -0,0 +1,143 @@
+use super::queue::{SequentialJob, ParallelJob, Queue};
+use super::KeyPair;
+use super::types::Callbacks;
+use super::peer::Peer;
+use super::{REJECT_AFTER_MESSAGES, SIZE_TAG};
+use super::messages::{TransportHeader, TYPE_TRANSPORT};
+
+use super::super::{tun, udp, Endpoint};
+
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+
+use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
+use zerocopy::{AsBytes, LayoutVerified};
+use spin::Mutex;
+
+struct Inner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
+ ready: AtomicBool,
+ buffer: Mutex<Vec<u8>>,
+ counter: u64,
+ keypair: Arc<KeyPair>,
+ peer: Peer<E, C, T, B>,
+}
+
+pub struct SendJob<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> (
+ Arc<Inner<E, C, T, B>>
+);
+
+impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for SendJob<E, C, T, B> {
+ fn clone(&self) -> SendJob<E, C, T, B> {
+ SendJob(self.0.clone())
+ }
+}
+
+impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SendJob<E, C, T, B> {
+ pub fn new(
+ buffer: Vec<u8>,
+ counter: u64,
+ keypair: Arc<KeyPair>,
+ peer: Peer<E, C, T, B>
+ ) -> SendJob<E, C, T, B> {
+ SendJob(Arc::new(Inner{
+ buffer: Mutex::new(buffer),
+ counter,
+ keypair,
+ peer,
+ ready: AtomicBool::new(false)
+ }))
+ }
+}
+
+impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SequentialJob for SendJob<E, C, T, B> {
+
+ fn is_ready(&self) -> bool {
+ self.0.ready.load(Ordering::Acquire)
+ }
+
+ fn sequential_work(self) {
+ debug_assert_eq!(
+ self.is_ready(),
+ true,
+ "doing sequential work
+ on an incomplete job"
+ );
+ log::trace!("processing sequential send job");
+
+ // send to peer
+ let job = &self.0;
+ let msg = job.buffer.lock();
+ let xmit = job.peer.send_raw(&msg[..]).is_ok();
+
+ // trigger callback (for timers)
+ C::send(
+ &job.peer.opaque,
+ msg.len(),
+ xmit,
+ &job.keypair,
+ job.counter,
+ );
+ }
+}
+
+
+impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ParallelJob for SendJob<E, C, T, B> {
+
+ fn queue(&self) -> &Queue<Self> {
+ &self.0.peer.outbound
+ }
+
+ fn parallel_work(&self) {
+ debug_assert_eq!(
+ self.is_ready(),
+ false,
+ "doing parallel work on completed job"
+ );
+ log::trace!("processing parallel send job");
+
+ // encrypt body
+ {
+ // make space for the tag
+ let job = &*self.0;
+ let mut msg = job.buffer.lock();
+ msg.extend([0u8; SIZE_TAG].iter());
+
+ // cast to header (should never fail)
+ let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
+ LayoutVerified::new_from_prefix(&mut msg[..])
+ .expect("earlier code should ensure that there is ample space");
+
+ // set header fields
+ debug_assert!(
+ job.counter < REJECT_AFTER_MESSAGES,
+ "should be checked when assigning counters"
+ );
+ header.f_type.set(TYPE_TRANSPORT);
+ header.f_receiver.set(job.keypair.send.id);
+ header.f_counter.set(job.counter);
+
+ // create a nonce object
+ let mut nonce = [0u8; 12];
+ debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
+ nonce[4..].copy_from_slice(header.f_counter.as_bytes());
+ let nonce = Nonce::assume_unique_for_key(nonce);
+
+ // do the weird ring AEAD dance
+ let key = LessSafeKey::new(
+ UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.send.key[..]).unwrap(),
+ );
+
+ // encrypt contents of transport message in-place
+ let end = packet.len() - SIZE_TAG;
+ let tag = key
+ .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end])
+ .unwrap();
+
+ // append tag
+ packet[end..].copy_from_slice(tag.as_ref());
+ }
+
+ // mark ready
+ self.0.ready.store(true, Ordering::Release);
+ }
+} \ No newline at end of file
diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs
index 15db368..3d5c79b 100644
--- a/src/wireguard/router/tests.rs
+++ b/src/wireguard/router/tests.rs
@@ -50,7 +50,6 @@ mod tests {
}))
}
- #[allow(dead_code)]
fn reset(&self) {
self.0.send.lock().unwrap().clear();
self.0.recv.lock().unwrap().clear();
@@ -104,9 +103,9 @@ mod tests {
}
}
- // wait for scheduling (VERY conservative)
+ // wait for scheduling
fn wait() {
- thread::sleep(Duration::from_millis(30));
+ thread::sleep(Duration::from_millis(15));
}
fn init() {
@@ -162,7 +161,7 @@ mod tests {
};
let msg = make_packet_padded(1024, src, dst, 0);
- // every iteration sends 10 MB
+ // every iteration sends 10 GB
b.iter(|| {
opaque.store(0, Ordering::SeqCst);
while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 {
diff --git a/src/wireguard/router/worker.rs b/src/wireguard/router/worker.rs
new file mode 100644
index 0000000..bbb644c
--- /dev/null
+++ b/src/wireguard/router/worker.rs
@@ -0,0 +1,31 @@
+use super::super::{tun, udp, Endpoint};
+use super::types::Callbacks;
+
+use super::queue::ParallelJob;
+use super::receive::ReceiveJob;
+use super::send::SendJob;
+
+use crossbeam_channel::Receiver;
+
+pub enum JobUnion<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
+ Outbound(SendJob<E, C, T, B>),
+ Inbound(ReceiveJob<E, C, T, B>),
+}
+
+pub fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ receiver: Receiver<JobUnion<E, C, T, B>>,
+) {
+ loop {
+ match receiver.recv() {
+ Err(_) => break,
+ Ok(JobUnion::Inbound(job)) => {
+ job.parallel_work();
+ job.queue().consume();
+ }
+ Ok(JobUnion::Outbound(job)) => {
+ job.parallel_work();
+ job.queue().consume();
+ }
+ }
+ }
+}
diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs
index 6b852bb..0197a9e 100644
--- a/src/wireguard/timers.rs
+++ b/src/wireguard/timers.rs
@@ -319,8 +319,8 @@ impl Timers {
let timers = peer.timers();
if timers.enabled && timers.keepalive_interval > 0 {
timers.send_keepalive.stop();
- let queued = peer.router.send_keepalive();
- log::trace!("{} : keepalive queued {}", peer, queued);
+ peer.router.send_keepalive();
+ log::trace!("{} : keepalive queued", peer);
timers
.send_persistent_keepalive
.start(Duration::from_secs(timers.keepalive_interval));
diff --git a/src/wireguard/workers.rs b/src/wireguard/workers.rs
index 02db160..70e3b3a 100644
--- a/src/wireguard/workers.rs
+++ b/src/wireguard/workers.rs
@@ -194,7 +194,8 @@ pub fn handshake_worker<T: Tun, B: UDP>(
let mut resp_len: u64 = 0;
if let Some(msg) = resp {
resp_len = msg.len() as u64;
- let _ = wg.router.write(&msg[..], &mut src).map_err(|e| {
+ // TODO: consider a more elegant solution for accessing the bind
+ let _ = wg.router.send_raw(&msg[..], &mut src).map_err(|e| {
debug!(
"{} : handshake worker, failed to send response, error = {}",
wg, e
@@ -252,7 +253,7 @@ pub fn handshake_worker<T: Tun, B: UDP>(
);
let device = wg.peers.read();
let _ = device.begin(&mut OsRng, &peer.pk).map(|msg| {
- let _ = peer.router.send(&msg[..]).map_err(|e| {
+ let _ = peer.router.send_raw(&msg[..]).map_err(|e| {
debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e)
});
peer.state.sent_handshake_initiation();