From ead75828cdaa5253e57b5792b51e3d99a4a78ea0 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sun, 16 Feb 2020 20:25:31 +0100 Subject: Simplified router code --- src/wireguard/router/constants.rs | 4 +- src/wireguard/router/device.rs | 125 ++++++----------------- src/wireguard/router/inbound.rs | 190 ----------------------------------- src/wireguard/router/mod.rs | 10 +- src/wireguard/router/outbound.rs | 110 -------------------- src/wireguard/router/peer.rs | 206 +++++++++++++++----------------------- src/wireguard/router/pool.rs | 164 ------------------------------ src/wireguard/router/queue.rs | 144 ++++++++++++++++++++++++++ src/wireguard/router/receive.rs | 192 +++++++++++++++++++++++++++++++++++ src/wireguard/router/runq.rs | 129 ------------------------ src/wireguard/router/send.rs | 143 ++++++++++++++++++++++++++ src/wireguard/router/tests.rs | 7 +- src/wireguard/router/worker.rs | 31 ++++++ src/wireguard/timers.rs | 4 +- src/wireguard/workers.rs | 5 +- 15 files changed, 638 insertions(+), 826 deletions(-) delete mode 100644 src/wireguard/router/inbound.rs delete mode 100644 src/wireguard/router/outbound.rs delete mode 100644 src/wireguard/router/pool.rs create mode 100644 src/wireguard/router/queue.rs create mode 100644 src/wireguard/router/receive.rs delete mode 100644 src/wireguard/router/runq.rs create mode 100644 src/wireguard/router/send.rs create mode 100644 src/wireguard/router/worker.rs diff --git a/src/wireguard/router/constants.rs b/src/wireguard/router/constants.rs index af76299..f083811 100644 --- a/src/wireguard/router/constants.rs +++ b/src/wireguard/router/constants.rs @@ -4,6 +4,6 @@ pub const MAX_QUEUED_PACKETS: usize = 1024; // performance constants -pub const PARALLEL_QUEUE_SIZE: usize = MAX_QUEUED_PACKETS; +pub const PARALLEL_QUEUE_SIZE: usize = 4 * MAX_QUEUED_PACKETS; + pub const INORDER_QUEUE_SIZE: usize = MAX_QUEUED_PACKETS; -pub const MAX_INORDER_CONSUME: usize = INORDER_QUEUE_SIZE; diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index 96b7d82..9d78178 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -10,19 +10,16 @@ use spin::{Mutex, RwLock}; use zerocopy::LayoutVerified; use super::anti_replay::AntiReplay; -use super::pool::Job; use super::constants::PARALLEL_QUEUE_SIZE; -use super::inbound; -use super::outbound; - use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::peer::{new_peer, Peer, PeerHandle}; use super::types::{Callbacks, RouterError}; use super::SIZE_MESSAGE_PREFIX; +use super::receive::ReceiveJob; use super::route::RoutingTable; -use super::runq::RunQueue; +use super::worker::{worker, JobUnion}; use super::super::{tun, udp, Endpoint, KeyPair}; use super::ParallelQueue; @@ -38,13 +35,8 @@ pub struct DeviceInner>>>, // receiver id -> decryption state pub table: RoutingTable>, - // work queues - pub queue_outbound: ParallelQueue, outbound::Outbound>>, - pub queue_inbound: ParallelQueue, inbound::Inbound>>, - - // run queues - pub run_inbound: RunQueue>, - pub run_outbound: RunQueue>, + // work queue + pub work: ParallelQueue>, } pub struct EncryptionState { @@ -101,13 +93,8 @@ impl> Drop fn drop(&mut self) { debug!("router: dropping device"); - // close worker queues - self.state.queue_outbound.close(); - self.state.queue_inbound.close(); - - // close run queues - self.state.run_outbound.close(); - self.state.run_inbound.close(); + // close worker queue + self.state.work.close(); // join all worker threads while match self.handles.pop() { @@ -118,77 +105,28 @@ impl> Drop } _ => false, } {} - - debug!("router: device dropped"); } } impl> DeviceHandle { pub fn new(num_workers: usize, tun: T) -> DeviceHandle { - // allocate shared device state - let (queue_outbound, mut outrx) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE); - let (queue_inbound, mut inrx) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE); + let (work, mut consumers) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE); let device = Device { inner: Arc::new(DeviceInner { + work, inbound: tun, - queue_inbound, outbound: RwLock::new((true, None)), - queue_outbound, - run_inbound: RunQueue::new(), - run_outbound: RunQueue::new(), recv: RwLock::new(HashMap::new()), table: RoutingTable::new(), }), }; // start worker threads - let mut threads = Vec::with_capacity(4 * num_workers); - - // inbound/decryption workers - for _ in 0..num_workers { - // parallel workers (parallel processing) - { - let device = device.clone(); - let rx = inrx.pop().unwrap(); - threads.push(thread::spawn(move || { - log::debug!("inbound parallel router worker started"); - inbound::parallel(device, rx) - })); - } - - // sequential workers (in-order processing) - { - let device = device.clone(); - threads.push(thread::spawn(move || { - log::debug!("inbound sequential router worker started"); - inbound::sequential(device) - })); - } - } - - // outbound/encryption workers - for _ in 0..num_workers { - // parallel workers (parallel processing) - { - let device = device.clone(); - let rx = outrx.pop().unwrap(); - threads.push(thread::spawn(move || { - log::debug!("outbound parallel router worker started"); - outbound::parallel(device, rx) - })); - } - - // sequential workers (in-order processing) - { - let device = device.clone(); - threads.push(thread::spawn(move || { - log::debug!("outbound sequential router worker started"); - outbound::sequential(device) - })); - } + let mut threads = Vec::with_capacity(num_workers); + while let Some(rx) = consumers.pop() { + threads.push(thread::spawn(move || worker(rx))); } - - debug_assert_eq!(threads.len(), num_workers * 4); + debug_assert_eq!(threads.len(), num_workers); // return exported device handle DeviceHandle { @@ -197,6 +135,16 @@ impl> DeviceHandle< } } + pub fn send_raw(&self, msg : &[u8], dst: &mut E) -> Result<(), B::Error> { + let bind = self.state.outbound.read(); + if bind.0 { + if let Some(bind) = bind.1.as_ref() { + return bind.write(msg, dst); + } + } + return Ok(()) + } + /// Brings the router down. /// When the router is brought down it: /// - Prevents transmission of outbound messages. @@ -250,10 +198,7 @@ impl> DeviceHandle< .ok_or(RouterError::NoCryptoKeyRoute)?; // schedule for encryption and transmission to peer - if let Some(job) = peer.send_job(msg, true) { - self.state.queue_outbound.send(job); - } - + peer.send(msg, true); Ok(()) } @@ -297,10 +242,13 @@ impl> DeviceHandle< .get(&header.f_receiver.get()) .ok_or(RouterError::UnknownReceiverId)?; - // schedule for decryption and TUN write - if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) { - log::trace!("schedule decryption of transport message"); - self.state.queue_inbound.send(job); + // create inbound job + let job = ReceiveJob::new(msg, dec.clone(), src); + + // 1. add to sequential queue (drop if full) + // 2. then add to parallel work queue (wait if full) + if dec.peer.inbound.push(job.clone()) { + self.state.work.send(JobUnion::Inbound(job)); } Ok(()) } @@ -311,17 +259,4 @@ impl> DeviceHandle< pub fn set_outbound_writer(&self, new: B) { self.state.outbound.write().1 = Some(new); } - - pub fn write(&self, msg: &[u8], endpoint: &mut E) -> Result<(), RouterError> { - let outbound = self.state.outbound.read(); - if outbound.0 { - outbound - .1 - .as_ref() - .ok_or(RouterError::SendError) - .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError)) - } else { - Ok(()) - } - } } diff --git a/src/wireguard/router/inbound.rs b/src/wireguard/router/inbound.rs deleted file mode 100644 index dc2c44e..0000000 --- a/src/wireguard/router/inbound.rs +++ /dev/null @@ -1,190 +0,0 @@ -use std::mem; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use crossbeam_channel::Receiver; -use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; -use zerocopy::{AsBytes, LayoutVerified}; - -use super::constants::MAX_INORDER_CONSUME; -use super::device::DecryptionState; -use super::device::Device; -use super::messages::TransportHeader; -use super::peer::Peer; -use super::pool::*; -use super::types::Callbacks; -use super::{tun, udp, Endpoint}; -use super::{REJECT_AFTER_MESSAGES, SIZE_TAG}; - -pub struct Inbound> { - msg: Vec, - failed: bool, - state: Arc>, - endpoint: Option, -} - -impl> Inbound { - pub fn new( - msg: Vec, - state: Arc>, - endpoint: E, - ) -> Inbound { - Inbound { - msg, - state, - failed: false, - endpoint: Some(endpoint), - } - } -} - -#[inline(always)] -pub fn parallel>( - device: Device, - receiver: Receiver, Inbound>>, -) { - // parallel work to apply - #[inline(always)] - fn work>( - peer: &Peer, - body: &mut Inbound, - ) { - log::trace!("worker, parallel section, obtained job"); - - // cast to header followed by payload - let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = - match LayoutVerified::new_from_prefix(&mut body.msg[..]) { - Some(v) => v, - None => { - log::debug!("inbound worker: failed to parse message"); - return; - } - }; - - // authenticate and decrypt payload - { - // create nonce object - let mut nonce = [0u8; 12]; - debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); - nonce[4..].copy_from_slice(header.f_counter.as_bytes()); - let nonce = Nonce::assume_unique_for_key(nonce); - - // do the weird ring AEAD dance - let key = LessSafeKey::new( - UnboundKey::new(&CHACHA20_POLY1305, &body.state.keypair.recv.key[..]).unwrap(), - ); - - // attempt to open (and authenticate) the body - match key.open_in_place(nonce, Aad::empty(), packet) { - Ok(_) => (), - Err(_) => { - // fault and return early - log::trace!("inbound worker: authentication failure"); - body.failed = true; - return; - } - } - } - - // check that counter not after reject - if header.f_counter.get() >= REJECT_AFTER_MESSAGES { - body.failed = true; - return; - } - - // cryptokey route and strip padding - let inner_len = { - let length = packet.len() - SIZE_TAG; - if length > 0 { - peer.device.table.check_route(&peer, &packet[..length]) - } else { - Some(0) - } - }; - - // truncate to remove tag - match inner_len { - None => { - log::trace!("inbound worker: cryptokey routing failed"); - body.failed = true; - } - Some(len) => { - log::trace!( - "inbound worker: good route, length = {} {}", - len, - if len == 0 { "(keepalive)" } else { "" } - ); - body.msg.truncate(mem::size_of::() + len); - } - } - } - - worker_parallel(device, |dev| &dev.run_inbound, receiver, work) -} - -#[inline(always)] -pub fn sequential>( - device: Device, -) { - // sequential work to apply - fn work>( - peer: &Peer, - body: &mut Inbound, - ) { - log::trace!("worker, sequential section, obtained job"); - - // decryption failed, return early - if body.failed { - log::trace!("job faulted, remove from queue and ignore"); - return; - } - - // cast transport header - let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = - match LayoutVerified::new_from_prefix(&body.msg[..]) { - Some(v) => v, - None => { - log::debug!("inbound worker: failed to parse message"); - return; - } - }; - - // check for replay - if !body.state.protector.lock().update(header.f_counter.get()) { - log::debug!("inbound worker: replay detected"); - return; - } - - // check for confirms key - if !body.state.confirmed.swap(true, Ordering::SeqCst) { - log::debug!("inbound worker: message confirms key"); - peer.confirm_key(&body.state.keypair); - } - - // update endpoint - *peer.endpoint.lock() = body.endpoint.take(); - - // check if should be written to TUN - let mut sent = false; - if packet.len() > 0 { - sent = match peer.device.inbound.write(&packet[..]) { - Err(e) => { - log::debug!("failed to write inbound packet to TUN: {:?}", e); - false - } - Ok(_) => true, - } - } else { - log::debug!("inbound worker: received keepalive") - } - - // trigger callback - C::recv(&peer.opaque, body.msg.len(), sent, &body.state.keypair); - } - - // handle message from the peers inbound queue - device.run_inbound.run(|peer| { - peer.inbound - .handle(|body| work(&peer, body), MAX_INORDER_CONSUME) - }); -} diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs index 8238d32..699c621 100644 --- a/src/wireguard/router/mod.rs +++ b/src/wireguard/router/mod.rs @@ -1,16 +1,17 @@ mod anti_replay; mod constants; mod device; -mod inbound; mod ip; mod messages; -mod outbound; mod peer; -mod pool; mod route; -mod runq; mod types; +mod queue; +mod receive; +mod send; +mod worker; + #[cfg(test)] mod tests; @@ -20,7 +21,6 @@ use std::mem; use super::constants::REJECT_AFTER_MESSAGES; use super::queue::ParallelQueue; use super::types::*; -use super::{tun, udp, Endpoint}; pub const SIZE_TAG: usize = 16; pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::(); diff --git a/src/wireguard/router/outbound.rs b/src/wireguard/router/outbound.rs deleted file mode 100644 index 1edb2fb..0000000 --- a/src/wireguard/router/outbound.rs +++ /dev/null @@ -1,110 +0,0 @@ -use std::sync::Arc; - -use crossbeam_channel::Receiver; -use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; -use zerocopy::{AsBytes, LayoutVerified}; - -use super::constants::MAX_INORDER_CONSUME; -use super::device::Device; -use super::messages::{TransportHeader, TYPE_TRANSPORT}; -use super::peer::Peer; -use super::pool::*; -use super::types::Callbacks; -use super::KeyPair; -use super::{tun, udp, Endpoint}; -use super::{REJECT_AFTER_MESSAGES, SIZE_TAG}; - -pub struct Outbound { - msg: Vec, - keypair: Arc, - counter: u64, -} - -impl Outbound { - pub fn new(msg: Vec, keypair: Arc, counter: u64) -> Outbound { - Outbound { - msg, - keypair, - counter, - } - } -} - -#[inline(always)] -pub fn parallel>( - device: Device, - receiver: Receiver, Outbound>>, -) { - #[inline(always)] - fn work>( - _peer: &Peer, - body: &mut Outbound, - ) { - log::trace!("worker, parallel section, obtained job"); - - // make space for the tag - body.msg.extend([0u8; SIZE_TAG].iter()); - - // cast to header (should never fail) - let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = - LayoutVerified::new_from_prefix(&mut body.msg[..]) - .expect("earlier code should ensure that there is ample space"); - - // set header fields - debug_assert!( - body.counter < REJECT_AFTER_MESSAGES, - "should be checked when assigning counters" - ); - header.f_type.set(TYPE_TRANSPORT); - header.f_receiver.set(body.keypair.send.id); - header.f_counter.set(body.counter); - - // create a nonce object - let mut nonce = [0u8; 12]; - debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); - nonce[4..].copy_from_slice(header.f_counter.as_bytes()); - let nonce = Nonce::assume_unique_for_key(nonce); - - // do the weird ring AEAD dance - let key = LessSafeKey::new( - UnboundKey::new(&CHACHA20_POLY1305, &body.keypair.send.key[..]).unwrap(), - ); - - // encrypt content of transport message in-place - let end = packet.len() - SIZE_TAG; - let tag = key - .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end]) - .unwrap(); - - // append tag - packet[end..].copy_from_slice(tag.as_ref()); - } - - worker_parallel(device, |dev| &dev.run_outbound, receiver, work); -} - -#[inline(always)] -pub fn sequential>( - device: Device, -) { - device.run_outbound.run(|peer| { - peer.outbound.handle( - |body| { - log::trace!("worker, sequential section, obtained job"); - - // send to peer - let xmit = peer.send(&body.msg[..]).is_ok(); - - // trigger callback - C::send( - &peer.opaque, - body.msg.len(), - xmit, - &body.keypair, - body.counter, - ); - }, - MAX_INORDER_CONSUME, - ) - }); -} diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index 8fe2e1c..a20908e 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -1,13 +1,3 @@ -use std::mem; -use std::net::{IpAddr, SocketAddr}; -use std::ops::Deref; -use std::sync::atomic::AtomicBool; -use std::sync::Arc; - -use arraydeque::{ArrayDeque, Wrapping}; -use log::debug; -use spin::Mutex; - use super::super::constants::*; use super::super::{tun, udp, Endpoint, KeyPair}; @@ -15,17 +5,25 @@ use super::anti_replay::AntiReplay; use super::device::DecryptionState; use super::device::Device; use super::device::EncryptionState; -use super::messages::TransportHeader; use super::constants::*; -use super::runq::ToKey; use super::types::{Callbacks, RouterError}; use super::SIZE_MESSAGE_PREFIX; -// worker pool related -use super::inbound::Inbound; -use super::outbound::Outbound; -use super::pool::{InorderQueue, Job}; +use super::queue::Queue; +use super::receive::ReceiveJob; +use super::send::SendJob; +use super::worker::JobUnion; + +use std::mem; +use std::net::{IpAddr, SocketAddr}; +use std::ops::Deref; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; + +use arraydeque::{ArrayDeque, Wrapping}; +use log::debug; +use spin::Mutex; pub struct KeyWheel { next: Option>, // next key state (unconfirmed) @@ -37,11 +35,11 @@ pub struct KeyWheel { pub struct PeerInner> { pub device: Device, pub opaque: C::Opaque, - pub outbound: InorderQueue, Outbound>, - pub inbound: InorderQueue, Inbound>, + pub outbound: Queue>, + pub inbound: Queue>, pub staged_packets: Mutex; MAX_QUEUED_PACKETS], Wrapping>>, pub keys: Mutex, - pub ekey: Mutex>, + pub enc_key: Mutex>, pub endpoint: Mutex>, } @@ -66,13 +64,6 @@ impl> PartialEq for } } -impl> ToKey for Peer { - type Key = usize; - fn to_key(&self) -> usize { - Arc::downgrade(&self.inner).into_raw() as usize - } -} - impl> Eq for Peer {} /* A peer is transparently dereferenced to the inner type @@ -154,7 +145,7 @@ impl> Drop for Peer keys.current = None; keys.previous = None; - *peer.ekey.lock() = None; + *peer.enc_key.lock() = None; *peer.endpoint.lock() = None; debug!("peer dropped & removed from device"); @@ -172,9 +163,9 @@ pub fn new_peer>( inner: Arc::new(PeerInner { opaque, device, - inbound: InorderQueue::new(), - outbound: InorderQueue::new(), - ekey: spin::Mutex::new(None), + inbound: Queue::new(), + outbound: Queue::new(), + enc_key: spin::Mutex::new(None), endpoint: spin::Mutex::new(None), keys: spin::Mutex::new(KeyWheel { next: None, @@ -200,7 +191,7 @@ impl> PeerInner Result<(), RouterError> { + pub fn send_raw(&self, msg: &[u8]) -> Result<(), RouterError> { debug!("peer.send"); // send to endpoint (if known) @@ -223,6 +214,57 @@ impl> PeerInner> Peer { + /// Encrypt and send a message to the peer + /// + /// Arguments: + /// + /// - `msg` : A padded vector holding the message (allows in-place construction of the transport header) + /// - `stage`: Should the message be staged if no key is available + /// + pub(super) fn send(&self, msg: Vec, stage: bool) { + // check if key available + let (job, need_key) = { + let mut enc_key = self.enc_key.lock(); + match enc_key.as_mut() { + None => { + if stage { + self.staged_packets.lock().push_back(msg); + }; + (None, true) + } + Some(mut state) => { + // avoid integer overflow in nonce + if state.nonce >= REJECT_AFTER_MESSAGES - 1 { + *enc_key = None; + if stage { + self.staged_packets.lock().push_back(msg); + } + (None, true) + } else { + debug!("encryption state available, nonce = {}", state.nonce); + let job = + SendJob::new(msg, state.nonce, state.keypair.clone(), self.clone()); + if self.outbound.push(job.clone()) { + state.nonce += 1; + (Some(job), false) + } else { + (None, false) + } + } + } + } + }; + + if need_key { + debug_assert!(job.is_none()); + C::need_key(&self.opaque); + }; + + if let Some(job) = job { + self.device.work.send(JobUnion::Outbound(job)) + } + } + // Transmit all staged packets fn send_staged(&self) -> bool { debug!("peer.send_staged"); @@ -232,29 +274,14 @@ impl> Peer { sent = true; - self.send_raw(msg, false); + self.send(msg, false); } None => break sent, } } } - // Treat the msg as the payload of a transport message - // - // Returns true if the message was queued for transmission. - fn send_raw(&self, msg: Vec, stage: bool) -> bool { - log::debug!("peer.send_raw"); - match self.send_job(msg, stage) { - Some(job) => { - self.device.queue_outbound.send(job); - debug!("send_raw: got obtained send_job"); - true - } - None => false, - } - } - - pub fn confirm_key(&self, keypair: &Arc) { + pub(super) fn confirm_key(&self, keypair: &Arc) { debug!("peer.confirm_key"); { // take lock and check keypair = keys.next @@ -282,76 +309,12 @@ impl> Peer>, - msg: Vec, - ) -> Option>> { - let job = Job::new(self.clone(), Inbound::new(msg, dec, src)); - self.inbound.send(job.clone()); - Some(job) - } - - pub fn send_job(&self, msg: Vec, stage: bool) -> Option> { - debug!( - "peer.send_job, msg.len() = {}, stage = {}", - msg.len(), - stage - ); - debug_assert!( - msg.len() >= mem::size_of::(), - "received message with size: {:}", - msg.len() - ); - - // check if has key - let (keypair, counter) = { - let keypair = { - // TODO: consider using atomic ptr for ekey state - let mut ekey = self.ekey.lock(); - match ekey.as_mut() { - None => None, - Some(mut state) => { - // avoid integer overflow in nonce - if state.nonce >= REJECT_AFTER_MESSAGES - 1 { - *ekey = None; - None - } else { - debug!("encryption state available, nonce = {}", state.nonce); - let counter = state.nonce; - state.nonce += 1; - Some((state.keypair.clone(), counter)) - } - } - } - }; - - // If not suitable key was found: - // 1. Stage packet for later transmission - // 2. Request new key - if keypair.is_none() && stage { - log::trace!("packet staged"); - self.staged_packets.lock().push_back(msg); - C::need_key(&self.opaque); - return None; - }; - - keypair - }?; - - // add job to in-order queue and return sender to device for inclusion in worker pool - let job = Job::new(self.clone(), Outbound::new(msg, keypair, counter)); - self.outbound.send(job.clone()); - Some(job) - } } impl> PeerHandle { @@ -403,7 +366,7 @@ impl> PeerHandle> PeerHandle> PeerHandle> PeerHandle bool { + pub fn send_keepalive(&self) { debug!("peer.send_keepalive"); - self.peer.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX], true) + self.peer.send(vec![0u8; SIZE_MESSAGE_PREFIX], false) } /// Map a subnet to the peer diff --git a/src/wireguard/router/pool.rs b/src/wireguard/router/pool.rs deleted file mode 100644 index 3fc0026..0000000 --- a/src/wireguard/router/pool.rs +++ /dev/null @@ -1,164 +0,0 @@ -use std::mem; -use std::sync::Arc; - -use arraydeque::ArrayDeque; -use crossbeam_channel::Receiver; -use spin::{Mutex, MutexGuard}; - -use super::constants::INORDER_QUEUE_SIZE; -use super::runq::{RunQueue, ToKey}; - -pub struct InnerJob { - // peer (used by worker to schedule/handle inorder queue), - // when the peer is None, the job is complete - peer: Option

, - pub body: B, -} - -pub struct Job { - inner: Arc>>, -} - -impl Clone for Job { - fn clone(&self) -> Job { - Job { - inner: self.inner.clone(), - } - } -} - -impl Job { - pub fn new(peer: P, body: B) -> Job { - Job { - inner: Arc::new(Mutex::new(InnerJob { - peer: Some(peer), - body, - })), - } - } -} - -impl Job { - /// Returns a mutex guard to the inner job if complete - pub fn complete(&self) -> Option>> { - self.inner - .try_lock() - .and_then(|m| if m.peer.is_none() { Some(m) } else { None }) - } -} - -pub struct InorderQueue { - queue: Mutex; INORDER_QUEUE_SIZE]>>, -} - -impl InorderQueue { - pub fn new() -> InorderQueue { - InorderQueue { - queue: Mutex::new(ArrayDeque::new()), - } - } - - /// Add a new job to the in-order queue - /// - /// # Arguments - /// - /// - `job`: The job added to the back of the queue - /// - /// # Returns - /// - /// True if the element was added, - /// false to indicate that the queue is full. - pub fn send(&self, job: Job) -> bool { - self.queue.lock().push_back(job).is_ok() - } - - /// Consume completed jobs from the in-order queue - /// - /// # Arguments - /// - /// - `f`: function to apply to the body of each jobof each job. - /// - `limit`: maximum number of jobs to handle before returning - /// - /// # Returns - /// - /// A boolean indicating if the limit was reached: - /// true indicating that the limit was reached, - /// while false implies that the queue is empty or an uncompleted job was reached. - #[inline(always)] - pub fn handle(&self, f: F, mut limit: usize) -> bool { - // take the mutex - let mut queue = self.queue.lock(); - - while limit > 0 { - // attempt to extract front element - let front = queue.pop_front(); - let elem = match front { - Some(elem) => elem, - _ => { - return false; - } - }; - - // apply function if job complete - let ret = if let Some(mut guard) = elem.complete() { - mem::drop(queue); - f(&mut guard.body); - queue = self.queue.lock(); - false - } else { - true - }; - - // job not complete yet, return job to front - if ret { - queue.push_front(elem).unwrap(); - return false; - } - limit -= 1; - } - - // did not complete all jobs - true - } -} - -/// Allows easy construction of a parallel worker. -/// Applicable for both decryption and encryption workers. -#[inline(always)] -pub fn worker_parallel< - P: ToKey, // represents a peer (atomic reference counted pointer) - B, // inner body type (message buffer, key material, ...) - D, // device - W: Fn(&P, &mut B), - Q: Fn(&D) -> &RunQueue

, ->( - device: D, - queue: Q, - receiver: Receiver>, - work: W, -) { - log::trace!("router worker started"); - loop { - // handle new job - let peer = { - // get next job - let job = match receiver.recv() { - Ok(job) => job, - _ => return, - }; - - // lock the job - let mut job = job.inner.lock(); - - // take the peer from the job - let peer = job.peer.take().unwrap(); - - // process job - work(&peer, &mut job.body); - peer - }; - - // process inorder jobs for peer - queue(&device).insert(peer); - } -} diff --git a/src/wireguard/router/queue.rs b/src/wireguard/router/queue.rs new file mode 100644 index 0000000..ec4492e --- /dev/null +++ b/src/wireguard/router/queue.rs @@ -0,0 +1,144 @@ +use arraydeque::ArrayDeque; +use spin::Mutex; + +use std::mem; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use super::constants::INORDER_QUEUE_SIZE; + +pub trait SequentialJob { + fn is_ready(&self) -> bool; + + fn sequential_work(self); +} + +pub trait ParallelJob: Sized + SequentialJob { + fn queue(&self) -> &Queue; + + fn parallel_work(&self); +} + +pub struct Queue { + contenders: AtomicUsize, + queue: Mutex>, + + #[cfg(debug)] + _flag: Mutex<()>, +} + +impl Queue { + pub fn new() -> Queue { + Queue { + contenders: AtomicUsize::new(0), + queue: Mutex::new(ArrayDeque::new()), + + #[cfg(debug)] + _flag: Mutex::new(()), + } + } + + pub fn push(&self, job: J) -> bool { + self.queue.lock().push_back(job).is_ok() + } + + pub fn consume(&self) { + // check if we are the first contender + let pos = self.contenders.fetch_add(1, Ordering::SeqCst); + if pos > 0 { + assert!(usize::max_value() > pos, "contenders overflow"); + return; + } + + // enter the critical section + let mut contenders = 1; // myself + while contenders > 0 { + // check soundness in debug builds + #[cfg(debug)] + let _flag = self + ._flag + .try_lock() + .expect("contenders should ensure mutual exclusion"); + + // handle every ready element + loop { + let mut queue = self.queue.lock(); + + // check if front job is ready + match queue.front() { + None => break, + Some(job) => { + if job.is_ready() { + () + } else { + break; + } + } + }; + + // take the job out of the queue + let job = queue.pop_front().unwrap(); + debug_assert!(job.is_ready()); + mem::drop(queue); + + // process element + job.sequential_work(); + } + + #[cfg(debug)] + mem::drop(_flag); + + // decrease contenders + contenders = self.contenders.fetch_sub(contenders, Ordering::SeqCst) - contenders; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::sync::Arc; + use std::thread; + + use rand::thread_rng; + use rand::Rng; + + struct TestJob {} + + impl SequentialJob for TestJob { + fn is_ready(&self) -> bool { + true + } + + fn sequential_work(self) {} + } + + /* Fuzz the Queue */ + #[test] + fn test_queue() { + fn hammer(queue: &Arc>) { + let mut rng = thread_rng(); + for _ in 0..1_000_000 { + if rng.gen() { + queue.push(TestJob {}); + } else { + queue.consume(); + } + } + } + + let queue = Arc::new(Queue::new()); + + // repeatedly apply operations randomly from concurrent threads + let other = { + let queue = queue.clone(); + thread::spawn(move || hammer(&queue)) + }; + hammer(&queue); + + // wait, consume and check empty + other.join().unwrap(); + queue.consume(); + assert_eq!(queue.queue.lock().len(), 0); + } +} diff --git a/src/wireguard/router/receive.rs b/src/wireguard/router/receive.rs new file mode 100644 index 0000000..c5fe3da --- /dev/null +++ b/src/wireguard/router/receive.rs @@ -0,0 +1,192 @@ +use super::device::DecryptionState; +use super::messages::TransportHeader; +use super::queue::{ParallelJob, Queue, SequentialJob}; +use super::types::Callbacks; +use super::{REJECT_AFTER_MESSAGES, SIZE_TAG}; + +use super::super::{tun, udp, Endpoint}; + +use std::mem; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; +use spin::Mutex; +use zerocopy::{AsBytes, LayoutVerified}; + +struct Inner> { + ready: AtomicBool, + buffer: Mutex<(Option, Vec)>, // endpoint & ciphertext buffer + state: Arc>, // decryption state (keys and replay protector) +} + +pub struct ReceiveJob>( + Arc>, +); + +impl> Clone + for ReceiveJob +{ + fn clone(&self) -> ReceiveJob { + ReceiveJob(self.0.clone()) + } +} + +impl> ReceiveJob { + pub fn new( + buffer: Vec, + state: Arc>, + endpoint: E, + ) -> ReceiveJob { + ReceiveJob(Arc::new(Inner { + ready: AtomicBool::new(false), + buffer: Mutex::new((Some(endpoint), buffer)), + state, + })) + } +} + +impl> ParallelJob + for ReceiveJob +{ + fn queue(&self) -> &Queue { + &self.0.state.peer.inbound + } + + fn parallel_work(&self) { + // TODO: refactor + // decrypt + { + let job = &self.0; + let peer = &job.state.peer; + let mut msg = job.buffer.lock(); + + // cast to header followed by payload + let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = + match LayoutVerified::new_from_prefix(&mut msg.1[..]) { + Some(v) => v, + None => { + log::debug!("inbound worker: failed to parse message"); + return; + } + }; + + // authenticate and decrypt payload + { + // create nonce object + let mut nonce = [0u8; 12]; + debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); + nonce[4..].copy_from_slice(header.f_counter.as_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + // do the weird ring AEAD dance + let key = LessSafeKey::new( + UnboundKey::new(&CHACHA20_POLY1305, &job.state.keypair.recv.key[..]).unwrap(), + ); + + // attempt to open (and authenticate) the body + match key.open_in_place(nonce, Aad::empty(), packet) { + Ok(_) => (), + Err(_) => { + // fault and return early + log::trace!("inbound worker: authentication failure"); + msg.1.truncate(0); + return; + } + } + } + + // check that counter not after reject + if header.f_counter.get() >= REJECT_AFTER_MESSAGES { + msg.1.truncate(0); + return; + } + + // cryptokey route and strip padding + let inner_len = { + let length = packet.len() - SIZE_TAG; + if length > 0 { + peer.device.table.check_route(&peer, &packet[..length]) + } else { + Some(0) + } + }; + + // truncate to remove tag + match inner_len { + None => { + log::trace!("inbound worker: cryptokey routing failed"); + msg.1.truncate(0); + } + Some(len) => { + log::trace!( + "inbound worker: good route, length = {} {}", + len, + if len == 0 { "(keepalive)" } else { "" } + ); + msg.1.truncate(mem::size_of::() + len); + } + } + } + + // mark ready + self.0.ready.store(true, Ordering::Release); + } +} + +impl> SequentialJob + for ReceiveJob +{ + fn is_ready(&self) -> bool { + self.0.ready.load(Ordering::Acquire) + } + + fn sequential_work(self) { + let job = &self.0; + let peer = &job.state.peer; + let mut msg = job.buffer.lock(); + let endpoint = msg.0.take(); + + // cast transport header + let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = + match LayoutVerified::new_from_prefix(&msg.1[..]) { + Some(v) => v, + None => { + // also covers authentication failure + return; + } + }; + + // check for replay + if !job.state.protector.lock().update(header.f_counter.get()) { + log::debug!("inbound worker: replay detected"); + return; + } + + // check for confirms key + if !job.state.confirmed.swap(true, Ordering::SeqCst) { + log::debug!("inbound worker: message confirms key"); + peer.confirm_key(&job.state.keypair); + } + + // update endpoint + *peer.endpoint.lock() = endpoint; + + // check if should be written to TUN + let mut sent = false; + if packet.len() > 0 { + sent = match peer.device.inbound.write(&packet[..]) { + Err(e) => { + log::debug!("failed to write inbound packet to TUN: {:?}", e); + false + } + Ok(_) => true, + } + } else { + log::debug!("inbound worker: received keepalive") + } + + // trigger callback + C::recv(&peer.opaque, msg.1.len(), sent, &job.state.keypair); + } +} diff --git a/src/wireguard/router/runq.rs b/src/wireguard/router/runq.rs deleted file mode 100644 index 4c848cd..0000000 --- a/src/wireguard/router/runq.rs +++ /dev/null @@ -1,129 +0,0 @@ -use std::hash::Hash; -use std::mem; -use std::sync::{Condvar, Mutex}; - -use std::collections::hash_map::Entry; -use std::collections::HashMap; -use std::collections::VecDeque; - -pub trait ToKey { - type Key: Hash + Eq; - fn to_key(&self) -> Self::Key; -} - -pub struct RunQueue { - cvar: Condvar, - inner: Mutex>, -} - -struct Inner { - stop: bool, - queue: VecDeque, - members: HashMap, -} - -impl RunQueue { - pub fn close(&self) { - let mut inner = self.inner.lock().unwrap(); - inner.stop = true; - self.cvar.notify_all(); - } - - pub fn new() -> RunQueue { - RunQueue { - cvar: Condvar::new(), - inner: Mutex::new(Inner { - stop: false, - queue: VecDeque::new(), - members: HashMap::new(), - }), - } - } - - pub fn insert(&self, v: T) { - let key = v.to_key(); - let mut inner = self.inner.lock().unwrap(); - match inner.members.entry(key) { - Entry::Occupied(mut elem) => { - *elem.get_mut() += 1; - } - Entry::Vacant(spot) => { - // add entry to back of queue - spot.insert(0); - inner.queue.push_back(v); - - // wake a thread - self.cvar.notify_one(); - } - } - } - - /// Run (consume from) the run queue using the provided function. - /// The function should return wheter the given element should be rescheduled. - /// - /// # Arguments - /// - /// - `f` : function to apply to every element - /// - /// # Note - /// - /// The function f may be called again even when the element was not inserted back in to the - /// queue since the last applciation and no rescheduling was requested. - /// - /// This happens then the function handles all work for T, - /// but T is added to the run queue while the function is running. - pub fn run bool>(&self, f: F) { - let mut inner = self.inner.lock().unwrap(); - loop { - // fetch next element - let elem = loop { - // run-queue closed - if inner.stop { - return; - } - - // try to pop from queue - match inner.queue.pop_front() { - Some(elem) => { - break elem; - } - None => (), - }; - - // wait for an element to be inserted - inner = self.cvar.wait(inner).unwrap(); - }; - - // fetch current request number - let key = elem.to_key(); - let old_n = *inner.members.get(&key).unwrap(); - mem::drop(inner); // drop guard - - // handle element - let rerun = f(&elem); - - // if the function requested a re-run add the element to the back of the queue - inner = self.inner.lock().unwrap(); - if rerun { - inner.queue.push_back(elem); - continue; - } - - // otherwise check if new requests have come in since we ran the function - match inner.members.entry(key) { - Entry::Occupied(occ) => { - if *occ.get() == old_n { - // no new requests since last, remove entry. - occ.remove(); - } else { - // new requests, reschedule. - inner.queue.push_back(elem); - } - } - Entry::Vacant(_) => { - unreachable!(); - } - } - } - } -} diff --git a/src/wireguard/router/send.rs b/src/wireguard/router/send.rs new file mode 100644 index 0000000..8e41796 --- /dev/null +++ b/src/wireguard/router/send.rs @@ -0,0 +1,143 @@ +use super::queue::{SequentialJob, ParallelJob, Queue}; +use super::KeyPair; +use super::types::Callbacks; +use super::peer::Peer; +use super::{REJECT_AFTER_MESSAGES, SIZE_TAG}; +use super::messages::{TransportHeader, TYPE_TRANSPORT}; + +use super::super::{tun, udp, Endpoint}; + +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; +use zerocopy::{AsBytes, LayoutVerified}; +use spin::Mutex; + +struct Inner> { + ready: AtomicBool, + buffer: Mutex>, + counter: u64, + keypair: Arc, + peer: Peer, +} + +pub struct SendJob> ( + Arc> +); + +impl > Clone for SendJob { + fn clone(&self) -> SendJob { + SendJob(self.0.clone()) + } +} + +impl > SendJob { + pub fn new( + buffer: Vec, + counter: u64, + keypair: Arc, + peer: Peer + ) -> SendJob { + SendJob(Arc::new(Inner{ + buffer: Mutex::new(buffer), + counter, + keypair, + peer, + ready: AtomicBool::new(false) + })) + } +} + +impl > SequentialJob for SendJob { + + fn is_ready(&self) -> bool { + self.0.ready.load(Ordering::Acquire) + } + + fn sequential_work(self) { + debug_assert_eq!( + self.is_ready(), + true, + "doing sequential work + on an incomplete job" + ); + log::trace!("processing sequential send job"); + + // send to peer + let job = &self.0; + let msg = job.buffer.lock(); + let xmit = job.peer.send_raw(&msg[..]).is_ok(); + + // trigger callback (for timers) + C::send( + &job.peer.opaque, + msg.len(), + xmit, + &job.keypair, + job.counter, + ); + } +} + + +impl > ParallelJob for SendJob { + + fn queue(&self) -> &Queue { + &self.0.peer.outbound + } + + fn parallel_work(&self) { + debug_assert_eq!( + self.is_ready(), + false, + "doing parallel work on completed job" + ); + log::trace!("processing parallel send job"); + + // encrypt body + { + // make space for the tag + let job = &*self.0; + let mut msg = job.buffer.lock(); + msg.extend([0u8; SIZE_TAG].iter()); + + // cast to header (should never fail) + let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = + LayoutVerified::new_from_prefix(&mut msg[..]) + .expect("earlier code should ensure that there is ample space"); + + // set header fields + debug_assert!( + job.counter < REJECT_AFTER_MESSAGES, + "should be checked when assigning counters" + ); + header.f_type.set(TYPE_TRANSPORT); + header.f_receiver.set(job.keypair.send.id); + header.f_counter.set(job.counter); + + // create a nonce object + let mut nonce = [0u8; 12]; + debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); + nonce[4..].copy_from_slice(header.f_counter.as_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + // do the weird ring AEAD dance + let key = LessSafeKey::new( + UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.send.key[..]).unwrap(), + ); + + // encrypt contents of transport message in-place + let end = packet.len() - SIZE_TAG; + let tag = key + .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end]) + .unwrap(); + + // append tag + packet[end..].copy_from_slice(tag.as_ref()); + } + + // mark ready + self.0.ready.store(true, Ordering::Release); + } +} \ No newline at end of file diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs index 15db368..3d5c79b 100644 --- a/src/wireguard/router/tests.rs +++ b/src/wireguard/router/tests.rs @@ -50,7 +50,6 @@ mod tests { })) } - #[allow(dead_code)] fn reset(&self) { self.0.send.lock().unwrap().clear(); self.0.recv.lock().unwrap().clear(); @@ -104,9 +103,9 @@ mod tests { } } - // wait for scheduling (VERY conservative) + // wait for scheduling fn wait() { - thread::sleep(Duration::from_millis(30)); + thread::sleep(Duration::from_millis(15)); } fn init() { @@ -162,7 +161,7 @@ mod tests { }; let msg = make_packet_padded(1024, src, dst, 0); - // every iteration sends 10 MB + // every iteration sends 10 GB b.iter(|| { opaque.store(0, Ordering::SeqCst); while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 { diff --git a/src/wireguard/router/worker.rs b/src/wireguard/router/worker.rs new file mode 100644 index 0000000..bbb644c --- /dev/null +++ b/src/wireguard/router/worker.rs @@ -0,0 +1,31 @@ +use super::super::{tun, udp, Endpoint}; +use super::types::Callbacks; + +use super::queue::ParallelJob; +use super::receive::ReceiveJob; +use super::send::SendJob; + +use crossbeam_channel::Receiver; + +pub enum JobUnion> { + Outbound(SendJob), + Inbound(ReceiveJob), +} + +pub fn worker>( + receiver: Receiver>, +) { + loop { + match receiver.recv() { + Err(_) => break, + Ok(JobUnion::Inbound(job)) => { + job.parallel_work(); + job.queue().consume(); + } + Ok(JobUnion::Outbound(job)) => { + job.parallel_work(); + job.queue().consume(); + } + } + } +} diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index 6b852bb..0197a9e 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -319,8 +319,8 @@ impl Timers { let timers = peer.timers(); if timers.enabled && timers.keepalive_interval > 0 { timers.send_keepalive.stop(); - let queued = peer.router.send_keepalive(); - log::trace!("{} : keepalive queued {}", peer, queued); + peer.router.send_keepalive(); + log::trace!("{} : keepalive queued", peer); timers .send_persistent_keepalive .start(Duration::from_secs(timers.keepalive_interval)); diff --git a/src/wireguard/workers.rs b/src/wireguard/workers.rs index 02db160..70e3b3a 100644 --- a/src/wireguard/workers.rs +++ b/src/wireguard/workers.rs @@ -194,7 +194,8 @@ pub fn handshake_worker( let mut resp_len: u64 = 0; if let Some(msg) = resp { resp_len = msg.len() as u64; - let _ = wg.router.write(&msg[..], &mut src).map_err(|e| { + // TODO: consider a more elegant solution for accessing the bind + let _ = wg.router.send_raw(&msg[..], &mut src).map_err(|e| { debug!( "{} : handshake worker, failed to send response, error = {}", wg, e @@ -252,7 +253,7 @@ pub fn handshake_worker( ); let device = wg.peers.read(); let _ = device.begin(&mut OsRng, &peer.pk).map(|msg| { - let _ = peer.router.send(&msg[..]).map_err(|e| { + let _ = peer.router.send_raw(&msg[..]).map_err(|e| { debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e) }); peer.state.sent_handshake_initiation(); -- cgit v1.2.3-59-g8ed1b