aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-12-03 21:49:08 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-12-03 21:49:08 +0100
commit5a7f762d6ce6b5bbdbd10f5966adc909597f37d6 (patch)
treeb53fa0c1ee02c1e211d6cf94c6ba0334135ec42e /src
parentClose socket fd after getmtu ioctl (diff)
downloadwireguard-rs-5a7f762d6ce6b5bbdbd10f5966adc909597f37d6.tar.xz
wireguard-rs-5a7f762d6ce6b5bbdbd10f5966adc909597f37d6.zip
Moving away from peer threads
Diffstat (limited to '')
-rw-r--r--src/main.rs1
-rw-r--r--src/platform/dummy/tun.rs2
-rw-r--r--src/platform/linux/tun.rs24
-rw-r--r--src/wireguard/peer.rs2
-rw-r--r--src/wireguard/router/device copy.rs228
-rw-r--r--src/wireguard/router/device.rs141
-rw-r--r--src/wireguard/router/inbound.rs172
-rw-r--r--src/wireguard/router/mod.rs16
-rw-r--r--src/wireguard/router/outbound.rs104
-rw-r--r--src/wireguard/router/peer.rs220
-rw-r--r--src/wireguard/router/pool.rs132
-rw-r--r--src/wireguard/router/route.rs27
-rw-r--r--src/wireguard/tests.rs2
-rw-r--r--src/wireguard/timers.rs1
14 files changed, 640 insertions, 432 deletions
diff --git a/src/main.rs b/src/main.rs
index 5ea830f..e68c771 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,4 +1,5 @@
#![feature(test)]
+#![feature(weak_into_raw)]
#![allow(dead_code)]
use log;
diff --git a/src/platform/dummy/tun.rs b/src/platform/dummy/tun.rs
index 5d13628..50c6654 100644
--- a/src/platform/dummy/tun.rs
+++ b/src/platform/dummy/tun.rs
@@ -6,9 +6,7 @@ use rand::Rng;
use std::cmp::min;
use std::error::Error;
use std::fmt;
-use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
-use std::sync::Arc;
use std::sync::Mutex;
use std::thread;
use std::time::Duration;
diff --git a/src/platform/linux/tun.rs b/src/platform/linux/tun.rs
index 2bac49f..39b9320 100644
--- a/src/platform/linux/tun.rs
+++ b/src/platform/linux/tun.rs
@@ -359,31 +359,9 @@ impl PlatformTun for LinuxTun {
// create PlatformTunMTU instance
Ok((
- vec![LinuxTunReader { fd }], // TODO: enable multi-queue for Linux
+ vec![LinuxTunReader { fd }], // TODO: use multi-queue for Linux
LinuxTunWriter { fd },
LinuxTunStatus::new(req.name)?,
))
}
}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use std::env;
-
- fn is_root() -> bool {
- match env::var("USER") {
- Ok(val) => val == "root",
- Err(_) => false,
- }
- }
-
- #[test]
- fn test_tun_create() {
- if !is_root() {
- return;
- }
- let (readers, writers, mtu) = LinuxTun::create("test").unwrap();
- // TODO: test (any good idea how?)
- }
-}
diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs
index 5bcd070..04622fd 100644
--- a/src/wireguard/peer.rs
+++ b/src/wireguard/peer.rs
@@ -18,7 +18,7 @@ use crossbeam_channel::Sender;
use x25519_dalek::PublicKey;
pub struct Peer<T: Tun, B: UDP> {
- pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>,
+ pub router: Arc<router::PeerHandle<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>,
pub state: Arc<PeerInner<T, B>>,
}
diff --git a/src/wireguard/router/device copy.rs b/src/wireguard/router/device copy.rs
deleted file mode 100644
index 04b2045..0000000
--- a/src/wireguard/router/device copy.rs
+++ /dev/null
@@ -1,228 +0,0 @@
-use std::collections::HashMap;
-
-use std::net::{Ipv4Addr, Ipv6Addr};
-use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
-use std::sync::mpsc::sync_channel;
-use std::sync::mpsc::SyncSender;
-use std::sync::Arc;
-use std::thread;
-use std::time::Instant;
-
-use log::debug;
-use spin::{Mutex, RwLock};
-use treebitmap::IpLookupTable;
-use zerocopy::LayoutVerified;
-
-use super::anti_replay::AntiReplay;
-use super::constants::*;
-
-use super::messages::{TransportHeader, TYPE_TRANSPORT};
-use super::peer::{new_peer, Peer, PeerInner};
-use super::types::{Callbacks, RouterError};
-use super::workers::{worker_parallel, JobParallel};
-use super::SIZE_MESSAGE_PREFIX;
-
-use super::route::get_route;
-
-use super::super::{bind, tun, Endpoint, KeyPair};
-
-pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
- // inbound writer (TUN)
- pub inbound: T,
-
- // outbound writer (Bind)
- pub outbound: RwLock<(bool, Option<B>)>,
-
- // routing
- pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state
- pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv4 cryptkey routing
- pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv6 cryptkey routing
-
- // work queues
- pub queue_next: AtomicUsize, // next round-robin index
- pub queues: Mutex<Vec<SyncSender<JobParallel>>>, // work queues (1 per thread)
-}
-
-pub struct EncryptionState {
- pub keypair: Arc<KeyPair>, // keypair
- pub nonce: u64, // next available nonce
- pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout)
-}
-
-pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
- pub keypair: Arc<KeyPair>,
- pub confirmed: AtomicBool,
- pub protector: Mutex<AntiReplay>,
- pub peer: Arc<PeerInner<E, C, T, B>>,
- pub death: Instant, // time when the key can no longer be used for decryption
-}
-
-pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
- state: Arc<DeviceInner<E, C, T, B>>, // reference to device state
- handles: Vec<thread::JoinHandle<()>>, // join handles for workers
-}
-
-impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Device<E, C, T, B> {
- fn drop(&mut self) {
- debug!("router: dropping device");
-
- // drop all queues
- {
- let mut queues = self.state.queues.lock();
- while queues.pop().is_some() {}
- }
-
- // join all worker threads
- while match self.handles.pop() {
- Some(handle) => {
- handle.thread().unpark();
- handle.join().unwrap();
- true
- }
- _ => false,
- } {}
-
- debug!("router: device dropped");
- }
-}
-
-impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, T, B> {
- pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> {
- // allocate shared device state
- let inner = DeviceInner {
- inbound: tun,
- outbound: RwLock::new((true, None)),
- queues: Mutex::new(Vec::with_capacity(num_workers)),
- queue_next: AtomicUsize::new(0),
- recv: RwLock::new(HashMap::new()),
- ipv4: RwLock::new(IpLookupTable::new()),
- ipv6: RwLock::new(IpLookupTable::new()),
- };
-
- // start worker threads
- let mut threads = Vec::with_capacity(num_workers);
- for _ in 0..num_workers {
- let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
- inner.queues.lock().push(tx);
- threads.push(thread::spawn(move || worker_parallel(rx)));
- }
-
- // return exported device handle
- Device {
- state: Arc::new(inner),
- handles: threads,
- }
- }
-
- /// Brings the router down.
- /// When the router is brought down it:
- /// - Prevents transmission of outbound messages.
- pub fn down(&self) {
- self.state.outbound.write().0 = false;
- }
-
- /// Brints the router up
- /// When the router is brought up it enables the transmission of outbound messages.
- pub fn up(&self) {
- self.state.outbound.write().0 = true;
- }
-
- /// A new secret key has been set for the device.
- /// According to WireGuard semantics, this should cause all "sending" keys to be discarded.
- pub fn new_sk(&self) {}
-
- /// Adds a new peer to the device
- ///
- /// # Returns
- ///
- /// A atomic ref. counted peer (with liftime matching the device)
- pub fn new_peer(&self, opaque: C::Opaque) -> Peer<E, C, T, B> {
- new_peer(self.state.clone(), opaque)
- }
-
- /// Cryptkey routes and sends a plaintext message (IP packet)
- ///
- /// # Arguments
- ///
- /// - msg: IP packet to crypt-key route
- ///
- pub fn send(&self, msg: Vec<u8>) -> Result<(), RouterError> {
- debug_assert!(msg.len() > SIZE_MESSAGE_PREFIX);
- log::trace!(
- "Router, outbound packet = {}",
- hex::encode(&msg[SIZE_MESSAGE_PREFIX..])
- );
-
- // ignore header prefix (for in-place transport message construction)
- let packet = &msg[SIZE_MESSAGE_PREFIX..];
-
- // lookup peer based on IP packet destination address
- let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptoKeyRoute)?;
-
- // schedule for encryption and transmission to peer
- if let Some(job) = peer.send_job(msg, true) {
- // add job to worker queue
- let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
- let queues = self.state.queues.lock();
- queues[idx % queues.len()].send(job).unwrap();
- }
-
- Ok(())
- }
-
- /// Receive an encrypted transport message
- ///
- /// # Arguments
- ///
- /// - src: Source address of the packet
- /// - msg: Encrypted transport message
- ///
- /// # Returns
- ///
- ///
- pub fn recv(&self, src: E, msg: Vec<u8>) -> Result<(), RouterError> {
- // parse / cast
- let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) {
- Some(v) => v,
- None => {
- return Err(RouterError::MalformedTransportMessage);
- }
- };
-
- let header: LayoutVerified<&[u8], TransportHeader> = header;
-
- debug_assert!(
- header.f_type.get() == TYPE_TRANSPORT as u32,
- "this should be checked by the message type multiplexer"
- );
-
- log::trace!(
- "Router, handle transport message: (receiver = {}, counter = {})",
- header.f_receiver,
- header.f_counter
- );
-
- // lookup peer based on receiver id
- let dec = self.state.recv.read();
- let dec = dec
- .get(&header.f_receiver.get())
- .ok_or(RouterError::UnknownReceiverId)?;
-
- // schedule for decryption and TUN write
- if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
- // add job to worker queue
- let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
- let queues = self.state.queues.lock();
- queues[idx % queues.len()].send(job).unwrap();
- }
-
- Ok(())
- }
-
- /// Set outbound writer
- ///
- ///
- pub fn set_outbound_writer(&self, new: B) {
- self.state.outbound.write().1 = Some(new);
- }
-}
diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs
index 621010b..88eeae1 100644
--- a/src/wireguard/router/device.rs
+++ b/src/wireguard/router/device.rs
@@ -1,7 +1,8 @@
use std::collections::HashMap;
+use std::ops::Deref;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::mpsc::sync_channel;
-use std::sync::mpsc::SyncSender;
+use std::sync::mpsc::{Receiver, SyncSender};
use std::sync::Arc;
use std::thread;
use std::time::Instant;
@@ -11,18 +12,61 @@ use spin::{Mutex, RwLock};
use zerocopy::LayoutVerified;
use super::anti_replay::AntiReplay;
-use super::constants::*;
+use super::pool::Job;
+
+use super::inbound;
+use super::outbound;
use super::messages::{TransportHeader, TYPE_TRANSPORT};
-use super::peer::{new_peer, Peer, PeerInner};
+use super::peer::{new_peer, Peer, PeerHandle};
use super::types::{Callbacks, RouterError};
-use super::workers::{worker_parallel, JobParallel};
use super::SIZE_MESSAGE_PREFIX;
use super::route::RoutingTable;
use super::super::{tun, udp, Endpoint, KeyPair};
+pub struct ParallelQueue<T> {
+ next: AtomicUsize, // next round-robin index
+ queues: Vec<Mutex<SyncSender<T>>>, // work queues (1 per thread)
+}
+
+impl<T> ParallelQueue<T> {
+ fn new(queues: usize) -> (Vec<Receiver<T>>, Self) {
+ let mut rxs = vec![];
+ let mut txs = vec![];
+
+ for _ in 0..queues {
+ let (tx, rx) = sync_channel(128);
+ txs.push(Mutex::new(tx));
+ rxs.push(rx);
+ }
+
+ (
+ rxs,
+ ParallelQueue {
+ next: AtomicUsize::new(0),
+ queues: txs,
+ },
+ )
+ }
+
+ pub fn send(&self, v: T) {
+ let len = self.queues.len();
+ let idx = self.next.fetch_add(1, Ordering::SeqCst);
+ let que = self.queues[idx % len].lock();
+ que.send(v).unwrap();
+ }
+
+ pub fn close(&self) {
+ for i in 0..self.queues.len() {
+ let (tx, _) = sync_channel(0);
+ let queue = &self.queues[i];
+ *queue.lock() = tx;
+ }
+ }
+}
+
pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
// inbound writer (TUN)
pub inbound: T,
@@ -32,11 +76,11 @@ pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer
// routing
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state
- pub table: RoutingTable<PeerInner<E, C, T, B>>,
+ pub table: RoutingTable<Peer<E, C, T, B>>,
// work queues
- pub queue_next: AtomicUsize, // next round-robin index
- pub queues: Mutex<Vec<SyncSender<JobParallel>>>, // work queues (1 per thread)
+ pub outbound_queue: ParallelQueue<Job<Peer<E, C, T, B>, outbound::Outbound>>,
+ pub inbound_queue: ParallelQueue<Job<Peer<E, C, T, B>, inbound::Inbound<E, C, T, B>>>,
}
pub struct EncryptionState {
@@ -49,24 +93,53 @@ pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Wr
pub keypair: Arc<KeyPair>,
pub confirmed: AtomicBool,
pub protector: Mutex<AntiReplay>,
- pub peer: Arc<PeerInner<E, C, T, B>>,
+ pub peer: Peer<E, C, T, B>,
pub death: Instant, // time when the key can no longer be used for decryption
}
pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
- state: Arc<DeviceInner<E, C, T, B>>, // reference to device state
+ inner: Arc<DeviceInner<E, C, T, B>>,
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for Device<E, C, T, B> {
+ fn clone(&self) -> Self {
+ Device {
+ inner: self.inner.clone(),
+ }
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq
+ for Device<E, C, T, B>
+{
+ fn eq(&self, other: &Self) -> bool {
+ Arc::ptr_eq(&self.inner, &other.inner)
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Eq for Device<E, C, T, B> {}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Device<E, C, T, B> {
+ type Target = DeviceInner<E, C, T, B>;
+ fn deref(&self) -> &Self::Target {
+ &self.inner
+ }
+}
+
+pub struct DeviceHandle<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
+ state: Device<E, C, T, B>, // reference to device state
handles: Vec<thread::JoinHandle<()>>, // join handles for workers
}
-impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Device<E, C, T, B> {
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop
+ for DeviceHandle<E, C, T, B>
+{
fn drop(&mut self) {
debug!("router: dropping device");
- // drop all queues
- {
- let mut queues = self.state.queues.lock();
- while queues.pop().is_some() {}
- }
+ // close worker queues
+ self.state.outbound_queue.close();
+ self.state.inbound_queue.close();
// join all worker threads
while match self.handles.pop() {
@@ -82,14 +155,16 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Devi
}
}
-impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C, T, B> {
- pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> {
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<E, C, T, B> {
+ pub fn new(num_workers: usize, tun: T) -> DeviceHandle<E, C, T, B> {
// allocate shared device state
+ let (mut outrx, outbound_queue) = ParallelQueue::new(num_workers);
+ let (mut inrx, inbound_queue) = ParallelQueue::new(num_workers);
let inner = DeviceInner {
inbound: tun,
+ inbound_queue,
outbound: RwLock::new((true, None)),
- queues: Mutex::new(Vec::with_capacity(num_workers)),
- queue_next: AtomicUsize::new(0),
+ outbound_queue,
recv: RwLock::new(HashMap::new()),
table: RoutingTable::new(),
};
@@ -97,14 +172,20 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C,
// start worker threads
let mut threads = Vec::with_capacity(num_workers);
for _ in 0..num_workers {
- let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
- inner.queues.lock().push(tx);
- threads.push(thread::spawn(move || worker_parallel(rx)));
+ let rx = inrx.pop().unwrap();
+ threads.push(thread::spawn(move || inbound::worker(rx)));
+ }
+
+ for _ in 0..num_workers {
+ let rx = outrx.pop().unwrap();
+ threads.push(thread::spawn(move || outbound::worker(rx)));
}
// return exported device handle
- Device {
- state: Arc::new(inner),
+ DeviceHandle {
+ state: Device {
+ inner: Arc::new(inner),
+ },
handles: threads,
}
}
@@ -131,7 +212,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C,
/// # Returns
///
/// A atomic ref. counted peer (with liftime matching the device)
- pub fn new_peer(&self, opaque: C::Opaque) -> Peer<E, C, T, B> {
+ pub fn new_peer(&self, opaque: C::Opaque) -> PeerHandle<E, C, T, B> {
new_peer(self.state.clone(), opaque)
}
@@ -160,10 +241,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C,
// schedule for encryption and transmission to peer
if let Some(job) = peer.send_job(msg, true) {
- // add job to worker queue
- let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
- let queues = self.state.queues.lock();
- queues[idx % queues.len()].send(job).unwrap();
+ self.state.outbound_queue.send(job);
}
Ok(())
@@ -209,10 +287,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C,
// schedule for decryption and TUN write
if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
- // add job to worker queue
- let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
- let queues = self.state.queues.lock();
- queues[idx % queues.len()].send(job).unwrap();
+ self.state.inbound_queue.send(job);
}
Ok(())
diff --git a/src/wireguard/router/inbound.rs b/src/wireguard/router/inbound.rs
new file mode 100644
index 0000000..d4ad307
--- /dev/null
+++ b/src/wireguard/router/inbound.rs
@@ -0,0 +1,172 @@
+use super::device::DecryptionState;
+use super::messages::TransportHeader;
+use super::peer::Peer;
+use super::pool::*;
+use super::types::Callbacks;
+use super::{tun, udp, Endpoint};
+
+use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
+use zerocopy::{AsBytes, LayoutVerified};
+
+use std::mem;
+use std::sync::atomic::Ordering;
+use std::sync::mpsc::Receiver;
+use std::sync::Arc;
+
+pub const SIZE_TAG: usize = 16;
+
+pub struct Inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
+ msg: Vec<u8>,
+ failed: bool,
+ state: Arc<DecryptionState<E, C, T, B>>,
+ endpoint: Option<E>,
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Inbound<E, C, T, B> {
+ pub fn new(
+ msg: Vec<u8>,
+ state: Arc<DecryptionState<E, C, T, B>>,
+ endpoint: E,
+ ) -> Inbound<E, C, T, B> {
+ Inbound {
+ msg,
+ state,
+ failed: false,
+ endpoint: Some(endpoint),
+ }
+ }
+}
+
+#[inline(always)]
+fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ peer: &Peer<E, C, T, B>,
+ body: &mut Inbound<E, C, T, B>,
+) {
+ // cast to header followed by payload
+ let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
+ match LayoutVerified::new_from_prefix(&mut body.msg[..]) {
+ Some(v) => v,
+ None => {
+ log::debug!("inbound worker: failed to parse message");
+ return;
+ }
+ };
+
+ // authenticate and decrypt payload
+ {
+ // create nonce object
+ let mut nonce = [0u8; 12];
+ debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
+ nonce[4..].copy_from_slice(header.f_counter.as_bytes());
+ let nonce = Nonce::assume_unique_for_key(nonce);
+
+ // do the weird ring AEAD dance
+ let key = LessSafeKey::new(
+ UnboundKey::new(&CHACHA20_POLY1305, &body.state.keypair.recv.key[..]).unwrap(),
+ );
+
+ // attempt to open (and authenticate) the body
+ match key.open_in_place(nonce, Aad::empty(), packet) {
+ Ok(_) => (),
+ Err(_) => {
+ // fault and return early
+ body.failed = true;
+ return;
+ }
+ }
+ }
+
+ // cryptokey route and strip padding
+ let inner_len = {
+ let length = packet.len() - SIZE_TAG;
+ if length > 0 {
+ peer.device.table.check_route(&peer, &packet[..length])
+ } else {
+ Some(0)
+ }
+ };
+
+ // truncate to remove tag
+ match inner_len {
+ None => {
+ body.failed = true;
+ }
+ Some(len) => {
+ body.msg.truncate(mem::size_of::<TransportHeader>() + len);
+ }
+ }
+}
+
+#[inline(always)]
+fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ peer: &Peer<E, C, T, B>,
+ body: &mut Inbound<E, C, T, B>,
+) {
+ // decryption failed, return early
+ if body.failed {
+ return;
+ }
+
+ // cast transport header
+ let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
+ match LayoutVerified::new_from_prefix(&body.msg[..]) {
+ Some(v) => v,
+ None => {
+ log::debug!("inbound worker: failed to parse message");
+ return;
+ }
+ };
+ debug_assert!(
+ packet.len() >= CHACHA20_POLY1305.tag_len(),
+ "this should be checked earlier in the pipeline (decryption should fail)"
+ );
+
+ // check for replay
+ if !body.state.protector.lock().update(header.f_counter.get()) {
+ log::debug!("inbound worker: replay detected");
+ return;
+ }
+
+ // check for confirms key
+ if !body.state.confirmed.swap(true, Ordering::SeqCst) {
+ log::debug!("inbound worker: message confirms key");
+ peer.confirm_key(&body.state.keypair);
+ }
+
+ // update endpoint
+ *peer.endpoint.lock() = body.endpoint.take();
+
+ // calculate length of IP packet + padding
+ let length = packet.len() - SIZE_TAG;
+ log::debug!("inbound worker: plaintext length = {}", length);
+
+ // check if should be written to TUN
+ let mut sent = false;
+ if length > 0 {
+ sent = match peer.device.inbound.write(&packet[..]) {
+ Err(e) => {
+ log::debug!("failed to write inbound packet to TUN: {:?}", e);
+ false
+ }
+ Ok(_) => true,
+ }
+ } else {
+ log::debug!("inbound worker: received keepalive")
+ }
+
+ // trigger callback
+ C::recv(&peer.opaque, body.msg.len(), sent, &body.state.keypair);
+}
+
+#[inline(always)]
+fn queue<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ peer: &Peer<E, C, T, B>,
+) -> &InorderQueue<Peer<E, C, T, B>, Inbound<E, C, T, B>> {
+ &peer.inbound
+}
+
+pub fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ receiver: Receiver<Job<Peer<E, C, T, B>, Inbound<E, C, T, B>>>,
+) {
+ worker_template(receiver, parallel, sequential, queue)
+}
diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs
index 6aa894d..3243b88 100644
--- a/src/wireguard/router/mod.rs
+++ b/src/wireguard/router/mod.rs
@@ -1,12 +1,16 @@
mod anti_replay;
mod constants;
mod device;
+mod inbound;
mod ip;
mod messages;
+mod outbound;
mod peer;
+mod pool;
mod route;
mod types;
-mod workers;
+
+// mod workers;
#[cfg(test)]
mod tests;
@@ -16,15 +20,17 @@ use std::mem;
use super::constants::REJECT_AFTER_MESSAGES;
use super::types::*;
+use super::{tun, udp, Endpoint};
+pub const SIZE_TAG: usize = 16;
pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>();
-pub const CAPACITY_MESSAGE_POSTFIX: usize = workers::SIZE_TAG;
+pub const CAPACITY_MESSAGE_POSTFIX: usize = SIZE_TAG;
pub const fn message_data_len(payload: usize) -> usize {
- payload + mem::size_of::<TransportHeader>() + workers::SIZE_TAG
+ payload + mem::size_of::<TransportHeader>() + SIZE_TAG
}
-pub use device::Device;
+pub use device::DeviceHandle as Device;
pub use messages::TYPE_TRANSPORT;
-pub use peer::Peer;
+pub use peer::PeerHandle;
pub use types::Callbacks;
diff --git a/src/wireguard/router/outbound.rs b/src/wireguard/router/outbound.rs
new file mode 100644
index 0000000..30b7c2c
--- /dev/null
+++ b/src/wireguard/router/outbound.rs
@@ -0,0 +1,104 @@
+use super::messages::{TransportHeader, TYPE_TRANSPORT};
+use super::peer::Peer;
+use super::pool::*;
+use super::types::Callbacks;
+use super::KeyPair;
+use super::REJECT_AFTER_MESSAGES;
+use super::{tun, udp, Endpoint};
+
+use std::sync::mpsc::Receiver;
+use std::sync::Arc;
+
+use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
+use zerocopy::{AsBytes, LayoutVerified};
+
+pub const SIZE_TAG: usize = 16;
+
+pub struct Outbound {
+ msg: Vec<u8>,
+ keypair: Arc<KeyPair>,
+ counter: u64,
+}
+
+impl Outbound {
+ pub fn new(msg: Vec<u8>, keypair: Arc<KeyPair>, counter: u64) -> Outbound {
+ Outbound {
+ msg,
+ keypair,
+ counter,
+ }
+ }
+}
+
+#[inline(always)]
+fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ _peer: &Peer<E, C, T, B>,
+ body: &mut Outbound,
+) {
+ // make space for the tag
+ body.msg.extend([0u8; SIZE_TAG].iter());
+
+ // cast to header (should never fail)
+ let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
+ LayoutVerified::new_from_prefix(&mut body.msg[..])
+ .expect("earlier code should ensure that there is ample space");
+
+ // set header fields
+ debug_assert!(
+ body.counter < REJECT_AFTER_MESSAGES,
+ "should be checked when assigning counters"
+ );
+ header.f_type.set(TYPE_TRANSPORT);
+ header.f_receiver.set(body.keypair.send.id);
+ header.f_counter.set(body.counter);
+
+ // create a nonce object
+ let mut nonce = [0u8; 12];
+ debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
+ nonce[4..].copy_from_slice(header.f_counter.as_bytes());
+ let nonce = Nonce::assume_unique_for_key(nonce);
+
+ // do the weird ring AEAD dance
+ let key =
+ LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &body.keypair.send.key[..]).unwrap());
+
+ // encrypt content of transport message in-place
+ let end = packet.len() - SIZE_TAG;
+ let tag = key
+ .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end])
+ .unwrap();
+
+ // append tag
+ packet[end..].copy_from_slice(tag.as_ref());
+}
+
+#[inline(always)]
+fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ peer: &Peer<E, C, T, B>,
+ body: &mut Outbound,
+) {
+ // send to peer
+ let xmit = peer.send(&body.msg[..]).is_ok();
+
+ // trigger callback
+ C::send(
+ &peer.opaque,
+ body.msg.len(),
+ xmit,
+ &body.keypair,
+ body.counter,
+ );
+}
+
+#[inline(always)]
+pub fn queue<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ peer: &Peer<E, C, T, B>,
+) -> &InorderQueue<Peer<E, C, T, B>, Outbound> {
+ &peer.outbound
+}
+
+pub fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
+ receiver: Receiver<Job<Peer<E, C, T, B>, Outbound>>,
+) {
+ worker_template(receiver, parallel, sequential, queue)
+}
diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs
index fff4dfc..192d4e2 100644
--- a/src/wireguard/router/peer.rs
+++ b/src/wireguard/router/peer.rs
@@ -2,10 +2,7 @@ use std::mem;
use std::net::{IpAddr, SocketAddr};
use std::ops::Deref;
use std::sync::atomic::AtomicBool;
-use std::sync::atomic::Ordering;
-use std::sync::mpsc::{sync_channel, SyncSender};
use std::sync::Arc;
-use std::thread;
use arraydeque::{ArrayDeque, Wrapping};
use log::debug;
@@ -16,18 +13,18 @@ use super::super::{tun, udp, Endpoint, KeyPair};
use super::anti_replay::AntiReplay;
use super::device::DecryptionState;
-use super::device::DeviceInner;
+use super::device::Device;
use super::device::EncryptionState;
use super::messages::TransportHeader;
-use futures::*;
-
-use super::workers::{worker_inbound, worker_outbound};
-use super::workers::{JobDecryption, JobEncryption, JobInbound, JobOutbound, JobParallel};
-use super::SIZE_MESSAGE_PREFIX;
-
use super::constants::*;
use super::types::{Callbacks, RouterError};
+use super::SIZE_MESSAGE_PREFIX;
+
+// worker pool related
+use super::inbound::Inbound;
+use super::outbound::Outbound;
+use super::pool::{InorderQueue, Job};
pub struct KeyWheel {
next: Option<Arc<KeyPair>>, // next key state (unconfirmed)
@@ -37,10 +34,10 @@ pub struct KeyWheel {
}
pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
- pub device: Arc<DeviceInner<E, C, T, B>>,
+ pub device: Device<E, C, T, B>,
pub opaque: C::Opaque,
- pub outbound: Mutex<SyncSender<JobOutbound>>,
- pub inbound: Mutex<SyncSender<JobInbound<E, C, T, B>>>,
+ pub outbound: InorderQueue<Peer<E, C, T, B>, Outbound>,
+ pub inbound: InorderQueue<Peer<E, C, T, B>, Inbound<E, C, T, B>>,
pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>,
pub keys: Mutex<KeyWheel>,
pub ekey: Mutex<Option<EncryptionState>>,
@@ -48,16 +45,42 @@ pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E
}
pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
- state: Arc<PeerInner<E, C, T, B>>,
- thread_outbound: Option<thread::JoinHandle<()>>,
- thread_inbound: Option<thread::JoinHandle<()>>,
+ inner: Arc<PeerInner<E, C, T, B>>,
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for Peer<E, C, T, B> {
+ fn clone(&self) -> Self {
+ Peer {
+ inner: self.inner.clone(),
+ }
+ }
}
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq for Peer<E, C, T, B> {
+ fn eq(&self, other: &Self) -> bool {
+ Arc::ptr_eq(&self.inner, &other.inner)
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Eq for Peer<E, C, T, B> {}
+
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Peer<E, C, T, B> {
- type Target = Arc<PeerInner<E, C, T, B>>;
+ type Target = PeerInner<E, C, T, B>;
+ fn deref(&self) -> &Self::Target {
+ &self.inner
+ }
+}
+
+pub struct PeerHandle<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
+ peer: Peer<E, C, T, B>,
+}
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref
+ for PeerHandle<E, C, T, B>
+{
+ type Target = PeerInner<E, C, T, B>;
fn deref(&self) -> &Self::Target {
- &self.state
+ &self.peer
}
}
@@ -72,37 +95,24 @@ impl EncryptionState {
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DecryptionState<E, C, T, B> {
- fn new(
- peer: &Arc<PeerInner<E, C, T, B>>,
- keypair: &Arc<KeyPair>,
- ) -> DecryptionState<E, C, T, B> {
+ fn new(peer: Peer<E, C, T, B>, keypair: &Arc<KeyPair>) -> DecryptionState<E, C, T, B> {
DecryptionState {
confirmed: AtomicBool::new(keypair.initiator),
keypair: keypair.clone(),
protector: spin::Mutex::new(AntiReplay::new()),
- peer: peer.clone(),
death: keypair.birth + REJECT_AFTER_TIME,
+ peer,
}
}
}
-impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Peer<E, C, T, B> {
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for PeerHandle<E, C, T, B> {
fn drop(&mut self) {
- let peer = &self.state;
+ let peer = &self.peer;
// remove from cryptkey router
- self.state.device.table.remove(peer);
-
- // drop channels
-
- mem::replace(&mut *peer.inbound.lock(), sync_channel(0).0);
- mem::replace(&mut *peer.outbound.lock(), sync_channel(0).0);
-
- // join with workers
-
- mem::replace(&mut self.thread_inbound, None).map(|v| v.join());
- mem::replace(&mut self.thread_outbound, None).map(|v| v.join());
+ self.peer.device.table.remove(peer);
// release ids from the receiver map
@@ -134,50 +144,32 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Peer
}
pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
- device: Arc<DeviceInner<E, C, T, B>>,
+ device: Device<E, C, T, B>,
opaque: C::Opaque,
-) -> Peer<E, C, T, B> {
- let (out_tx, out_rx) = sync_channel(128);
- let (in_tx, in_rx) = sync_channel(128);
-
+) -> PeerHandle<E, C, T, B> {
// allocate peer object
let peer = {
let device = device.clone();
- Arc::new(PeerInner {
- opaque,
- device,
- inbound: Mutex::new(in_tx),
- outbound: Mutex::new(out_tx),
- ekey: spin::Mutex::new(None),
- endpoint: spin::Mutex::new(None),
- keys: spin::Mutex::new(KeyWheel {
- next: None,
- current: None,
- previous: None,
- retired: vec![],
+ Peer {
+ inner: Arc::new(PeerInner {
+ opaque,
+ device,
+ inbound: InorderQueue::new(),
+ outbound: InorderQueue::new(),
+ ekey: spin::Mutex::new(None),
+ endpoint: spin::Mutex::new(None),
+ keys: spin::Mutex::new(KeyWheel {
+ next: None,
+ current: None,
+ previous: None,
+ retired: vec![],
+ }),
+ staged_packets: spin::Mutex::new(ArrayDeque::new()),
}),
- staged_packets: spin::Mutex::new(ArrayDeque::new()),
- })
- };
-
- // spawn outbound thread
- let thread_inbound = {
- let peer = peer.clone();
- thread::spawn(move || worker_outbound(peer, out_rx))
- };
-
- // spawn inbound thread
- let thread_outbound = {
- let peer = peer.clone();
- let device = device.clone();
- thread::spawn(move || worker_inbound(device, peer, in_rx))
+ }
};
- Peer {
- state: peer,
- thread_inbound: Some(thread_inbound),
- thread_outbound: Some(thread_outbound),
- }
+ PeerHandle { peer }
}
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, C, T, B> {
@@ -210,7 +202,9 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E,
None => Err(RouterError::NoEndpoint),
}
}
+}
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, B> {
// Transmit all staged packets
fn send_staged(&self) -> bool {
debug!("peer.send_staged");
@@ -230,16 +224,12 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E,
// Treat the msg as the payload of a transport message
// Unlike device.send, peer.send_raw does not buffer messages when a key is not available.
fn send_raw(&self, msg: Vec<u8>) -> bool {
- debug!("peer.send_raw");
+ log::debug!("peer.send_raw");
match self.send_job(msg, false) {
Some(job) => {
+ self.device.outbound_queue.send(job);
debug!("send_raw: got obtained send_job");
- let index = self.device.queue_next.fetch_add(1, Ordering::SeqCst);
- let queues = self.device.queues.lock();
- match queues[index % queues.len()].send(job) {
- Ok(_) => true,
- Err(_) => false,
- }
+ true
}
None => false,
}
@@ -285,16 +275,11 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E,
src: E,
dec: Arc<DecryptionState<E, C, T, B>>,
msg: Vec<u8>,
- ) -> Option<JobParallel> {
- let (tx, rx) = oneshot();
- let keypair = dec.keypair.clone();
- match self.inbound.lock().try_send((dec, src, rx)) {
- Ok(_) => Some(JobParallel::Decryption(tx, JobDecryption { msg, keypair })),
- Err(_) => None,
- }
+ ) -> Option<Job<Self, Inbound<E, C, T, B>>> {
+ Some(Job::new(self.clone(), Inbound::new(msg, dec, src)))
}
- pub fn send_job(&self, msg: Vec<u8>, stage: bool) -> Option<JobParallel> {
+ pub fn send_job(&self, msg: Vec<u8>, stage: bool) -> Option<Job<Self, Outbound>> {
debug!("peer.send_job");
debug_assert!(
msg.len() >= mem::size_of::<TransportHeader>(),
@@ -337,22 +322,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E,
}?;
// add job to in-order queue and return sender to device for inclusion in worker pool
- let (tx, rx) = oneshot();
- match self.outbound.lock().try_send(rx) {
- Ok(_) => Some(JobParallel::Encryption(
- tx,
- JobEncryption {
- msg,
- counter,
- keypair,
- },
- )),
- Err(_) => None,
- }
+ let job = Job::new(self.clone(), Outbound::new(msg, keypair, counter));
+ self.outbound.send(job.clone());
+ Some(job)
}
}
-impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, B> {
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, C, T, B> {
/// Set the endpoint of the peer
///
/// # Arguments
@@ -365,7 +341,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
/// as sockets should be "unsticked" when manually updating the endpoint
pub fn set_endpoint(&self, endpoint: E) {
debug!("peer.set_endpoint");
- *self.state.endpoint.lock() = Some(endpoint);
+ *self.peer.endpoint.lock() = Some(endpoint);
}
/// Returns the current endpoint of the peer (for configuration)
@@ -375,11 +351,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
/// Does not convey potential "sticky socket" information
pub fn get_endpoint(&self) -> Option<SocketAddr> {
debug!("peer.get_endpoint");
- self.state
- .endpoint
- .lock()
- .as_ref()
- .map(|e| e.into_address())
+ self.peer.endpoint.lock().as_ref().map(|e| e.into_address())
}
/// Zero all key-material related to the peer
@@ -387,7 +359,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
debug!("peer.zero_keys");
let mut release: Vec<u32> = Vec::with_capacity(3);
- let mut keys = self.state.keys.lock();
+ let mut keys = self.peer.keys.lock();
// update key-wheel
@@ -398,14 +370,14 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
// update inbound "recv" map
{
- let mut recv = self.state.device.recv.write();
+ let mut recv = self.peer.device.recv.write();
for id in release {
recv.remove(&id);
}
}
// clear encryption state
- *self.state.ekey.lock() = None;
+ *self.peer.ekey.lock() = None;
}
pub fn down(&self) {
@@ -436,13 +408,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
let initiator = new.initiator;
let release = {
let new = Arc::new(new);
- let mut keys = self.state.keys.lock();
+ let mut keys = self.peer.keys.lock();
let mut release = mem::replace(&mut keys.retired, vec![]);
// update key-wheel
if new.initiator {
// start using key for encryption
- *self.state.ekey.lock() = Some(EncryptionState::new(&new));
+ *self.peer.ekey.lock() = Some(EncryptionState::new(&new));
// move current into previous
keys.previous = keys.current.as_ref().map(|v| v.clone());
@@ -456,7 +428,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
// update incoming packet id map
{
debug!("peer.add_keypair: updating inbound id map");
- let mut recv = self.state.device.recv.write();
+ let mut recv = self.peer.device.recv.write();
// purge recv map of previous id
keys.previous.as_ref().map(|k| {
@@ -468,7 +440,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
debug_assert!(!recv.contains_key(&new.recv.id));
recv.insert(
new.recv.id,
- Arc::new(DecryptionState::new(&self.state, &new)),
+ Arc::new(DecryptionState::new(self.peer.clone(), &new)),
);
}
release
@@ -476,10 +448,10 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
// schedule confirmation
if initiator {
- debug_assert!(self.state.ekey.lock().is_some());
+ debug_assert!(self.peer.ekey.lock().is_some());
debug!("peer.add_keypair: is initiator, must confirm the key");
// attempt to confirm using staged packets
- if !self.state.send_staged() {
+ if !self.peer.send_staged() {
// fall back to keepalive packet
let ok = self.send_keepalive();
debug!(
@@ -499,7 +471,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
pub fn send_keepalive(&self) -> bool {
debug!("peer.send_keepalive");
- self.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX])
+ self.peer.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX])
}
/// Map a subnet to the peer
@@ -517,10 +489,10 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
/// If an identical value already exists as part of a prior peer,
/// the allowed IP entry will be removed from that peer and added to this peer.
pub fn add_allowed_ip(&self, ip: IpAddr, masklen: u32) {
- self.state
+ self.peer
.device
.table
- .insert(ip, masklen, self.state.clone())
+ .insert(ip, masklen, self.peer.clone())
}
/// List subnets mapped to the peer
@@ -529,23 +501,21 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
///
/// A vector of subnets, represented by as mask/size
pub fn list_allowed_ips(&self) -> Vec<(IpAddr, u32)> {
- self.state.device.table.list(&self.state)
+ self.peer.device.table.list(&self.peer)
}
/// Clear subnets mapped to the peer.
/// After the call, no subnets will be cryptkey routed to the peer.
/// Used for the UAPI command "replace_allowed_ips=true"
pub fn remove_allowed_ips(&self) {
- self.state.device.table.remove(&self.state)
+ self.peer.device.table.remove(&self.peer)
}
pub fn clear_src(&self) {
- (*self.state.endpoint.lock())
- .as_mut()
- .map(|e| e.clear_src());
+ (*self.peer.endpoint.lock()).as_mut().map(|e| e.clear_src());
}
pub fn purge_staged_packets(&self) {
- self.state.staged_packets.lock().clear();
+ self.peer.staged_packets.lock().clear();
}
}
diff --git a/src/wireguard/router/pool.rs b/src/wireguard/router/pool.rs
new file mode 100644
index 0000000..12956c1
--- /dev/null
+++ b/src/wireguard/router/pool.rs
@@ -0,0 +1,132 @@
+use arraydeque::ArrayDeque;
+use spin::{Mutex, MutexGuard};
+use std::sync::mpsc::Receiver;
+use std::sync::Arc;
+
+const INORDER_QUEUE_SIZE: usize = 64;
+
+pub struct InnerJob<P, B> {
+ // peer (used by worker to schedule/handle inorder queue),
+ // when the peer is None, the job is complete
+ peer: Option<P>,
+ pub body: B,
+}
+
+pub struct Job<P, B> {
+ inner: Arc<Mutex<InnerJob<P, B>>>,
+}
+
+impl<P, B> Clone for Job<P, B> {
+ fn clone(&self) -> Job<P, B> {
+ Job {
+ inner: self.inner.clone(),
+ }
+ }
+}
+
+impl<P, B> Job<P, B> {
+ pub fn new(peer: P, body: B) -> Job<P, B> {
+ Job {
+ inner: Arc::new(Mutex::new(InnerJob {
+ peer: Some(peer),
+ body,
+ })),
+ }
+ }
+}
+
+impl<P, B> Job<P, B> {
+ /// Returns a mutex guard to the inner job if complete
+ pub fn complete(&self) -> Option<MutexGuard<InnerJob<P, B>>> {
+ self.inner
+ .try_lock()
+ .and_then(|m| if m.peer.is_none() { Some(m) } else { None })
+ }
+}
+
+pub struct InorderQueue<P, B> {
+ queue: Mutex<ArrayDeque<[Job<P, B>; INORDER_QUEUE_SIZE]>>,
+}
+
+impl<P, B> InorderQueue<P, B> {
+ pub fn send(&self, job: Job<P, B>) -> bool {
+ self.queue.lock().push_back(job).is_ok()
+ }
+
+ pub fn new() -> InorderQueue<P, B> {
+ InorderQueue {
+ queue: Mutex::new(ArrayDeque::new()),
+ }
+ }
+
+ #[inline(always)]
+ pub fn handle<F: Fn(&mut InnerJob<P, B>)>(&self, f: F) {
+ // take the mutex
+ let mut queue = self.queue.lock();
+
+ // handle all complete messages
+ while queue
+ .pop_front()
+ .and_then(|j| {
+ // check if job is complete
+ let ret = if let Some(mut guard) = j.complete() {
+ f(&mut *guard);
+ false
+ } else {
+ true
+ };
+
+ // return job to cyclic buffer if not complete
+ if ret {
+ let _res = queue.push_front(j);
+ debug_assert!(_res.is_ok());
+ None
+ } else {
+ // add job back to pool
+ Some(())
+ }
+ })
+ .is_some()
+ {}
+ }
+}
+
+/// Allows easy construction of a semi-parallel worker.
+/// Applicable for both decryption and encryption workers.
+#[inline(always)]
+pub fn worker_template<
+ P, // represents a peer (atomic reference counted pointer)
+ B, // inner body type (message buffer, key material, ...)
+ W: Fn(&P, &mut B),
+ S: Fn(&P, &mut B),
+ Q: Fn(&P) -> &InorderQueue<P, B>,
+>(
+ receiver: Receiver<Job<P, B>>, // receiever for new jobs
+ work_parallel: W, // perform parallel / out-of-order work on peer
+ work_sequential: S, // perform sequential work on peer
+ queue: Q, // resolve a peer to an inorder queue
+) {
+ loop {
+ // handle new job
+ let peer = {
+ // get next job
+ let job = match receiver.recv() {
+ Ok(job) => job,
+ _ => return,
+ };
+
+ // lock the job
+ let mut job = job.inner.lock();
+
+ // take the peer from the job
+ let peer = job.peer.take().unwrap();
+
+ // process job
+ work_parallel(&peer, &mut job.body);
+ peer
+ };
+
+ // process inorder jobs for peer
+ queue(&peer).handle(|j| work_sequential(&peer, &mut j.body));
+ }
+}
diff --git a/src/wireguard/router/route.rs b/src/wireguard/router/route.rs
index 1c93009..40dc36b 100644
--- a/src/wireguard/router/route.rs
+++ b/src/wireguard/router/route.rs
@@ -4,7 +4,6 @@ use zerocopy::LayoutVerified;
use std::mem;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
-use std::sync::Arc;
use spin::RwLock;
use treebitmap::address::Address;
@@ -12,12 +11,12 @@ use treebitmap::IpLookupTable;
/* Functions for obtaining and validating "cryptokey" routes */
-pub struct RoutingTable<T> {
- ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<T>>>,
- ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<T>>>,
+pub struct RoutingTable<T: Eq + Clone> {
+ ipv4: RwLock<IpLookupTable<Ipv4Addr, T>>,
+ ipv6: RwLock<IpLookupTable<Ipv6Addr, T>>,
}
-impl<T> RoutingTable<T> {
+impl<T: Eq + Clone> RoutingTable<T> {
pub fn new() -> Self {
RoutingTable {
ipv4: RwLock::new(IpLookupTable::new()),
@@ -26,27 +25,27 @@ impl<T> RoutingTable<T> {
}
// collect keys mapping to the given value
- fn collect<A>(table: &IpLookupTable<A, Arc<T>>, value: &Arc<T>) -> Vec<(A, u32)>
+ fn collect<A>(table: &IpLookupTable<A, T>, value: &T) -> Vec<(A, u32)>
where
A: Address,
{
let mut res = Vec::new();
for (ip, cidr, v) in table.iter() {
- if Arc::ptr_eq(v, value) {
+ if v == value {
res.push((ip, cidr))
}
}
res
}
- pub fn insert(&self, ip: IpAddr, cidr: u32, value: Arc<T>) {
+ pub fn insert(&self, ip: IpAddr, cidr: u32, value: T) {
match ip {
IpAddr::V4(v4) => self.ipv4.write().insert(v4.mask(cidr), cidr, value),
IpAddr::V6(v6) => self.ipv6.write().insert(v6.mask(cidr), cidr, value),
};
}
- pub fn list(&self, value: &Arc<T>) -> Vec<(IpAddr, u32)> {
+ pub fn list(&self, value: &T) -> Vec<(IpAddr, u32)> {
let mut res = vec![];
res.extend(
Self::collect(&*self.ipv4.read(), value)
@@ -61,7 +60,7 @@ impl<T> RoutingTable<T> {
res
}
- pub fn remove(&self, value: &Arc<T>) {
+ pub fn remove(&self, value: &T) {
let mut v4 = self.ipv4.write();
for (ip, cidr) in Self::collect(&*v4, value) {
v4.remove(ip, cidr);
@@ -74,7 +73,7 @@ impl<T> RoutingTable<T> {
}
#[inline(always)]
- pub fn get_route(&self, packet: &[u8]) -> Option<Arc<T>> {
+ pub fn get_route(&self, packet: &[u8]) -> Option<T> {
match packet.get(0)? >> 4 {
VERSION_IP4 => {
// check length and cast to IPv4 header
@@ -113,7 +112,7 @@ impl<T> RoutingTable<T> {
}
#[inline(always)]
- pub fn check_route(&self, peer: &Arc<T>, packet: &[u8]) -> Option<usize> {
+ pub fn check_route(&self, peer: &T, packet: &[u8]) -> Option<usize> {
match packet.get(0)? >> 4 {
VERSION_IP4 => {
// check length and cast to IPv4 header
@@ -130,7 +129,7 @@ impl<T> RoutingTable<T> {
.read()
.longest_match(Ipv4Addr::from(header.f_source))
.and_then(|(_, _, p)| {
- if Arc::ptr_eq(p, peer) {
+ if p == peer {
Some(header.f_total_len.get() as usize)
} else {
None
@@ -152,7 +151,7 @@ impl<T> RoutingTable<T> {
.read()
.longest_match(Ipv6Addr::from(header.f_source))
.and_then(|(_, _, p)| {
- if Arc::ptr_eq(p, peer) {
+ if p == peer {
Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>())
} else {
None
diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs
index bf1bd5f..3cccb42 100644
--- a/src/wireguard/tests.rs
+++ b/src/wireguard/tests.rs
@@ -1,5 +1,5 @@
+use super::dummy;
use super::wireguard::Wireguard;
-use super::{dummy, tun, udp};
use std::net::IpAddr;
use std::thread;
diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs
index 0ce4210..e1aabad 100644
--- a/src/wireguard/timers.rs
+++ b/src/wireguard/timers.rs
@@ -137,6 +137,7 @@ impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> {
pub fn timers_handshake_complete(&self) {
let timers = self.timers();
if timers.enabled {
+ timers.retransmit_handshake.stop();
timers.handshake_attempts.store(0, Ordering::SeqCst);
timers.sent_lastminute_handshake.store(false, Ordering::SeqCst);
*self.walltime_last_handshake.lock() = Some(SystemTime::now());