From 106c5e8b5c865c8396f824f4f5aa14d1bf0952b1 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sun, 16 Feb 2020 18:12:43 +0100 Subject: Work on router optimizations --- src/wireguard/queue.rs | 3 +- src/wireguard/router/constants.rs | 4 +- src/wireguard/router/device.rs | 100 +++------------ src/wireguard/router/inbound.rs | 190 ---------------------------- src/wireguard/router/mod.rs | 5 - src/wireguard/router/outbound.rs | 110 ---------------- src/wireguard/router/peer.rs | 192 ++++++++++++---------------- src/wireguard/router/pool.rs | 164 ------------------------ src/wireguard/router/queue.rs | 92 ++++++++++++-- src/wireguard/router/receive.rs | 184 ++++++++++++++------------- src/wireguard/router/runq.rs | 164 ------------------------ src/wireguard/router/send.rs | 95 +++++++------- src/wireguard/router/worker.rs | 30 ++++- src/wireguard/router/workers.rs | 257 -------------------------------------- src/wireguard/wireguard.rs | 2 +- 15 files changed, 350 insertions(+), 1242 deletions(-) delete mode 100644 src/wireguard/router/inbound.rs delete mode 100644 src/wireguard/router/outbound.rs delete mode 100644 src/wireguard/router/pool.rs delete mode 100644 src/wireguard/router/runq.rs delete mode 100644 src/wireguard/router/workers.rs (limited to 'src') diff --git a/src/wireguard/queue.rs b/src/wireguard/queue.rs index 75b9104..eea1ccf 100644 --- a/src/wireguard/queue.rs +++ b/src/wireguard/queue.rs @@ -1,6 +1,7 @@ -use crossbeam_channel::{bounded, Receiver, Sender}; use std::sync::Mutex; +use crossbeam_channel::{bounded, Receiver, Sender}; + pub struct ParallelQueue { queue: Mutex>>, } diff --git a/src/wireguard/router/constants.rs b/src/wireguard/router/constants.rs index af76299..f083811 100644 --- a/src/wireguard/router/constants.rs +++ b/src/wireguard/router/constants.rs @@ -4,6 +4,6 @@ pub const MAX_QUEUED_PACKETS: usize = 1024; // performance constants -pub const PARALLEL_QUEUE_SIZE: usize = MAX_QUEUED_PACKETS; +pub const PARALLEL_QUEUE_SIZE: usize = 4 * MAX_QUEUED_PACKETS; + pub const INORDER_QUEUE_SIZE: usize = MAX_QUEUED_PACKETS; -pub const MAX_INORDER_CONSUME: usize = INORDER_QUEUE_SIZE; diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index f903a8e..b8e3821 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -10,19 +10,16 @@ use spin::{Mutex, RwLock}; use zerocopy::LayoutVerified; use super::anti_replay::AntiReplay; -use super::pool::Job; use super::constants::PARALLEL_QUEUE_SIZE; -use super::inbound; -use super::outbound; - use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::peer::{new_peer, Peer, PeerHandle}; use super::types::{Callbacks, RouterError}; use super::SIZE_MESSAGE_PREFIX; +use super::receive::ReceiveJob; use super::route::RoutingTable; -use super::runq::RunQueue; +use super::worker::{worker, JobUnion}; use super::super::{tun, udp, Endpoint, KeyPair}; use super::ParallelQueue; @@ -38,13 +35,8 @@ pub struct DeviceInner>>>, // receiver id -> decryption state pub table: RoutingTable>, - // work queues - pub queue_outbound: ParallelQueue, outbound::Outbound>>, - pub queue_inbound: ParallelQueue, inbound::Inbound>>, - - // run queues - pub run_inbound: RunQueue>, - pub run_outbound: RunQueue>, + // work queue + pub work: ParallelQueue>, } pub struct EncryptionState { @@ -101,13 +93,8 @@ impl> Drop fn drop(&mut self) { debug!("router: dropping device"); - // close worker queues - self.state.queue_outbound.close(); - self.state.queue_inbound.close(); - - // close run queues - self.state.run_outbound.close(); - self.state.run_inbound.close(); + // close worker queue + self.state.work.close(); // join all worker threads while match self.handles.pop() { @@ -118,24 +105,17 @@ impl> Drop } _ => false, } {} - - debug!("router: device dropped"); } } impl> DeviceHandle { pub fn new(num_workers: usize, tun: T) -> DeviceHandle { - // allocate shared device state - let (queue_outbound, mut outrx) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE); - let (queue_inbound, mut inrx) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE); + let (work, mut consumers) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE); let device = Device { inner: Arc::new(DeviceInner { + work, inbound: tun, - queue_inbound, outbound: RwLock::new((true, None)), - queue_outbound, - run_inbound: RunQueue::new(), - run_outbound: RunQueue::new(), recv: RwLock::new(HashMap::new()), table: RoutingTable::new(), }), @@ -143,52 +123,10 @@ impl> DeviceHandle< // start worker threads let mut threads = Vec::with_capacity(num_workers); - - // inbound/decryption workers - for _ in 0..num_workers { - // parallel workers (parallel processing) - { - let device = device.clone(); - let rx = inrx.pop().unwrap(); - threads.push(thread::spawn(move || { - log::debug!("inbound parallel router worker started"); - inbound::parallel(device, rx) - })); - } - - // sequential workers (in-order processing) - { - let device = device.clone(); - threads.push(thread::spawn(move || { - log::debug!("inbound sequential router worker started"); - inbound::sequential(device) - })); - } + while let Some(rx) = consumers.pop() { + threads.push(thread::spawn(move || worker(rx))); } - - // outbound/encryption workers - for _ in 0..num_workers { - // parallel workers (parallel processing) - { - let device = device.clone(); - let rx = outrx.pop().unwrap(); - threads.push(thread::spawn(move || { - log::debug!("outbound parallel router worker started"); - outbound::parallel(device, rx) - })); - } - - // sequential workers (in-order processing) - { - let device = device.clone(); - threads.push(thread::spawn(move || { - log::debug!("outbound sequential router worker started"); - outbound::sequential(device) - })); - } - } - - debug_assert_eq!(threads.len(), num_workers * 4); + debug_assert_eq!(threads.len(), num_workers); // return exported device handle DeviceHandle { @@ -250,10 +188,7 @@ impl> DeviceHandle< .ok_or(RouterError::NoCryptoKeyRoute)?; // schedule for encryption and transmission to peer - if let Some(job) = peer.send_job(msg, true) { - self.state.queue_outbound.send(job); - } - + peer.send(msg, true); Ok(()) } @@ -297,10 +232,13 @@ impl> DeviceHandle< .get(&header.f_receiver.get()) .ok_or(RouterError::UnknownReceiverId)?; - // schedule for decryption and TUN write - if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) { - log::trace!("schedule decryption of transport message"); - self.state.queue_inbound.send(job); + // create inbound job + let job = ReceiveJob::new(msg, dec.clone(), src); + + // 1. add to sequential queue (drop if full) + // 2. then add to parallel work queue (wait if full) + if dec.peer.inbound.push(job.clone()) { + self.state.work.send(JobUnion::Inbound(job)); } Ok(()) } diff --git a/src/wireguard/router/inbound.rs b/src/wireguard/router/inbound.rs deleted file mode 100644 index dc2c44e..0000000 --- a/src/wireguard/router/inbound.rs +++ /dev/null @@ -1,190 +0,0 @@ -use std::mem; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use crossbeam_channel::Receiver; -use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; -use zerocopy::{AsBytes, LayoutVerified}; - -use super::constants::MAX_INORDER_CONSUME; -use super::device::DecryptionState; -use super::device::Device; -use super::messages::TransportHeader; -use super::peer::Peer; -use super::pool::*; -use super::types::Callbacks; -use super::{tun, udp, Endpoint}; -use super::{REJECT_AFTER_MESSAGES, SIZE_TAG}; - -pub struct Inbound> { - msg: Vec, - failed: bool, - state: Arc>, - endpoint: Option, -} - -impl> Inbound { - pub fn new( - msg: Vec, - state: Arc>, - endpoint: E, - ) -> Inbound { - Inbound { - msg, - state, - failed: false, - endpoint: Some(endpoint), - } - } -} - -#[inline(always)] -pub fn parallel>( - device: Device, - receiver: Receiver, Inbound>>, -) { - // parallel work to apply - #[inline(always)] - fn work>( - peer: &Peer, - body: &mut Inbound, - ) { - log::trace!("worker, parallel section, obtained job"); - - // cast to header followed by payload - let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = - match LayoutVerified::new_from_prefix(&mut body.msg[..]) { - Some(v) => v, - None => { - log::debug!("inbound worker: failed to parse message"); - return; - } - }; - - // authenticate and decrypt payload - { - // create nonce object - let mut nonce = [0u8; 12]; - debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); - nonce[4..].copy_from_slice(header.f_counter.as_bytes()); - let nonce = Nonce::assume_unique_for_key(nonce); - - // do the weird ring AEAD dance - let key = LessSafeKey::new( - UnboundKey::new(&CHACHA20_POLY1305, &body.state.keypair.recv.key[..]).unwrap(), - ); - - // attempt to open (and authenticate) the body - match key.open_in_place(nonce, Aad::empty(), packet) { - Ok(_) => (), - Err(_) => { - // fault and return early - log::trace!("inbound worker: authentication failure"); - body.failed = true; - return; - } - } - } - - // check that counter not after reject - if header.f_counter.get() >= REJECT_AFTER_MESSAGES { - body.failed = true; - return; - } - - // cryptokey route and strip padding - let inner_len = { - let length = packet.len() - SIZE_TAG; - if length > 0 { - peer.device.table.check_route(&peer, &packet[..length]) - } else { - Some(0) - } - }; - - // truncate to remove tag - match inner_len { - None => { - log::trace!("inbound worker: cryptokey routing failed"); - body.failed = true; - } - Some(len) => { - log::trace!( - "inbound worker: good route, length = {} {}", - len, - if len == 0 { "(keepalive)" } else { "" } - ); - body.msg.truncate(mem::size_of::() + len); - } - } - } - - worker_parallel(device, |dev| &dev.run_inbound, receiver, work) -} - -#[inline(always)] -pub fn sequential>( - device: Device, -) { - // sequential work to apply - fn work>( - peer: &Peer, - body: &mut Inbound, - ) { - log::trace!("worker, sequential section, obtained job"); - - // decryption failed, return early - if body.failed { - log::trace!("job faulted, remove from queue and ignore"); - return; - } - - // cast transport header - let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = - match LayoutVerified::new_from_prefix(&body.msg[..]) { - Some(v) => v, - None => { - log::debug!("inbound worker: failed to parse message"); - return; - } - }; - - // check for replay - if !body.state.protector.lock().update(header.f_counter.get()) { - log::debug!("inbound worker: replay detected"); - return; - } - - // check for confirms key - if !body.state.confirmed.swap(true, Ordering::SeqCst) { - log::debug!("inbound worker: message confirms key"); - peer.confirm_key(&body.state.keypair); - } - - // update endpoint - *peer.endpoint.lock() = body.endpoint.take(); - - // check if should be written to TUN - let mut sent = false; - if packet.len() > 0 { - sent = match peer.device.inbound.write(&packet[..]) { - Err(e) => { - log::debug!("failed to write inbound packet to TUN: {:?}", e); - false - } - Ok(_) => true, - } - } else { - log::debug!("inbound worker: received keepalive") - } - - // trigger callback - C::recv(&peer.opaque, body.msg.len(), sent, &body.state.keypair); - } - - // handle message from the peers inbound queue - device.run_inbound.run(|peer| { - peer.inbound - .handle(|body| work(&peer, body), MAX_INORDER_CONSUME) - }); -} diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs index ec5cc63..699c621 100644 --- a/src/wireguard/router/mod.rs +++ b/src/wireguard/router/mod.rs @@ -1,14 +1,10 @@ mod anti_replay; mod constants; mod device; -mod inbound; mod ip; mod messages; -mod outbound; mod peer; -mod pool; mod route; -mod runq; mod types; mod queue; @@ -25,7 +21,6 @@ use std::mem; use super::constants::REJECT_AFTER_MESSAGES; use super::queue::ParallelQueue; use super::types::*; -use super::{tun, udp, Endpoint}; pub const SIZE_TAG: usize = 16; pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::(); diff --git a/src/wireguard/router/outbound.rs b/src/wireguard/router/outbound.rs deleted file mode 100644 index 1edb2fb..0000000 --- a/src/wireguard/router/outbound.rs +++ /dev/null @@ -1,110 +0,0 @@ -use std::sync::Arc; - -use crossbeam_channel::Receiver; -use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; -use zerocopy::{AsBytes, LayoutVerified}; - -use super::constants::MAX_INORDER_CONSUME; -use super::device::Device; -use super::messages::{TransportHeader, TYPE_TRANSPORT}; -use super::peer::Peer; -use super::pool::*; -use super::types::Callbacks; -use super::KeyPair; -use super::{tun, udp, Endpoint}; -use super::{REJECT_AFTER_MESSAGES, SIZE_TAG}; - -pub struct Outbound { - msg: Vec, - keypair: Arc, - counter: u64, -} - -impl Outbound { - pub fn new(msg: Vec, keypair: Arc, counter: u64) -> Outbound { - Outbound { - msg, - keypair, - counter, - } - } -} - -#[inline(always)] -pub fn parallel>( - device: Device, - receiver: Receiver, Outbound>>, -) { - #[inline(always)] - fn work>( - _peer: &Peer, - body: &mut Outbound, - ) { - log::trace!("worker, parallel section, obtained job"); - - // make space for the tag - body.msg.extend([0u8; SIZE_TAG].iter()); - - // cast to header (should never fail) - let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = - LayoutVerified::new_from_prefix(&mut body.msg[..]) - .expect("earlier code should ensure that there is ample space"); - - // set header fields - debug_assert!( - body.counter < REJECT_AFTER_MESSAGES, - "should be checked when assigning counters" - ); - header.f_type.set(TYPE_TRANSPORT); - header.f_receiver.set(body.keypair.send.id); - header.f_counter.set(body.counter); - - // create a nonce object - let mut nonce = [0u8; 12]; - debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); - nonce[4..].copy_from_slice(header.f_counter.as_bytes()); - let nonce = Nonce::assume_unique_for_key(nonce); - - // do the weird ring AEAD dance - let key = LessSafeKey::new( - UnboundKey::new(&CHACHA20_POLY1305, &body.keypair.send.key[..]).unwrap(), - ); - - // encrypt content of transport message in-place - let end = packet.len() - SIZE_TAG; - let tag = key - .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end]) - .unwrap(); - - // append tag - packet[end..].copy_from_slice(tag.as_ref()); - } - - worker_parallel(device, |dev| &dev.run_outbound, receiver, work); -} - -#[inline(always)] -pub fn sequential>( - device: Device, -) { - device.run_outbound.run(|peer| { - peer.outbound.handle( - |body| { - log::trace!("worker, sequential section, obtained job"); - - // send to peer - let xmit = peer.send(&body.msg[..]).is_ok(); - - // trigger callback - C::send( - &peer.opaque, - body.msg.len(), - xmit, - &body.keypair, - body.counter, - ); - }, - MAX_INORDER_CONSUME, - ) - }); -} diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index 7312bc7..710cf32 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -1,13 +1,3 @@ -use std::mem; -use std::net::{IpAddr, SocketAddr}; -use std::ops::Deref; -use std::sync::atomic::AtomicBool; -use std::sync::Arc; - -use arraydeque::{ArrayDeque, Wrapping}; -use log::debug; -use spin::Mutex; - use super::super::constants::*; use super::super::{tun, udp, Endpoint, KeyPair}; @@ -15,20 +5,25 @@ use super::anti_replay::AntiReplay; use super::device::DecryptionState; use super::device::Device; use super::device::EncryptionState; -use super::messages::TransportHeader; use super::constants::*; -use super::runq::ToKey; use super::types::{Callbacks, RouterError}; use super::SIZE_MESSAGE_PREFIX; -// worker pool related -use super::inbound::Inbound; -use super::outbound::Outbound; use super::queue::Queue; - -use super::send::SendJob; use super::receive::ReceiveJob; +use super::send::SendJob; +use super::worker::JobUnion; + +use std::mem; +use std::net::{IpAddr, SocketAddr}; +use std::ops::Deref; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; + +use arraydeque::{ArrayDeque, Wrapping}; +use log::debug; +use spin::Mutex; pub struct KeyWheel { next: Option>, // next key state (unconfirmed) @@ -44,7 +39,7 @@ pub struct PeerInner>, pub staged_packets: Mutex; MAX_QUEUED_PACKETS], Wrapping>>, pub keys: Mutex, - pub ekey: Mutex>, + pub enc_key: Mutex>, pub endpoint: Mutex>, } @@ -69,13 +64,6 @@ impl> PartialEq for } } -impl> ToKey for Peer { - type Key = usize; - fn to_key(&self) -> usize { - Arc::downgrade(&self.inner).into_raw() as usize - } -} - impl> Eq for Peer {} /* A peer is transparently dereferenced to the inner type @@ -157,7 +145,7 @@ impl> Drop for Peer keys.current = None; keys.previous = None; - *peer.ekey.lock() = None; + *peer.enc_key.lock() = None; *peer.endpoint.lock() = None; debug!("peer dropped & removed from device"); @@ -175,9 +163,9 @@ pub fn new_peer>( inner: Arc::new(PeerInner { opaque, device, - inbound: InorderQueue::new(), - outbound: InorderQueue::new(), - ekey: spin::Mutex::new(None), + inbound: Queue::new(), + outbound: Queue::new(), + enc_key: spin::Mutex::new(None), endpoint: spin::Mutex::new(None), keys: spin::Mutex::new(KeyWheel { next: None, @@ -203,7 +191,7 @@ impl> PeerInner Result<(), RouterError> { + pub fn send_raw(&self, msg: &[u8]) -> Result<(), RouterError> { debug!("peer.send"); // send to endpoint (if known) @@ -226,6 +214,57 @@ impl> PeerInner> Peer { + /// Encrypt and send a message to the peer + /// + /// Arguments: + /// + /// - `msg` : A padded vector holding the message (allows in-place construction of the transport header) + /// - `stage`: Should the message be staged if no key is available + /// + pub(super) fn send(&self, msg: Vec, stage: bool) { + // check if key available + let (job, need_key) = { + let mut enc_key = self.enc_key.lock(); + match enc_key.as_mut() { + None => { + if stage { + self.staged_packets.lock().push_back(msg); + }; + (None, true) + } + Some(mut state) => { + // avoid integer overflow in nonce + if state.nonce >= REJECT_AFTER_MESSAGES - 1 { + *enc_key = None; + if stage { + self.staged_packets.lock().push_back(msg); + } + (None, true) + } else { + debug!("encryption state available, nonce = {}", state.nonce); + let job = + SendJob::new(msg, state.nonce, state.keypair.clone(), self.clone()); + if self.outbound.push(job.clone()) { + state.nonce += 1; + (Some(job), false) + } else { + (None, false) + } + } + } + } + }; + + if need_key { + debug_assert!(job.is_none()); + C::need_key(&self.opaque); + }; + + if let Some(job) = job { + self.device.work.send(JobUnion::Outbound(job)) + } + } + // Transmit all staged packets fn send_staged(&self) -> bool { debug!("peer.send_staged"); @@ -235,28 +274,14 @@ impl> Peer { sent = true; - self.send_raw(msg); + self.send(msg, false); } None => break sent, } } } - // Treat the msg as the payload of a transport message - // Unlike device.send, peer.send_raw does not buffer messages when a key is not available. - fn send_raw(&self, msg: Vec) -> bool { - log::debug!("peer.send_raw"); - match self.send_job(msg, false) { - Some(job) => { - self.device.queue_outbound.send(job); - debug!("send_raw: got obtained send_job"); - true - } - None => false, - } - } - - pub fn confirm_key(&self, keypair: &Arc) { + pub(super) fn confirm_key(&self, keypair: &Arc) { debug!("peer.confirm_key"); { // take lock and check keypair = keys.next @@ -284,68 +309,12 @@ impl> Peer, stage: bool) -> Option> { - debug!("peer.send_job"); - debug_assert!( - msg.len() >= mem::size_of::(), - "received TUN message with size: {:}", - msg.len() - ); - - // check if has key - let (keypair, counter) = { - let keypair = { - // TODO: consider using atomic ptr for ekey state - let mut ekey = self.ekey.lock(); - match ekey.as_mut() { - None => None, - Some(mut state) => { - // avoid integer overflow in nonce - if state.nonce >= REJECT_AFTER_MESSAGES - 1 { - *ekey = None; - None - } else { - debug!("encryption state available, nonce = {}", state.nonce); - let counter = state.nonce; - state.nonce += 1; - - SendJob::new( - msg, - state.nonce, - state.keypair.clone(), - self.clone() - ); - - Some((state.keypair.clone(), counter)) - } - } - } - }; - - // If not suitable key was found: - // 1. Stage packet for later transmission - // 2. Request new key - if keypair.is_none() && stage { - self.staged_packets.lock().push_back(msg); - C::need_key(&self.opaque); - return None; - }; - - keypair - }?; - - // add job to in-order queue and return sender to device for inclusion in worker pool - let job = Job::new(self.clone(), Outbound::new(msg, keypair, counter)); - self.outbound.send(job.clone()); - Some(job) - } } impl> PeerHandle { @@ -397,7 +366,7 @@ impl> PeerHandle> PeerHandle> PeerHandle> PeerHandle bool { + pub fn send_keepalive(&self) { debug!("peer.send_keepalive"); - self.peer.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX]) + self.peer.send(vec![0u8; SIZE_MESSAGE_PREFIX], false) } /// Map a subnet to the peer diff --git a/src/wireguard/router/pool.rs b/src/wireguard/router/pool.rs deleted file mode 100644 index 3fc0026..0000000 --- a/src/wireguard/router/pool.rs +++ /dev/null @@ -1,164 +0,0 @@ -use std::mem; -use std::sync::Arc; - -use arraydeque::ArrayDeque; -use crossbeam_channel::Receiver; -use spin::{Mutex, MutexGuard}; - -use super::constants::INORDER_QUEUE_SIZE; -use super::runq::{RunQueue, ToKey}; - -pub struct InnerJob { - // peer (used by worker to schedule/handle inorder queue), - // when the peer is None, the job is complete - peer: Option

, - pub body: B, -} - -pub struct Job { - inner: Arc>>, -} - -impl Clone for Job { - fn clone(&self) -> Job { - Job { - inner: self.inner.clone(), - } - } -} - -impl Job { - pub fn new(peer: P, body: B) -> Job { - Job { - inner: Arc::new(Mutex::new(InnerJob { - peer: Some(peer), - body, - })), - } - } -} - -impl Job { - /// Returns a mutex guard to the inner job if complete - pub fn complete(&self) -> Option>> { - self.inner - .try_lock() - .and_then(|m| if m.peer.is_none() { Some(m) } else { None }) - } -} - -pub struct InorderQueue { - queue: Mutex; INORDER_QUEUE_SIZE]>>, -} - -impl InorderQueue { - pub fn new() -> InorderQueue { - InorderQueue { - queue: Mutex::new(ArrayDeque::new()), - } - } - - /// Add a new job to the in-order queue - /// - /// # Arguments - /// - /// - `job`: The job added to the back of the queue - /// - /// # Returns - /// - /// True if the element was added, - /// false to indicate that the queue is full. - pub fn send(&self, job: Job) -> bool { - self.queue.lock().push_back(job).is_ok() - } - - /// Consume completed jobs from the in-order queue - /// - /// # Arguments - /// - /// - `f`: function to apply to the body of each jobof each job. - /// - `limit`: maximum number of jobs to handle before returning - /// - /// # Returns - /// - /// A boolean indicating if the limit was reached: - /// true indicating that the limit was reached, - /// while false implies that the queue is empty or an uncompleted job was reached. - #[inline(always)] - pub fn handle(&self, f: F, mut limit: usize) -> bool { - // take the mutex - let mut queue = self.queue.lock(); - - while limit > 0 { - // attempt to extract front element - let front = queue.pop_front(); - let elem = match front { - Some(elem) => elem, - _ => { - return false; - } - }; - - // apply function if job complete - let ret = if let Some(mut guard) = elem.complete() { - mem::drop(queue); - f(&mut guard.body); - queue = self.queue.lock(); - false - } else { - true - }; - - // job not complete yet, return job to front - if ret { - queue.push_front(elem).unwrap(); - return false; - } - limit -= 1; - } - - // did not complete all jobs - true - } -} - -/// Allows easy construction of a parallel worker. -/// Applicable for both decryption and encryption workers. -#[inline(always)] -pub fn worker_parallel< - P: ToKey, // represents a peer (atomic reference counted pointer) - B, // inner body type (message buffer, key material, ...) - D, // device - W: Fn(&P, &mut B), - Q: Fn(&D) -> &RunQueue

, ->( - device: D, - queue: Q, - receiver: Receiver>, - work: W, -) { - log::trace!("router worker started"); - loop { - // handle new job - let peer = { - // get next job - let job = match receiver.recv() { - Ok(job) => job, - _ => return, - }; - - // lock the job - let mut job = job.inner.lock(); - - // take the peer from the job - let peer = job.peer.take().unwrap(); - - // process job - work(&peer, &mut job.body); - peer - }; - - // process inorder jobs for peer - queue(&device).insert(peer); - } -} diff --git a/src/wireguard/router/queue.rs b/src/wireguard/router/queue.rs index 045fd51..ec4492e 100644 --- a/src/wireguard/router/queue.rs +++ b/src/wireguard/router/queue.rs @@ -4,29 +4,36 @@ use spin::Mutex; use std::mem; use std::sync::atomic::{AtomicUsize, Ordering}; -const QUEUE_SIZE: usize = 1024; - -pub trait Job: Sized { - fn queue(&self) -> &Queue; +use super::constants::INORDER_QUEUE_SIZE; +pub trait SequentialJob { fn is_ready(&self) -> bool; - fn parallel_work(&self); - fn sequential_work(self); } +pub trait ParallelJob: Sized + SequentialJob { + fn queue(&self) -> &Queue; + + fn parallel_work(&self); +} -pub struct Queue { +pub struct Queue { contenders: AtomicUsize, - queue: Mutex>, + queue: Mutex>, + + #[cfg(debug)] + _flag: Mutex<()>, } -impl Queue { +impl Queue { pub fn new() -> Queue { Queue { contenders: AtomicUsize::new(0), queue: Mutex::new(ArrayDeque::new()), + + #[cfg(debug)] + _flag: Mutex::new(()), } } @@ -36,14 +43,22 @@ impl Queue { pub fn consume(&self) { // check if we are the first contender - let pos = self.contenders.fetch_add(1, Ordering::Acquire); + let pos = self.contenders.fetch_add(1, Ordering::SeqCst); if pos > 0 { - assert!(pos < usize::max_value(), "contenders overflow"); + assert!(usize::max_value() > pos, "contenders overflow"); + return; } // enter the critical section let mut contenders = 1; // myself while contenders > 0 { + // check soundness in debug builds + #[cfg(debug)] + let _flag = self + ._flag + .try_lock() + .expect("contenders should ensure mutual exclusion"); + // handle every ready element loop { let mut queue = self.queue.lock(); @@ -69,8 +84,61 @@ impl Queue { job.sequential_work(); } + #[cfg(debug)] + mem::drop(_flag); + // decrease contenders - contenders = self.contenders.fetch_sub(contenders, Ordering::Acquire) - contenders; + contenders = self.contenders.fetch_sub(contenders, Ordering::SeqCst) - contenders; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::sync::Arc; + use std::thread; + + use rand::thread_rng; + use rand::Rng; + + struct TestJob {} + + impl SequentialJob for TestJob { + fn is_ready(&self) -> bool { + true + } + + fn sequential_work(self) {} + } + + /* Fuzz the Queue */ + #[test] + fn test_queue() { + fn hammer(queue: &Arc>) { + let mut rng = thread_rng(); + for _ in 0..1_000_000 { + if rng.gen() { + queue.push(TestJob {}); + } else { + queue.consume(); + } + } } + + let queue = Arc::new(Queue::new()); + + // repeatedly apply operations randomly from concurrent threads + let other = { + let queue = queue.clone(); + thread::spawn(move || hammer(&queue)) + }; + hammer(&queue); + + // wait, consume and check empty + other.join().unwrap(); + queue.consume(); + assert_eq!(queue.queue.lock().len(), 0); } } diff --git a/src/wireguard/router/receive.rs b/src/wireguard/router/receive.rs index 53890e3..c5fe3da 100644 --- a/src/wireguard/router/receive.rs +++ b/src/wireguard/router/receive.rs @@ -1,21 +1,18 @@ -use super::queue::{Job, Queue}; -use super::KeyPair; +use super::device::DecryptionState; +use super::messages::TransportHeader; +use super::queue::{ParallelJob, Queue, SequentialJob}; use super::types::Callbacks; -use super::peer::Peer; use super::{REJECT_AFTER_MESSAGES, SIZE_TAG}; -use super::messages::{TransportHeader, TYPE_TRANSPORT}; -use super::device::DecryptionState; use super::super::{tun, udp, Endpoint}; +use std::mem; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use std::mem; use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; -use zerocopy::{AsBytes, LayoutVerified}; use spin::Mutex; - +use zerocopy::{AsBytes, LayoutVerified}; struct Inner> { ready: AtomicBool, @@ -23,49 +20,49 @@ struct Inner> { state: Arc>, // decryption state (keys and replay protector) } -pub struct ReceiveJob> { - inner: Arc>, +pub struct ReceiveJob>( + Arc>, +); + +impl> Clone + for ReceiveJob +{ + fn clone(&self) -> ReceiveJob { + ReceiveJob(self.0.clone()) + } } -impl > ReceiveJob { - fn new(buffer: Vec, state: Arc>, endpoint: E) -> Option> { - // create job - let inner = Arc::new(Inner{ +impl> ReceiveJob { + pub fn new( + buffer: Vec, + state: Arc>, + endpoint: E, + ) -> ReceiveJob { + ReceiveJob(Arc::new(Inner { ready: AtomicBool::new(false), buffer: Mutex::new((Some(endpoint), buffer)), - state - }); - - // attempt to add to queue - if state.peer.inbound.push(ReceiveJob{ inner: inner.clone()}) { - Some(ReceiveJob{inner}) - } else { - None - } - + state, + })) } } -impl > Job for ReceiveJob { +impl> ParallelJob + for ReceiveJob +{ fn queue(&self) -> &Queue { - &self.inner.state.peer.inbound - } - - fn is_ready(&self) -> bool { - self.inner.ready.load(Ordering::Acquire) + &self.0.state.peer.inbound } fn parallel_work(&self) { // TODO: refactor // decrypt { - let job = &self.inner; + let job = &self.0; let peer = &job.state.peer; let mut msg = job.buffer.lock(); - - let failed = || { - // cast to header followed by payload - let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = + + // cast to header followed by payload + let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = match LayoutVerified::new_from_prefix(&mut msg.1[..]) { Some(v) => v, None => { @@ -74,73 +71,81 @@ impl > Job for Rece } }; - // authenticate and decrypt payload - { - // create nonce object - let mut nonce = [0u8; 12]; - debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); - nonce[4..].copy_from_slice(header.f_counter.as_bytes()); - let nonce = Nonce::assume_unique_for_key(nonce); - - // do the weird ring AEAD dance - let key = LessSafeKey::new( - UnboundKey::new(&CHACHA20_POLY1305, &job.state.keypair.recv.key[..]).unwrap(), - ); - - // attempt to open (and authenticate) the body - match key.open_in_place(nonce, Aad::empty(), packet) { - Ok(_) => (), - Err(_) => { - // fault and return early - log::trace!("inbound worker: authentication failure"); - msg.1.truncate(0); - return; - } + // authenticate and decrypt payload + { + // create nonce object + let mut nonce = [0u8; 12]; + debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); + nonce[4..].copy_from_slice(header.f_counter.as_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + // do the weird ring AEAD dance + let key = LessSafeKey::new( + UnboundKey::new(&CHACHA20_POLY1305, &job.state.keypair.recv.key[..]).unwrap(), + ); + + // attempt to open (and authenticate) the body + match key.open_in_place(nonce, Aad::empty(), packet) { + Ok(_) => (), + Err(_) => { + // fault and return early + log::trace!("inbound worker: authentication failure"); + msg.1.truncate(0); + return; } } + } - // check that counter not after reject - if header.f_counter.get() >= REJECT_AFTER_MESSAGES { - msg.1.truncate(0); - return; - } - - // cryptokey route and strip padding - let inner_len = { - let length = packet.len() - SIZE_TAG; - if length > 0 { - peer.device.table.check_route(&peer, &packet[..length]) - } else { - Some(0) - } - }; + // check that counter not after reject + if header.f_counter.get() >= REJECT_AFTER_MESSAGES { + msg.1.truncate(0); + return; + } - // truncate to remove tag - match inner_len { - None => { - log::trace!("inbound worker: cryptokey routing failed"); - msg.1.truncate(0); - } - Some(len) => { - log::trace!( - "inbound worker: good route, length = {} {}", - len, - if len == 0 { "(keepalive)" } else { "" } - ); - msg.1.truncate(mem::size_of::() + len); - } + // cryptokey route and strip padding + let inner_len = { + let length = packet.len() - SIZE_TAG; + if length > 0 { + peer.device.table.check_route(&peer, &packet[..length]) + } else { + Some(0) } }; + + // truncate to remove tag + match inner_len { + None => { + log::trace!("inbound worker: cryptokey routing failed"); + msg.1.truncate(0); + } + Some(len) => { + log::trace!( + "inbound worker: good route, length = {} {}", + len, + if len == 0 { "(keepalive)" } else { "" } + ); + msg.1.truncate(mem::size_of::() + len); + } + } } // mark ready - self.inner.ready.store(true, Ordering::Release); + self.0.ready.store(true, Ordering::Release); + } +} + +impl> SequentialJob + for ReceiveJob +{ + fn is_ready(&self) -> bool { + self.0.ready.load(Ordering::Acquire) } fn sequential_work(self) { - let job = &self.inner; + let job = &self.0; let peer = &job.state.peer; let mut msg = job.buffer.lock(); + let endpoint = msg.0.take(); // cast transport header let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = @@ -165,7 +170,7 @@ impl > Job for Rece } // update endpoint - *peer.endpoint.lock() = msg.0.take(); + *peer.endpoint.lock() = endpoint; // check if should be written to TUN let mut sent = false; @@ -184,5 +189,4 @@ impl > Job for Rece // trigger callback C::recv(&peer.opaque, msg.1.len(), sent, &job.state.keypair); } - } diff --git a/src/wireguard/router/runq.rs b/src/wireguard/router/runq.rs deleted file mode 100644 index 44e11a1..0000000 --- a/src/wireguard/router/runq.rs +++ /dev/null @@ -1,164 +0,0 @@ -use std::hash::Hash; -use std::mem; -use std::sync::{Condvar, Mutex}; - -use std::collections::hash_map::Entry; -use std::collections::HashMap; -use std::collections::VecDeque; - -pub trait ToKey { - type Key: Hash + Eq; - fn to_key(&self) -> Self::Key; -} - -pub struct RunQueue { - cvar: Condvar, - inner: Mutex>, -} - -struct Inner { - stop: bool, - queue: VecDeque, - members: HashMap, -} - -impl RunQueue { - pub fn close(&self) { - let mut inner = self.inner.lock().unwrap(); - inner.stop = true; - self.cvar.notify_all(); - } - - pub fn new() -> RunQueue { - RunQueue { - cvar: Condvar::new(), - inner: Mutex::new(Inner { - stop: false, - queue: VecDeque::new(), - members: HashMap::new(), - }), - } - } - - pub fn insert(&self, v: T) { - let key = v.to_key(); - let mut inner = self.inner.lock().unwrap(); - match inner.members.entry(key) { - Entry::Occupied(mut elem) => { - *elem.get_mut() += 1; - } - Entry::Vacant(spot) => { - // add entry to back of queue - spot.insert(0); - inner.queue.push_back(v); - - // wake a thread - self.cvar.notify_one(); - } - } - } - - /// Run (consume from) the run queue using the provided function. - /// The function should return wheter the given element should be rescheduled. - /// - /// # Arguments - /// - /// - `f` : function to apply to every element - /// - /// # Note - /// - /// The function f may be called again even when the element was not inserted back in to the - /// queue since the last applciation and no rescheduling was requested. - /// - /// This happens then the function handles all work for T, - /// but T is added to the run queue while the function is running. - pub fn run bool>(&self, f: F) { - let mut inner = self.inner.lock().unwrap(); - loop { - // fetch next element - let elem = loop { - // run-queue closed - if inner.stop { - return; - } - - // try to pop from queue - match inner.queue.pop_front() { - Some(elem) => { - break elem; - } - None => (), - }; - - // wait for an element to be inserted - inner = self.cvar.wait(inner).unwrap(); - }; - - // fetch current request number - let key = elem.to_key(); - let old_n = *inner.members.get(&key).unwrap(); - mem::drop(inner); // drop guard - - // handle element - let rerun = f(&elem); - - // if the function requested a re-run add the element to the back of the queue - inner = self.inner.lock().unwrap(); - if rerun { - inner.queue.push_back(elem); - continue; - } - - // otherwise check if new requests have come in since we ran the function - match inner.members.entry(key) { - Entry::Occupied(occ) => { - if *occ.get() == old_n { - // no new requests since last, remove entry. - occ.remove(); - } else { - // new requests, reschedule. - inner.queue.push_back(elem); - } - } - Entry::Vacant(_) => { - unreachable!(); - } - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::thread; - use std::time::Duration; - - /* - #[test] - fn test_wait() { - let queue: Arc> = Arc::new(RunQueue::new()); - - { - let queue = queue.clone(); - thread::spawn(move || { - queue.run(|e| { - println!("t0 {}", e); - thread::sleep(Duration::from_millis(100)); - }) - }); - } - - { - let queue = queue.clone(); - thread::spawn(move || { - queue.run(|e| { - println!("t1 {}", e); - thread::sleep(Duration::from_millis(100)); - }) - }); - } - - } - */ -} diff --git a/src/wireguard/router/send.rs b/src/wireguard/router/send.rs index 2bd4abd..8e41796 100644 --- a/src/wireguard/router/send.rs +++ b/src/wireguard/router/send.rs @@ -1,4 +1,4 @@ -use super::queue::{Job, Queue}; +use super::queue::{SequentialJob, ParallelJob, Queue}; use super::KeyPair; use super::types::Callbacks; use super::peer::Peer; @@ -22,8 +22,14 @@ struct Inner> { peer: Peer, } -pub struct SendJob> { - inner: Arc>, +pub struct SendJob> ( + Arc> +); + +impl > Clone for SendJob { + fn clone(&self) -> SendJob { + SendJob(self.0.clone()) + } } impl > SendJob { @@ -32,32 +38,53 @@ impl > SendJob, peer: Peer - ) -> Option> { - // create job - let inner = Arc::new(Inner{ + ) -> SendJob { + SendJob(Arc::new(Inner{ buffer: Mutex::new(buffer), counter, keypair, peer, ready: AtomicBool::new(false) - }); - - // attempt to add to queue - if peer.outbound.push(SendJob{ inner: inner.clone()}) { - Some(SendJob{inner}) - } else { - None - } + })) } } -impl > Job for SendJob { - fn queue(&self) -> &Queue { - &self.inner.peer.outbound +impl > SequentialJob for SendJob { + + fn is_ready(&self) -> bool { + self.0.ready.load(Ordering::Acquire) } - fn is_ready(&self) -> bool { - self.inner.ready.load(Ordering::Acquire) + fn sequential_work(self) { + debug_assert_eq!( + self.is_ready(), + true, + "doing sequential work + on an incomplete job" + ); + log::trace!("processing sequential send job"); + + // send to peer + let job = &self.0; + let msg = job.buffer.lock(); + let xmit = job.peer.send_raw(&msg[..]).is_ok(); + + // trigger callback (for timers) + C::send( + &job.peer.opaque, + msg.len(), + xmit, + &job.keypair, + job.counter, + ); + } +} + + +impl > ParallelJob for SendJob { + + fn queue(&self) -> &Queue { + &self.0.peer.outbound } fn parallel_work(&self) { @@ -71,7 +98,7 @@ impl > Job for Send // encrypt body { // make space for the tag - let job = &*self.inner; + let job = &*self.0; let mut msg = job.buffer.lock(); msg.extend([0u8; SIZE_TAG].iter()); @@ -111,30 +138,6 @@ impl > Job for Send } // mark ready - self.inner.ready.store(true, Ordering::Release); - } - - fn sequential_work(self) { - debug_assert_eq!( - self.is_ready(), - true, - "doing sequential work - on an incomplete job" - ); - log::trace!("processing sequential send job"); - - // send to peer - let job = &self.inner; - let msg = job.buffer.lock(); - let xmit = job.peer.send(&msg[..]).is_ok(); - - // trigger callback (for timers) - C::send( - &job.peer.opaque, - msg.len(), - xmit, - &job.keypair, - job.counter, - ); + self.0.ready.store(true, Ordering::Release); } -} +} \ No newline at end of file diff --git a/src/wireguard/router/worker.rs b/src/wireguard/router/worker.rs index d95050e..bbb644c 100644 --- a/src/wireguard/router/worker.rs +++ b/src/wireguard/router/worker.rs @@ -1,13 +1,31 @@ -use super::Device; - use super::super::{tun, udp, Endpoint}; use super::types::Callbacks; -use super::receive::ReceieveJob; +use super::queue::ParallelJob; +use super::receive::ReceiveJob; use super::send::SendJob; -fn worker>( - device: Device, +use crossbeam_channel::Receiver; + +pub enum JobUnion> { + Outbound(SendJob), + Inbound(ReceiveJob), +} + +pub fn worker>( + receiver: Receiver>, ) { - // fetch job + loop { + match receiver.recv() { + Err(_) => break, + Ok(JobUnion::Inbound(job)) => { + job.parallel_work(); + job.queue().consume(); + } + Ok(JobUnion::Outbound(job)) => { + job.parallel_work(); + job.queue().consume(); + } + } + } } diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs deleted file mode 100644 index 8ddc136..0000000 --- a/src/wireguard/router/workers.rs +++ /dev/null @@ -1,257 +0,0 @@ -use std::sync::Arc; - -use log::{debug, trace}; - -use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; - -use crossbeam_channel::Receiver; -use std::sync::atomic::Ordering; -use zerocopy::{AsBytes, LayoutVerified}; - -use super::device::{DecryptionState, DeviceInner}; -use super::messages::{TransportHeader, TYPE_TRANSPORT}; -use super::peer::PeerInner; -use super::types::Callbacks; - -use super::REJECT_AFTER_MESSAGES; - -use super::super::types::KeyPair; -use super::super::{tun, udp, Endpoint}; - -pub const SIZE_TAG: usize = 16; - -pub struct JobEncryption { - pub msg: Vec, - pub keypair: Arc, - pub counter: u64, -} - -pub struct JobDecryption { - pub msg: Vec, - pub keypair: Arc, -} - -pub enum JobParallel { - Encryption(oneshot::Sender, JobEncryption), - Decryption(oneshot::Sender>, JobDecryption), -} - -#[allow(type_alias_bounds)] -pub type JobInbound> = ( - Arc>, - E, - oneshot::Receiver>, -); - -pub type JobOutbound = oneshot::Receiver; - -/* TODO: Replace with run-queue - */ -pub fn worker_inbound>( - device: Arc>, // related device - peer: Arc>, // related peer - receiver: Receiver>, -) { - loop { - // fetch job - let (state, endpoint, rx) = match receiver.recv() { - Ok(v) => v, - _ => { - return; - } - }; - debug!("inbound worker: obtained job"); - - // wait for job to complete - let _ = rx - .map(|buf| { - debug!("inbound worker: job complete"); - if let Some(buf) = buf { - // cast transport header - let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = - match LayoutVerified::new_from_prefix(&buf.msg[..]) { - Some(v) => v, - None => { - debug!("inbound worker: failed to parse message"); - return; - } - }; - - debug_assert!( - packet.len() >= CHACHA20_POLY1305.tag_len(), - "this should be checked earlier in the pipeline (decryption should fail)" - ); - - // check for replay - if !state.protector.lock().update(header.f_counter.get()) { - debug!("inbound worker: replay detected"); - return; - } - - // check for confirms key - if !state.confirmed.swap(true, Ordering::SeqCst) { - debug!("inbound worker: message confirms key"); - peer.confirm_key(&state.keypair); - } - - // update endpoint - *peer.endpoint.lock() = Some(endpoint); - - // calculate length of IP packet + padding - let length = packet.len() - SIZE_TAG; - debug!("inbound worker: plaintext length = {}", length); - - // check if should be written to TUN - let mut sent = false; - if length > 0 { - if let Some(inner_len) = device.table.check_route(&peer, &packet[..length]) - { - // TODO: Consider moving the cryptkey route check to parallel decryption worker - debug_assert!(inner_len <= length, "should be validated earlier"); - if inner_len <= length { - sent = match device.inbound.write(&packet[..inner_len]) { - Err(e) => { - debug!("failed to write inbound packet to TUN: {:?}", e); - false - } - Ok(_) => true, - } - } - } - } else { - debug!("inbound worker: received keepalive") - } - - // trigger callback - C::recv(&peer.opaque, buf.msg.len(), sent, &buf.keypair); - } else { - debug!("inbound worker: authentication failure") - } - }) - .wait(); - } -} - - -pub fn worker_outbound>( - peer: Arc>, - receiver: Receiver, -) { - loop { - // fetch job - let rx = match receiver.recv() { - Ok(v) => v, - _ => { - return; - } - }; - debug!("outbound worker: obtained job"); - - // wait for job to complete - let _ = rx - .map(|buf| { - debug!("outbound worker: job complete"); - - // send to peer - let xmit = peer.send(&buf.msg[..]).is_ok(); - - // trigger callback - C::send(&peer.opaque, buf.msg.len(), xmit, &buf.keypair, buf.counter); - }) - .wait(); - } -} - -pub fn worker_parallel(receiver: Receiver) { - loop { - // fetch next job - let job = match receiver.recv() { - Err(_) => { - return; - } - Ok(val) => val, - }; - trace!("parallel worker: obtained job"); - - // handle job - match job { - JobParallel::Encryption(tx, mut job) => { - job.msg.extend([0u8; SIZE_TAG].iter()); - - // cast to header (should never fail) - let (mut header, body): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = - LayoutVerified::new_from_prefix(&mut job.msg[..]) - .expect("earlier code should ensure that there is ample space"); - - // set header fields - debug_assert!( - job.counter < REJECT_AFTER_MESSAGES, - "should be checked when assigning counters" - ); - header.f_type.set(TYPE_TRANSPORT); - header.f_receiver.set(job.keypair.send.id); - header.f_counter.set(job.counter); - - // create a nonce object - let mut nonce = [0u8; 12]; - debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); - nonce[4..].copy_from_slice(header.f_counter.as_bytes()); - let nonce = Nonce::assume_unique_for_key(nonce); - - // do the weird ring AEAD dance - let key = LessSafeKey::new( - UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.send.key[..]).unwrap(), - ); - - // encrypt content of transport message in-place - let end = body.len() - SIZE_TAG; - let tag = key - .seal_in_place_separate_tag(nonce, Aad::empty(), &mut body[..end]) - .unwrap(); - - // append tag - body[end..].copy_from_slice(tag.as_ref()); - - // pass ownership - let _ = tx.send(job); - } - JobParallel::Decryption(tx, mut job) => { - // cast to header (could fail) - let layout: Option<(LayoutVerified<&mut [u8], TransportHeader>, &mut [u8])> = - LayoutVerified::new_from_prefix(&mut job.msg[..]); - - let _ = tx.send(match layout { - Some((header, body)) => { - debug_assert_eq!( - header.f_type.get(), - TYPE_TRANSPORT, - "type and reserved bits should be checked by message de-multiplexer" - ); - if header.f_counter.get() < REJECT_AFTER_MESSAGES { - // create a nonce object - let mut nonce = [0u8; 12]; - debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); - nonce[4..].copy_from_slice(header.f_counter.as_bytes()); - let nonce = Nonce::assume_unique_for_key(nonce); - - // do the weird ring AEAD dance - let key = LessSafeKey::new( - UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.recv.key[..]) - .unwrap(), - ); - - // attempt to open (and authenticate) the body - match key.open_in_place(nonce, Aad::empty(), body) { - Ok(_) => Some(job), - Err(_) => None, - } - } else { - None - } - } - None => None, - }); - } - } - } -} diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 45b1fcb..94e240d 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -603,7 +603,7 @@ impl Wireguard { ); let device = wg.handshake.read(); let _ = device.begin(&mut rng, &peer.pk).map(|msg| { - let _ = peer.router.send(&msg[..]).map_err(|e| { + let _ = peer.router.send_raw(&msg[..]).map_err(|e| { debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e) }); peer.state.sent_handshake_initiation(); -- cgit v1.2.3-59-g8ed1b