diff options
Diffstat (limited to 'src/wireguard')
-rw-r--r-- | src/wireguard/router/device.rs | 101 | ||||
-rw-r--r-- | src/wireguard/router/inbound.rs | 242 | ||||
-rw-r--r-- | src/wireguard/router/mod.rs | 1 | ||||
-rw-r--r-- | src/wireguard/router/outbound.rs | 138 | ||||
-rw-r--r-- | src/wireguard/router/peer.rs | 21 | ||||
-rw-r--r-- | src/wireguard/router/pool.rs | 77 | ||||
-rw-r--r-- | src/wireguard/router/runq.rs | 145 | ||||
-rw-r--r-- | src/wireguard/router/tests.rs | 2 |
8 files changed, 477 insertions, 250 deletions
diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index e405446..9bba199 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -20,6 +20,7 @@ use super::peer::{new_peer, Peer, PeerHandle}; use super::types::{Callbacks, RouterError}; use super::SIZE_MESSAGE_PREFIX; +use super::runq::RunQueue; use super::route::RoutingTable; use super::super::{tun, udp, Endpoint, KeyPair}; @@ -37,8 +38,12 @@ pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer pub table: RoutingTable<Peer<E, C, T, B>>, // work queues - pub outbound_queue: ParallelQueue<Job<Peer<E, C, T, B>, outbound::Outbound>>, - pub inbound_queue: ParallelQueue<Job<Peer<E, C, T, B>, inbound::Inbound<E, C, T, B>>>, + pub queue_outbound: ParallelQueue<Job<Peer<E, C, T, B>, outbound::Outbound>>, + pub queue_inbound: ParallelQueue<Job<Peer<E, C, T, B>, inbound::Inbound<E, C, T, B>>>, + + // run queues + pub run_inbound: RunQueue<Peer<E, C, T, B>>, + pub run_outbound: RunQueue<Peer<E, C, T, B>>, } pub struct EncryptionState { @@ -96,8 +101,12 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop debug!("router: dropping device"); // close worker queues - self.state.outbound_queue.close(); - self.state.inbound_queue.close(); + self.state.queue_outbound.close(); + self.state.queue_inbound.close(); + + // close run queues + self.state.run_outbound.close(); + self.state.run_inbound.close(); // join all worker threads while match self.handles.pop() { @@ -116,43 +125,73 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<E, C, T, B> { pub fn new(num_workers: usize, tun: T) -> DeviceHandle<E, C, T, B> { // allocate shared device state - let (mut outrx, outbound_queue) = ParallelQueue::new(num_workers); - let (mut inrx, inbound_queue) = ParallelQueue::new(num_workers); - let inner = DeviceInner { - inbound: tun, - inbound_queue, - outbound: RwLock::new((true, None)), - outbound_queue, - recv: RwLock::new(HashMap::new()), - table: RoutingTable::new(), + let (mut outrx, queue_outbound) = ParallelQueue::new(num_workers); + let (mut inrx, queue_inbound) = ParallelQueue::new(num_workers); + let device = Device { + inner: Arc::new(DeviceInner { + inbound: tun, + queue_inbound, + outbound: RwLock::new((true, None)), + queue_outbound, + run_inbound: RunQueue::new(), + run_outbound: RunQueue::new(), + recv: RwLock::new(HashMap::new()), + table: RoutingTable::new(), + }) }; // start worker threads let mut threads = Vec::with_capacity(num_workers); + // inbound/decryption workers for _ in 0..num_workers { - let rx = inrx.pop().unwrap(); - threads.push(thread::spawn(move || { - log::debug!("inbound router worker started"); - inbound::worker(rx) - })); + // parallel workers (parallel processing) + { + let device = device.clone(); + let rx = inrx.pop().unwrap(); + threads.push(thread::spawn(move || { + log::debug!("inbound parallel router worker started"); + inbound::parallel(device, rx) + })); + } + + // sequential workers (in-order processing) + { + let device = device.clone(); + threads.push(thread::spawn(move || { + log::debug!("inbound sequential router worker started"); + inbound::sequential(device) + })); + } } + // outbound/encryption workers for _ in 0..num_workers { - let rx = outrx.pop().unwrap(); - threads.push(thread::spawn(move || { - log::debug!("outbound router worker started"); - outbound::worker(rx) - })); + // parallel workers (parallel processing) + { + let device = device.clone(); + let rx = outrx.pop().unwrap(); + threads.push(thread::spawn(move || { + log::debug!("outbound parallel router worker started"); + outbound::parallel(device, rx) + })); + } + + // sequential workers (in-order processing) + { + let device = device.clone(); + threads.push(thread::spawn(move || { + log::debug!("outbound sequential router worker started"); + outbound::sequential(device) + })); + } } - debug_assert_eq!(threads.len(), num_workers * 2); + debug_assert_eq!(threads.len(), num_workers * 4); // return exported device handle DeviceHandle { - state: Device { - inner: Arc::new(inner), - }, + state: device, handles: threads, } } @@ -192,7 +231,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle< pub fn send(&self, msg: Vec<u8>) -> Result<(), RouterError> { debug_assert!(msg.len() > SIZE_MESSAGE_PREFIX); log::trace!( - "Router, outbound packet = {}", + "send, packet = {}", hex::encode(&msg[SIZE_MESSAGE_PREFIX..]) ); @@ -208,7 +247,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle< // schedule for encryption and transmission to peer if let Some(job) = peer.send_job(msg, true) { - self.state.outbound_queue.send(job); + self.state.queue_outbound.send(job); } Ok(()) @@ -225,6 +264,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle< /// /// pub fn recv(&self, src: E, msg: Vec<u8>) -> Result<(), RouterError> { + log::trace!("receive, src: {}", src.into_address()); + // parse / cast let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) { Some(v) => v, @@ -255,7 +296,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle< // schedule for decryption and TUN write if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) { log::trace!("schedule decryption of transport message"); - self.state.inbound_queue.send(job); + self.state.queue_inbound.send(job); } Ok(()) } diff --git a/src/wireguard/router/inbound.rs b/src/wireguard/router/inbound.rs index 3d47bb7..9b15750 100644 --- a/src/wireguard/router/inbound.rs +++ b/src/wireguard/router/inbound.rs @@ -4,6 +4,8 @@ use super::peer::Peer; use super::pool::*; use super::types::Callbacks; use super::{tun, udp, Endpoint}; +use super::device::Device; +use super::runq::RunQueue; use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; use zerocopy::{AsBytes, LayoutVerified}; @@ -38,139 +40,151 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Inbound<E, C, } #[inline(always)] -fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - peer: &Peer<E, C, T, B>, - body: &mut Inbound<E, C, T, B>, +pub fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + device: Device<E, C, T, B>, + receiver: Receiver<Job<Peer<E, C, T, B>, Inbound<E, C, T, B>>>, ) { - log::trace!("worker, parallel section, obtained job"); + // run queue to schedule + fn queue<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + device: &Device<E, C, T, B>, + ) -> &RunQueue<Peer<E, C, T, B>> { + &device.run_inbound + } + + // parallel work to apply + fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + peer: &Peer<E, C, T, B>, + body: &mut Inbound<E, C, T, B>, + ) { + log::trace!("worker, parallel section, obtained job"); + + // cast to header followed by payload + let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = + match LayoutVerified::new_from_prefix(&mut body.msg[..]) { + Some(v) => v, + None => { + log::debug!("inbound worker: failed to parse message"); + return; + } + }; + + // authenticate and decrypt payload + { + // create nonce object + let mut nonce = [0u8; 12]; + debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); + nonce[4..].copy_from_slice(header.f_counter.as_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + // do the weird ring AEAD dance + let key = LessSafeKey::new( + UnboundKey::new(&CHACHA20_POLY1305, &body.state.keypair.recv.key[..]).unwrap(), + ); - // cast to header followed by payload - let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = - match LayoutVerified::new_from_prefix(&mut body.msg[..]) { - Some(v) => v, - None => { - log::debug!("inbound worker: failed to parse message"); - return; + // attempt to open (and authenticate) the body + match key.open_in_place(nonce, Aad::empty(), packet) { + Ok(_) => (), + Err(_) => { + // fault and return early + log::trace!("inbound worker: authentication failure"); + body.failed = true; + return; + } + } + } + + // cryptokey route and strip padding + let inner_len = { + let length = packet.len() - SIZE_TAG; + if length > 0 { + peer.device.table.check_route(&peer, &packet[..length]) + } else { + Some(0) } }; - // authenticate and decrypt payload - { - // create nonce object - let mut nonce = [0u8; 12]; - debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); - nonce[4..].copy_from_slice(header.f_counter.as_bytes()); - let nonce = Nonce::assume_unique_for_key(nonce); - - // do the weird ring AEAD dance - let key = LessSafeKey::new( - UnboundKey::new(&CHACHA20_POLY1305, &body.state.keypair.recv.key[..]).unwrap(), - ); - - // attempt to open (and authenticate) the body - match key.open_in_place(nonce, Aad::empty(), packet) { - Ok(_) => (), - Err(_) => { - // fault and return early - log::trace!("inbound worker: authentication failure"); + // truncate to remove tag + match inner_len { + None => { + log::trace!("inbound worker: cryptokey routing failed"); body.failed = true; - return; + } + Some(len) => { + log::trace!( + "inbound worker: good route, length = {} {}", + len, + if len == 0 { "(keepalive)" } else { "" } + ); + body.msg.truncate(mem::size_of::<TransportHeader>() + len); } } } - // cryptokey route and strip padding - let inner_len = { - let length = packet.len() - SIZE_TAG; - if length > 0 { - peer.device.table.check_route(&peer, &packet[..length]) - } else { - Some(0) - } - }; - - // truncate to remove tag - match inner_len { - None => { - log::trace!("inbound worker: cryptokey routing failed"); - body.failed = true; - } - Some(len) => { - log::trace!( - "inbound worker: good route, length = {} {}", - len, - if len == 0 { "(keepalive)" } else { "" } - ); - body.msg.truncate(mem::size_of::<TransportHeader>() + len); - } - } + worker_parallel(device, |dev| &dev.run_inbound, receiver, work) } #[inline(always)] -fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - peer: &Peer<E, C, T, B>, - body: &mut Inbound<E, C, T, B>, +pub fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + device: Device<E, C, T, B>, ) { - log::trace!("worker, sequential section, obtained job"); - - // decryption failed, return early - if body.failed { - log::trace!("job faulted, remove from queue and ignore"); - return; - } - - // cast transport header - let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = - match LayoutVerified::new_from_prefix(&body.msg[..]) { - Some(v) => v, - None => { - log::debug!("inbound worker: failed to parse message"); - return; - } - }; - - // check for replay - if !body.state.protector.lock().update(header.f_counter.get()) { - log::debug!("inbound worker: replay detected"); - return; - } + // sequential work to apply + fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + peer: &Peer<E, C, T, B>, + body: &mut Inbound<E, C, T, B> + ) { + log::trace!("worker, sequential section, obtained job"); + + // decryption failed, return early + if body.failed { + log::trace!("job faulted, remove from queue and ignore"); + return; + } - // check for confirms key - if !body.state.confirmed.swap(true, Ordering::SeqCst) { - log::debug!("inbound worker: message confirms key"); - peer.confirm_key(&body.state.keypair); - } + // cast transport header + let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = + match LayoutVerified::new_from_prefix(&body.msg[..]) { + Some(v) => v, + None => { + log::debug!("inbound worker: failed to parse message"); + return; + } + }; + + // check for replay + if !body.state.protector.lock().update(header.f_counter.get()) { + log::debug!("inbound worker: replay detected"); + return; + } - // update endpoint - *peer.endpoint.lock() = body.endpoint.take(); + // check for confirms key + if !body.state.confirmed.swap(true, Ordering::SeqCst) { + log::debug!("inbound worker: message confirms key"); + peer.confirm_key(&body.state.keypair); + } - // check if should be written to TUN - let mut sent = false; - if packet.len() > 0 { - sent = match peer.device.inbound.write(&packet[..]) { - Err(e) => { - log::debug!("failed to write inbound packet to TUN: {:?}", e); - false + // update endpoint + *peer.endpoint.lock() = body.endpoint.take(); + + // check if should be written to TUN + let mut sent = false; + if packet.len() > 0 { + sent = match peer.device.inbound.write(&packet[..]) { + Err(e) => { + log::debug!("failed to write inbound packet to TUN: {:?}", e); + false + } + Ok(_) => true, } - Ok(_) => true, + } else { + log::debug!("inbound worker: received keepalive") } - } else { - log::debug!("inbound worker: received keepalive") - } - // trigger callback - C::recv(&peer.opaque, body.msg.len(), sent, &body.state.keypair); -} - -#[inline(always)] -fn queue<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - peer: &Peer<E, C, T, B>, -) -> &InorderQueue<Peer<E, C, T, B>, Inbound<E, C, T, B>> { - &peer.inbound -} + // trigger callback + C::recv(&peer.opaque, body.msg.len(), sent, &body.state.keypair); + } -pub fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - receiver: Receiver<Job<Peer<E, C, T, B>, Inbound<E, C, T, B>>>, -) { - worker_template(receiver, parallel, sequential, queue) + // handle message from the peers inbound queue + device.run_inbound.run(|peer| { + peer.inbound.handle(|body| work(&peer, body)); + }); } diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs index bccb0a9..35efe4c 100644 --- a/src/wireguard/router/mod.rs +++ b/src/wireguard/router/mod.rs @@ -10,6 +10,7 @@ mod pool; mod queue; mod route; mod types; +mod runq; // mod workers; diff --git a/src/wireguard/router/outbound.rs b/src/wireguard/router/outbound.rs index d08637b..6c42d8f 100644 --- a/src/wireguard/router/outbound.rs +++ b/src/wireguard/router/outbound.rs @@ -5,6 +5,7 @@ use super::types::Callbacks; use super::KeyPair; use super::REJECT_AFTER_MESSAGES; use super::{tun, udp, Endpoint}; +use super::device::Device; use std::sync::mpsc::Receiver; use std::sync::Arc; @@ -31,78 +32,77 @@ impl Outbound { } #[inline(always)] -fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - _peer: &Peer<E, C, T, B>, - body: &mut Outbound, +pub fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + device: Device<E, C, T, B>, + receiver: Receiver<Job<Peer<E, C, T, B>, Outbound>>, + ) { - log::trace!("worker, parallel section, obtained job"); - - // make space for the tag - body.msg.extend([0u8; SIZE_TAG].iter()); - - // cast to header (should never fail) - let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = - LayoutVerified::new_from_prefix(&mut body.msg[..]) - .expect("earlier code should ensure that there is ample space"); - - // set header fields - debug_assert!( - body.counter < REJECT_AFTER_MESSAGES, - "should be checked when assigning counters" - ); - header.f_type.set(TYPE_TRANSPORT); - header.f_receiver.set(body.keypair.send.id); - header.f_counter.set(body.counter); - - // create a nonce object - let mut nonce = [0u8; 12]; - debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); - nonce[4..].copy_from_slice(header.f_counter.as_bytes()); - let nonce = Nonce::assume_unique_for_key(nonce); - - // do the weird ring AEAD dance - let key = - LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &body.keypair.send.key[..]).unwrap()); - - // encrypt content of transport message in-place - let end = packet.len() - SIZE_TAG; - let tag = key - .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end]) - .unwrap(); - - // append tag - packet[end..].copy_from_slice(tag.as_ref()); -} + fn work<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + _peer: &Peer<E, C, T, B>, + body: &mut Outbound, + ) { + log::trace!("worker, parallel section, obtained job"); + + // make space for the tag + body.msg.extend([0u8; SIZE_TAG].iter()); + + // cast to header (should never fail) + let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = + LayoutVerified::new_from_prefix(&mut body.msg[..]) + .expect("earlier code should ensure that there is ample space"); + + // set header fields + debug_assert!( + body.counter < REJECT_AFTER_MESSAGES, + "should be checked when assigning counters" + ); + header.f_type.set(TYPE_TRANSPORT); + header.f_receiver.set(body.keypair.send.id); + header.f_counter.set(body.counter); + + // create a nonce object + let mut nonce = [0u8; 12]; + debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); + nonce[4..].copy_from_slice(header.f_counter.as_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + // do the weird ring AEAD dance + let key = + LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &body.keypair.send.key[..]).unwrap()); + + // encrypt content of transport message in-place + let end = packet.len() - SIZE_TAG; + let tag = key + .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end]) + .unwrap(); + + // append tag + packet[end..].copy_from_slice(tag.as_ref()); + } -#[inline(always)] -fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - peer: &Peer<E, C, T, B>, - body: &mut Outbound, -) { - log::trace!("worker, sequential section, obtained job"); - - // send to peer - let xmit = peer.send(&body.msg[..]).is_ok(); - - // trigger callback - C::send( - &peer.opaque, - body.msg.len(), - xmit, - &body.keypair, - body.counter, - ); + worker_parallel(device, |dev| &dev.run_outbound, receiver, work); } -#[inline(always)] -pub fn queue<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - peer: &Peer<E, C, T, B>, -) -> &InorderQueue<Peer<E, C, T, B>, Outbound> { - &peer.outbound -} -pub fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - receiver: Receiver<Job<Peer<E, C, T, B>, Outbound>>, +#[inline(always)] +pub fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + device: Device<E, C, T, B>, ) { - worker_template(receiver, parallel, sequential, queue) -} + device.run_outbound.run(|peer| { + peer.outbound.handle(|body| { + log::trace!("worker, sequential section, obtained job"); + + // send to peer + let xmit = peer.send(&body.msg[..]).is_ok(); + + // trigger callback + C::send( + &peer.opaque, + body.msg.len(), + xmit, + &body.keypair, + body.counter, + ); + }); + }); +}
\ No newline at end of file diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index 40442a8..a00ce1a 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -20,6 +20,7 @@ use super::messages::TransportHeader; use super::constants::*; use super::types::{Callbacks, RouterError}; use super::SIZE_MESSAGE_PREFIX; +use super::runq::ToKey; // worker pool related use super::inbound::Inbound; @@ -56,14 +57,28 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for Pee } } +/* Equality of peers is defined as pointer equality + * the atomic reference counted pointer. + */ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq for Peer<E, C, T, B> { fn eq(&self, other: &Self) -> bool { Arc::ptr_eq(&self.inner, &other.inner) } } +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ToKey for Peer<E, C, T, B> { + type Key = usize; + fn to_key(&self) -> usize { + Arc::downgrade(&self.inner).into_raw() as usize + } +} + impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Eq for Peer<E, C, T, B> {} +/* A peer is transparently dereferenced to the inner type + * + */ + impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Peer<E, C, T, B> { type Target = PeerInner<E, C, T, B>; fn deref(&self) -> &Self::Target { @@ -71,6 +86,10 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Pee } } + +/* A peer handle is a specially designated peer pointer + * which removes the peer from the device when dropped. + */ pub struct PeerHandle<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { peer: Peer<E, C, T, B>, } @@ -227,7 +246,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, log::debug!("peer.send_raw"); match self.send_job(msg, false) { Some(job) => { - self.device.outbound_queue.send(job); + self.device.queue_outbound.send(job); debug!("send_raw: got obtained send_job"); true } diff --git a/src/wireguard/router/pool.rs b/src/wireguard/router/pool.rs index 9c72372..98b1144 100644 --- a/src/wireguard/router/pool.rs +++ b/src/wireguard/router/pool.rs @@ -2,6 +2,9 @@ use arraydeque::ArrayDeque; use spin::{Mutex, MutexGuard}; use std::sync::mpsc::Receiver; use std::sync::Arc; +use std::mem; + +use super::runq::{RunQueue, ToKey}; const INORDER_QUEUE_SIZE: usize = 64; @@ -60,51 +63,53 @@ impl<P, B> InorderQueue<P, B> { } #[inline(always)] - pub fn handle<F: Fn(&mut InnerJob<P, B>)>(&self, f: F) { + pub fn handle<F: Fn(&mut B)>(&self, f: F) { // take the mutex let mut queue = self.queue.lock(); - // handle all complete messages - while queue - .pop_front() - .and_then(|j| { - // check if job is complete - let ret = if let Some(mut guard) = j.complete() { - f(&mut *guard); - false - } else { - true - }; - - // return job to cyclic buffer if not complete - if ret { - let _res = queue.push_front(j); - debug_assert!(_res.is_ok()); - None - } else { - // add job back to pool - Some(()) + loop { + // attempt to extract front element + let front = queue.pop_front(); + let elem = match front { + Some(elem) => elem, + _ => { + return; } - }) - .is_some() - {} + }; + + // apply function if job complete + let ret = if let Some(mut guard) = elem.complete() { + mem::drop(queue); + f(&mut guard.body); + queue = self.queue.lock(); + false + } else { + true + }; + + // job not complete yet, return job to front + if ret { + queue.push_front(elem).unwrap(); + return; + } + } } } /// Allows easy construction of a semi-parallel worker. /// Applicable for both decryption and encryption workers. #[inline(always)] -pub fn worker_template< - P, // represents a peer (atomic reference counted pointer) +pub fn worker_parallel< + P : ToKey, // represents a peer (atomic reference counted pointer) B, // inner body type (message buffer, key material, ...) + D, // device W: Fn(&P, &mut B), - S: Fn(&P, &mut B), - Q: Fn(&P) -> &InorderQueue<P, B>, + Q: Fn(&D) -> &RunQueue<P>, >( - receiver: Receiver<Job<P, B>>, // receiever for new jobs - work_parallel: W, // perform parallel / out-of-order work on peer - work_sequential: S, // perform sequential work on peer - queue: Q, // resolve a peer to an inorder queue + device: D, + queue: Q, + receiver: Receiver<Job<P, B>>, + work: W, ) { log::trace!("router worker started"); loop { @@ -123,11 +128,11 @@ pub fn worker_template< let peer = job.peer.take().unwrap(); // process job - work_parallel(&peer, &mut job.body); + work(&peer, &mut job.body); peer }; - + // process inorder jobs for peer - queue(&peer).handle(|j| work_sequential(&peer, &mut j.body)); + queue(&device).insert(peer); } -} +}
\ No newline at end of file diff --git a/src/wireguard/router/runq.rs b/src/wireguard/router/runq.rs new file mode 100644 index 0000000..6d96490 --- /dev/null +++ b/src/wireguard/router/runq.rs @@ -0,0 +1,145 @@ +use std::mem; +use std::sync::{Condvar, Mutex}; +use std::hash::Hash; + +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::collections::VecDeque; + +pub trait ToKey { + type Key: Hash + Eq; + fn to_key(&self) -> Self::Key; +} + +pub struct RunQueue<T : ToKey> { + cvar: Condvar, + inner: Mutex<Inner<T>>, +} + +struct Inner<T : ToKey> { + stop: bool, + queue: VecDeque<T>, + members: HashMap<T::Key, usize>, +} + +impl<T : ToKey> RunQueue<T> { + pub fn close(&self) { + let mut inner = self.inner.lock().unwrap(); + inner.stop = true; + self.cvar.notify_all(); + } + + pub fn new() -> RunQueue<T> { + RunQueue { + cvar: Condvar::new(), + inner: Mutex::new(Inner { + stop:false, + queue: VecDeque::new(), + members: HashMap::new(), + }), + } + } + + pub fn insert(&self, v: T) { + let key = v.to_key(); + let mut inner = self.inner.lock().unwrap(); + match inner.members.entry(key) { + Entry::Occupied(mut elem) => { + *elem.get_mut() += 1; + } + Entry::Vacant(spot) => { + // add entry to back of queue + spot.insert(0); + inner.queue.push_back(v); + + // wake a thread + self.cvar.notify_one(); + } + } + } + + pub fn run<F: Fn(&T) -> ()>(&self, f: F) { + let mut inner = self.inner.lock().unwrap(); + loop { + // fetch next element + let elem = loop { + // run-queue closed + if inner.stop { + return; + } + + // try to pop from queue + match inner.queue.pop_front() { + Some(elem) => { + break elem; + } + None => (), + }; + + // wait for an element to be inserted + inner = self.cvar.wait(inner).unwrap(); + }; + + // fetch current request number + let key = elem.to_key(); + let old_n = *inner.members.get(&key).unwrap(); + mem::drop(inner); // drop guard + + // handle element + f(&elem); + + // retake lock and check if should be added back to queue + inner = self.inner.lock().unwrap(); + match inner.members.entry(key) { + Entry::Occupied(occ) => { + if *occ.get() == old_n { + // no new requests since last, remove entry. + occ.remove(); + } else { + // new requests, reschedule. + inner.queue.push_back(elem); + } + } + Entry::Vacant(_) => { + unreachable!(); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::thread; + use std::sync::Arc; + use std::time::Duration; + + /* + #[test] + fn test_wait() { + let queue: Arc<RunQueue<usize>> = Arc::new(RunQueue::new()); + + { + let queue = queue.clone(); + thread::spawn(move || { + queue.run(|e| { + println!("t0 {}", e); + thread::sleep(Duration::from_millis(100)); + }) + }); + } + + { + let queue = queue.clone(); + thread::spawn(move || { + queue.run(|e| { + println!("t1 {}", e); + thread::sleep(Duration::from_millis(100)); + }) + }); + } + + } + */ +}
\ No newline at end of file diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs index 1f500c0..fe1fbbe 100644 --- a/src/wireguard/router/tests.rs +++ b/src/wireguard/router/tests.rs @@ -273,6 +273,8 @@ mod tests { } } } + + println!("Test complete, drop device"); } #[test] |