From 5a7f762d6ce6b5bbdbd10f5966adc909597f37d6 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Tue, 3 Dec 2019 21:49:08 +0100 Subject: Moving away from peer threads --- src/main.rs | 1 + src/platform/dummy/tun.rs | 2 - src/platform/linux/tun.rs | 24 +--- src/wireguard/peer.rs | 2 +- src/wireguard/router/device copy.rs | 228 ------------------------------------ src/wireguard/router/device.rs | 141 ++++++++++++++++------ src/wireguard/router/inbound.rs | 172 +++++++++++++++++++++++++++ src/wireguard/router/mod.rs | 16 ++- src/wireguard/router/outbound.rs | 104 ++++++++++++++++ src/wireguard/router/peer.rs | 220 +++++++++++++++------------------- src/wireguard/router/pool.rs | 132 +++++++++++++++++++++ src/wireguard/router/route.rs | 27 ++--- src/wireguard/tests.rs | 2 +- src/wireguard/timers.rs | 1 + 14 files changed, 640 insertions(+), 432 deletions(-) delete mode 100644 src/wireguard/router/device copy.rs create mode 100644 src/wireguard/router/inbound.rs create mode 100644 src/wireguard/router/outbound.rs create mode 100644 src/wireguard/router/pool.rs (limited to 'src') diff --git a/src/main.rs b/src/main.rs index 5ea830f..e68c771 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ #![feature(test)] +#![feature(weak_into_raw)] #![allow(dead_code)] use log; diff --git a/src/platform/dummy/tun.rs b/src/platform/dummy/tun.rs index 5d13628..50c6654 100644 --- a/src/platform/dummy/tun.rs +++ b/src/platform/dummy/tun.rs @@ -6,9 +6,7 @@ use rand::Rng; use std::cmp::min; use std::error::Error; use std::fmt; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; -use std::sync::Arc; use std::sync::Mutex; use std::thread; use std::time::Duration; diff --git a/src/platform/linux/tun.rs b/src/platform/linux/tun.rs index 2bac49f..39b9320 100644 --- a/src/platform/linux/tun.rs +++ b/src/platform/linux/tun.rs @@ -359,31 +359,9 @@ impl PlatformTun for LinuxTun { // create PlatformTunMTU instance Ok(( - vec![LinuxTunReader { fd }], // TODO: enable multi-queue for Linux + vec![LinuxTunReader { fd }], // TODO: use multi-queue for Linux LinuxTunWriter { fd }, LinuxTunStatus::new(req.name)?, )) } } - -#[cfg(test)] -mod tests { - use super::*; - use std::env; - - fn is_root() -> bool { - match env::var("USER") { - Ok(val) => val == "root", - Err(_) => false, - } - } - - #[test] - fn test_tun_create() { - if !is_root() { - return; - } - let (readers, writers, mtu) = LinuxTun::create("test").unwrap(); - // TODO: test (any good idea how?) - } -} diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs index 5bcd070..04622fd 100644 --- a/src/wireguard/peer.rs +++ b/src/wireguard/peer.rs @@ -18,7 +18,7 @@ use crossbeam_channel::Sender; use x25519_dalek::PublicKey; pub struct Peer { - pub router: Arc, T::Writer, B::Writer>>, + pub router: Arc, T::Writer, B::Writer>>, pub state: Arc>, } diff --git a/src/wireguard/router/device copy.rs b/src/wireguard/router/device copy.rs deleted file mode 100644 index 04b2045..0000000 --- a/src/wireguard/router/device copy.rs +++ /dev/null @@ -1,228 +0,0 @@ -use std::collections::HashMap; - -use std::net::{Ipv4Addr, Ipv6Addr}; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::sync::mpsc::sync_channel; -use std::sync::mpsc::SyncSender; -use std::sync::Arc; -use std::thread; -use std::time::Instant; - -use log::debug; -use spin::{Mutex, RwLock}; -use treebitmap::IpLookupTable; -use zerocopy::LayoutVerified; - -use super::anti_replay::AntiReplay; -use super::constants::*; - -use super::messages::{TransportHeader, TYPE_TRANSPORT}; -use super::peer::{new_peer, Peer, PeerInner}; -use super::types::{Callbacks, RouterError}; -use super::workers::{worker_parallel, JobParallel}; -use super::SIZE_MESSAGE_PREFIX; - -use super::route::get_route; - -use super::super::{bind, tun, Endpoint, KeyPair}; - -pub struct DeviceInner> { - // inbound writer (TUN) - pub inbound: T, - - // outbound writer (Bind) - pub outbound: RwLock<(bool, Option)>, - - // routing - pub recv: RwLock>>>, // receiver id -> decryption state - pub ipv4: RwLock>>>, // ipv4 cryptkey routing - pub ipv6: RwLock>>>, // ipv6 cryptkey routing - - // work queues - pub queue_next: AtomicUsize, // next round-robin index - pub queues: Mutex>>, // work queues (1 per thread) -} - -pub struct EncryptionState { - pub keypair: Arc, // keypair - pub nonce: u64, // next available nonce - pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout) -} - -pub struct DecryptionState> { - pub keypair: Arc, - pub confirmed: AtomicBool, - pub protector: Mutex, - pub peer: Arc>, - pub death: Instant, // time when the key can no longer be used for decryption -} - -pub struct Device> { - state: Arc>, // reference to device state - handles: Vec>, // join handles for workers -} - -impl> Drop for Device { - fn drop(&mut self) { - debug!("router: dropping device"); - - // drop all queues - { - let mut queues = self.state.queues.lock(); - while queues.pop().is_some() {} - } - - // join all worker threads - while match self.handles.pop() { - Some(handle) => { - handle.thread().unpark(); - handle.join().unwrap(); - true - } - _ => false, - } {} - - debug!("router: device dropped"); - } -} - -impl> Device { - pub fn new(num_workers: usize, tun: T) -> Device { - // allocate shared device state - let inner = DeviceInner { - inbound: tun, - outbound: RwLock::new((true, None)), - queues: Mutex::new(Vec::with_capacity(num_workers)), - queue_next: AtomicUsize::new(0), - recv: RwLock::new(HashMap::new()), - ipv4: RwLock::new(IpLookupTable::new()), - ipv6: RwLock::new(IpLookupTable::new()), - }; - - // start worker threads - let mut threads = Vec::with_capacity(num_workers); - for _ in 0..num_workers { - let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE); - inner.queues.lock().push(tx); - threads.push(thread::spawn(move || worker_parallel(rx))); - } - - // return exported device handle - Device { - state: Arc::new(inner), - handles: threads, - } - } - - /// Brings the router down. - /// When the router is brought down it: - /// - Prevents transmission of outbound messages. - pub fn down(&self) { - self.state.outbound.write().0 = false; - } - - /// Brints the router up - /// When the router is brought up it enables the transmission of outbound messages. - pub fn up(&self) { - self.state.outbound.write().0 = true; - } - - /// A new secret key has been set for the device. - /// According to WireGuard semantics, this should cause all "sending" keys to be discarded. - pub fn new_sk(&self) {} - - /// Adds a new peer to the device - /// - /// # Returns - /// - /// A atomic ref. counted peer (with liftime matching the device) - pub fn new_peer(&self, opaque: C::Opaque) -> Peer { - new_peer(self.state.clone(), opaque) - } - - /// Cryptkey routes and sends a plaintext message (IP packet) - /// - /// # Arguments - /// - /// - msg: IP packet to crypt-key route - /// - pub fn send(&self, msg: Vec) -> Result<(), RouterError> { - debug_assert!(msg.len() > SIZE_MESSAGE_PREFIX); - log::trace!( - "Router, outbound packet = {}", - hex::encode(&msg[SIZE_MESSAGE_PREFIX..]) - ); - - // ignore header prefix (for in-place transport message construction) - let packet = &msg[SIZE_MESSAGE_PREFIX..]; - - // lookup peer based on IP packet destination address - let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptoKeyRoute)?; - - // schedule for encryption and transmission to peer - if let Some(job) = peer.send_job(msg, true) { - // add job to worker queue - let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); - let queues = self.state.queues.lock(); - queues[idx % queues.len()].send(job).unwrap(); - } - - Ok(()) - } - - /// Receive an encrypted transport message - /// - /// # Arguments - /// - /// - src: Source address of the packet - /// - msg: Encrypted transport message - /// - /// # Returns - /// - /// - pub fn recv(&self, src: E, msg: Vec) -> Result<(), RouterError> { - // parse / cast - let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) { - Some(v) => v, - None => { - return Err(RouterError::MalformedTransportMessage); - } - }; - - let header: LayoutVerified<&[u8], TransportHeader> = header; - - debug_assert!( - header.f_type.get() == TYPE_TRANSPORT as u32, - "this should be checked by the message type multiplexer" - ); - - log::trace!( - "Router, handle transport message: (receiver = {}, counter = {})", - header.f_receiver, - header.f_counter - ); - - // lookup peer based on receiver id - let dec = self.state.recv.read(); - let dec = dec - .get(&header.f_receiver.get()) - .ok_or(RouterError::UnknownReceiverId)?; - - // schedule for decryption and TUN write - if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) { - // add job to worker queue - let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); - let queues = self.state.queues.lock(); - queues[idx % queues.len()].send(job).unwrap(); - } - - Ok(()) - } - - /// Set outbound writer - /// - /// - pub fn set_outbound_writer(&self, new: B) { - self.state.outbound.write().1 = Some(new); - } -} diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index 621010b..88eeae1 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -1,7 +1,8 @@ use std::collections::HashMap; +use std::ops::Deref; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::mpsc::sync_channel; -use std::sync::mpsc::SyncSender; +use std::sync::mpsc::{Receiver, SyncSender}; use std::sync::Arc; use std::thread; use std::time::Instant; @@ -11,18 +12,61 @@ use spin::{Mutex, RwLock}; use zerocopy::LayoutVerified; use super::anti_replay::AntiReplay; -use super::constants::*; +use super::pool::Job; + +use super::inbound; +use super::outbound; use super::messages::{TransportHeader, TYPE_TRANSPORT}; -use super::peer::{new_peer, Peer, PeerInner}; +use super::peer::{new_peer, Peer, PeerHandle}; use super::types::{Callbacks, RouterError}; -use super::workers::{worker_parallel, JobParallel}; use super::SIZE_MESSAGE_PREFIX; use super::route::RoutingTable; use super::super::{tun, udp, Endpoint, KeyPair}; +pub struct ParallelQueue { + next: AtomicUsize, // next round-robin index + queues: Vec>>, // work queues (1 per thread) +} + +impl ParallelQueue { + fn new(queues: usize) -> (Vec>, Self) { + let mut rxs = vec![]; + let mut txs = vec![]; + + for _ in 0..queues { + let (tx, rx) = sync_channel(128); + txs.push(Mutex::new(tx)); + rxs.push(rx); + } + + ( + rxs, + ParallelQueue { + next: AtomicUsize::new(0), + queues: txs, + }, + ) + } + + pub fn send(&self, v: T) { + let len = self.queues.len(); + let idx = self.next.fetch_add(1, Ordering::SeqCst); + let que = self.queues[idx % len].lock(); + que.send(v).unwrap(); + } + + pub fn close(&self) { + for i in 0..self.queues.len() { + let (tx, _) = sync_channel(0); + let queue = &self.queues[i]; + *queue.lock() = tx; + } + } +} + pub struct DeviceInner> { // inbound writer (TUN) pub inbound: T, @@ -32,11 +76,11 @@ pub struct DeviceInner>>>, // receiver id -> decryption state - pub table: RoutingTable>, + pub table: RoutingTable>, // work queues - pub queue_next: AtomicUsize, // next round-robin index - pub queues: Mutex>>, // work queues (1 per thread) + pub outbound_queue: ParallelQueue, outbound::Outbound>>, + pub inbound_queue: ParallelQueue, inbound::Inbound>>, } pub struct EncryptionState { @@ -49,24 +93,53 @@ pub struct DecryptionState, pub confirmed: AtomicBool, pub protector: Mutex, - pub peer: Arc>, + pub peer: Peer, pub death: Instant, // time when the key can no longer be used for decryption } pub struct Device> { - state: Arc>, // reference to device state + inner: Arc>, +} + +impl> Clone for Device { + fn clone(&self) -> Self { + Device { + inner: self.inner.clone(), + } + } +} + +impl> PartialEq + for Device +{ + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.inner, &other.inner) + } +} + +impl> Eq for Device {} + +impl> Deref for Device { + type Target = DeviceInner; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +pub struct DeviceHandle> { + state: Device, // reference to device state handles: Vec>, // join handles for workers } -impl> Drop for Device { +impl> Drop + for DeviceHandle +{ fn drop(&mut self) { debug!("router: dropping device"); - // drop all queues - { - let mut queues = self.state.queues.lock(); - while queues.pop().is_some() {} - } + // close worker queues + self.state.outbound_queue.close(); + self.state.inbound_queue.close(); // join all worker threads while match self.handles.pop() { @@ -82,14 +155,16 @@ impl> Drop for Devi } } -impl> Device { - pub fn new(num_workers: usize, tun: T) -> Device { +impl> DeviceHandle { + pub fn new(num_workers: usize, tun: T) -> DeviceHandle { // allocate shared device state + let (mut outrx, outbound_queue) = ParallelQueue::new(num_workers); + let (mut inrx, inbound_queue) = ParallelQueue::new(num_workers); let inner = DeviceInner { inbound: tun, + inbound_queue, outbound: RwLock::new((true, None)), - queues: Mutex::new(Vec::with_capacity(num_workers)), - queue_next: AtomicUsize::new(0), + outbound_queue, recv: RwLock::new(HashMap::new()), table: RoutingTable::new(), }; @@ -97,14 +172,20 @@ impl> Device> Device Peer { + pub fn new_peer(&self, opaque: C::Opaque) -> PeerHandle { new_peer(self.state.clone(), opaque) } @@ -160,10 +241,7 @@ impl> Device> Device> { + msg: Vec, + failed: bool, + state: Arc>, + endpoint: Option, +} + +impl> Inbound { + pub fn new( + msg: Vec, + state: Arc>, + endpoint: E, + ) -> Inbound { + Inbound { + msg, + state, + failed: false, + endpoint: Some(endpoint), + } + } +} + +#[inline(always)] +fn parallel>( + peer: &Peer, + body: &mut Inbound, +) { + // cast to header followed by payload + let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = + match LayoutVerified::new_from_prefix(&mut body.msg[..]) { + Some(v) => v, + None => { + log::debug!("inbound worker: failed to parse message"); + return; + } + }; + + // authenticate and decrypt payload + { + // create nonce object + let mut nonce = [0u8; 12]; + debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); + nonce[4..].copy_from_slice(header.f_counter.as_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + // do the weird ring AEAD dance + let key = LessSafeKey::new( + UnboundKey::new(&CHACHA20_POLY1305, &body.state.keypair.recv.key[..]).unwrap(), + ); + + // attempt to open (and authenticate) the body + match key.open_in_place(nonce, Aad::empty(), packet) { + Ok(_) => (), + Err(_) => { + // fault and return early + body.failed = true; + return; + } + } + } + + // cryptokey route and strip padding + let inner_len = { + let length = packet.len() - SIZE_TAG; + if length > 0 { + peer.device.table.check_route(&peer, &packet[..length]) + } else { + Some(0) + } + }; + + // truncate to remove tag + match inner_len { + None => { + body.failed = true; + } + Some(len) => { + body.msg.truncate(mem::size_of::() + len); + } + } +} + +#[inline(always)] +fn sequential>( + peer: &Peer, + body: &mut Inbound, +) { + // decryption failed, return early + if body.failed { + return; + } + + // cast transport header + let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = + match LayoutVerified::new_from_prefix(&body.msg[..]) { + Some(v) => v, + None => { + log::debug!("inbound worker: failed to parse message"); + return; + } + }; + debug_assert!( + packet.len() >= CHACHA20_POLY1305.tag_len(), + "this should be checked earlier in the pipeline (decryption should fail)" + ); + + // check for replay + if !body.state.protector.lock().update(header.f_counter.get()) { + log::debug!("inbound worker: replay detected"); + return; + } + + // check for confirms key + if !body.state.confirmed.swap(true, Ordering::SeqCst) { + log::debug!("inbound worker: message confirms key"); + peer.confirm_key(&body.state.keypair); + } + + // update endpoint + *peer.endpoint.lock() = body.endpoint.take(); + + // calculate length of IP packet + padding + let length = packet.len() - SIZE_TAG; + log::debug!("inbound worker: plaintext length = {}", length); + + // check if should be written to TUN + let mut sent = false; + if length > 0 { + sent = match peer.device.inbound.write(&packet[..]) { + Err(e) => { + log::debug!("failed to write inbound packet to TUN: {:?}", e); + false + } + Ok(_) => true, + } + } else { + log::debug!("inbound worker: received keepalive") + } + + // trigger callback + C::recv(&peer.opaque, body.msg.len(), sent, &body.state.keypair); +} + +#[inline(always)] +fn queue>( + peer: &Peer, +) -> &InorderQueue, Inbound> { + &peer.inbound +} + +pub fn worker>( + receiver: Receiver, Inbound>>, +) { + worker_template(receiver, parallel, sequential, queue) +} diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs index 6aa894d..3243b88 100644 --- a/src/wireguard/router/mod.rs +++ b/src/wireguard/router/mod.rs @@ -1,12 +1,16 @@ mod anti_replay; mod constants; mod device; +mod inbound; mod ip; mod messages; +mod outbound; mod peer; +mod pool; mod route; mod types; -mod workers; + +// mod workers; #[cfg(test)] mod tests; @@ -16,15 +20,17 @@ use std::mem; use super::constants::REJECT_AFTER_MESSAGES; use super::types::*; +use super::{tun, udp, Endpoint}; +pub const SIZE_TAG: usize = 16; pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::(); -pub const CAPACITY_MESSAGE_POSTFIX: usize = workers::SIZE_TAG; +pub const CAPACITY_MESSAGE_POSTFIX: usize = SIZE_TAG; pub const fn message_data_len(payload: usize) -> usize { - payload + mem::size_of::() + workers::SIZE_TAG + payload + mem::size_of::() + SIZE_TAG } -pub use device::Device; +pub use device::DeviceHandle as Device; pub use messages::TYPE_TRANSPORT; -pub use peer::Peer; +pub use peer::PeerHandle; pub use types::Callbacks; diff --git a/src/wireguard/router/outbound.rs b/src/wireguard/router/outbound.rs new file mode 100644 index 0000000..30b7c2c --- /dev/null +++ b/src/wireguard/router/outbound.rs @@ -0,0 +1,104 @@ +use super::messages::{TransportHeader, TYPE_TRANSPORT}; +use super::peer::Peer; +use super::pool::*; +use super::types::Callbacks; +use super::KeyPair; +use super::REJECT_AFTER_MESSAGES; +use super::{tun, udp, Endpoint}; + +use std::sync::mpsc::Receiver; +use std::sync::Arc; + +use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; +use zerocopy::{AsBytes, LayoutVerified}; + +pub const SIZE_TAG: usize = 16; + +pub struct Outbound { + msg: Vec, + keypair: Arc, + counter: u64, +} + +impl Outbound { + pub fn new(msg: Vec, keypair: Arc, counter: u64) -> Outbound { + Outbound { + msg, + keypair, + counter, + } + } +} + +#[inline(always)] +fn parallel>( + _peer: &Peer, + body: &mut Outbound, +) { + // make space for the tag + body.msg.extend([0u8; SIZE_TAG].iter()); + + // cast to header (should never fail) + let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = + LayoutVerified::new_from_prefix(&mut body.msg[..]) + .expect("earlier code should ensure that there is ample space"); + + // set header fields + debug_assert!( + body.counter < REJECT_AFTER_MESSAGES, + "should be checked when assigning counters" + ); + header.f_type.set(TYPE_TRANSPORT); + header.f_receiver.set(body.keypair.send.id); + header.f_counter.set(body.counter); + + // create a nonce object + let mut nonce = [0u8; 12]; + debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); + nonce[4..].copy_from_slice(header.f_counter.as_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + // do the weird ring AEAD dance + let key = + LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &body.keypair.send.key[..]).unwrap()); + + // encrypt content of transport message in-place + let end = packet.len() - SIZE_TAG; + let tag = key + .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end]) + .unwrap(); + + // append tag + packet[end..].copy_from_slice(tag.as_ref()); +} + +#[inline(always)] +fn sequential>( + peer: &Peer, + body: &mut Outbound, +) { + // send to peer + let xmit = peer.send(&body.msg[..]).is_ok(); + + // trigger callback + C::send( + &peer.opaque, + body.msg.len(), + xmit, + &body.keypair, + body.counter, + ); +} + +#[inline(always)] +pub fn queue>( + peer: &Peer, +) -> &InorderQueue, Outbound> { + &peer.outbound +} + +pub fn worker>( + receiver: Receiver, Outbound>>, +) { + worker_template(receiver, parallel, sequential, queue) +} diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index fff4dfc..192d4e2 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -2,10 +2,7 @@ use std::mem; use std::net::{IpAddr, SocketAddr}; use std::ops::Deref; use std::sync::atomic::AtomicBool; -use std::sync::atomic::Ordering; -use std::sync::mpsc::{sync_channel, SyncSender}; use std::sync::Arc; -use std::thread; use arraydeque::{ArrayDeque, Wrapping}; use log::debug; @@ -16,18 +13,18 @@ use super::super::{tun, udp, Endpoint, KeyPair}; use super::anti_replay::AntiReplay; use super::device::DecryptionState; -use super::device::DeviceInner; +use super::device::Device; use super::device::EncryptionState; use super::messages::TransportHeader; -use futures::*; - -use super::workers::{worker_inbound, worker_outbound}; -use super::workers::{JobDecryption, JobEncryption, JobInbound, JobOutbound, JobParallel}; -use super::SIZE_MESSAGE_PREFIX; - use super::constants::*; use super::types::{Callbacks, RouterError}; +use super::SIZE_MESSAGE_PREFIX; + +// worker pool related +use super::inbound::Inbound; +use super::outbound::Outbound; +use super::pool::{InorderQueue, Job}; pub struct KeyWheel { next: Option>, // next key state (unconfirmed) @@ -37,10 +34,10 @@ pub struct KeyWheel { } pub struct PeerInner> { - pub device: Arc>, + pub device: Device, pub opaque: C::Opaque, - pub outbound: Mutex>, - pub inbound: Mutex>>, + pub outbound: InorderQueue, Outbound>, + pub inbound: InorderQueue, Inbound>, pub staged_packets: Mutex; MAX_STAGED_PACKETS], Wrapping>>, pub keys: Mutex, pub ekey: Mutex>, @@ -48,16 +45,42 @@ pub struct PeerInner> { - state: Arc>, - thread_outbound: Option>, - thread_inbound: Option>, + inner: Arc>, +} + +impl> Clone for Peer { + fn clone(&self) -> Self { + Peer { + inner: self.inner.clone(), + } + } } +impl> PartialEq for Peer { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.inner, &other.inner) + } +} + +impl> Eq for Peer {} + impl> Deref for Peer { - type Target = Arc>; + type Target = PeerInner; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +pub struct PeerHandle> { + peer: Peer, +} +impl> Deref + for PeerHandle +{ + type Target = PeerInner; fn deref(&self) -> &Self::Target { - &self.state + &self.peer } } @@ -72,37 +95,24 @@ impl EncryptionState { } impl> DecryptionState { - fn new( - peer: &Arc>, - keypair: &Arc, - ) -> DecryptionState { + fn new(peer: Peer, keypair: &Arc) -> DecryptionState { DecryptionState { confirmed: AtomicBool::new(keypair.initiator), keypair: keypair.clone(), protector: spin::Mutex::new(AntiReplay::new()), - peer: peer.clone(), death: keypair.birth + REJECT_AFTER_TIME, + peer, } } } -impl> Drop for Peer { +impl> Drop for PeerHandle { fn drop(&mut self) { - let peer = &self.state; + let peer = &self.peer; // remove from cryptkey router - self.state.device.table.remove(peer); - - // drop channels - - mem::replace(&mut *peer.inbound.lock(), sync_channel(0).0); - mem::replace(&mut *peer.outbound.lock(), sync_channel(0).0); - - // join with workers - - mem::replace(&mut self.thread_inbound, None).map(|v| v.join()); - mem::replace(&mut self.thread_outbound, None).map(|v| v.join()); + self.peer.device.table.remove(peer); // release ids from the receiver map @@ -134,50 +144,32 @@ impl> Drop for Peer } pub fn new_peer>( - device: Arc>, + device: Device, opaque: C::Opaque, -) -> Peer { - let (out_tx, out_rx) = sync_channel(128); - let (in_tx, in_rx) = sync_channel(128); - +) -> PeerHandle { // allocate peer object let peer = { let device = device.clone(); - Arc::new(PeerInner { - opaque, - device, - inbound: Mutex::new(in_tx), - outbound: Mutex::new(out_tx), - ekey: spin::Mutex::new(None), - endpoint: spin::Mutex::new(None), - keys: spin::Mutex::new(KeyWheel { - next: None, - current: None, - previous: None, - retired: vec![], + Peer { + inner: Arc::new(PeerInner { + opaque, + device, + inbound: InorderQueue::new(), + outbound: InorderQueue::new(), + ekey: spin::Mutex::new(None), + endpoint: spin::Mutex::new(None), + keys: spin::Mutex::new(KeyWheel { + next: None, + current: None, + previous: None, + retired: vec![], + }), + staged_packets: spin::Mutex::new(ArrayDeque::new()), }), - staged_packets: spin::Mutex::new(ArrayDeque::new()), - }) - }; - - // spawn outbound thread - let thread_inbound = { - let peer = peer.clone(); - thread::spawn(move || worker_outbound(peer, out_rx)) - }; - - // spawn inbound thread - let thread_outbound = { - let peer = peer.clone(); - let device = device.clone(); - thread::spawn(move || worker_inbound(device, peer, in_rx)) + } }; - Peer { - state: peer, - thread_inbound: Some(thread_inbound), - thread_outbound: Some(thread_outbound), - } + PeerHandle { peer } } impl> PeerInner { @@ -210,7 +202,9 @@ impl> PeerInner Err(RouterError::NoEndpoint), } } +} +impl> Peer { // Transmit all staged packets fn send_staged(&self) -> bool { debug!("peer.send_staged"); @@ -230,16 +224,12 @@ impl> PeerInner) -> bool { - debug!("peer.send_raw"); + log::debug!("peer.send_raw"); match self.send_job(msg, false) { Some(job) => { + self.device.outbound_queue.send(job); debug!("send_raw: got obtained send_job"); - let index = self.device.queue_next.fetch_add(1, Ordering::SeqCst); - let queues = self.device.queues.lock(); - match queues[index % queues.len()].send(job) { - Ok(_) => true, - Err(_) => false, - } + true } None => false, } @@ -285,16 +275,11 @@ impl> PeerInner>, msg: Vec, - ) -> Option { - let (tx, rx) = oneshot(); - let keypair = dec.keypair.clone(); - match self.inbound.lock().try_send((dec, src, rx)) { - Ok(_) => Some(JobParallel::Decryption(tx, JobDecryption { msg, keypair })), - Err(_) => None, - } + ) -> Option>> { + Some(Job::new(self.clone(), Inbound::new(msg, dec, src))) } - pub fn send_job(&self, msg: Vec, stage: bool) -> Option { + pub fn send_job(&self, msg: Vec, stage: bool) -> Option> { debug!("peer.send_job"); debug_assert!( msg.len() >= mem::size_of::(), @@ -337,22 +322,13 @@ impl> PeerInner Some(JobParallel::Encryption( - tx, - JobEncryption { - msg, - counter, - keypair, - }, - )), - Err(_) => None, - } + let job = Job::new(self.clone(), Outbound::new(msg, keypair, counter)); + self.outbound.send(job.clone()); + Some(job) } } -impl> Peer { +impl> PeerHandle { /// Set the endpoint of the peer /// /// # Arguments @@ -365,7 +341,7 @@ impl> Peer> Peer Option { debug!("peer.get_endpoint"); - self.state - .endpoint - .lock() - .as_ref() - .map(|e| e.into_address()) + self.peer.endpoint.lock().as_ref().map(|e| e.into_address()) } /// Zero all key-material related to the peer @@ -387,7 +359,7 @@ impl> Peer = Vec::with_capacity(3); - let mut keys = self.state.keys.lock(); + let mut keys = self.peer.keys.lock(); // update key-wheel @@ -398,14 +370,14 @@ impl> Peer> Peer> Peer> Peer> Peer> Peer bool { debug!("peer.send_keepalive"); - self.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX]) + self.peer.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX]) } /// Map a subnet to the peer @@ -517,10 +489,10 @@ impl> Peer> Peer Vec<(IpAddr, u32)> { - self.state.device.table.list(&self.state) + self.peer.device.table.list(&self.peer) } /// Clear subnets mapped to the peer. /// After the call, no subnets will be cryptkey routed to the peer. /// Used for the UAPI command "replace_allowed_ips=true" pub fn remove_allowed_ips(&self) { - self.state.device.table.remove(&self.state) + self.peer.device.table.remove(&self.peer) } pub fn clear_src(&self) { - (*self.state.endpoint.lock()) - .as_mut() - .map(|e| e.clear_src()); + (*self.peer.endpoint.lock()).as_mut().map(|e| e.clear_src()); } pub fn purge_staged_packets(&self) { - self.state.staged_packets.lock().clear(); + self.peer.staged_packets.lock().clear(); } } diff --git a/src/wireguard/router/pool.rs b/src/wireguard/router/pool.rs new file mode 100644 index 0000000..12956c1 --- /dev/null +++ b/src/wireguard/router/pool.rs @@ -0,0 +1,132 @@ +use arraydeque::ArrayDeque; +use spin::{Mutex, MutexGuard}; +use std::sync::mpsc::Receiver; +use std::sync::Arc; + +const INORDER_QUEUE_SIZE: usize = 64; + +pub struct InnerJob { + // peer (used by worker to schedule/handle inorder queue), + // when the peer is None, the job is complete + peer: Option

, + pub body: B, +} + +pub struct Job { + inner: Arc>>, +} + +impl Clone for Job { + fn clone(&self) -> Job { + Job { + inner: self.inner.clone(), + } + } +} + +impl Job { + pub fn new(peer: P, body: B) -> Job { + Job { + inner: Arc::new(Mutex::new(InnerJob { + peer: Some(peer), + body, + })), + } + } +} + +impl Job { + /// Returns a mutex guard to the inner job if complete + pub fn complete(&self) -> Option>> { + self.inner + .try_lock() + .and_then(|m| if m.peer.is_none() { Some(m) } else { None }) + } +} + +pub struct InorderQueue { + queue: Mutex; INORDER_QUEUE_SIZE]>>, +} + +impl InorderQueue { + pub fn send(&self, job: Job) -> bool { + self.queue.lock().push_back(job).is_ok() + } + + pub fn new() -> InorderQueue { + InorderQueue { + queue: Mutex::new(ArrayDeque::new()), + } + } + + #[inline(always)] + pub fn handle)>(&self, f: F) { + // take the mutex + let mut queue = self.queue.lock(); + + // handle all complete messages + while queue + .pop_front() + .and_then(|j| { + // check if job is complete + let ret = if let Some(mut guard) = j.complete() { + f(&mut *guard); + false + } else { + true + }; + + // return job to cyclic buffer if not complete + if ret { + let _res = queue.push_front(j); + debug_assert!(_res.is_ok()); + None + } else { + // add job back to pool + Some(()) + } + }) + .is_some() + {} + } +} + +/// Allows easy construction of a semi-parallel worker. +/// Applicable for both decryption and encryption workers. +#[inline(always)] +pub fn worker_template< + P, // represents a peer (atomic reference counted pointer) + B, // inner body type (message buffer, key material, ...) + W: Fn(&P, &mut B), + S: Fn(&P, &mut B), + Q: Fn(&P) -> &InorderQueue, +>( + receiver: Receiver>, // receiever for new jobs + work_parallel: W, // perform parallel / out-of-order work on peer + work_sequential: S, // perform sequential work on peer + queue: Q, // resolve a peer to an inorder queue +) { + loop { + // handle new job + let peer = { + // get next job + let job = match receiver.recv() { + Ok(job) => job, + _ => return, + }; + + // lock the job + let mut job = job.inner.lock(); + + // take the peer from the job + let peer = job.peer.take().unwrap(); + + // process job + work_parallel(&peer, &mut job.body); + peer + }; + + // process inorder jobs for peer + queue(&peer).handle(|j| work_sequential(&peer, &mut j.body)); + } +} diff --git a/src/wireguard/router/route.rs b/src/wireguard/router/route.rs index 1c93009..40dc36b 100644 --- a/src/wireguard/router/route.rs +++ b/src/wireguard/router/route.rs @@ -4,7 +4,6 @@ use zerocopy::LayoutVerified; use std::mem; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -use std::sync::Arc; use spin::RwLock; use treebitmap::address::Address; @@ -12,12 +11,12 @@ use treebitmap::IpLookupTable; /* Functions for obtaining and validating "cryptokey" routes */ -pub struct RoutingTable { - ipv4: RwLock>>, - ipv6: RwLock>>, +pub struct RoutingTable { + ipv4: RwLock>, + ipv6: RwLock>, } -impl RoutingTable { +impl RoutingTable { pub fn new() -> Self { RoutingTable { ipv4: RwLock::new(IpLookupTable::new()), @@ -26,27 +25,27 @@ impl RoutingTable { } // collect keys mapping to the given value - fn collect(table: &IpLookupTable>, value: &Arc) -> Vec<(A, u32)> + fn collect(table: &IpLookupTable, value: &T) -> Vec<(A, u32)> where A: Address, { let mut res = Vec::new(); for (ip, cidr, v) in table.iter() { - if Arc::ptr_eq(v, value) { + if v == value { res.push((ip, cidr)) } } res } - pub fn insert(&self, ip: IpAddr, cidr: u32, value: Arc) { + pub fn insert(&self, ip: IpAddr, cidr: u32, value: T) { match ip { IpAddr::V4(v4) => self.ipv4.write().insert(v4.mask(cidr), cidr, value), IpAddr::V6(v6) => self.ipv6.write().insert(v6.mask(cidr), cidr, value), }; } - pub fn list(&self, value: &Arc) -> Vec<(IpAddr, u32)> { + pub fn list(&self, value: &T) -> Vec<(IpAddr, u32)> { let mut res = vec![]; res.extend( Self::collect(&*self.ipv4.read(), value) @@ -61,7 +60,7 @@ impl RoutingTable { res } - pub fn remove(&self, value: &Arc) { + pub fn remove(&self, value: &T) { let mut v4 = self.ipv4.write(); for (ip, cidr) in Self::collect(&*v4, value) { v4.remove(ip, cidr); @@ -74,7 +73,7 @@ impl RoutingTable { } #[inline(always)] - pub fn get_route(&self, packet: &[u8]) -> Option> { + pub fn get_route(&self, packet: &[u8]) -> Option { match packet.get(0)? >> 4 { VERSION_IP4 => { // check length and cast to IPv4 header @@ -113,7 +112,7 @@ impl RoutingTable { } #[inline(always)] - pub fn check_route(&self, peer: &Arc, packet: &[u8]) -> Option { + pub fn check_route(&self, peer: &T, packet: &[u8]) -> Option { match packet.get(0)? >> 4 { VERSION_IP4 => { // check length and cast to IPv4 header @@ -130,7 +129,7 @@ impl RoutingTable { .read() .longest_match(Ipv4Addr::from(header.f_source)) .and_then(|(_, _, p)| { - if Arc::ptr_eq(p, peer) { + if p == peer { Some(header.f_total_len.get() as usize) } else { None @@ -152,7 +151,7 @@ impl RoutingTable { .read() .longest_match(Ipv6Addr::from(header.f_source)) .and_then(|(_, _, p)| { - if Arc::ptr_eq(p, peer) { + if p == peer { Some(header.f_len.get() as usize + mem::size_of::()) } else { None diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs index bf1bd5f..3cccb42 100644 --- a/src/wireguard/tests.rs +++ b/src/wireguard/tests.rs @@ -1,5 +1,5 @@ +use super::dummy; use super::wireguard::Wireguard; -use super::{dummy, tun, udp}; use std::net::IpAddr; use std::thread; diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index 0ce4210..e1aabad 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -137,6 +137,7 @@ impl PeerInner { pub fn timers_handshake_complete(&self) { let timers = self.timers(); if timers.enabled { + timers.retransmit_handshake.stop(); timers.handshake_attempts.store(0, Ordering::SeqCst); timers.sent_lastminute_handshake.store(false, Ordering::SeqCst); *self.walltime_last_handshake.lock() = Some(SystemTime::now()); -- cgit v1.2.3-59-g8ed1b