diff options
Diffstat (limited to 'src/wireguard/router')
-rw-r--r-- | src/wireguard/router/constants.rs | 4 | ||||
-rw-r--r-- | src/wireguard/router/device.rs | 125 | ||||
-rw-r--r-- | src/wireguard/router/inbound.rs | 190 | ||||
-rw-r--r-- | src/wireguard/router/mod.rs | 10 | ||||
-rw-r--r-- | src/wireguard/router/outbound.rs | 110 | ||||
-rw-r--r-- | src/wireguard/router/peer.rs | 206 | ||||
-rw-r--r-- | src/wireguard/router/pool.rs | 164 | ||||
-rw-r--r-- | src/wireguard/router/queue.rs | 144 | ||||
-rw-r--r-- | src/wireguard/router/receive.rs | 192 | ||||
-rw-r--r-- | src/wireguard/router/runq.rs | 129 | ||||
-rw-r--r-- | src/wireguard/router/send.rs | 143 | ||||
-rw-r--r-- | src/wireguard/router/tests.rs | 7 | ||||
-rw-r--r-- | src/wireguard/router/worker.rs | 31 |
13 files changed, 633 insertions, 822 deletions
diff --git a/src/wireguard/router/constants.rs b/src/wireguard/router/constants.rs index af76299..f083811 100644 --- a/src/wireguard/router/constants.rs +++ b/src/wireguard/router/constants.rs @@ -4,6 +4,6 @@ pub const MAX_QUEUED_PACKETS: usize = 1024; // performance constants -pub const PARALLEL_QUEUE_SIZE: usize = MAX_QUEUED_PACKETS; +pub const PARALLEL_QUEUE_SIZE: usize = 4 * MAX_QUEUED_PACKETS; + pub const INORDER_QUEUE_SIZE: usize = MAX_QUEUED_PACKETS; -pub const MAX_INORDER_CONSUME: usize = INORDER_QUEUE_SIZE; diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index 96b7d82..9d78178 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -10,19 +10,16 @@ use spin::{Mutex, RwLock}; use zerocopy::LayoutVerified; use super::anti_replay::AntiReplay; -use super::pool::Job; use super::constants::PARALLEL_QUEUE_SIZE; -use super::inbound; -use super::outbound; - use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::peer::{new_peer, Peer, PeerHandle}; use super::types::{Callbacks, RouterError}; use super::SIZE_MESSAGE_PREFIX; +use super::receive::ReceiveJob; use super::route::RoutingTable; -use super::runq::RunQueue; +use super::worker::{worker, JobUnion}; use super::super::{tun, udp, Endpoint, KeyPair}; use super::ParallelQueue; @@ -38,13 +35,8 @@ pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state pub table: RoutingTable<Peer<E, C, T, B>>, - // work queues - pub queue_outbound: ParallelQueue<Job<Peer<E, C, T, B>, outbound::Outbound>>, - pub queue_inbound: ParallelQueue<Job<Peer<E, C, T, B>, inbound::Inbound<E, C, T, B>>>, - - // run queues - pub run_inbound: RunQueue<Peer<E, C, T, B>>, - pub run_outbound: RunQueue<Peer<E, C, T, B>>, + // work queue + pub work: ParallelQueue<JobUnion<E, C, T, B>>, } pub struct EncryptionState { @@ -101,13 +93,8 @@ impl<E: Endpoint, C: 
Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop fn drop(&mut self) { debug!("router: dropping device"); - // close worker queues - self.state.queue_outbound.close(); - self.state.queue_inbound.close(); - - // close run queues - self.state.run_outbound.close(); - self.state.run_inbound.close(); + // close worker queue + self.state.work.close(); // join all worker threads while match self.handles.pop() { @@ -118,77 +105,28 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop } _ => false, } {} - - debug!("router: device dropped"); } } impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<E, C, T, B> { pub fn new(num_workers: usize, tun: T) -> DeviceHandle<E, C, T, B> { - // allocate shared device state - let (queue_outbound, mut outrx) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE); - let (queue_inbound, mut inrx) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE); + let (work, mut consumers) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE); let device = Device { inner: Arc::new(DeviceInner { + work, inbound: tun, - queue_inbound, outbound: RwLock::new((true, None)), - queue_outbound, - run_inbound: RunQueue::new(), - run_outbound: RunQueue::new(), recv: RwLock::new(HashMap::new()), table: RoutingTable::new(), }), }; // start worker threads - let mut threads = Vec::with_capacity(4 * num_workers); - - // inbound/decryption workers - for _ in 0..num_workers { - // parallel workers (parallel processing) - { - let device = device.clone(); - let rx = inrx.pop().unwrap(); - threads.push(thread::spawn(move || { - log::debug!("inbound parallel router worker started"); - inbound::parallel(device, rx) - })); - } - - // sequential workers (in-order processing) - { - let device = device.clone(); - threads.push(thread::spawn(move || { - log::debug!("inbound sequential router worker started"); - inbound::sequential(device) - })); - } - } - - // outbound/encryption workers - for _ in 0..num_workers { - // parallel 
workers (parallel processing) - { - let device = device.clone(); - let rx = outrx.pop().unwrap(); - threads.push(thread::spawn(move || { - log::debug!("outbound parallel router worker started"); - outbound::parallel(device, rx) - })); - } - - // sequential workers (in-order processing) - { - let device = device.clone(); - threads.push(thread::spawn(move || { - log::debug!("outbound sequential router worker started"); - outbound::sequential(device) - })); - } + let mut threads = Vec::with_capacity(num_workers); + while let Some(rx) = consumers.pop() { + threads.push(thread::spawn(move || worker(rx))); } - - debug_assert_eq!(threads.len(), num_workers * 4); + debug_assert_eq!(threads.len(), num_workers); // return exported device handle DeviceHandle { @@ -197,6 +135,16 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle< } } + pub fn send_raw(&self, msg : &[u8], dst: &mut E) -> Result<(), B::Error> { + let bind = self.state.outbound.read(); + if bind.0 { + if let Some(bind) = bind.1.as_ref() { + return bind.write(msg, dst); + } + } + return Ok(()) + } + /// Brings the router down. /// When the router is brought down it: /// - Prevents transmission of outbound messages. 
@@ -250,10 +198,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle< .ok_or(RouterError::NoCryptoKeyRoute)?; // schedule for encryption and transmission to peer - if let Some(job) = peer.send_job(msg, true) { - self.state.queue_outbound.send(job); - } - + peer.send(msg, true); Ok(()) } @@ -297,10 +242,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle< .get(&header.f_receiver.get()) .ok_or(RouterError::UnknownReceiverId)?; - // schedule for decryption and TUN write - if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) { - log::trace!("schedule decryption of transport message"); - self.state.queue_inbound.send(job); + // create inbound job + let job = ReceiveJob::new(msg, dec.clone(), src); + + // 1. add to sequential queue (drop if full) + // 2. then add to parallel work queue (wait if full) + if dec.peer.inbound.push(job.clone()) { + self.state.work.send(JobUnion::Inbound(job)); } Ok(()) } @@ -311,17 +259,4 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle< pub fn set_outbound_writer(&self, new: B) { self.state.outbound.write().1 = Some(new); } - - pub fn write(&self, msg: &[u8], endpoint: &mut E) -> Result<(), RouterError> { - let outbound = self.state.outbound.read(); - if outbound.0 { - outbound - .1 - .as_ref() - .ok_or(RouterError::SendError) - .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError)) - } else { - Ok(()) - } - } } diff --git a/src/wireguard/router/inbound.rs b/src/wireguard/router/inbound.rs deleted file mode 100644 index dc2c44e..0000000 --- a/src/wireguard/router/inbound.rs +++ /dev/null @@ -1,190 +0,0 @@ -use std::mem; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use crossbeam_channel::Receiver; -use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; -use zerocopy::{AsBytes, LayoutVerified}; - -use super::constants::MAX_INORDER_CONSUME; -use super::device::DecryptionState; -use 
super::device::Device; -use super::messages::TransportHeader; -use super::peer::Peer; -use super::pool::*; -use super::types::Callbacks; -use super::{tun, udp, Endpoint}; -use super::{REJECT_AFTER_MESSAGES, SIZE_TAG}; - -pub struct Inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { - msg: Vec<u8>, - failed: bool, - state: Arc<DecryptionState<E, C, T, B>>, - endpoint: Option<E>, -} - -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Inbound<E, C, T, B> { - pub fn new( - msg: Vec<u8>, - state: Arc<DecryptionState<E, C, T, B>>, - endpoint: E, - ) -> Inbound<E, C, T, B> { - Inbound { - msg, - state, - failed: false, - endpoint: Some(endpoint), - } - } -} - -#[inline(always)] -pub fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - device: Device<E, C, T, B>, - receiver: Receiver<Job<Peer<E, C, T, B>, Inbound<E, C, T, B>>>, -) { - // parallel work to apply - #[inline(always)] - fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - peer: &Peer<E, C, T, B>, - body: &mut Inbound<E, C, T, B>, - ) { - log::trace!("worker, parallel section, obtained job"); - - // cast to header followed by payload - let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = - match LayoutVerified::new_from_prefix(&mut body.msg[..]) { - Some(v) => v, - None => { - log::debug!("inbound worker: failed to parse message"); - return; - } - }; - - // authenticate and decrypt payload - { - // create nonce object - let mut nonce = [0u8; 12]; - debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); - nonce[4..].copy_from_slice(header.f_counter.as_bytes()); - let nonce = Nonce::assume_unique_for_key(nonce); - - // do the weird ring AEAD dance - let key = LessSafeKey::new( - UnboundKey::new(&CHACHA20_POLY1305, &body.state.keypair.recv.key[..]).unwrap(), - ); - - // attempt to open (and authenticate) the body - match key.open_in_place(nonce, Aad::empty(), packet) { - Ok(_) => (), - Err(_) => { - 
// fault and return early - log::trace!("inbound worker: authentication failure"); - body.failed = true; - return; - } - } - } - - // check that counter not after reject - if header.f_counter.get() >= REJECT_AFTER_MESSAGES { - body.failed = true; - return; - } - - // cryptokey route and strip padding - let inner_len = { - let length = packet.len() - SIZE_TAG; - if length > 0 { - peer.device.table.check_route(&peer, &packet[..length]) - } else { - Some(0) - } - }; - - // truncate to remove tag - match inner_len { - None => { - log::trace!("inbound worker: cryptokey routing failed"); - body.failed = true; - } - Some(len) => { - log::trace!( - "inbound worker: good route, length = {} {}", - len, - if len == 0 { "(keepalive)" } else { "" } - ); - body.msg.truncate(mem::size_of::<TransportHeader>() + len); - } - } - } - - worker_parallel(device, |dev| &dev.run_inbound, receiver, work) -} - -#[inline(always)] -pub fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - device: Device<E, C, T, B>, -) { - // sequential work to apply - fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - peer: &Peer<E, C, T, B>, - body: &mut Inbound<E, C, T, B>, - ) { - log::trace!("worker, sequential section, obtained job"); - - // decryption failed, return early - if body.failed { - log::trace!("job faulted, remove from queue and ignore"); - return; - } - - // cast transport header - let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = - match LayoutVerified::new_from_prefix(&body.msg[..]) { - Some(v) => v, - None => { - log::debug!("inbound worker: failed to parse message"); - return; - } - }; - - // check for replay - if !body.state.protector.lock().update(header.f_counter.get()) { - log::debug!("inbound worker: replay detected"); - return; - } - - // check for confirms key - if !body.state.confirmed.swap(true, Ordering::SeqCst) { - log::debug!("inbound worker: message confirms key"); - peer.confirm_key(&body.state.keypair); 
- } - - // update endpoint - *peer.endpoint.lock() = body.endpoint.take(); - - // check if should be written to TUN - let mut sent = false; - if packet.len() > 0 { - sent = match peer.device.inbound.write(&packet[..]) { - Err(e) => { - log::debug!("failed to write inbound packet to TUN: {:?}", e); - false - } - Ok(_) => true, - } - } else { - log::debug!("inbound worker: received keepalive") - } - - // trigger callback - C::recv(&peer.opaque, body.msg.len(), sent, &body.state.keypair); - } - - // handle message from the peers inbound queue - device.run_inbound.run(|peer| { - peer.inbound - .handle(|body| work(&peer, body), MAX_INORDER_CONSUME) - }); -} diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs index 8238d32..699c621 100644 --- a/src/wireguard/router/mod.rs +++ b/src/wireguard/router/mod.rs @@ -1,16 +1,17 @@ mod anti_replay; mod constants; mod device; -mod inbound; mod ip; mod messages; -mod outbound; mod peer; -mod pool; mod route; -mod runq; mod types; +mod queue; +mod receive; +mod send; +mod worker; + #[cfg(test)] mod tests; @@ -20,7 +21,6 @@ use std::mem; use super::constants::REJECT_AFTER_MESSAGES; use super::queue::ParallelQueue; use super::types::*; -use super::{tun, udp, Endpoint}; pub const SIZE_TAG: usize = 16; pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>(); diff --git a/src/wireguard/router/outbound.rs b/src/wireguard/router/outbound.rs deleted file mode 100644 index 1edb2fb..0000000 --- a/src/wireguard/router/outbound.rs +++ /dev/null @@ -1,110 +0,0 @@ -use std::sync::Arc; - -use crossbeam_channel::Receiver; -use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; -use zerocopy::{AsBytes, LayoutVerified}; - -use super::constants::MAX_INORDER_CONSUME; -use super::device::Device; -use super::messages::{TransportHeader, TYPE_TRANSPORT}; -use super::peer::Peer; -use super::pool::*; -use super::types::Callbacks; -use super::KeyPair; -use super::{tun, udp, Endpoint}; -use 
super::{REJECT_AFTER_MESSAGES, SIZE_TAG}; - -pub struct Outbound { - msg: Vec<u8>, - keypair: Arc<KeyPair>, - counter: u64, -} - -impl Outbound { - pub fn new(msg: Vec<u8>, keypair: Arc<KeyPair>, counter: u64) -> Outbound { - Outbound { - msg, - keypair, - counter, - } - } -} - -#[inline(always)] -pub fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - device: Device<E, C, T, B>, - receiver: Receiver<Job<Peer<E, C, T, B>, Outbound>>, -) { - #[inline(always)] - fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - _peer: &Peer<E, C, T, B>, - body: &mut Outbound, - ) { - log::trace!("worker, parallel section, obtained job"); - - // make space for the tag - body.msg.extend([0u8; SIZE_TAG].iter()); - - // cast to header (should never fail) - let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = - LayoutVerified::new_from_prefix(&mut body.msg[..]) - .expect("earlier code should ensure that there is ample space"); - - // set header fields - debug_assert!( - body.counter < REJECT_AFTER_MESSAGES, - "should be checked when assigning counters" - ); - header.f_type.set(TYPE_TRANSPORT); - header.f_receiver.set(body.keypair.send.id); - header.f_counter.set(body.counter); - - // create a nonce object - let mut nonce = [0u8; 12]; - debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); - nonce[4..].copy_from_slice(header.f_counter.as_bytes()); - let nonce = Nonce::assume_unique_for_key(nonce); - - // do the weird ring AEAD dance - let key = LessSafeKey::new( - UnboundKey::new(&CHACHA20_POLY1305, &body.keypair.send.key[..]).unwrap(), - ); - - // encrypt content of transport message in-place - let end = packet.len() - SIZE_TAG; - let tag = key - .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end]) - .unwrap(); - - // append tag - packet[end..].copy_from_slice(tag.as_ref()); - } - - worker_parallel(device, |dev| &dev.run_outbound, receiver, work); -} - -#[inline(always)] -pub fn 
sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - device: Device<E, C, T, B>, -) { - device.run_outbound.run(|peer| { - peer.outbound.handle( - |body| { - log::trace!("worker, sequential section, obtained job"); - - // send to peer - let xmit = peer.send(&body.msg[..]).is_ok(); - - // trigger callback - C::send( - &peer.opaque, - body.msg.len(), - xmit, - &body.keypair, - body.counter, - ); - }, - MAX_INORDER_CONSUME, - ) - }); -} diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index 8fe2e1c..a20908e 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -1,13 +1,3 @@ -use std::mem; -use std::net::{IpAddr, SocketAddr}; -use std::ops::Deref; -use std::sync::atomic::AtomicBool; -use std::sync::Arc; - -use arraydeque::{ArrayDeque, Wrapping}; -use log::debug; -use spin::Mutex; - use super::super::constants::*; use super::super::{tun, udp, Endpoint, KeyPair}; @@ -15,17 +5,25 @@ use super::anti_replay::AntiReplay; use super::device::DecryptionState; use super::device::Device; use super::device::EncryptionState; -use super::messages::TransportHeader; use super::constants::*; -use super::runq::ToKey; use super::types::{Callbacks, RouterError}; use super::SIZE_MESSAGE_PREFIX; -// worker pool related -use super::inbound::Inbound; -use super::outbound::Outbound; -use super::pool::{InorderQueue, Job}; +use super::queue::Queue; +use super::receive::ReceiveJob; +use super::send::SendJob; +use super::worker::JobUnion; + +use std::mem; +use std::net::{IpAddr, SocketAddr}; +use std::ops::Deref; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; + +use arraydeque::{ArrayDeque, Wrapping}; +use log::debug; +use spin::Mutex; pub struct KeyWheel { next: Option<Arc<KeyPair>>, // next key state (unconfirmed) @@ -37,11 +35,11 @@ pub struct KeyWheel { pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { pub device: Device<E, C, T, B>, pub opaque: C::Opaque, - pub outbound: 
InorderQueue<Peer<E, C, T, B>, Outbound>, - pub inbound: InorderQueue<Peer<E, C, T, B>, Inbound<E, C, T, B>>, + pub outbound: Queue<SendJob<E, C, T, B>>, + pub inbound: Queue<ReceiveJob<E, C, T, B>>, pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_QUEUED_PACKETS], Wrapping>>, pub keys: Mutex<KeyWheel>, - pub ekey: Mutex<Option<EncryptionState>>, + pub enc_key: Mutex<Option<EncryptionState>>, pub endpoint: Mutex<Option<E>>, } @@ -66,13 +64,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq for } } -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ToKey for Peer<E, C, T, B> { - type Key = usize; - fn to_key(&self) -> usize { - Arc::downgrade(&self.inner).into_raw() as usize - } -} - impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Eq for Peer<E, C, T, B> {} /* A peer is transparently dereferenced to the inner type @@ -154,7 +145,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Peer keys.current = None; keys.previous = None; - *peer.ekey.lock() = None; + *peer.enc_key.lock() = None; *peer.endpoint.lock() = None; debug!("peer dropped & removed from device"); @@ -172,9 +163,9 @@ pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( inner: Arc::new(PeerInner { opaque, device, - inbound: InorderQueue::new(), - outbound: InorderQueue::new(), - ekey: spin::Mutex::new(None), + inbound: Queue::new(), + outbound: Queue::new(), + enc_key: spin::Mutex::new(None), endpoint: spin::Mutex::new(None), keys: spin::Mutex::new(KeyWheel { next: None, @@ -200,7 +191,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, /// # Returns /// /// Unit if packet was sent, or an error indicating why sending failed - pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> { + pub fn send_raw(&self, msg: &[u8]) -> Result<(), RouterError> { debug!("peer.send"); // send to endpoint (if known) @@ -223,6 +214,57 @@ impl<E: Endpoint, C: 
Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, } impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, B> { + /// Encrypt and send a message to the peer + /// + /// Arguments: + /// + /// - `msg` : A padded vector holding the message (allows in-place construction of the transport header) + /// - `stage`: Should the message be staged if no key is available + /// + pub(super) fn send(&self, msg: Vec<u8>, stage: bool) { + // check if key available + let (job, need_key) = { + let mut enc_key = self.enc_key.lock(); + match enc_key.as_mut() { + None => { + if stage { + self.staged_packets.lock().push_back(msg); + }; + (None, true) + } + Some(mut state) => { + // avoid integer overflow in nonce + if state.nonce >= REJECT_AFTER_MESSAGES - 1 { + *enc_key = None; + if stage { + self.staged_packets.lock().push_back(msg); + } + (None, true) + } else { + debug!("encryption state available, nonce = {}", state.nonce); + let job = + SendJob::new(msg, state.nonce, state.keypair.clone(), self.clone()); + if self.outbound.push(job.clone()) { + state.nonce += 1; + (Some(job), false) + } else { + (None, false) + } + } + } + } + }; + + if need_key { + debug_assert!(job.is_none()); + C::need_key(&self.opaque); + }; + + if let Some(job) = job { + self.device.work.send(JobUnion::Outbound(job)) + } + } + // Transmit all staged packets fn send_staged(&self) -> bool { debug!("peer.send_staged"); @@ -232,29 +274,14 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, match staged.pop_front() { Some(msg) => { sent = true; - self.send_raw(msg, false); + self.send(msg, false); } None => break sent, } } } - // Treat the msg as the payload of a transport message - // - // Returns true if the message was queued for transmission. 
- fn send_raw(&self, msg: Vec<u8>, stage: bool) -> bool { - log::debug!("peer.send_raw"); - match self.send_job(msg, stage) { - Some(job) => { - self.device.queue_outbound.send(job); - debug!("send_raw: got obtained send_job"); - true - } - None => false, - } - } - - pub fn confirm_key(&self, keypair: &Arc<KeyPair>) { + pub(super) fn confirm_key(&self, keypair: &Arc<KeyPair>) { debug!("peer.confirm_key"); { // take lock and check keypair = keys.next @@ -282,76 +309,12 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, C::key_confirmed(&self.opaque); // set new key for encryption - *self.ekey.lock() = ekey; + *self.enc_key.lock() = ekey; } // start transmission of staged packets self.send_staged(); } - - pub fn recv_job( - &self, - src: E, - dec: Arc<DecryptionState<E, C, T, B>>, - msg: Vec<u8>, - ) -> Option<Job<Self, Inbound<E, C, T, B>>> { - let job = Job::new(self.clone(), Inbound::new(msg, dec, src)); - self.inbound.send(job.clone()); - Some(job) - } - - pub fn send_job(&self, msg: Vec<u8>, stage: bool) -> Option<Job<Self, Outbound>> { - debug!( - "peer.send_job, msg.len() = {}, stage = {}", - msg.len(), - stage - ); - debug_assert!( - msg.len() >= mem::size_of::<TransportHeader>(), - "received message with size: {:}", - msg.len() - ); - - // check if has key - let (keypair, counter) = { - let keypair = { - // TODO: consider using atomic ptr for ekey state - let mut ekey = self.ekey.lock(); - match ekey.as_mut() { - None => None, - Some(mut state) => { - // avoid integer overflow in nonce - if state.nonce >= REJECT_AFTER_MESSAGES - 1 { - *ekey = None; - None - } else { - debug!("encryption state available, nonce = {}", state.nonce); - let counter = state.nonce; - state.nonce += 1; - Some((state.keypair.clone(), counter)) - } - } - } - }; - - // If not suitable key was found: - // 1. Stage packet for later transmission - // 2. 
Request new key - if keypair.is_none() && stage { - log::trace!("packet staged"); - self.staged_packets.lock().push_back(msg); - C::need_key(&self.opaque); - return None; - }; - - keypair - }?; - - // add job to in-order queue and return sender to device for inclusion in worker pool - let job = Job::new(self.clone(), Outbound::new(msg, keypair, counter)); - self.outbound.send(job.clone()); - Some(job) - } } impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, C, T, B> { @@ -403,7 +366,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, } // clear encryption state - *self.peer.ekey.lock() = None; + *self.peer.enc_key.lock() = None; } pub fn down(&self) { @@ -440,7 +403,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, // update key-wheel if new.initiator { // start using key for encryption - *self.peer.ekey.lock() = Some(EncryptionState::new(&new)); + *self.peer.enc_key.lock() = Some(EncryptionState::new(&new)); // move current into previous keys.previous = keys.current.as_ref().map(|v| v.clone()); @@ -474,16 +437,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, // schedule confirmation if initiator { - debug_assert!(self.peer.ekey.lock().is_some()); + debug_assert!(self.peer.enc_key.lock().is_some()); debug!("peer.add_keypair: is initiator, must confirm the key"); // attempt to confirm using staged packets if !self.peer.send_staged() { // fall back to keepalive packet - let ok = self.send_keepalive(); - debug!( - "peer.add_keypair: keepalive for confirmation, sent = {}", - ok - ); + self.send_keepalive(); + debug!("peer.add_keypair: keepalive for confirmation",); } debug!("peer.add_keypair: key attempted confirmed"); } @@ -495,9 +455,9 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, release } - pub fn send_keepalive(&self) -> bool { + pub fn send_keepalive(&self) { debug!("peer.send_keepalive"); 
- self.peer.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX], true) + self.peer.send(vec![0u8; SIZE_MESSAGE_PREFIX], false) } /// Map a subnet to the peer diff --git a/src/wireguard/router/pool.rs b/src/wireguard/router/pool.rs deleted file mode 100644 index 3fc0026..0000000 --- a/src/wireguard/router/pool.rs +++ /dev/null @@ -1,164 +0,0 @@ -use std::mem; -use std::sync::Arc; - -use arraydeque::ArrayDeque; -use crossbeam_channel::Receiver; -use spin::{Mutex, MutexGuard}; - -use super::constants::INORDER_QUEUE_SIZE; -use super::runq::{RunQueue, ToKey}; - -pub struct InnerJob<P, B> { - // peer (used by worker to schedule/handle inorder queue), - // when the peer is None, the job is complete - peer: Option<P>, - pub body: B, -} - -pub struct Job<P, B> { - inner: Arc<Mutex<InnerJob<P, B>>>, -} - -impl<P, B> Clone for Job<P, B> { - fn clone(&self) -> Job<P, B> { - Job { - inner: self.inner.clone(), - } - } -} - -impl<P, B> Job<P, B> { - pub fn new(peer: P, body: B) -> Job<P, B> { - Job { - inner: Arc::new(Mutex::new(InnerJob { - peer: Some(peer), - body, - })), - } - } -} - -impl<P, B> Job<P, B> { - /// Returns a mutex guard to the inner job if complete - pub fn complete(&self) -> Option<MutexGuard<InnerJob<P, B>>> { - self.inner - .try_lock() - .and_then(|m| if m.peer.is_none() { Some(m) } else { None }) - } -} - -pub struct InorderQueue<P, B> { - queue: Mutex<ArrayDeque<[Job<P, B>; INORDER_QUEUE_SIZE]>>, -} - -impl<P, B> InorderQueue<P, B> { - pub fn new() -> InorderQueue<P, B> { - InorderQueue { - queue: Mutex::new(ArrayDeque::new()), - } - } - - /// Add a new job to the in-order queue - /// - /// # Arguments - /// - /// - `job`: The job added to the back of the queue - /// - /// # Returns - /// - /// True if the element was added, - /// false to indicate that the queue is full. 
- pub fn send(&self, job: Job<P, B>) -> bool { - self.queue.lock().push_back(job).is_ok() - } - - /// Consume completed jobs from the in-order queue - /// - /// # Arguments - /// - /// - `f`: function to apply to the body of each jobof each job. - /// - `limit`: maximum number of jobs to handle before returning - /// - /// # Returns - /// - /// A boolean indicating if the limit was reached: - /// true indicating that the limit was reached, - /// while false implies that the queue is empty or an uncompleted job was reached. - #[inline(always)] - pub fn handle<F: Fn(&mut B)>(&self, f: F, mut limit: usize) -> bool { - // take the mutex - let mut queue = self.queue.lock(); - - while limit > 0 { - // attempt to extract front element - let front = queue.pop_front(); - let elem = match front { - Some(elem) => elem, - _ => { - return false; - } - }; - - // apply function if job complete - let ret = if let Some(mut guard) = elem.complete() { - mem::drop(queue); - f(&mut guard.body); - queue = self.queue.lock(); - false - } else { - true - }; - - // job not complete yet, return job to front - if ret { - queue.push_front(elem).unwrap(); - return false; - } - limit -= 1; - } - - // did not complete all jobs - true - } -} - -/// Allows easy construction of a parallel worker. -/// Applicable for both decryption and encryption workers. -#[inline(always)] -pub fn worker_parallel< - P: ToKey, // represents a peer (atomic reference counted pointer) - B, // inner body type (message buffer, key material, ...) 
- D, // device - W: Fn(&P, &mut B), - Q: Fn(&D) -> &RunQueue<P>, ->( - device: D, - queue: Q, - receiver: Receiver<Job<P, B>>, - work: W, -) { - log::trace!("router worker started"); - loop { - // handle new job - let peer = { - // get next job - let job = match receiver.recv() { - Ok(job) => job, - _ => return, - }; - - // lock the job - let mut job = job.inner.lock(); - - // take the peer from the job - let peer = job.peer.take().unwrap(); - - // process job - work(&peer, &mut job.body); - peer - }; - - // process inorder jobs for peer - queue(&device).insert(peer); - } -} diff --git a/src/wireguard/router/queue.rs b/src/wireguard/router/queue.rs new file mode 100644 index 0000000..ec4492e --- /dev/null +++ b/src/wireguard/router/queue.rs @@ -0,0 +1,144 @@ +use arraydeque::ArrayDeque; +use spin::Mutex; + +use std::mem; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use super::constants::INORDER_QUEUE_SIZE; + +pub trait SequentialJob { + fn is_ready(&self) -> bool; + + fn sequential_work(self); +} + +pub trait ParallelJob: Sized + SequentialJob { + fn queue(&self) -> &Queue<Self>; + + fn parallel_work(&self); +} + +pub struct Queue<J: SequentialJob> { + contenders: AtomicUsize, + queue: Mutex<ArrayDeque<[J; INORDER_QUEUE_SIZE]>>, + + #[cfg(debug)] + _flag: Mutex<()>, +} + +impl<J: SequentialJob> Queue<J> { + pub fn new() -> Queue<J> { + Queue { + contenders: AtomicUsize::new(0), + queue: Mutex::new(ArrayDeque::new()), + + #[cfg(debug)] + _flag: Mutex::new(()), + } + } + + pub fn push(&self, job: J) -> bool { + self.queue.lock().push_back(job).is_ok() + } + + pub fn consume(&self) { + // check if we are the first contender + let pos = self.contenders.fetch_add(1, Ordering::SeqCst); + if pos > 0 { + assert!(usize::max_value() > pos, "contenders overflow"); + return; + } + + // enter the critical section + let mut contenders = 1; // myself + while contenders > 0 { + // check soundness in debug builds + #[cfg(debug)] + let _flag = self + ._flag + .try_lock() + 
.expect("contenders should ensure mutual exclusion"); + + // handle every ready element + loop { + let mut queue = self.queue.lock(); + + // check if front job is ready + match queue.front() { + None => break, + Some(job) => { + if job.is_ready() { + () + } else { + break; + } + } + }; + + // take the job out of the queue + let job = queue.pop_front().unwrap(); + debug_assert!(job.is_ready()); + mem::drop(queue); + + // process element + job.sequential_work(); + } + + #[cfg(debug)] + mem::drop(_flag); + + // decrease contenders + contenders = self.contenders.fetch_sub(contenders, Ordering::SeqCst) - contenders; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::sync::Arc; + use std::thread; + + use rand::thread_rng; + use rand::Rng; + + struct TestJob {} + + impl SequentialJob for TestJob { + fn is_ready(&self) -> bool { + true + } + + fn sequential_work(self) {} + } + + /* Fuzz the Queue */ + #[test] + fn test_queue() { + fn hammer(queue: &Arc<Queue<TestJob>>) { + let mut rng = thread_rng(); + for _ in 0..1_000_000 { + if rng.gen() { + queue.push(TestJob {}); + } else { + queue.consume(); + } + } + } + + let queue = Arc::new(Queue::new()); + + // repeatedly apply operations randomly from concurrent threads + let other = { + let queue = queue.clone(); + thread::spawn(move || hammer(&queue)) + }; + hammer(&queue); + + // wait, consume and check empty + other.join().unwrap(); + queue.consume(); + assert_eq!(queue.queue.lock().len(), 0); + } +} diff --git a/src/wireguard/router/receive.rs b/src/wireguard/router/receive.rs new file mode 100644 index 0000000..c5fe3da --- /dev/null +++ b/src/wireguard/router/receive.rs @@ -0,0 +1,192 @@ +use super::device::DecryptionState; +use super::messages::TransportHeader; +use super::queue::{ParallelJob, Queue, SequentialJob}; +use super::types::Callbacks; +use super::{REJECT_AFTER_MESSAGES, SIZE_TAG}; + +use super::super::{tun, udp, Endpoint}; + +use std::mem; +use std::sync::atomic::{AtomicBool, Ordering}; +use 
std::sync::Arc; + +use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; +use spin::Mutex; +use zerocopy::{AsBytes, LayoutVerified}; + +struct Inner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { + ready: AtomicBool, + buffer: Mutex<(Option<E>, Vec<u8>)>, // endpoint & ciphertext buffer + state: Arc<DecryptionState<E, C, T, B>>, // decryption state (keys and replay protector) +} + +pub struct ReceiveJob<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + Arc<Inner<E, C, T, B>>, +); + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone + for ReceiveJob<E, C, T, B> +{ + fn clone(&self) -> ReceiveJob<E, C, T, B> { + ReceiveJob(self.0.clone()) + } +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ReceiveJob<E, C, T, B> { + pub fn new( + buffer: Vec<u8>, + state: Arc<DecryptionState<E, C, T, B>>, + endpoint: E, + ) -> ReceiveJob<E, C, T, B> { + ReceiveJob(Arc::new(Inner { + ready: AtomicBool::new(false), + buffer: Mutex::new((Some(endpoint), buffer)), + state, + })) + } +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ParallelJob + for ReceiveJob<E, C, T, B> +{ + fn queue(&self) -> &Queue<Self> { + &self.0.state.peer.inbound + } + + fn parallel_work(&self) { + // TODO: refactor + // decrypt + { + let job = &self.0; + let peer = &job.state.peer; + let mut msg = job.buffer.lock(); + + // cast to header followed by payload + let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = + match LayoutVerified::new_from_prefix(&mut msg.1[..]) { + Some(v) => v, + None => { + log::debug!("inbound worker: failed to parse message"); + return; + } + }; + + // authenticate and decrypt payload + { + // create nonce object + let mut nonce = [0u8; 12]; + debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); + nonce[4..].copy_from_slice(header.f_counter.as_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + // do the weird ring 
AEAD dance + let key = LessSafeKey::new( + UnboundKey::new(&CHACHA20_POLY1305, &job.state.keypair.recv.key[..]).unwrap(), + ); + + // attempt to open (and authenticate) the body + match key.open_in_place(nonce, Aad::empty(), packet) { + Ok(_) => (), + Err(_) => { + // fault and return early + log::trace!("inbound worker: authentication failure"); + msg.1.truncate(0); + return; + } + } + } + + // check that counter not after reject + if header.f_counter.get() >= REJECT_AFTER_MESSAGES { + msg.1.truncate(0); + return; + } + + // cryptokey route and strip padding + let inner_len = { + let length = packet.len() - SIZE_TAG; + if length > 0 { + peer.device.table.check_route(&peer, &packet[..length]) + } else { + Some(0) + } + }; + + // truncate to remove tag + match inner_len { + None => { + log::trace!("inbound worker: cryptokey routing failed"); + msg.1.truncate(0); + } + Some(len) => { + log::trace!( + "inbound worker: good route, length = {} {}", + len, + if len == 0 { "(keepalive)" } else { "" } + ); + msg.1.truncate(mem::size_of::<TransportHeader>() + len); + } + } + } + + // mark ready + self.0.ready.store(true, Ordering::Release); + } +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SequentialJob + for ReceiveJob<E, C, T, B> +{ + fn is_ready(&self) -> bool { + self.0.ready.load(Ordering::Acquire) + } + + fn sequential_work(self) { + let job = &self.0; + let peer = &job.state.peer; + let mut msg = job.buffer.lock(); + let endpoint = msg.0.take(); + + // cast transport header + let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = + match LayoutVerified::new_from_prefix(&msg.1[..]) { + Some(v) => v, + None => { + // also covers authentication failure + return; + } + }; + + // check for replay + if !job.state.protector.lock().update(header.f_counter.get()) { + log::debug!("inbound worker: replay detected"); + return; + } + + // check for confirms key + if !job.state.confirmed.swap(true, Ordering::SeqCst) { + 
log::debug!("inbound worker: message confirms key"); + peer.confirm_key(&job.state.keypair); + } + + // update endpoint + *peer.endpoint.lock() = endpoint; + + // check if should be written to TUN + let mut sent = false; + if packet.len() > 0 { + sent = match peer.device.inbound.write(&packet[..]) { + Err(e) => { + log::debug!("failed to write inbound packet to TUN: {:?}", e); + false + } + Ok(_) => true, + } + } else { + log::debug!("inbound worker: received keepalive") + } + + // trigger callback + C::recv(&peer.opaque, msg.1.len(), sent, &job.state.keypair); + } +} diff --git a/src/wireguard/router/runq.rs b/src/wireguard/router/runq.rs deleted file mode 100644 index 4c848cd..0000000 --- a/src/wireguard/router/runq.rs +++ /dev/null @@ -1,129 +0,0 @@ -use std::hash::Hash; -use std::mem; -use std::sync::{Condvar, Mutex}; - -use std::collections::hash_map::Entry; -use std::collections::HashMap; -use std::collections::VecDeque; - -pub trait ToKey { - type Key: Hash + Eq; - fn to_key(&self) -> Self::Key; -} - -pub struct RunQueue<T: ToKey> { - cvar: Condvar, - inner: Mutex<Inner<T>>, -} - -struct Inner<T: ToKey> { - stop: bool, - queue: VecDeque<T>, - members: HashMap<T::Key, usize>, -} - -impl<T: ToKey> RunQueue<T> { - pub fn close(&self) { - let mut inner = self.inner.lock().unwrap(); - inner.stop = true; - self.cvar.notify_all(); - } - - pub fn new() -> RunQueue<T> { - RunQueue { - cvar: Condvar::new(), - inner: Mutex::new(Inner { - stop: false, - queue: VecDeque::new(), - members: HashMap::new(), - }), - } - } - - pub fn insert(&self, v: T) { - let key = v.to_key(); - let mut inner = self.inner.lock().unwrap(); - match inner.members.entry(key) { - Entry::Occupied(mut elem) => { - *elem.get_mut() += 1; - } - Entry::Vacant(spot) => { - // add entry to back of queue - spot.insert(0); - inner.queue.push_back(v); - - // wake a thread - self.cvar.notify_one(); - } - } - } - - /// Run (consume from) the run queue using the provided function. 
- /// The function should return whether the given element should be rescheduled. - /// - /// # Arguments - /// - /// - `f` : function to apply to every element - /// - /// # Note - /// - /// The function f may be called again even when the element was not inserted back into the - /// queue since the last application and no rescheduling was requested. - /// - /// This happens when the function handles all work for T, - /// but T is added to the run queue while the function is running. - pub fn run<F: Fn(&T) -> bool>(&self, f: F) { - let mut inner = self.inner.lock().unwrap(); - loop { - // fetch next element - let elem = loop { - // run-queue closed - if inner.stop { - return; - } - - // try to pop from queue - match inner.queue.pop_front() { - Some(elem) => { - break elem; - } - None => (), - }; - - // wait for an element to be inserted - inner = self.cvar.wait(inner).unwrap(); - }; - - // fetch current request number - let key = elem.to_key(); - let old_n = *inner.members.get(&key).unwrap(); - mem::drop(inner); // drop guard - - // handle element - let rerun = f(&elem); - - // if the function requested a re-run add the element to the back of the queue - inner = self.inner.lock().unwrap(); - if rerun { - inner.queue.push_back(elem); - continue; - } - - // otherwise check if new requests have come in since we ran the function - match inner.members.entry(key) { - Entry::Occupied(occ) => { - if *occ.get() == old_n { - // no new requests since last, remove entry. - occ.remove(); - } else { - // new requests, reschedule.
- inner.queue.push_back(elem); - } - } - Entry::Vacant(_) => { - unreachable!(); - } - } - } - } -} diff --git a/src/wireguard/router/send.rs b/src/wireguard/router/send.rs new file mode 100644 index 0000000..8e41796 --- /dev/null +++ b/src/wireguard/router/send.rs @@ -0,0 +1,143 @@ +use super::queue::{SequentialJob, ParallelJob, Queue}; +use super::KeyPair; +use super::types::Callbacks; +use super::peer::Peer; +use super::{REJECT_AFTER_MESSAGES, SIZE_TAG}; +use super::messages::{TransportHeader, TYPE_TRANSPORT}; + +use super::super::{tun, udp, Endpoint}; + +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; +use zerocopy::{AsBytes, LayoutVerified}; +use spin::Mutex; + +struct Inner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { + ready: AtomicBool, + buffer: Mutex<Vec<u8>>, + counter: u64, + keypair: Arc<KeyPair>, + peer: Peer<E, C, T, B>, +} + +pub struct SendJob<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ( + Arc<Inner<E, C, T, B>> +); + +impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for SendJob<E, C, T, B> { + fn clone(&self) -> SendJob<E, C, T, B> { + SendJob(self.0.clone()) + } +} + +impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SendJob<E, C, T, B> { + pub fn new( + buffer: Vec<u8>, + counter: u64, + keypair: Arc<KeyPair>, + peer: Peer<E, C, T, B> + ) -> SendJob<E, C, T, B> { + SendJob(Arc::new(Inner{ + buffer: Mutex::new(buffer), + counter, + keypair, + peer, + ready: AtomicBool::new(false) + })) + } +} + +impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SequentialJob for SendJob<E, C, T, B> { + + fn is_ready(&self) -> bool { + self.0.ready.load(Ordering::Acquire) + } + + fn sequential_work(self) { + debug_assert_eq!( + self.is_ready(), + true, + "doing sequential work + on an incomplete job" + ); + log::trace!("processing sequential send job"); + + // send to 
peer + let job = &self.0; + let msg = job.buffer.lock(); + let xmit = job.peer.send_raw(&msg[..]).is_ok(); + + // trigger callback (for timers) + C::send( + &job.peer.opaque, + msg.len(), + xmit, + &job.keypair, + job.counter, + ); + } +} + + +impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ParallelJob for SendJob<E, C, T, B> { + + fn queue(&self) -> &Queue<Self> { + &self.0.peer.outbound + } + + fn parallel_work(&self) { + debug_assert_eq!( + self.is_ready(), + false, + "doing parallel work on completed job" + ); + log::trace!("processing parallel send job"); + + // encrypt body + { + // make space for the tag + let job = &*self.0; + let mut msg = job.buffer.lock(); + msg.extend([0u8; SIZE_TAG].iter()); + + // cast to header (should never fail) + let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = + LayoutVerified::new_from_prefix(&mut msg[..]) + .expect("earlier code should ensure that there is ample space"); + + // set header fields + debug_assert!( + job.counter < REJECT_AFTER_MESSAGES, + "should be checked when assigning counters" + ); + header.f_type.set(TYPE_TRANSPORT); + header.f_receiver.set(job.keypair.send.id); + header.f_counter.set(job.counter); + + // create a nonce object + let mut nonce = [0u8; 12]; + debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); + nonce[4..].copy_from_slice(header.f_counter.as_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + // do the weird ring AEAD dance + let key = LessSafeKey::new( + UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.send.key[..]).unwrap(), + ); + + // encrypt contents of transport message in-place + let end = packet.len() - SIZE_TAG; + let tag = key + .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end]) + .unwrap(); + + // append tag + packet[end..].copy_from_slice(tag.as_ref()); + } + + // mark ready + self.0.ready.store(true, Ordering::Release); + } +}
\ No newline at end of file diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs index 15db368..3d5c79b 100644 --- a/src/wireguard/router/tests.rs +++ b/src/wireguard/router/tests.rs @@ -50,7 +50,6 @@ mod tests { })) } - #[allow(dead_code)] fn reset(&self) { self.0.send.lock().unwrap().clear(); self.0.recv.lock().unwrap().clear(); @@ -104,9 +103,9 @@ mod tests { } } - // wait for scheduling (VERY conservative) + // wait for scheduling fn wait() { - thread::sleep(Duration::from_millis(30)); + thread::sleep(Duration::from_millis(15)); } fn init() { @@ -162,7 +161,7 @@ mod tests { }; let msg = make_packet_padded(1024, src, dst, 0); - // every iteration sends 10 MB + // every iteration sends 10 GB b.iter(|| { opaque.store(0, Ordering::SeqCst); while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 { diff --git a/src/wireguard/router/worker.rs b/src/wireguard/router/worker.rs new file mode 100644 index 0000000..bbb644c --- /dev/null +++ b/src/wireguard/router/worker.rs @@ -0,0 +1,31 @@ +use super::super::{tun, udp, Endpoint}; +use super::types::Callbacks; + +use super::queue::ParallelJob; +use super::receive::ReceiveJob; +use super::send::SendJob; + +use crossbeam_channel::Receiver; + +pub enum JobUnion<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { + Outbound(SendJob<E, C, T, B>), + Inbound(ReceiveJob<E, C, T, B>), +} + +pub fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + receiver: Receiver<JobUnion<E, C, T, B>>, +) { + loop { + match receiver.recv() { + Err(_) => break, + Ok(JobUnion::Inbound(job)) => { + job.parallel_work(); + job.queue().consume(); + } + Ok(JobUnion::Outbound(job)) => { + job.parallel_work(); + job.queue().consume(); + } + } + } +} |