diff options
-rw-r--r-- | src/wireguard/queue.rs | 3 | ||||
-rw-r--r-- | src/wireguard/router/constants.rs | 4 | ||||
-rw-r--r-- | src/wireguard/router/device.rs | 100 | ||||
-rw-r--r-- | src/wireguard/router/inbound.rs | 190 | ||||
-rw-r--r-- | src/wireguard/router/mod.rs | 5 | ||||
-rw-r--r-- | src/wireguard/router/outbound.rs | 110 | ||||
-rw-r--r-- | src/wireguard/router/peer.rs | 192 | ||||
-rw-r--r-- | src/wireguard/router/pool.rs | 164 | ||||
-rw-r--r-- | src/wireguard/router/queue.rs | 92 | ||||
-rw-r--r-- | src/wireguard/router/receive.rs | 184 | ||||
-rw-r--r-- | src/wireguard/router/runq.rs | 164 | ||||
-rw-r--r-- | src/wireguard/router/send.rs | 95 | ||||
-rw-r--r-- | src/wireguard/router/worker.rs | 30 | ||||
-rw-r--r-- | src/wireguard/router/workers.rs | 257 | ||||
-rw-r--r-- | src/wireguard/wireguard.rs | 2 |
15 files changed, 350 insertions, 1242 deletions
diff --git a/src/wireguard/queue.rs b/src/wireguard/queue.rs index 75b9104..eea1ccf 100644 --- a/src/wireguard/queue.rs +++ b/src/wireguard/queue.rs @@ -1,6 +1,7 @@ -use crossbeam_channel::{bounded, Receiver, Sender}; use std::sync::Mutex; +use crossbeam_channel::{bounded, Receiver, Sender}; + pub struct ParallelQueue<T> { queue: Mutex<Option<Sender<T>>>, } diff --git a/src/wireguard/router/constants.rs b/src/wireguard/router/constants.rs index af76299..f083811 100644 --- a/src/wireguard/router/constants.rs +++ b/src/wireguard/router/constants.rs @@ -4,6 +4,6 @@ pub const MAX_QUEUED_PACKETS: usize = 1024; // performance constants -pub const PARALLEL_QUEUE_SIZE: usize = MAX_QUEUED_PACKETS; +pub const PARALLEL_QUEUE_SIZE: usize = 4 * MAX_QUEUED_PACKETS; + pub const INORDER_QUEUE_SIZE: usize = MAX_QUEUED_PACKETS; -pub const MAX_INORDER_CONSUME: usize = INORDER_QUEUE_SIZE; diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index f903a8e..b8e3821 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -10,19 +10,16 @@ use spin::{Mutex, RwLock}; use zerocopy::LayoutVerified; use super::anti_replay::AntiReplay; -use super::pool::Job; use super::constants::PARALLEL_QUEUE_SIZE; -use super::inbound; -use super::outbound; - use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::peer::{new_peer, Peer, PeerHandle}; use super::types::{Callbacks, RouterError}; use super::SIZE_MESSAGE_PREFIX; +use super::receive::ReceiveJob; use super::route::RoutingTable; -use super::runq::RunQueue; +use super::worker::{worker, JobUnion}; use super::super::{tun, udp, Endpoint, KeyPair}; use super::ParallelQueue; @@ -38,13 +35,8 @@ pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state pub table: RoutingTable<Peer<E, C, T, B>>, - // work queues - pub queue_outbound: ParallelQueue<Job<Peer<E, C, T, B>, outbound::Outbound>>, - pub queue_inbound: ParallelQueue<Job<Peer<E, C, T, B>, inbound::Inbound<E, C, T, B>>>, - - // run queues - pub run_inbound: RunQueue<Peer<E, C, T, B>>, - pub run_outbound: RunQueue<Peer<E, C, T, B>>, + // work queue + pub work: ParallelQueue<JobUnion<E, C, T, B>>, } pub struct EncryptionState { @@ -101,13 +93,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop fn drop(&mut self) { debug!("router: dropping device"); - // close worker queues - self.state.queue_outbound.close(); - self.state.queue_inbound.close(); - - // close run queues - self.state.run_outbound.close(); - self.state.run_inbound.close(); + // close worker queue + self.state.work.close(); // join all worker threads while match self.handles.pop() { @@ -118,24 +105,17 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop } _ => false, } {} - - debug!("router: device dropped"); } } impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<E, C, T, B> { pub fn new(num_workers: usize, tun: T) -> DeviceHandle<E, C, T, B> { - // allocate shared device state - let (queue_outbound, mut outrx) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE); - let (queue_inbound, mut inrx) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE); + let (work, mut consumers) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE); let device = Device { inner: Arc::new(DeviceInner { + work, inbound: tun, - queue_inbound, outbound: RwLock::new((true, None)), - queue_outbound, - run_inbound: RunQueue::new(), - run_outbound: RunQueue::new(), recv: RwLock::new(HashMap::new()), table: RoutingTable::new(), }), @@ -143,52 +123,10 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle< // start worker threads let mut threads = Vec::with_capacity(num_workers); - - // inbound/decryption workers - for _ in 0..num_workers { - // parallel workers (parallel processing) - { - let device = device.clone(); - let rx = inrx.pop().unwrap(); - threads.push(thread::spawn(move || { - log::debug!("inbound parallel router worker started"); - inbound::parallel(device, rx) - })); - } - - // sequential workers (in-order processing) - { - let device = device.clone(); - threads.push(thread::spawn(move || { - log::debug!("inbound sequential router worker started"); - inbound::sequential(device) - })); - } + while let Some(rx) = consumers.pop() { + threads.push(thread::spawn(move || worker(rx))); } - - // outbound/encryption workers - for _ in 0..num_workers { - // parallel workers (parallel processing) - { - let device = device.clone(); - let rx = outrx.pop().unwrap(); - threads.push(thread::spawn(move || { - log::debug!("outbound parallel router worker started"); - outbound::parallel(device, rx) - })); - } - - // sequential workers (in-order processing) - { - let device = device.clone(); - threads.push(thread::spawn(move || { - log::debug!("outbound sequential router worker started"); - outbound::sequential(device) - })); - } - } - - debug_assert_eq!(threads.len(), num_workers * 4); + debug_assert_eq!(threads.len(), num_workers); // return exported device handle DeviceHandle { @@ -250,10 +188,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle< .ok_or(RouterError::NoCryptoKeyRoute)?; // schedule for encryption and transmission to peer - if let Some(job) = peer.send_job(msg, true) { - self.state.queue_outbound.send(job); - } - + peer.send(msg, true); Ok(()) } @@ -297,10 +232,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle< .get(&header.f_receiver.get()) .ok_or(RouterError::UnknownReceiverId)?; - // schedule for decryption and TUN write - if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) { - log::trace!("schedule decryption of transport message"); - self.state.queue_inbound.send(job); + // create inbound job + let job = ReceiveJob::new(msg, dec.clone(), src); + + // 1. add to sequential queue (drop if full) + // 2. then add to parallel work queue (wait if full) + if dec.peer.inbound.push(job.clone()) { + self.state.work.send(JobUnion::Inbound(job)); } Ok(()) } diff --git a/src/wireguard/router/inbound.rs b/src/wireguard/router/inbound.rs deleted file mode 100644 index dc2c44e..0000000 --- a/src/wireguard/router/inbound.rs +++ /dev/null @@ -1,190 +0,0 @@ -use std::mem; -use std::sync::atomic::Ordering; -use std::sync::Arc; - -use crossbeam_channel::Receiver; -use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; -use zerocopy::{AsBytes, LayoutVerified}; - -use super::constants::MAX_INORDER_CONSUME; -use super::device::DecryptionState; -use super::device::Device; -use super::messages::TransportHeader; -use super::peer::Peer; -use super::pool::*; -use super::types::Callbacks; -use super::{tun, udp, Endpoint}; -use super::{REJECT_AFTER_MESSAGES, SIZE_TAG}; - -pub struct Inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { - msg: Vec<u8>, - failed: bool, - state: Arc<DecryptionState<E, C, T, B>>, - endpoint: Option<E>, -} - -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Inbound<E, C, T, B> { - pub fn new( - msg: Vec<u8>, - state: Arc<DecryptionState<E, C, T, B>>, - endpoint: E, - ) -> Inbound<E, C, T, B> { - Inbound { - msg, - state, - failed: false, - endpoint: Some(endpoint), - } - } -} - -#[inline(always)] -pub fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - device: Device<E, C, T, B>, - receiver: Receiver<Job<Peer<E, C, T, B>, Inbound<E, C, T, B>>>, -) { - // parallel work to apply - #[inline(always)] - fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - peer: &Peer<E, C, T, B>, - body: &mut Inbound<E, C, T, B>, - ) { - log::trace!("worker, parallel section, obtained job"); - - // cast to header followed by payload - let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = - match LayoutVerified::new_from_prefix(&mut body.msg[..]) { - Some(v) => v, - None => { - log::debug!("inbound worker: failed to parse message"); - return; - } - }; - - // authenticate and decrypt payload - { - // create nonce object - let mut nonce = [0u8; 12]; - debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); - nonce[4..].copy_from_slice(header.f_counter.as_bytes()); - let nonce = Nonce::assume_unique_for_key(nonce); - - // do the weird ring AEAD dance - let key = LessSafeKey::new( - UnboundKey::new(&CHACHA20_POLY1305, &body.state.keypair.recv.key[..]).unwrap(), - ); - - // attempt to open (and authenticate) the body - match key.open_in_place(nonce, Aad::empty(), packet) { - Ok(_) => (), - Err(_) => { - // fault and return early - log::trace!("inbound worker: authentication failure"); - body.failed = true; - return; - } - } - } - - // check that counter not after reject - if header.f_counter.get() >= REJECT_AFTER_MESSAGES { - body.failed = true; - return; - } - - // cryptokey route and strip padding - let inner_len = { - let length = packet.len() - SIZE_TAG; - if length > 0 { - peer.device.table.check_route(&peer, &packet[..length]) - } else { - Some(0) - } - }; - - // truncate to remove tag - match inner_len { - None => { - log::trace!("inbound worker: cryptokey routing failed"); - body.failed = true; - } - Some(len) => { - log::trace!( - "inbound worker: good route, length = {} {}", - len, - if len == 0 { "(keepalive)" } else { "" } - ); - body.msg.truncate(mem::size_of::<TransportHeader>() + len); - } - } - } - - worker_parallel(device, |dev| &dev.run_inbound, receiver, work) -} - -#[inline(always)] -pub fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - device: Device<E, C, T, B>, -) { - // sequential work to apply - fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - peer: &Peer<E, C, T, B>, - body: &mut Inbound<E, C, T, B>, - ) { - log::trace!("worker, sequential section, obtained job"); - - // decryption failed, return early - if body.failed { - log::trace!("job faulted, remove from queue and ignore"); - return; - } - - // cast transport header - let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = - match LayoutVerified::new_from_prefix(&body.msg[..]) { - Some(v) => v, - None => { - log::debug!("inbound worker: failed to parse message"); - return; - } - }; - - // check for replay - if !body.state.protector.lock().update(header.f_counter.get()) { - log::debug!("inbound worker: replay detected"); - return; - } - - // check for confirms key - if !body.state.confirmed.swap(true, Ordering::SeqCst) { - log::debug!("inbound worker: message confirms key"); - peer.confirm_key(&body.state.keypair); - } - - // update endpoint - *peer.endpoint.lock() = body.endpoint.take(); - - // check if should be written to TUN - let mut sent = false; - if packet.len() > 0 { - sent = match peer.device.inbound.write(&packet[..]) { - Err(e) => { - log::debug!("failed to write inbound packet to TUN: {:?}", e); - false - } - Ok(_) => true, - } - } else { - log::debug!("inbound worker: received keepalive") - } - - // trigger callback - C::recv(&peer.opaque, body.msg.len(), sent, &body.state.keypair); - } - - // handle message from the peers inbound queue - device.run_inbound.run(|peer| { - peer.inbound - .handle(|body| work(&peer, body), MAX_INORDER_CONSUME) - }); -} diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs index ec5cc63..699c621 100644 --- a/src/wireguard/router/mod.rs +++ b/src/wireguard/router/mod.rs @@ -1,14 +1,10 @@ mod anti_replay; mod constants; mod device; -mod inbound; mod ip; mod messages; -mod outbound; mod peer; -mod pool; mod route; -mod runq; mod types; mod queue; @@ -25,7 +21,6 @@ use std::mem; use super::constants::REJECT_AFTER_MESSAGES; use super::queue::ParallelQueue; use super::types::*; -use super::{tun, udp, Endpoint}; pub const SIZE_TAG: usize = 16; pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>(); diff --git a/src/wireguard/router/outbound.rs b/src/wireguard/router/outbound.rs deleted file mode 100644 index 1edb2fb..0000000 --- a/src/wireguard/router/outbound.rs +++ /dev/null @@ -1,110 +0,0 @@ -use std::sync::Arc; - -use crossbeam_channel::Receiver; -use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; -use zerocopy::{AsBytes, LayoutVerified}; - -use super::constants::MAX_INORDER_CONSUME; -use super::device::Device; -use super::messages::{TransportHeader, TYPE_TRANSPORT}; -use super::peer::Peer; -use super::pool::*; -use super::types::Callbacks; -use super::KeyPair; -use super::{tun, udp, Endpoint}; -use super::{REJECT_AFTER_MESSAGES, SIZE_TAG}; - -pub struct Outbound { - msg: Vec<u8>, - keypair: Arc<KeyPair>, - counter: u64, -} - -impl Outbound { - pub fn new(msg: Vec<u8>, keypair: Arc<KeyPair>, counter: u64) -> Outbound { - Outbound { - msg, - keypair, - counter, - } - } -} - -#[inline(always)] -pub fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - device: Device<E, C, T, B>, - receiver: Receiver<Job<Peer<E, C, T, B>, Outbound>>, -) { - #[inline(always)] - fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - _peer: &Peer<E, C, T, B>, - body: &mut Outbound, - ) { - log::trace!("worker, parallel section, obtained job"); - - // make space for the tag - body.msg.extend([0u8; SIZE_TAG].iter()); - - // cast to header (should never fail) - let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = - LayoutVerified::new_from_prefix(&mut body.msg[..]) - .expect("earlier code should ensure that there is ample space"); - - // set header fields - debug_assert!( - body.counter < REJECT_AFTER_MESSAGES, - "should be checked when assigning counters" - ); - header.f_type.set(TYPE_TRANSPORT); - header.f_receiver.set(body.keypair.send.id); - header.f_counter.set(body.counter); - - // create a nonce object - let mut nonce = [0u8; 12]; - debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); - nonce[4..].copy_from_slice(header.f_counter.as_bytes()); - let nonce = Nonce::assume_unique_for_key(nonce); - - // do the weird ring AEAD dance - let key = LessSafeKey::new( - UnboundKey::new(&CHACHA20_POLY1305, &body.keypair.send.key[..]).unwrap(), - ); - - // encrypt content of transport message in-place - let end = packet.len() - SIZE_TAG; - let tag = key - .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end]) - .unwrap(); - - // append tag - packet[end..].copy_from_slice(tag.as_ref()); - } - - worker_parallel(device, |dev| &dev.run_outbound, receiver, work); -} - -#[inline(always)] -pub fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - device: Device<E, C, T, B>, -) { - device.run_outbound.run(|peer| { - peer.outbound.handle( - |body| { - log::trace!("worker, sequential section, obtained job"); - - // send to peer - let xmit = peer.send(&body.msg[..]).is_ok(); - - // trigger callback - C::send( - &peer.opaque, - body.msg.len(), - xmit, - &body.keypair, - body.counter, - ); - }, - MAX_INORDER_CONSUME, - ) - }); -} diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index 7312bc7..710cf32 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -1,13 +1,3 @@ -use std::mem; -use std::net::{IpAddr, SocketAddr}; -use std::ops::Deref; -use std::sync::atomic::AtomicBool; -use std::sync::Arc; - -use arraydeque::{ArrayDeque, Wrapping}; -use log::debug; -use spin::Mutex; - use super::super::constants::*; use super::super::{tun, udp, Endpoint, KeyPair}; @@ -15,20 +5,25 @@ use super::anti_replay::AntiReplay; use super::device::DecryptionState; use super::device::Device; use super::device::EncryptionState; -use super::messages::TransportHeader; use super::constants::*; -use super::runq::ToKey; use super::types::{Callbacks, RouterError}; use super::SIZE_MESSAGE_PREFIX; -// worker pool related -use super::inbound::Inbound; -use super::outbound::Outbound; use super::queue::Queue; - -use super::send::SendJob; use super::receive::ReceiveJob; +use super::send::SendJob; +use super::worker::JobUnion; + +use std::mem; +use std::net::{IpAddr, SocketAddr}; +use std::ops::Deref; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; + +use arraydeque::{ArrayDeque, Wrapping}; +use log::debug; +use spin::Mutex; pub struct KeyWheel { next: Option<Arc<KeyPair>>, // next key state (unconfirmed) @@ -44,7 +39,7 @@ pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E pub inbound: Queue<ReceiveJob<E, C, T, B>>, pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_QUEUED_PACKETS], Wrapping>>, pub keys: Mutex<KeyWheel>, - pub ekey: Mutex<Option<EncryptionState>>, + pub enc_key: Mutex<Option<EncryptionState>>, pub endpoint: Mutex<Option<E>>, } @@ -69,13 +64,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq for } } -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ToKey for Peer<E, C, T, B> { - type Key = usize; - fn to_key(&self) -> usize { - Arc::downgrade(&self.inner).into_raw() as usize - } -} - impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Eq for Peer<E, C, T, B> {} /* A peer is transparently dereferenced to the inner type @@ -157,7 +145,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Peer keys.current = None; keys.previous = None; - *peer.ekey.lock() = None; + *peer.enc_key.lock() = None; *peer.endpoint.lock() = None; debug!("peer dropped & removed from device"); @@ -175,9 +163,9 @@ pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( inner: Arc::new(PeerInner { opaque, device, - inbound: InorderQueue::new(), - outbound: InorderQueue::new(), - ekey: spin::Mutex::new(None), + inbound: Queue::new(), + outbound: Queue::new(), + enc_key: spin::Mutex::new(None), endpoint: spin::Mutex::new(None), keys: spin::Mutex::new(KeyWheel { next: None, @@ -203,7 +191,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, /// # Returns /// /// Unit if packet was sent, or an error indicating why sending failed - pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> { + pub fn send_raw(&self, msg: &[u8]) -> Result<(), RouterError> { debug!("peer.send"); // send to endpoint (if known) @@ -226,6 +214,57 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, } impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, B> { + /// Encrypt and send a message to the peer + /// + /// Arguments: + /// + /// - `msg` : A padded vector holding the message (allows in-place construction of the transport header) + /// - `stage`: Should the message be staged if no key is available + /// + pub(super) fn send(&self, msg: Vec<u8>, stage: bool) { + // check if key available + let (job, need_key) = { + let mut enc_key = self.enc_key.lock(); + match enc_key.as_mut() { + None => { + if stage { + self.staged_packets.lock().push_back(msg); + }; + (None, true) + } + Some(mut state) => { + // avoid integer overflow in nonce + if state.nonce >= REJECT_AFTER_MESSAGES - 1 { + *enc_key = None; + if stage { + self.staged_packets.lock().push_back(msg); + } + (None, true) + } else { + debug!("encryption state available, nonce = {}", state.nonce); + let job = + SendJob::new(msg, state.nonce, state.keypair.clone(), self.clone()); + if self.outbound.push(job.clone()) { + state.nonce += 1; + (Some(job), false) + } else { + (None, false) + } + } + } + } + }; + + if need_key { + debug_assert!(job.is_none()); + C::need_key(&self.opaque); + }; + + if let Some(job) = job { + self.device.work.send(JobUnion::Outbound(job)) + } + } + // Transmit all staged packets fn send_staged(&self) -> bool { debug!("peer.send_staged"); @@ -235,28 +274,14 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, match staged.pop_front() { Some(msg) => { sent = true; - self.send_raw(msg); + self.send(msg, false); } None => break sent, } } } - // Treat the msg as the payload of a transport message - // Unlike device.send, peer.send_raw does not buffer messages when a key is not available. - fn send_raw(&self, msg: Vec<u8>) -> bool { - log::debug!("peer.send_raw"); - match self.send_job(msg, false) { - Some(job) => { - self.device.queue_outbound.send(job); - debug!("send_raw: got obtained send_job"); - true - } - None => false, - } - } - - pub fn confirm_key(&self, keypair: &Arc<KeyPair>) { + pub(super) fn confirm_key(&self, keypair: &Arc<KeyPair>) { debug!("peer.confirm_key"); { // take lock and check keypair = keys.next @@ -284,68 +309,12 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, C::key_confirmed(&self.opaque); // set new key for encryption - *self.ekey.lock() = ekey; + *self.enc_key.lock() = ekey; } // start transmission of staged packets self.send_staged(); } - - pub fn send_job(&self, msg: Vec<u8>, stage: bool) -> Option<SendJob<E, C, T, B>> { - debug!("peer.send_job"); - debug_assert!( - msg.len() >= mem::size_of::<TransportHeader>(), - "received TUN message with size: {:}", - msg.len() - ); - - // check if has key - let (keypair, counter) = { - let keypair = { - // TODO: consider using atomic ptr for ekey state - let mut ekey = self.ekey.lock(); - match ekey.as_mut() { - None => None, - Some(mut state) => { - // avoid integer overflow in nonce - if state.nonce >= REJECT_AFTER_MESSAGES - 1 { - *ekey = None; - None - } else { - debug!("encryption state available, nonce = {}", state.nonce); - let counter = state.nonce; - state.nonce += 1; - - SendJob::new( - msg, - state.nonce, - state.keypair.clone(), - self.clone() - ); - - Some((state.keypair.clone(), counter)) - } - } - } - }; - - // If not suitable key was found: - // 1. Stage packet for later transmission - // 2. Request new key - if keypair.is_none() && stage { - self.staged_packets.lock().push_back(msg); - C::need_key(&self.opaque); - return None; - }; - - keypair - }?; - - // add job to in-order queue and return sender to device for inclusion in worker pool - let job = Job::new(self.clone(), Outbound::new(msg, keypair, counter)); - self.outbound.send(job.clone()); - Some(job) - } } impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, C, T, B> { @@ -397,7 +366,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, } // clear encryption state - *self.peer.ekey.lock() = None; + *self.peer.enc_key.lock() = None; } pub fn down(&self) { @@ -434,7 +403,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, // update key-wheel if new.initiator { // start using key for encryption - *self.peer.ekey.lock() = Some(EncryptionState::new(&new)); + *self.peer.enc_key.lock() = Some(EncryptionState::new(&new)); // move current into previous keys.previous = keys.current.as_ref().map(|v| v.clone()); @@ -468,16 +437,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, // schedule confirmation if initiator { - debug_assert!(self.peer.ekey.lock().is_some()); + debug_assert!(self.peer.enc_key.lock().is_some()); debug!("peer.add_keypair: is initiator, must confirm the key"); // attempt to confirm using staged packets if !self.peer.send_staged() { // fall back to keepalive packet - let ok = self.send_keepalive(); - debug!( - "peer.add_keypair: keepalive for confirmation, sent = {}", - ok - ); + self.send_keepalive(); + debug!("peer.add_keypair: keepalive for confirmation",); } debug!("peer.add_keypair: key attempted confirmed"); } @@ -489,9 +455,9 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, release } - pub fn send_keepalive(&self) -> bool { + pub fn send_keepalive(&self) { debug!("peer.send_keepalive"); - self.peer.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX]) + self.peer.send(vec![0u8; SIZE_MESSAGE_PREFIX], false) } /// Map a subnet to the peer diff --git a/src/wireguard/router/pool.rs b/src/wireguard/router/pool.rs deleted file mode 100644 index 3fc0026..0000000 --- a/src/wireguard/router/pool.rs +++ /dev/null @@ -1,164 +0,0 @@ -use std::mem; -use std::sync::Arc; - -use arraydeque::ArrayDeque; -use crossbeam_channel::Receiver; -use spin::{Mutex, MutexGuard}; - -use super::constants::INORDER_QUEUE_SIZE; -use super::runq::{RunQueue, ToKey}; - -pub struct InnerJob<P, B> { - // peer (used by worker to schedule/handle inorder queue), - // when the peer is None, the job is complete - peer: Option<P>, - pub body: B, -} - -pub struct Job<P, B> { - inner: Arc<Mutex<InnerJob<P, B>>>, -} - -impl<P, B> Clone for Job<P, B> { - fn clone(&self) -> Job<P, B> { - Job { - inner: self.inner.clone(), - } - } -} - -impl<P, B> Job<P, B> { - pub fn new(peer: P, body: B) -> Job<P, B> { - Job { - inner: Arc::new(Mutex::new(InnerJob { - peer: Some(peer), - body, - })), - } - } -} - -impl<P, B> Job<P, B> { - /// Returns a mutex guard to the inner job if complete - pub fn complete(&self) -> Option<MutexGuard<InnerJob<P, B>>> { - self.inner - .try_lock() - .and_then(|m| if m.peer.is_none() { Some(m) } else { None }) - } -} - -pub struct InorderQueue<P, B> { - queue: Mutex<ArrayDeque<[Job<P, B>; INORDER_QUEUE_SIZE]>>, -} - -impl<P, B> InorderQueue<P, B> { - pub fn new() -> InorderQueue<P, B> { - InorderQueue { - queue: Mutex::new(ArrayDeque::new()), - } - } - - /// Add a new job to the in-order queue - /// - /// # Arguments - /// - /// - `job`: The job added to the back of the queue - /// - /// # Returns - /// - /// True if the element was added, - /// false to indicate that the queue is full. - pub fn send(&self, job: Job<P, B>) -> bool { - self.queue.lock().push_back(job).is_ok() - } - - /// Consume completed jobs from the in-order queue - /// - /// # Arguments - /// - /// - `f`: function to apply to the body of each jobof each job. - /// - `limit`: maximum number of jobs to handle before returning - /// - /// # Returns - /// - /// A boolean indicating if the limit was reached: - /// true indicating that the limit was reached, - /// while false implies that the queue is empty or an uncompleted job was reached. - #[inline(always)] - pub fn handle<F: Fn(&mut B)>(&self, f: F, mut limit: usize) -> bool { - // take the mutex - let mut queue = self.queue.lock(); - - while limit > 0 { - // attempt to extract front element - let front = queue.pop_front(); - let elem = match front { - Some(elem) => elem, - _ => { - return false; - } - }; - - // apply function if job complete - let ret = if let Some(mut guard) = elem.complete() { - mem::drop(queue); - f(&mut guard.body); - queue = self.queue.lock(); - false - } else { - true - }; - - // job not complete yet, return job to front - if ret { - queue.push_front(elem).unwrap(); - return false; - } - limit -= 1; - } - - // did not complete all jobs - true - } -} - -/// Allows easy construction of a parallel worker. -/// Applicable for both decryption and encryption workers. -#[inline(always)] -pub fn worker_parallel< - P: ToKey, // represents a peer (atomic reference counted pointer) - B, // inner body type (message buffer, key material, ...) - D, // device - W: Fn(&P, &mut B), - Q: Fn(&D) -> &RunQueue<P>, ->( - device: D, - queue: Q, - receiver: Receiver<Job<P, B>>, - work: W, -) { - log::trace!("router worker started"); - loop { - // handle new job - let peer = { - // get next job - let job = match receiver.recv() { - Ok(job) => job, - _ => return, - }; - - // lock the job - let mut job = job.inner.lock(); - - // take the peer from the job - let peer = job.peer.take().unwrap(); - - // process job - work(&peer, &mut job.body); - peer - }; - - // process inorder jobs for peer - queue(&device).insert(peer); - } -} diff --git a/src/wireguard/router/queue.rs b/src/wireguard/router/queue.rs index 045fd51..ec4492e 100644 --- a/src/wireguard/router/queue.rs +++ b/src/wireguard/router/queue.rs @@ -4,29 +4,36 @@ use spin::Mutex; use std::mem; use std::sync::atomic::{AtomicUsize, Ordering}; -const QUEUE_SIZE: usize = 1024; - -pub trait Job: Sized { - fn queue(&self) -> &Queue<Self>; +use super::constants::INORDER_QUEUE_SIZE; +pub trait SequentialJob { fn is_ready(&self) -> bool; - fn parallel_work(&self); - fn sequential_work(self); } +pub trait ParallelJob: Sized + SequentialJob { + fn queue(&self) -> &Queue<Self>; + + fn parallel_work(&self); +} -pub struct Queue<J: Job> { +pub struct Queue<J: SequentialJob> { contenders: AtomicUsize, - queue: Mutex<ArrayDeque<[J; QUEUE_SIZE]>>, + queue: Mutex<ArrayDeque<[J; INORDER_QUEUE_SIZE]>>, + + #[cfg(debug)] + _flag: Mutex<()>, } -impl<J: Job> Queue<J> { +impl<J: SequentialJob> Queue<J> { pub fn new() -> Queue<J> { Queue { contenders: AtomicUsize::new(0), queue: Mutex::new(ArrayDeque::new()), + + #[cfg(debug)] + _flag: Mutex::new(()), } } @@ -36,14 +43,22 @@ impl<J: Job> Queue<J> { pub fn consume(&self) { // check if we are the first contender - let pos = self.contenders.fetch_add(1, Ordering::Acquire); + let pos = self.contenders.fetch_add(1, Ordering::SeqCst); if pos > 0 { - assert!(pos < usize::max_value(), "contenders overflow"); + assert!(usize::max_value() > pos, "contenders overflow"); + return; } // enter the critical section let mut contenders = 1; // myself while contenders > 0 { + // check soundness in debug builds + #[cfg(debug)] + let _flag = self + ._flag + .try_lock() + .expect("contenders should ensure mutual exclusion"); + // handle every ready element loop { let mut queue = self.queue.lock(); @@ -69,8 +84,61 @@ impl<J: Job> Queue<J> { job.sequential_work(); } + #[cfg(debug)] + mem::drop(_flag); + // decrease contenders - contenders = self.contenders.fetch_sub(contenders, Ordering::Acquire) - contenders; + contenders = self.contenders.fetch_sub(contenders, Ordering::SeqCst) - contenders; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::sync::Arc; + use std::thread; + + use rand::thread_rng; + use rand::Rng; + + struct TestJob {} + + impl SequentialJob for TestJob { + fn is_ready(&self) -> bool { + true + } + + fn sequential_work(self) {} + } + + /* Fuzz the Queue */ + #[test] + fn test_queue() { + fn hammer(queue: &Arc<Queue<TestJob>>) { + let mut rng = thread_rng(); + for _ in 0..1_000_000 { + if rng.gen() { + queue.push(TestJob {}); + } else { + queue.consume(); + } + } } + + let queue = Arc::new(Queue::new()); + + // repeatedly apply operations randomly from concurrent threads + let other = { + let queue = queue.clone(); + thread::spawn(move || hammer(&queue)) + }; + hammer(&queue); + + // wait, consume and check empty + other.join().unwrap(); + queue.consume(); + assert_eq!(queue.queue.lock().len(), 0); } } diff --git a/src/wireguard/router/receive.rs b/src/wireguard/router/receive.rs index 53890e3..c5fe3da 100644 --- a/src/wireguard/router/receive.rs +++ b/src/wireguard/router/receive.rs @@ -1,21 +1,18 @@ -use super::queue::{Job, Queue}; -use super::KeyPair; +use super::device::DecryptionState; +use super::messages::TransportHeader; +use super::queue::{ParallelJob, Queue, SequentialJob}; use super::types::Callbacks; -use super::peer::Peer; use super::{REJECT_AFTER_MESSAGES, SIZE_TAG}; -use super::messages::{TransportHeader, TYPE_TRANSPORT}; -use super::device::DecryptionState; use super::super::{tun, udp, Endpoint}; +use std::mem; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use std::mem; use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; -use zerocopy::{AsBytes, LayoutVerified}; use spin::Mutex; - +use zerocopy::{AsBytes, LayoutVerified}; struct Inner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { ready: AtomicBool, @@ -23,49 +20,49 @@ struct Inner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { state: Arc<DecryptionState<E, C, T, B>>, // decryption state (keys and replay protector) } -pub struct ReceiveJob<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { - inner: Arc<Inner<E, C, T, B>>, +pub struct ReceiveJob<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + Arc<Inner<E, C, T, B>>, +); + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone + for ReceiveJob<E, C, T, B> +{ + fn clone(&self) -> ReceiveJob<E, C, T, B> { + ReceiveJob(self.0.clone()) + } } -impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ReceiveJob<E, C, T, B> { - fn new(buffer: Vec<u8>, state: Arc<DecryptionState<E, C, T, B>>, endpoint: E) -> Option<ReceiveJob<E, C, T, B>> { - // create job - let inner = Arc::new(Inner{ +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ReceiveJob<E, C, T, B> { + pub fn new( + buffer: Vec<u8>, + state: Arc<DecryptionState<E, C, T, B>>, + endpoint: E, + ) -> ReceiveJob<E, C, T, B> { + ReceiveJob(Arc::new(Inner { ready: AtomicBool::new(false), buffer: Mutex::new((Some(endpoint), buffer)), - state - }); - - // attempt to add to queue - if state.peer.inbound.push(ReceiveJob{ inner: inner.clone()}) { - Some(ReceiveJob{inner}) - } else { - None - } - + state, + })) } } -impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for ReceiveJob<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ParallelJob + for ReceiveJob<E, C, T, B> +{ fn queue(&self) -> &Queue<Self> { - &self.inner.state.peer.inbound - } - - fn is_ready(&self) -> bool { - self.inner.ready.load(Ordering::Acquire) + &self.0.state.peer.inbound } fn parallel_work(&self) { // TODO: refactor // decrypt { - let job = &self.inner; + let job = &self.0; let peer = &job.state.peer; let mut msg = job.buffer.lock(); - - let failed = || { - // cast to header followed by payload - let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = + + // cast to header followed by payload + let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = match LayoutVerified::new_from_prefix(&mut msg.1[..]) { Some(v) => v, None => { @@ -74,73 +71,81 @@ impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for Rece } }; - // authenticate and decrypt payload - { - // create nonce object - let mut nonce = [0u8; 12]; - debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); - nonce[4..].copy_from_slice(header.f_counter.as_bytes()); - let nonce = Nonce::assume_unique_for_key(nonce); - - // do the weird ring AEAD dance - let key = LessSafeKey::new( - UnboundKey::new(&CHACHA20_POLY1305, &job.state.keypair.recv.key[..]).unwrap(), - ); - - // attempt to open (and authenticate) the body - match key.open_in_place(nonce, Aad::empty(), packet) { - Ok(_) => (), - Err(_) => { - // fault and return early - log::trace!("inbound worker: authentication failure"); - msg.1.truncate(0); - return; - } + // authenticate and decrypt payload + { + // create nonce object + let mut nonce = [0u8; 12]; + debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); + nonce[4..].copy_from_slice(header.f_counter.as_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + // do the weird ring AEAD dance + let key = LessSafeKey::new( + UnboundKey::new(&CHACHA20_POLY1305, &job.state.keypair.recv.key[..]).unwrap(), + ); + + // attempt to open (and authenticate) the body + match key.open_in_place(nonce, Aad::empty(), packet) { + Ok(_) => (), + Err(_) => { + // fault and return early + log::trace!("inbound worker: authentication failure"); + msg.1.truncate(0); + return; } } + } - // check that counter not after reject - if header.f_counter.get() >= REJECT_AFTER_MESSAGES { - msg.1.truncate(0); - return; - } - - // cryptokey route and strip padding - let inner_len = { - let length = packet.len() - SIZE_TAG; - if length > 0 { - peer.device.table.check_route(&peer, &packet[..length]) - } else { - Some(0) - } - }; + // check that counter not after reject + if header.f_counter.get() >= REJECT_AFTER_MESSAGES { + msg.1.truncate(0); + return; + } - // truncate to remove tag - match inner_len { - None => { - log::trace!("inbound worker: cryptokey routing failed"); - msg.1.truncate(0); - } - Some(len) => { - log::trace!( - "inbound worker: good route, length = {} {}", - len, - if len == 0 { "(keepalive)" } else { "" } - ); - msg.1.truncate(mem::size_of::<TransportHeader>() + len); - } + // cryptokey route and strip padding + let inner_len = { + let length = packet.len() - SIZE_TAG; + if length > 0 { + peer.device.table.check_route(&peer, &packet[..length]) + } else { + Some(0) } }; + + // truncate to remove tag + match inner_len { + None => { + log::trace!("inbound worker: cryptokey routing failed"); + msg.1.truncate(0); + } + Some(len) => { + log::trace!( + "inbound worker: good route, length = {} {}", + len, + if len == 0 { "(keepalive)" } else { "" } + ); + msg.1.truncate(mem::size_of::<TransportHeader>() + len); + } + } } // mark ready - self.inner.ready.store(true, Ordering::Release); + self.0.ready.store(true, Ordering::Release); + } +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SequentialJob + for ReceiveJob<E, C, T, B> +{ + fn is_ready(&self) -> bool { + self.0.ready.load(Ordering::Acquire) } fn sequential_work(self) { - let job = &self.inner; + let job = &self.0; let peer = &job.state.peer; let mut msg = job.buffer.lock(); + let endpoint = msg.0.take(); // cast transport header let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = @@ -165,7 +170,7 @@ impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for Rece } // update endpoint - *peer.endpoint.lock() = msg.0.take(); + *peer.endpoint.lock() = endpoint; // check if should be written to TUN let mut sent = false; @@ -184,5 +189,4 @@ impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for Rece // trigger callback C::recv(&peer.opaque, msg.1.len(), sent, &job.state.keypair); } - } diff --git a/src/wireguard/router/runq.rs b/src/wireguard/router/runq.rs deleted file mode 100644 index 44e11a1..0000000 --- a/src/wireguard/router/runq.rs +++ /dev/null @@ -1,164 +0,0 @@ -use std::hash::Hash; -use std::mem; -use std::sync::{Condvar, Mutex}; - -use std::collections::hash_map::Entry; -use std::collections::HashMap; -use std::collections::VecDeque; - -pub trait ToKey { - type Key: Hash + Eq; - fn to_key(&self) -> Self::Key; -} - -pub struct RunQueue<T: ToKey> { - cvar: Condvar, - inner: Mutex<Inner<T>>, -} - -struct Inner<T: ToKey> { - stop: bool, - queue: VecDeque<T>, - members: HashMap<T::Key, usize>, -} - -impl<T: ToKey> RunQueue<T> { - pub fn close(&self) { - let mut inner = self.inner.lock().unwrap(); - inner.stop = true; - self.cvar.notify_all(); - } - - pub fn new() -> RunQueue<T> { - RunQueue { - cvar: Condvar::new(), - inner: Mutex::new(Inner { - stop: false, - queue: VecDeque::new(), - members: HashMap::new(), - }), - } - } - - pub fn insert(&self, v: T) { - let key = v.to_key(); - let mut inner = self.inner.lock().unwrap(); - match inner.members.entry(key) { - Entry::Occupied(mut elem) => { - *elem.get_mut() += 1; - } - Entry::Vacant(spot) => { - // add entry to back of queue - spot.insert(0); - inner.queue.push_back(v); - - // wake a thread - self.cvar.notify_one(); - } - } - } - - /// Run (consume from) the run queue using the provided function. - /// The function should return wheter the given element should be rescheduled. - /// - /// # Arguments - /// - /// - `f` : function to apply to every element - /// - /// # Note - /// - /// The function f may be called again even when the element was not inserted back in to the - /// queue since the last applciation and no rescheduling was requested. - /// - /// This happens then the function handles all work for T, - /// but T is added to the run queue while the function is running. - pub fn run<F: Fn(&T) -> bool>(&self, f: F) { - let mut inner = self.inner.lock().unwrap(); - loop { - // fetch next element - let elem = loop { - // run-queue closed - if inner.stop { - return; - } - - // try to pop from queue - match inner.queue.pop_front() { - Some(elem) => { - break elem; - } - None => (), - }; - - // wait for an element to be inserted - inner = self.cvar.wait(inner).unwrap(); - }; - - // fetch current request number - let key = elem.to_key(); - let old_n = *inner.members.get(&key).unwrap(); - mem::drop(inner); // drop guard - - // handle element - let rerun = f(&elem); - - // if the function requested a re-run add the element to the back of the queue - inner = self.inner.lock().unwrap(); - if rerun { - inner.queue.push_back(elem); - continue; - } - - // otherwise check if new requests have come in since we ran the function - match inner.members.entry(key) { - Entry::Occupied(occ) => { - if *occ.get() == old_n { - // no new requests since last, remove entry. - occ.remove(); - } else { - // new requests, reschedule. - inner.queue.push_back(elem); - } - } - Entry::Vacant(_) => { - unreachable!(); - } - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::thread; - use std::time::Duration; - - /* - #[test] - fn test_wait() { - let queue: Arc<RunQueue<usize>> = Arc::new(RunQueue::new()); - - { - let queue = queue.clone(); - thread::spawn(move || { - queue.run(|e| { - println!("t0 {}", e); - thread::sleep(Duration::from_millis(100)); - }) - }); - } - - { - let queue = queue.clone(); - thread::spawn(move || { - queue.run(|e| { - println!("t1 {}", e); - thread::sleep(Duration::from_millis(100)); - }) - }); - } - - } - */ -} diff --git a/src/wireguard/router/send.rs b/src/wireguard/router/send.rs index 2bd4abd..8e41796 100644 --- a/src/wireguard/router/send.rs +++ b/src/wireguard/router/send.rs @@ -1,4 +1,4 @@ -use super::queue::{Job, Queue}; +use super::queue::{SequentialJob, ParallelJob, Queue}; use super::KeyPair; use super::types::Callbacks; use super::peer::Peer; @@ -22,8 +22,14 @@ struct Inner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { peer: Peer<E, C, T, B>, } -pub struct SendJob<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { - inner: Arc<Inner<E, C, T, B>>, +pub struct SendJob<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ( + Arc<Inner<E, C, T, B>> +); + +impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for SendJob<E, C, T, B> { + fn clone(&self) -> SendJob<E, C, T, B> { + SendJob(self.0.clone()) + } } impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SendJob<E, C, T, B> { @@ -32,32 +38,53 @@ impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SendJob<E, C counter: u64, keypair: Arc<KeyPair>, peer: Peer<E, C, T, B> - ) -> Option<SendJob<E, C, T, B>> { - // create job - let inner = Arc::new(Inner{ + ) -> SendJob<E, C, T, B> { + SendJob(Arc::new(Inner{ buffer: Mutex::new(buffer), counter, keypair, peer, ready: AtomicBool::new(false) - }); - - // attempt to add to queue - if peer.outbound.push(SendJob{ inner: inner.clone()}) { - Some(SendJob{inner}) - } else { - None - } + })) } } -impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for SendJob<E, C, T, B> { - fn queue(&self) -> &Queue<Self> { - &self.inner.peer.outbound +impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SequentialJob for SendJob<E, C, T, B> { + + fn is_ready(&self) -> bool { + self.0.ready.load(Ordering::Acquire) } - fn is_ready(&self) -> bool { - self.inner.ready.load(Ordering::Acquire) + fn sequential_work(self) { + debug_assert_eq!( + self.is_ready(), + true, + "doing sequential work + on an incomplete job" + ); + log::trace!("processing sequential send job"); + + // send to peer + let job = &self.0; + let msg = job.buffer.lock(); + let xmit = job.peer.send_raw(&msg[..]).is_ok(); + + // trigger callback (for timers) + C::send( + &job.peer.opaque, + msg.len(), + xmit, + &job.keypair, + job.counter, + ); + } +} + + +impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ParallelJob for SendJob<E, C, T, B> { + + fn queue(&self) -> &Queue<Self> { + &self.0.peer.outbound } fn parallel_work(&self) { @@ -71,7 +98,7 @@ impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for Send // encrypt body { // make space for the tag - let job = &*self.inner; + let job = &*self.0; let mut msg = job.buffer.lock(); msg.extend([0u8; SIZE_TAG].iter()); @@ -111,30 +138,6 @@ impl <E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Job for Send } // mark ready - self.inner.ready.store(true, Ordering::Release); - } - - fn sequential_work(self) { - debug_assert_eq!( - self.is_ready(), - true, - "doing sequential work - on an incomplete job" - ); - log::trace!("processing sequential send job"); - - // send to peer - let job = &self.inner; - let msg = job.buffer.lock(); - let xmit = job.peer.send(&msg[..]).is_ok(); - - // trigger callback (for timers) - C::send( - &job.peer.opaque, - msg.len(), - xmit, - &job.keypair, - job.counter, - ); + self.0.ready.store(true, Ordering::Release); } -} +}
\ No newline at end of file diff --git a/src/wireguard/router/worker.rs b/src/wireguard/router/worker.rs index d95050e..bbb644c 100644 --- a/src/wireguard/router/worker.rs +++ b/src/wireguard/router/worker.rs @@ -1,13 +1,31 @@ -use super::Device; - use super::super::{tun, udp, Endpoint}; use super::types::Callbacks; -use super::receive::ReceieveJob; +use super::queue::ParallelJob; +use super::receive::ReceiveJob; use super::send::SendJob; -fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - device: Device<E, C, T, B>, +use crossbeam_channel::Receiver; + +pub enum JobUnion<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { + Outbound(SendJob<E, C, T, B>), + Inbound(ReceiveJob<E, C, T, B>), +} + +pub fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + receiver: Receiver<JobUnion<E, C, T, B>>, ) { - // fetch job + loop { + match receiver.recv() { + Err(_) => break, + Ok(JobUnion::Inbound(job)) => { + job.parallel_work(); + job.queue().consume(); + } + Ok(JobUnion::Outbound(job)) => { + job.parallel_work(); + job.queue().consume(); + } + } + } } diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs deleted file mode 100644 index 8ddc136..0000000 --- a/src/wireguard/router/workers.rs +++ /dev/null @@ -1,257 +0,0 @@ -use std::sync::Arc; - -use log::{debug, trace}; - -use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; - -use crossbeam_channel::Receiver; -use std::sync::atomic::Ordering; -use zerocopy::{AsBytes, LayoutVerified}; - -use super::device::{DecryptionState, DeviceInner}; -use super::messages::{TransportHeader, TYPE_TRANSPORT}; -use super::peer::PeerInner; -use super::types::Callbacks; - -use super::REJECT_AFTER_MESSAGES; - -use super::super::types::KeyPair; -use super::super::{tun, udp, Endpoint}; - -pub const SIZE_TAG: usize = 16; - -pub struct JobEncryption { - pub msg: Vec<u8>, - pub keypair: Arc<KeyPair>, - pub counter: u64, -} - -pub struct JobDecryption { - pub msg: Vec<u8>, - pub keypair: Arc<KeyPair>, -} - -pub enum JobParallel { - Encryption(oneshot::Sender<JobEncryption>, JobEncryption), - Decryption(oneshot::Sender<Option<JobDecryption>>, JobDecryption), -} - -#[allow(type_alias_bounds)] -pub type JobInbound<E, C, T, B: udp::Writer<E>> = ( - Arc<DecryptionState<E, C, T, B>>, - E, - oneshot::Receiver<Option<JobDecryption>>, -); - -pub type JobOutbound = oneshot::Receiver<JobEncryption>; - -/* TODO: Replace with run-queue - */ -pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - device: Arc<DeviceInner<E, C, T, B>>, // related device - peer: Arc<PeerInner<E, C, T, B>>, // related peer - receiver: Receiver<JobInbound<E, C, T, B>>, -) { - loop { - // fetch job - let (state, endpoint, rx) = match receiver.recv() { - Ok(v) => v, - _ => { - return; - } - }; - debug!("inbound worker: obtained job"); - - // wait for job to complete - let _ = rx - .map(|buf| { - debug!("inbound worker: job complete"); - if let Some(buf) = buf { - // cast transport header - let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = - match LayoutVerified::new_from_prefix(&buf.msg[..]) { - Some(v) => v, - None => { - debug!("inbound worker: failed to parse message"); - return; - } - }; - - debug_assert!( - packet.len() >= CHACHA20_POLY1305.tag_len(), - "this should be checked earlier in the pipeline (decryption should fail)" - ); - - // check for replay - if !state.protector.lock().update(header.f_counter.get()) { - debug!("inbound worker: replay detected"); - return; - } - - // check for confirms key - if !state.confirmed.swap(true, Ordering::SeqCst) { - debug!("inbound worker: message confirms key"); - peer.confirm_key(&state.keypair); - } - - // update endpoint - *peer.endpoint.lock() = Some(endpoint); - - // calculate length of IP packet + padding - let length = packet.len() - SIZE_TAG; - debug!("inbound worker: plaintext length = {}", length); - - // check if should be written to TUN - let mut sent = false; - if length > 0 { - if let Some(inner_len) = device.table.check_route(&peer, &packet[..length]) - { - // TODO: Consider moving the cryptkey route check to parallel decryption worker - debug_assert!(inner_len <= length, "should be validated earlier"); - if inner_len <= length { - sent = match device.inbound.write(&packet[..inner_len]) { - Err(e) => { - debug!("failed to write inbound packet to TUN: {:?}", e); - false - } - Ok(_) => true, - } - } - } - } else { - debug!("inbound worker: received keepalive") - } - - // trigger callback - C::recv(&peer.opaque, buf.msg.len(), sent, &buf.keypair); - } else { - debug!("inbound worker: authentication failure") - } - }) - .wait(); - } -} - - -pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - peer: Arc<PeerInner<E, C, T, B>>, - receiver: Receiver<JobOutbound>, -) { - loop { - // fetch job - let rx = match receiver.recv() { - Ok(v) => v, - _ => { - return; - } - }; - debug!("outbound worker: obtained job"); - - // wait for job to complete - let _ = rx - .map(|buf| { - debug!("outbound worker: job complete"); - - // send to peer - let xmit = peer.send(&buf.msg[..]).is_ok(); - - // trigger callback - C::send(&peer.opaque, buf.msg.len(), xmit, &buf.keypair, buf.counter); - }) - .wait(); - } -} - -pub fn worker_parallel(receiver: Receiver<JobParallel>) { - loop { - // fetch next job - let job = match receiver.recv() { - Err(_) => { - return; - } - Ok(val) => val, - }; - trace!("parallel worker: obtained job"); - - // handle job - match job { - JobParallel::Encryption(tx, mut job) => { - job.msg.extend([0u8; SIZE_TAG].iter()); - - // cast to header (should never fail) - let (mut header, body): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = - LayoutVerified::new_from_prefix(&mut job.msg[..]) - .expect("earlier code should ensure that there is ample space"); - - // set header fields - debug_assert!( - job.counter < REJECT_AFTER_MESSAGES, - "should be checked when assigning counters" - ); - header.f_type.set(TYPE_TRANSPORT); - header.f_receiver.set(job.keypair.send.id); - header.f_counter.set(job.counter); - - // create a nonce object - let mut nonce = [0u8; 12]; - debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); - nonce[4..].copy_from_slice(header.f_counter.as_bytes()); - let nonce = Nonce::assume_unique_for_key(nonce); - - // do the weird ring AEAD dance - let key = LessSafeKey::new( - UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.send.key[..]).unwrap(), - ); - - // encrypt content of transport message in-place - let end = body.len() - SIZE_TAG; - let tag = key - .seal_in_place_separate_tag(nonce, Aad::empty(), &mut body[..end]) - .unwrap(); - - // append tag - body[end..].copy_from_slice(tag.as_ref()); - - // pass ownership - let _ = tx.send(job); - } - JobParallel::Decryption(tx, mut job) => { - // cast to header (could fail) - let layout: Option<(LayoutVerified<&mut [u8], TransportHeader>, &mut [u8])> = - LayoutVerified::new_from_prefix(&mut job.msg[..]); - - let _ = tx.send(match layout { - Some((header, body)) => { - debug_assert_eq!( - header.f_type.get(), - TYPE_TRANSPORT, - "type and reserved bits should be checked by message de-multiplexer" - ); - if header.f_counter.get() < REJECT_AFTER_MESSAGES { - // create a nonce object - let mut nonce = [0u8; 12]; - debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); - nonce[4..].copy_from_slice(header.f_counter.as_bytes()); - let nonce = Nonce::assume_unique_for_key(nonce); - - // do the weird ring AEAD dance - let key = LessSafeKey::new( - UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.recv.key[..]) - .unwrap(), - ); - - // attempt to open (and authenticate) the body - match key.open_in_place(nonce, Aad::empty(), body) { - Ok(_) => Some(job), - Err(_) => None, - } - } else { - None - } - } - None => None, - }); - } - } - } -} diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 45b1fcb..94e240d 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -603,7 +603,7 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> { ); let device = wg.handshake.read(); let _ = device.begin(&mut rng, &peer.pk).map(|msg| { - let _ = peer.router.send(&msg[..]).map_err(|e| { + let _ = peer.router.send_raw(&msg[..]).map_err(|e| { debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e) }); peer.state.sent_handshake_initiation(); |