diff options
Diffstat (limited to '')
-rw-r--r-- | src/wireguard/peer.rs | 2 | ||||
-rw-r--r-- | src/wireguard/router/device copy.rs | 228 | ||||
-rw-r--r-- | src/wireguard/router/device.rs | 141 | ||||
-rw-r--r-- | src/wireguard/router/inbound.rs | 172 | ||||
-rw-r--r-- | src/wireguard/router/mod.rs | 16 | ||||
-rw-r--r-- | src/wireguard/router/outbound.rs | 104 | ||||
-rw-r--r-- | src/wireguard/router/peer.rs | 220 | ||||
-rw-r--r-- | src/wireguard/router/pool.rs | 132 | ||||
-rw-r--r-- | src/wireguard/router/route.rs | 27 | ||||
-rw-r--r-- | src/wireguard/tests.rs | 2 | ||||
-rw-r--r-- | src/wireguard/timers.rs | 1 |
11 files changed, 638 insertions, 407 deletions
diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs index 5bcd070..04622fd 100644 --- a/src/wireguard/peer.rs +++ b/src/wireguard/peer.rs @@ -18,7 +18,7 @@ use crossbeam_channel::Sender; use x25519_dalek::PublicKey; pub struct Peer<T: Tun, B: UDP> { - pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>, + pub router: Arc<router::PeerHandle<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>, pub state: Arc<PeerInner<T, B>>, } diff --git a/src/wireguard/router/device copy.rs b/src/wireguard/router/device copy.rs deleted file mode 100644 index 04b2045..0000000 --- a/src/wireguard/router/device copy.rs +++ /dev/null @@ -1,228 +0,0 @@ -use std::collections::HashMap; - -use std::net::{Ipv4Addr, Ipv6Addr}; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::sync::mpsc::sync_channel; -use std::sync::mpsc::SyncSender; -use std::sync::Arc; -use std::thread; -use std::time::Instant; - -use log::debug; -use spin::{Mutex, RwLock}; -use treebitmap::IpLookupTable; -use zerocopy::LayoutVerified; - -use super::anti_replay::AntiReplay; -use super::constants::*; - -use super::messages::{TransportHeader, TYPE_TRANSPORT}; -use super::peer::{new_peer, Peer, PeerInner}; -use super::types::{Callbacks, RouterError}; -use super::workers::{worker_parallel, JobParallel}; -use super::SIZE_MESSAGE_PREFIX; - -use super::route::get_route; - -use super::super::{bind, tun, Endpoint, KeyPair}; - -pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { - // inbound writer (TUN) - pub inbound: T, - - // outbound writer (Bind) - pub outbound: RwLock<(bool, Option<B>)>, - - // routing - pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state - pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv4 cryptkey routing - pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv6 cryptkey routing - - // work queues - pub queue_next: AtomicUsize, // next round-robin index - pub queues: Mutex<Vec<SyncSender<JobParallel>>>, // work queues (1 per thread) -} - -pub struct EncryptionState { - pub keypair: Arc<KeyPair>, // keypair - pub nonce: u64, // next available nonce - pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout) -} - -pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { - pub keypair: Arc<KeyPair>, - pub confirmed: AtomicBool, - pub protector: Mutex<AntiReplay>, - pub peer: Arc<PeerInner<E, C, T, B>>, - pub death: Instant, // time when the key can no longer be used for decryption -} - -pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { - state: Arc<DeviceInner<E, C, T, B>>, // reference to device state - handles: Vec<thread::JoinHandle<()>>, // join handles for workers -} - -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Device<E, C, T, B> { - fn drop(&mut self) { - debug!("router: dropping device"); - - // drop all queues - { - let mut queues = self.state.queues.lock(); - while queues.pop().is_some() {} - } - - // join all worker threads - while match self.handles.pop() { - Some(handle) => { - handle.thread().unpark(); - handle.join().unwrap(); - true - } - _ => false, - } {} - - debug!("router: device dropped"); - } -} - -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, T, B> { - pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> { - // allocate shared device state - let inner = DeviceInner { - inbound: tun, - outbound: RwLock::new((true, None)), - queues: Mutex::new(Vec::with_capacity(num_workers)), - queue_next: AtomicUsize::new(0), - recv: RwLock::new(HashMap::new()), - ipv4: RwLock::new(IpLookupTable::new()), - ipv6: RwLock::new(IpLookupTable::new()), - }; - - // start worker threads - let mut threads = Vec::with_capacity(num_workers); - for _ in 0..num_workers { - let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE); - inner.queues.lock().push(tx); - threads.push(thread::spawn(move || worker_parallel(rx))); - } - - // return exported device handle - Device { - state: Arc::new(inner), - handles: threads, - } - } - - /// Brings the router down. - /// When the router is brought down it: - /// - Prevents transmission of outbound messages. - pub fn down(&self) { - self.state.outbound.write().0 = false; - } - - /// Brints the router up - /// When the router is brought up it enables the transmission of outbound messages. - pub fn up(&self) { - self.state.outbound.write().0 = true; - } - - /// A new secret key has been set for the device. - /// According to WireGuard semantics, this should cause all "sending" keys to be discarded. - pub fn new_sk(&self) {} - - /// Adds a new peer to the device - /// - /// # Returns - /// - /// A atomic ref. counted peer (with liftime matching the device) - pub fn new_peer(&self, opaque: C::Opaque) -> Peer<E, C, T, B> { - new_peer(self.state.clone(), opaque) - } - - /// Cryptkey routes and sends a plaintext message (IP packet) - /// - /// # Arguments - /// - /// - msg: IP packet to crypt-key route - /// - pub fn send(&self, msg: Vec<u8>) -> Result<(), RouterError> { - debug_assert!(msg.len() > SIZE_MESSAGE_PREFIX); - log::trace!( - "Router, outbound packet = {}", - hex::encode(&msg[SIZE_MESSAGE_PREFIX..]) - ); - - // ignore header prefix (for in-place transport message construction) - let packet = &msg[SIZE_MESSAGE_PREFIX..]; - - // lookup peer based on IP packet destination address - let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptoKeyRoute)?; - - // schedule for encryption and transmission to peer - if let Some(job) = peer.send_job(msg, true) { - // add job to worker queue - let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); - let queues = self.state.queues.lock(); - queues[idx % queues.len()].send(job).unwrap(); - } - - Ok(()) - } - - /// Receive an encrypted transport message - /// - /// # Arguments - /// - /// - src: Source address of the packet - /// - msg: Encrypted transport message - /// - /// # Returns - /// - /// - pub fn recv(&self, src: E, msg: Vec<u8>) -> Result<(), RouterError> { - // parse / cast - let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) { - Some(v) => v, - None => { - return Err(RouterError::MalformedTransportMessage); - } - }; - - let header: LayoutVerified<&[u8], TransportHeader> = header; - - debug_assert!( - header.f_type.get() == TYPE_TRANSPORT as u32, - "this should be checked by the message type multiplexer" - ); - - log::trace!( - "Router, handle transport message: (receiver = {}, counter = {})", - header.f_receiver, - header.f_counter - ); - - // lookup peer based on receiver id - let dec = self.state.recv.read(); - let dec = dec - .get(&header.f_receiver.get()) - .ok_or(RouterError::UnknownReceiverId)?; - - // schedule for decryption and TUN write - if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) { - // add job to worker queue - let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); - let queues = self.state.queues.lock(); - queues[idx % queues.len()].send(job).unwrap(); - } - - Ok(()) - } - - /// Set outbound writer - /// - /// - pub fn set_outbound_writer(&self, new: B) { - self.state.outbound.write().1 = Some(new); - } -} diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index 621010b..88eeae1 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -1,7 +1,8 @@ use std::collections::HashMap; +use std::ops::Deref; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::mpsc::sync_channel; -use std::sync::mpsc::SyncSender; +use std::sync::mpsc::{Receiver, SyncSender}; use std::sync::Arc; use std::thread; use std::time::Instant; @@ -11,18 +12,61 @@ use spin::{Mutex, RwLock}; use zerocopy::LayoutVerified; use super::anti_replay::AntiReplay; -use super::constants::*; +use super::pool::Job; + +use super::inbound; +use super::outbound; use super::messages::{TransportHeader, TYPE_TRANSPORT}; -use super::peer::{new_peer, Peer, PeerInner}; +use super::peer::{new_peer, Peer, PeerHandle}; use super::types::{Callbacks, RouterError}; -use super::workers::{worker_parallel, JobParallel}; use super::SIZE_MESSAGE_PREFIX; use super::route::RoutingTable; use super::super::{tun, udp, Endpoint, KeyPair}; +pub struct ParallelQueue<T> { + next: AtomicUsize, // next round-robin index + queues: Vec<Mutex<SyncSender<T>>>, // work queues (1 per thread) +} + +impl<T> ParallelQueue<T> { + fn new(queues: usize) -> (Vec<Receiver<T>>, Self) { + let mut rxs = vec![]; + let mut txs = vec![]; + + for _ in 0..queues { + let (tx, rx) = sync_channel(128); + txs.push(Mutex::new(tx)); + rxs.push(rx); + } + + ( + rxs, + ParallelQueue { + next: AtomicUsize::new(0), + queues: txs, + }, + ) + } + + pub fn send(&self, v: T) { + let len = self.queues.len(); + let idx = self.next.fetch_add(1, Ordering::SeqCst); + let que = self.queues[idx % len].lock(); + que.send(v).unwrap(); + } + + pub fn close(&self) { + for i in 0..self.queues.len() { + let (tx, _) = sync_channel(0); + let queue = &self.queues[i]; + *queue.lock() = tx; + } + } +} + pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { // inbound writer (TUN) pub inbound: T, @@ -32,11 +76,11 @@ pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer // routing pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state - pub table: RoutingTable<PeerInner<E, C, T, B>>, + pub table: RoutingTable<Peer<E, C, T, B>>, // work queues - pub queue_next: AtomicUsize, // next round-robin index - pub queues: Mutex<Vec<SyncSender<JobParallel>>>, // work queues (1 per thread) + pub outbound_queue: ParallelQueue<Job<Peer<E, C, T, B>, outbound::Outbound>>, + pub inbound_queue: ParallelQueue<Job<Peer<E, C, T, B>, inbound::Inbound<E, C, T, B>>>, } pub struct EncryptionState { @@ -49,24 +93,53 @@ pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Wr pub keypair: Arc<KeyPair>, pub confirmed: AtomicBool, pub protector: Mutex<AntiReplay>, - pub peer: Arc<PeerInner<E, C, T, B>>, + pub peer: Peer<E, C, T, B>, pub death: Instant, // time when the key can no longer be used for decryption } pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { - state: Arc<DeviceInner<E, C, T, B>>, // reference to device state + inner: Arc<DeviceInner<E, C, T, B>>, +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for Device<E, C, T, B> { + fn clone(&self) -> Self { + Device { + inner: self.inner.clone(), + } + } +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq + for Device<E, C, T, B> +{ + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.inner, &other.inner) + } +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Eq for Device<E, C, T, B> {} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Device<E, C, T, B> { + type Target = DeviceInner<E, C, T, B>; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +pub struct DeviceHandle<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { + state: Device<E, C, T, B>, // reference to device state handles: Vec<thread::JoinHandle<()>>, // join handles for workers } -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Device<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop + for DeviceHandle<E, C, T, B> +{ fn drop(&mut self) { debug!("router: dropping device"); - // drop all queues - { - let mut queues = self.state.queues.lock(); - while queues.pop().is_some() {} - } + // close worker queues + self.state.outbound_queue.close(); + self.state.inbound_queue.close(); // join all worker threads while match self.handles.pop() { @@ -82,14 +155,16 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Devi } } -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C, T, B> { - pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<E, C, T, B> { + pub fn new(num_workers: usize, tun: T) -> DeviceHandle<E, C, T, B> { // allocate shared device state + let (mut outrx, outbound_queue) = ParallelQueue::new(num_workers); + let (mut inrx, inbound_queue) = ParallelQueue::new(num_workers); let inner = DeviceInner { inbound: tun, + inbound_queue, outbound: RwLock::new((true, None)), - queues: Mutex::new(Vec::with_capacity(num_workers)), - queue_next: AtomicUsize::new(0), + outbound_queue, recv: RwLock::new(HashMap::new()), table: RoutingTable::new(), }; @@ -97,14 +172,20 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C, // start worker threads let mut threads = Vec::with_capacity(num_workers); for _ in 0..num_workers { - let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE); - inner.queues.lock().push(tx); - threads.push(thread::spawn(move || worker_parallel(rx))); + let rx = inrx.pop().unwrap(); + threads.push(thread::spawn(move || inbound::worker(rx))); + } + + for _ in 0..num_workers { + let rx = outrx.pop().unwrap(); + threads.push(thread::spawn(move || outbound::worker(rx))); } // return exported device handle - Device { - state: Arc::new(inner), + DeviceHandle { + state: Device { + inner: Arc::new(inner), + }, handles: threads, } } @@ -131,7 +212,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C, /// # Returns /// /// A atomic ref. counted peer (with liftime matching the device) - pub fn new_peer(&self, opaque: C::Opaque) -> Peer<E, C, T, B> { + pub fn new_peer(&self, opaque: C::Opaque) -> PeerHandle<E, C, T, B> { new_peer(self.state.clone(), opaque) } @@ -160,10 +241,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C, // schedule for encryption and transmission to peer if let Some(job) = peer.send_job(msg, true) { - // add job to worker queue - let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); - let queues = self.state.queues.lock(); - queues[idx % queues.len()].send(job).unwrap(); + self.state.outbound_queue.send(job); } Ok(()) @@ -209,10 +287,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Device<E, C, // schedule for decryption and TUN write if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) { - // add job to worker queue - let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); - let queues = self.state.queues.lock(); - queues[idx % queues.len()].send(job).unwrap(); + self.state.inbound_queue.send(job); } Ok(()) diff --git a/src/wireguard/router/inbound.rs b/src/wireguard/router/inbound.rs new file mode 100644 index 0000000..d4ad307 --- /dev/null +++ b/src/wireguard/router/inbound.rs @@ -0,0 +1,172 @@ +use super::device::DecryptionState; +use super::messages::TransportHeader; +use super::peer::Peer; +use super::pool::*; +use super::types::Callbacks; +use super::{tun, udp, Endpoint}; + +use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; +use zerocopy::{AsBytes, LayoutVerified}; + +use std::mem; +use std::sync::atomic::Ordering; +use std::sync::mpsc::Receiver; +use std::sync::Arc; + +pub const SIZE_TAG: usize = 16; + +pub struct Inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { + msg: Vec<u8>, + failed: bool, + state: Arc<DecryptionState<E, C, T, B>>, + endpoint: Option<E>, +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Inbound<E, C, T, B> { + pub fn new( + msg: Vec<u8>, + state: Arc<DecryptionState<E, C, T, B>>, + endpoint: E, + ) -> Inbound<E, C, T, B> { + Inbound { + msg, + state, + failed: false, + endpoint: Some(endpoint), + } + } +} + +#[inline(always)] +fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + peer: &Peer<E, C, T, B>, + body: &mut Inbound<E, C, T, B>, +) { + // cast to header followed by payload + let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = + match LayoutVerified::new_from_prefix(&mut body.msg[..]) { + Some(v) => v, + None => { + log::debug!("inbound worker: failed to parse message"); + return; + } + }; + + // authenticate and decrypt payload + { + // create nonce object + let mut nonce = [0u8; 12]; + debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); + nonce[4..].copy_from_slice(header.f_counter.as_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + // do the weird ring AEAD dance + let key = LessSafeKey::new( + UnboundKey::new(&CHACHA20_POLY1305, &body.state.keypair.recv.key[..]).unwrap(), + ); + + // attempt to open (and authenticate) the body + match key.open_in_place(nonce, Aad::empty(), packet) { + Ok(_) => (), + Err(_) => { + // fault and return early + body.failed = true; + return; + } + } + } + + // cryptokey route and strip padding + let inner_len = { + let length = packet.len() - SIZE_TAG; + if length > 0 { + peer.device.table.check_route(&peer, &packet[..length]) + } else { + Some(0) + } + }; + + // truncate to remove tag + match inner_len { + None => { + body.failed = true; + } + Some(len) => { + body.msg.truncate(mem::size_of::<TransportHeader>() + len); + } + } +} + +#[inline(always)] +fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + peer: &Peer<E, C, T, B>, + body: &mut Inbound<E, C, T, B>, +) { + // decryption failed, return early + if body.failed { + return; + } + + // cast transport header + let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = + match LayoutVerified::new_from_prefix(&body.msg[..]) { + Some(v) => v, + None => { + log::debug!("inbound worker: failed to parse message"); + return; + } + }; + debug_assert!( + packet.len() >= CHACHA20_POLY1305.tag_len(), + "this should be checked earlier in the pipeline (decryption should fail)" + ); + + // check for replay + if !body.state.protector.lock().update(header.f_counter.get()) { + log::debug!("inbound worker: replay detected"); + return; + } + + // check for confirms key + if !body.state.confirmed.swap(true, Ordering::SeqCst) { + log::debug!("inbound worker: message confirms key"); + peer.confirm_key(&body.state.keypair); + } + + // update endpoint + *peer.endpoint.lock() = body.endpoint.take(); + + // calculate length of IP packet + padding + let length = packet.len() - SIZE_TAG; + log::debug!("inbound worker: plaintext length = {}", length); + + // check if should be written to TUN + let mut sent = false; + if length > 0 { + sent = match peer.device.inbound.write(&packet[..]) { + Err(e) => { + log::debug!("failed to write inbound packet to TUN: {:?}", e); + false + } + Ok(_) => true, + } + } else { + log::debug!("inbound worker: received keepalive") + } + + // trigger callback + C::recv(&peer.opaque, body.msg.len(), sent, &body.state.keypair); +} + +#[inline(always)] +fn queue<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + peer: &Peer<E, C, T, B>, +) -> &InorderQueue<Peer<E, C, T, B>, Inbound<E, C, T, B>> { + &peer.inbound +} + +pub fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + receiver: Receiver<Job<Peer<E, C, T, B>, Inbound<E, C, T, B>>>, +) { + worker_template(receiver, parallel, sequential, queue) +} diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs index 6aa894d..3243b88 100644 --- a/src/wireguard/router/mod.rs +++ b/src/wireguard/router/mod.rs @@ -1,12 +1,16 @@ mod anti_replay; mod constants; mod device; +mod inbound; mod ip; mod messages; +mod outbound; mod peer; +mod pool; mod route; mod types; -mod workers; + +// mod workers; #[cfg(test)] mod tests; @@ -16,15 +20,17 @@ use std::mem; use super::constants::REJECT_AFTER_MESSAGES; use super::types::*; +use super::{tun, udp, Endpoint}; +pub const SIZE_TAG: usize = 16; pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>(); -pub const CAPACITY_MESSAGE_POSTFIX: usize = workers::SIZE_TAG; +pub const CAPACITY_MESSAGE_POSTFIX: usize = SIZE_TAG; pub const fn message_data_len(payload: usize) -> usize { - payload + mem::size_of::<TransportHeader>() + workers::SIZE_TAG + payload + mem::size_of::<TransportHeader>() + SIZE_TAG } -pub use device::Device; +pub use device::DeviceHandle as Device; pub use messages::TYPE_TRANSPORT; -pub use peer::Peer; +pub use peer::PeerHandle; pub use types::Callbacks; diff --git a/src/wireguard/router/outbound.rs b/src/wireguard/router/outbound.rs new file mode 100644 index 0000000..30b7c2c --- /dev/null +++ b/src/wireguard/router/outbound.rs @@ -0,0 +1,104 @@ +use super::messages::{TransportHeader, TYPE_TRANSPORT}; +use super::peer::Peer; +use super::pool::*; +use super::types::Callbacks; +use super::KeyPair; +use super::REJECT_AFTER_MESSAGES; +use super::{tun, udp, Endpoint}; + +use std::sync::mpsc::Receiver; +use std::sync::Arc; + +use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; +use zerocopy::{AsBytes, LayoutVerified}; + +pub const SIZE_TAG: usize = 16; + +pub struct Outbound { + msg: Vec<u8>, + keypair: Arc<KeyPair>, + counter: u64, +} + +impl Outbound { + pub fn new(msg: Vec<u8>, keypair: Arc<KeyPair>, counter: u64) -> Outbound { + Outbound { + msg, + keypair, + counter, + } + } +} + +#[inline(always)] +fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + _peer: &Peer<E, C, T, B>, + body: &mut Outbound, +) { + // make space for the tag + body.msg.extend([0u8; SIZE_TAG].iter()); + + // cast to header (should never fail) + let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = + LayoutVerified::new_from_prefix(&mut body.msg[..]) + .expect("earlier code should ensure that there is ample space"); + + // set header fields + debug_assert!( + body.counter < REJECT_AFTER_MESSAGES, + "should be checked when assigning counters" + ); + header.f_type.set(TYPE_TRANSPORT); + header.f_receiver.set(body.keypair.send.id); + header.f_counter.set(body.counter); + + // create a nonce object + let mut nonce = [0u8; 12]; + debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); + nonce[4..].copy_from_slice(header.f_counter.as_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + // do the weird ring AEAD dance + let key = + LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &body.keypair.send.key[..]).unwrap()); + + // encrypt content of transport message in-place + let end = packet.len() - SIZE_TAG; + let tag = key + .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end]) + .unwrap(); + + // append tag + packet[end..].copy_from_slice(tag.as_ref()); +} + +#[inline(always)] +fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + peer: &Peer<E, C, T, B>, + body: &mut Outbound, +) { + // send to peer + let xmit = peer.send(&body.msg[..]).is_ok(); + + // trigger callback + C::send( + &peer.opaque, + body.msg.len(), + xmit, + &body.keypair, + body.counter, + ); +} + +#[inline(always)] +pub fn queue<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + peer: &Peer<E, C, T, B>, +) -> &InorderQueue<Peer<E, C, T, B>, Outbound> { + &peer.outbound +} + +pub fn worker<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( + receiver: Receiver<Job<Peer<E, C, T, B>, Outbound>>, +) { + worker_template(receiver, parallel, sequential, queue) +} diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index fff4dfc..192d4e2 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -2,10 +2,7 @@ use std::mem; use std::net::{IpAddr, SocketAddr}; use std::ops::Deref; use std::sync::atomic::AtomicBool; -use std::sync::atomic::Ordering; -use std::sync::mpsc::{sync_channel, SyncSender}; use std::sync::Arc; -use std::thread; use arraydeque::{ArrayDeque, Wrapping}; use log::debug; @@ -16,18 +13,18 @@ use super::super::{tun, udp, Endpoint, KeyPair}; use super::anti_replay::AntiReplay; use super::device::DecryptionState; -use super::device::DeviceInner; +use super::device::Device; use super::device::EncryptionState; use super::messages::TransportHeader; -use futures::*; - -use super::workers::{worker_inbound, worker_outbound}; -use super::workers::{JobDecryption, JobEncryption, JobInbound, JobOutbound, JobParallel}; -use super::SIZE_MESSAGE_PREFIX; - use super::constants::*; use super::types::{Callbacks, RouterError}; +use super::SIZE_MESSAGE_PREFIX; + +// worker pool related +use super::inbound::Inbound; +use super::outbound::Outbound; +use super::pool::{InorderQueue, Job}; pub struct KeyWheel { next: Option<Arc<KeyPair>>, // next key state (unconfirmed) @@ -37,10 +34,10 @@ pub struct KeyWheel { } pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { - pub device: Arc<DeviceInner<E, C, T, B>>, + pub device: Device<E, C, T, B>, pub opaque: C::Opaque, - pub outbound: Mutex<SyncSender<JobOutbound>>, - pub inbound: Mutex<SyncSender<JobInbound<E, C, T, B>>>, + pub outbound: InorderQueue<Peer<E, C, T, B>, Outbound>, + pub inbound: InorderQueue<Peer<E, C, T, B>, Inbound<E, C, T, B>>, pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>, pub keys: Mutex<KeyWheel>, pub ekey: Mutex<Option<EncryptionState>>, @@ -48,16 +45,42 @@ pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E } pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { - state: Arc<PeerInner<E, C, T, B>>, - thread_outbound: Option<thread::JoinHandle<()>>, - thread_inbound: Option<thread::JoinHandle<()>>, + inner: Arc<PeerInner<E, C, T, B>>, +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for Peer<E, C, T, B> { + fn clone(&self) -> Self { + Peer { + inner: self.inner.clone(), + } + } } +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq for Peer<E, C, T, B> { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.inner, &other.inner) + } +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Eq for Peer<E, C, T, B> {} + impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Peer<E, C, T, B> { - type Target = Arc<PeerInner<E, C, T, B>>; + type Target = PeerInner<E, C, T, B>; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +pub struct PeerHandle<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { + peer: Peer<E, C, T, B>, +} +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref + for PeerHandle<E, C, T, B> +{ + type Target = PeerInner<E, C, T, B>; fn deref(&self) -> &Self::Target { - &self.state + &self.peer } } @@ -72,37 +95,24 @@ impl EncryptionState { } impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DecryptionState<E, C, T, B> { - fn new( - peer: &Arc<PeerInner<E, C, T, B>>, - keypair: &Arc<KeyPair>, - ) -> DecryptionState<E, C, T, B> { + fn new(peer: Peer<E, C, T, B>, keypair: &Arc<KeyPair>) -> DecryptionState<E, C, T, B> { DecryptionState { confirmed: AtomicBool::new(keypair.initiator), keypair: keypair.clone(), protector: spin::Mutex::new(AntiReplay::new()), - peer: peer.clone(), death: keypair.birth + REJECT_AFTER_TIME, + peer, } } } -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Peer<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for PeerHandle<E, C, T, B> { fn drop(&mut self) { - let peer = &self.state; + let peer = &self.peer; // remove from cryptkey router - self.state.device.table.remove(peer); - - // drop channels - - mem::replace(&mut *peer.inbound.lock(), sync_channel(0).0); - mem::replace(&mut *peer.outbound.lock(), sync_channel(0).0); - - // join with workers - - mem::replace(&mut self.thread_inbound, None).map(|v| v.join()); - mem::replace(&mut self.thread_outbound, None).map(|v| v.join()); + self.peer.device.table.remove(peer); // release ids from the receiver map @@ -134,50 +144,32 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Peer } pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - device: Arc<DeviceInner<E, C, T, B>>, + device: Device<E, C, T, B>, opaque: C::Opaque, -) -> Peer<E, C, T, B> { - let (out_tx, out_rx) = sync_channel(128); - let (in_tx, in_rx) = sync_channel(128); - +) -> PeerHandle<E, C, T, B> { // allocate peer object let peer = { let device = device.clone(); - Arc::new(PeerInner { - opaque, - device, - inbound: Mutex::new(in_tx), - outbound: Mutex::new(out_tx), - ekey: spin::Mutex::new(None), - endpoint: spin::Mutex::new(None), - keys: spin::Mutex::new(KeyWheel { - next: None, - current: None, - previous: None, - retired: vec![], + Peer { + inner: Arc::new(PeerInner { + opaque, + device, + inbound: InorderQueue::new(), + outbound: InorderQueue::new(), + ekey: spin::Mutex::new(None), + endpoint: spin::Mutex::new(None), + keys: spin::Mutex::new(KeyWheel { + next: None, + current: None, + previous: None, + retired: vec![], + }), + staged_packets: spin::Mutex::new(ArrayDeque::new()), }), - staged_packets: spin::Mutex::new(ArrayDeque::new()), - }) - }; - - // spawn outbound thread - let thread_inbound = { - let peer = peer.clone(); - thread::spawn(move || worker_outbound(peer, out_rx)) - }; - - // spawn inbound thread - let thread_outbound = { - let peer = peer.clone(); - let device = device.clone(); - thread::spawn(move || worker_inbound(device, peer, in_rx)) + } }; - Peer { - state: peer, - thread_inbound: Some(thread_inbound), - thread_outbound: Some(thread_outbound), - } + PeerHandle { peer } } impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, C, T, B> { @@ -210,7 +202,9 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, None => Err(RouterError::NoEndpoint), } } +} +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, B> { // Transmit all staged packets fn send_staged(&self) -> bool { debug!("peer.send_staged"); @@ -230,16 +224,12 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, // Treat the msg as the payload of a transport message // Unlike device.send, peer.send_raw does not buffer messages when a key is not available. fn send_raw(&self, msg: Vec<u8>) -> bool { - debug!("peer.send_raw"); + log::debug!("peer.send_raw"); match self.send_job(msg, false) { Some(job) => { + self.device.outbound_queue.send(job); debug!("send_raw: got obtained send_job"); - let index = self.device.queue_next.fetch_add(1, Ordering::SeqCst); - let queues = self.device.queues.lock(); - match queues[index % queues.len()].send(job) { - Ok(_) => true, - Err(_) => false, - } + true } None => false, } @@ -285,16 +275,11 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, src: E, dec: Arc<DecryptionState<E, C, T, B>>, msg: Vec<u8>, - ) -> Option<JobParallel> { - let (tx, rx) = oneshot(); - let keypair = dec.keypair.clone(); - match self.inbound.lock().try_send((dec, src, rx)) { - Ok(_) => Some(JobParallel::Decryption(tx, JobDecryption { msg, keypair })), - Err(_) => None, - } + ) -> Option<Job<Self, Inbound<E, C, T, B>>> { + Some(Job::new(self.clone(), Inbound::new(msg, dec, src))) } - pub fn send_job(&self, msg: Vec<u8>, stage: bool) -> Option<JobParallel> { + pub fn send_job(&self, msg: Vec<u8>, stage: bool) -> Option<Job<Self, Outbound>> { debug!("peer.send_job"); debug_assert!( msg.len() >= mem::size_of::<TransportHeader>(), @@ -337,22 +322,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, }?; // add job to in-order queue and return sender to device for inclusion in worker pool - let (tx, rx) = oneshot(); - match self.outbound.lock().try_send(rx) { - Ok(_) => Some(JobParallel::Encryption( - tx, - JobEncryption { - msg, - counter, - keypair, - }, - )), - Err(_) => None, - } + let job = Job::new(self.clone(), Outbound::new(msg, keypair, counter)); + self.outbound.send(job.clone()); + Some(job) } } -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, C, T, B> { /// Set the endpoint of the peer /// /// # Arguments @@ -365,7 +341,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, /// as sockets should be "unsticked" when manually updating the endpoint pub fn set_endpoint(&self, endpoint: E) { debug!("peer.set_endpoint"); - *self.state.endpoint.lock() = Some(endpoint); + *self.peer.endpoint.lock() = Some(endpoint); } /// Returns the current endpoint of the peer (for configuration) @@ -375,11 +351,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, /// Does not convey potential "sticky socket" information pub fn get_endpoint(&self) -> Option<SocketAddr> { debug!("peer.get_endpoint"); - self.state - .endpoint - .lock() - .as_ref() - .map(|e| e.into_address()) + self.peer.endpoint.lock().as_ref().map(|e| e.into_address()) } /// Zero all key-material related to the peer @@ -387,7 +359,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, debug!("peer.zero_keys"); let mut release: Vec<u32> = Vec::with_capacity(3); - let mut keys = self.state.keys.lock(); + let mut keys = self.peer.keys.lock(); // update key-wheel @@ -398,14 +370,14 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, // update inbound "recv" map { - let mut recv = self.state.device.recv.write(); + let mut recv = self.peer.device.recv.write(); for id in release { recv.remove(&id); } } // clear encryption state - *self.state.ekey.lock() = None; + *self.peer.ekey.lock() = None; } pub fn down(&self) { @@ -436,13 +408,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, let initiator = new.initiator; let release = { let new = Arc::new(new); - let mut keys = self.state.keys.lock(); + let mut keys = self.peer.keys.lock(); let mut release = mem::replace(&mut keys.retired, vec![]); // update key-wheel if new.initiator { // start using key for encryption - *self.state.ekey.lock() = Some(EncryptionState::new(&new)); + *self.peer.ekey.lock() = Some(EncryptionState::new(&new)); // move current into previous keys.previous = keys.current.as_ref().map(|v| v.clone()); @@ -456,7 +428,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, // update incoming packet id map { debug!("peer.add_keypair: updating inbound id map"); - let mut recv = self.state.device.recv.write(); + let mut recv = self.peer.device.recv.write(); // purge recv map of previous id keys.previous.as_ref().map(|k| { @@ -468,7 +440,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, debug_assert!(!recv.contains_key(&new.recv.id)); recv.insert( new.recv.id, - Arc::new(DecryptionState::new(&self.state, &new)), + Arc::new(DecryptionState::new(self.peer.clone(), &new)), ); } release @@ -476,10 +448,10 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, // schedule confirmation if initiator { - debug_assert!(self.state.ekey.lock().is_some()); + debug_assert!(self.peer.ekey.lock().is_some()); debug!("peer.add_keypair: is initiator, must confirm the key"); // attempt to confirm using staged packets - if !self.state.send_staged() { + if !self.peer.send_staged() { // fall back to keepalive packet let ok = self.send_keepalive(); debug!( @@ -499,7 +471,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, pub fn send_keepalive(&self) -> bool { debug!("peer.send_keepalive"); - self.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX]) + self.peer.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX]) } /// Map a subnet to the peer @@ -517,10 +489,10 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, /// If an identical value already exists as part of a prior peer, /// the allowed IP entry will be removed from that peer and added to this peer. pub fn add_allowed_ip(&self, ip: IpAddr, masklen: u32) { - self.state + self.peer .device .table - .insert(ip, masklen, self.state.clone()) + .insert(ip, masklen, self.peer.clone()) } /// List subnets mapped to the peer @@ -529,23 +501,21 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, /// /// A vector of subnets, represented by as mask/size pub fn list_allowed_ips(&self) -> Vec<(IpAddr, u32)> { - self.state.device.table.list(&self.state) + self.peer.device.table.list(&self.peer) } /// Clear subnets mapped to the peer. /// After the call, no subnets will be cryptkey routed to the peer. /// Used for the UAPI command "replace_allowed_ips=true" pub fn remove_allowed_ips(&self) { - self.state.device.table.remove(&self.state) + self.peer.device.table.remove(&self.peer) } pub fn clear_src(&self) { - (*self.state.endpoint.lock()) - .as_mut() - .map(|e| e.clear_src()); + (*self.peer.endpoint.lock()).as_mut().map(|e| e.clear_src()); } pub fn purge_staged_packets(&self) { - self.state.staged_packets.lock().clear(); + self.peer.staged_packets.lock().clear(); } } diff --git a/src/wireguard/router/pool.rs b/src/wireguard/router/pool.rs new file mode 100644 index 0000000..12956c1 --- /dev/null +++ b/src/wireguard/router/pool.rs @@ -0,0 +1,132 @@ +use arraydeque::ArrayDeque; +use spin::{Mutex, MutexGuard}; +use std::sync::mpsc::Receiver; +use std::sync::Arc; + +const INORDER_QUEUE_SIZE: usize = 64; + +pub struct InnerJob<P, B> { + // peer (used by worker to schedule/handle inorder queue), + // when the peer is None, the job is complete + peer: Option<P>, + pub body: B, +} + +pub struct Job<P, B> { + inner: Arc<Mutex<InnerJob<P, B>>>, +} + +impl<P, B> Clone for Job<P, B> { + fn clone(&self) -> Job<P, B> { + Job { + inner: self.inner.clone(), + } + } +} + +impl<P, B> Job<P, B> { + pub fn new(peer: P, body: B) -> Job<P, B> { + Job { + inner: Arc::new(Mutex::new(InnerJob { + peer: Some(peer), + body, + })), + } + } +} + +impl<P, B> Job<P, B> { + /// Returns a mutex guard to the inner job if complete + pub fn complete(&self) -> Option<MutexGuard<InnerJob<P, B>>> { + self.inner + .try_lock() + .and_then(|m| if m.peer.is_none() { Some(m) } else { None }) + } +} + +pub struct InorderQueue<P, B> { + queue: Mutex<ArrayDeque<[Job<P, B>; INORDER_QUEUE_SIZE]>>, +} + +impl<P, B> InorderQueue<P, B> { + pub fn send(&self, job: Job<P, B>) -> bool { + self.queue.lock().push_back(job).is_ok() + } + + pub fn new() -> InorderQueue<P, B> { + InorderQueue { + queue: Mutex::new(ArrayDeque::new()), + } + } + + #[inline(always)] + pub fn handle<F: Fn(&mut InnerJob<P, B>)>(&self, f: F) { + // take the mutex + let mut queue = self.queue.lock(); + + // handle all complete messages + while queue + .pop_front() + .and_then(|j| { + // check if job is complete + let ret = if let Some(mut guard) = j.complete() { + f(&mut *guard); + false + } else { + true + }; + + // return job to cyclic buffer if not complete + if ret { + let _res = queue.push_front(j); + debug_assert!(_res.is_ok()); + None + } else { + // add job back to pool + Some(()) + } + }) + .is_some() + {} + } +} + +/// Allows easy construction of a semi-parallel worker. +/// Applicable for both decryption and encryption workers. +#[inline(always)] +pub fn worker_template< + P, // represents a peer (atomic reference counted pointer) + B, // inner body type (message buffer, key material, ...) + W: Fn(&P, &mut B), + S: Fn(&P, &mut B), + Q: Fn(&P) -> &InorderQueue<P, B>, +>( + receiver: Receiver<Job<P, B>>, // receiever for new jobs + work_parallel: W, // perform parallel / out-of-order work on peer + work_sequential: S, // perform sequential work on peer + queue: Q, // resolve a peer to an inorder queue +) { + loop { + // handle new job + let peer = { + // get next job + let job = match receiver.recv() { + Ok(job) => job, + _ => return, + }; + + // lock the job + let mut job = job.inner.lock(); + + // take the peer from the job + let peer = job.peer.take().unwrap(); + + // process job + work_parallel(&peer, &mut job.body); + peer + }; + + // process inorder jobs for peer + queue(&peer).handle(|j| work_sequential(&peer, &mut j.body)); + } +} diff --git a/src/wireguard/router/route.rs b/src/wireguard/router/route.rs index 1c93009..40dc36b 100644 --- a/src/wireguard/router/route.rs +++ b/src/wireguard/router/route.rs @@ -4,7 +4,6 @@ use zerocopy::LayoutVerified; use std::mem; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -use std::sync::Arc; use spin::RwLock; use treebitmap::address::Address; @@ -12,12 +11,12 @@ use treebitmap::IpLookupTable; /* Functions for obtaining and validating "cryptokey" routes */ -pub struct RoutingTable<T> { - ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<T>>>, - ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<T>>>, +pub struct RoutingTable<T: Eq + Clone> { + ipv4: RwLock<IpLookupTable<Ipv4Addr, T>>, + ipv6: RwLock<IpLookupTable<Ipv6Addr, T>>, } -impl<T> RoutingTable<T> { +impl<T: Eq + Clone> RoutingTable<T> { pub fn new() -> Self { RoutingTable { ipv4: RwLock::new(IpLookupTable::new()), @@ -26,27 +25,27 @@ impl<T> RoutingTable<T> { } // collect keys mapping to the given value - fn collect<A>(table: &IpLookupTable<A, Arc<T>>, value: &Arc<T>) -> Vec<(A, u32)> + fn collect<A>(table: &IpLookupTable<A, T>, value: &T) -> Vec<(A, u32)> where A: Address, { let mut res = Vec::new(); for (ip, cidr, v) in table.iter() { - if Arc::ptr_eq(v, value) { + if v == value { res.push((ip, cidr)) } } res } - pub fn insert(&self, ip: IpAddr, cidr: u32, value: Arc<T>) { + pub fn insert(&self, ip: IpAddr, cidr: u32, value: T) { match ip { IpAddr::V4(v4) => self.ipv4.write().insert(v4.mask(cidr), cidr, value), IpAddr::V6(v6) => self.ipv6.write().insert(v6.mask(cidr), cidr, value), }; } - pub fn list(&self, value: &Arc<T>) -> Vec<(IpAddr, u32)> { + pub fn list(&self, value: &T) -> Vec<(IpAddr, u32)> { let mut res = vec![]; res.extend( Self::collect(&*self.ipv4.read(), value) @@ -61,7 +60,7 @@ impl<T> RoutingTable<T> { res } - pub fn remove(&self, value: &Arc<T>) { + pub fn remove(&self, value: &T) { let mut v4 = self.ipv4.write(); for (ip, cidr) in Self::collect(&*v4, value) { v4.remove(ip, cidr); @@ -74,7 +73,7 @@ impl<T> RoutingTable<T> { } #[inline(always)] - pub fn get_route(&self, packet: &[u8]) -> Option<Arc<T>> { + pub fn get_route(&self, packet: &[u8]) -> Option<T> { match packet.get(0)? >> 4 { VERSION_IP4 => { // check length and cast to IPv4 header @@ -113,7 +112,7 @@ impl<T> RoutingTable<T> { } #[inline(always)] - pub fn check_route(&self, peer: &Arc<T>, packet: &[u8]) -> Option<usize> { + pub fn check_route(&self, peer: &T, packet: &[u8]) -> Option<usize> { match packet.get(0)? >> 4 { VERSION_IP4 => { // check length and cast to IPv4 header @@ -130,7 +129,7 @@ impl<T> RoutingTable<T> { .read() .longest_match(Ipv4Addr::from(header.f_source)) .and_then(|(_, _, p)| { - if Arc::ptr_eq(p, peer) { + if p == peer { Some(header.f_total_len.get() as usize) } else { None @@ -152,7 +151,7 @@ impl<T> RoutingTable<T> { .read() .longest_match(Ipv6Addr::from(header.f_source)) .and_then(|(_, _, p)| { - if Arc::ptr_eq(p, peer) { + if p == peer { Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>()) } else { None diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs index bf1bd5f..3cccb42 100644 --- a/src/wireguard/tests.rs +++ b/src/wireguard/tests.rs @@ -1,5 +1,5 @@ +use super::dummy; use super::wireguard::Wireguard; -use super::{dummy, tun, udp}; use std::net::IpAddr; use std::thread; diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index 0ce4210..e1aabad 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -137,6 +137,7 @@ impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> { pub fn timers_handshake_complete(&self) { let timers = self.timers(); if timers.enabled { + timers.retransmit_handshake.stop(); timers.handshake_attempts.store(0, Ordering::SeqCst); timers.sent_lastminute_handshake.store(false, Ordering::SeqCst); *self.walltime_last_handshake.lock() = Some(SystemTime::now()); |