diff options
Diffstat (limited to '')
-rw-r--r-- | src/wireguard/router/peer.rs | 220 |
1 files changed, 95 insertions, 125 deletions
diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index fff4dfc..192d4e2 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -2,10 +2,7 @@ use std::mem; use std::net::{IpAddr, SocketAddr}; use std::ops::Deref; use std::sync::atomic::AtomicBool; -use std::sync::atomic::Ordering; -use std::sync::mpsc::{sync_channel, SyncSender}; use std::sync::Arc; -use std::thread; use arraydeque::{ArrayDeque, Wrapping}; use log::debug; @@ -16,18 +13,18 @@ use super::super::{tun, udp, Endpoint, KeyPair}; use super::anti_replay::AntiReplay; use super::device::DecryptionState; -use super::device::DeviceInner; +use super::device::Device; use super::device::EncryptionState; use super::messages::TransportHeader; -use futures::*; - -use super::workers::{worker_inbound, worker_outbound}; -use super::workers::{JobDecryption, JobEncryption, JobInbound, JobOutbound, JobParallel}; -use super::SIZE_MESSAGE_PREFIX; - use super::constants::*; use super::types::{Callbacks, RouterError}; +use super::SIZE_MESSAGE_PREFIX; + +// worker pool related +use super::inbound::Inbound; +use super::outbound::Outbound; +use super::pool::{InorderQueue, Job}; pub struct KeyWheel { next: Option<Arc<KeyPair>>, // next key state (unconfirmed) @@ -37,10 +34,10 @@ pub struct KeyWheel { } pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { - pub device: Arc<DeviceInner<E, C, T, B>>, + pub device: Device<E, C, T, B>, pub opaque: C::Opaque, - pub outbound: Mutex<SyncSender<JobOutbound>>, - pub inbound: Mutex<SyncSender<JobInbound<E, C, T, B>>>, + pub outbound: InorderQueue<Peer<E, C, T, B>, Outbound>, + pub inbound: InorderQueue<Peer<E, C, T, B>, Inbound<E, C, T, B>>, pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>, pub keys: Mutex<KeyWheel>, pub ekey: Mutex<Option<EncryptionState>>, @@ -48,16 +45,42 @@ pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E } pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { - state: Arc<PeerInner<E, C, T, B>>, - thread_outbound: Option<thread::JoinHandle<()>>, - thread_inbound: Option<thread::JoinHandle<()>>, + inner: Arc<PeerInner<E, C, T, B>>, +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Clone for Peer<E, C, T, B> { + fn clone(&self) -> Self { + Peer { + inner: self.inner.clone(), + } + } } +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq for Peer<E, C, T, B> { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.inner, &other.inner) + } +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Eq for Peer<E, C, T, B> {} + impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref for Peer<E, C, T, B> { - type Target = Arc<PeerInner<E, C, T, B>>; + type Target = PeerInner<E, C, T, B>; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +pub struct PeerHandle<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> { + peer: Peer<E, C, T, B>, +} +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Deref + for PeerHandle<E, C, T, B> +{ + type Target = PeerInner<E, C, T, B>; fn deref(&self) -> &Self::Target { - &self.state + &self.peer } } @@ -72,37 +95,24 @@ impl EncryptionState { } impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DecryptionState<E, C, T, B> { - fn new( - peer: &Arc<PeerInner<E, C, T, B>>, - keypair: &Arc<KeyPair>, - ) -> DecryptionState<E, C, T, B> { + fn new(peer: Peer<E, C, T, B>, keypair: &Arc<KeyPair>) -> DecryptionState<E, C, T, B> { DecryptionState { confirmed: AtomicBool::new(keypair.initiator), keypair: keypair.clone(), protector: spin::Mutex::new(AntiReplay::new()), - peer: peer.clone(), death: keypair.birth + REJECT_AFTER_TIME, + peer, } } } -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Peer<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for PeerHandle<E, C, T, B> { fn drop(&mut self) { - let peer = &self.state; + let peer = &self.peer; // remove from cryptkey router - self.state.device.table.remove(peer); - - // drop channels - - mem::replace(&mut *peer.inbound.lock(), sync_channel(0).0); - mem::replace(&mut *peer.outbound.lock(), sync_channel(0).0); - - // join with workers - - mem::replace(&mut self.thread_inbound, None).map(|v| v.join()); - mem::replace(&mut self.thread_outbound, None).map(|v| v.join()); + self.peer.device.table.remove(peer); // release ids from the receiver map @@ -134,50 +144,32 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Peer } pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( - device: Arc<DeviceInner<E, C, T, B>>, + device: Device<E, C, T, B>, opaque: C::Opaque, -) -> Peer<E, C, T, B> { - let (out_tx, out_rx) = sync_channel(128); - let (in_tx, in_rx) = sync_channel(128); - +) -> PeerHandle<E, C, T, B> { // allocate peer object let peer = { let device = device.clone(); - Arc::new(PeerInner { - opaque, - device, - inbound: Mutex::new(in_tx), - outbound: Mutex::new(out_tx), - ekey: spin::Mutex::new(None), - endpoint: spin::Mutex::new(None), - keys: spin::Mutex::new(KeyWheel { - next: None, - current: None, - previous: None, - retired: vec![], + Peer { + inner: Arc::new(PeerInner { + opaque, + device, + inbound: InorderQueue::new(), + outbound: InorderQueue::new(), + ekey: spin::Mutex::new(None), + endpoint: spin::Mutex::new(None), + keys: spin::Mutex::new(KeyWheel { + next: None, + current: None, + previous: None, + retired: vec![], + }), + staged_packets: spin::Mutex::new(ArrayDeque::new()), }), - staged_packets: spin::Mutex::new(ArrayDeque::new()), - }) - }; - - // spawn outbound thread - let thread_inbound = { - let peer = peer.clone(); - thread::spawn(move || worker_outbound(peer, out_rx)) - }; - - // spawn inbound thread - let thread_outbound = { - let peer = peer.clone(); - let device = device.clone(); - thread::spawn(move || worker_inbound(device, peer, in_rx)) + } }; - Peer { - state: peer, - thread_inbound: Some(thread_inbound), - thread_outbound: Some(thread_outbound), - } + PeerHandle { peer } } impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, C, T, B> { @@ -210,7 +202,9 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, None => Err(RouterError::NoEndpoint), } } +} +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, B> { // Transmit all staged packets fn send_staged(&self) -> bool { debug!("peer.send_staged"); @@ -230,16 +224,12 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, // Treat the msg as the payload of a transport message // Unlike device.send, peer.send_raw does not buffer messages when a key is not available. fn send_raw(&self, msg: Vec<u8>) -> bool { - debug!("peer.send_raw"); + log::debug!("peer.send_raw"); match self.send_job(msg, false) { Some(job) => { + self.device.outbound_queue.send(job); debug!("send_raw: got obtained send_job"); - let index = self.device.queue_next.fetch_add(1, Ordering::SeqCst); - let queues = self.device.queues.lock(); - match queues[index % queues.len()].send(job) { - Ok(_) => true, - Err(_) => false, - } + true } None => false, } @@ -285,16 +275,11 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, src: E, dec: Arc<DecryptionState<E, C, T, B>>, msg: Vec<u8>, - ) -> Option<JobParallel> { - let (tx, rx) = oneshot(); - let keypair = dec.keypair.clone(); - match self.inbound.lock().try_send((dec, src, rx)) { - Ok(_) => Some(JobParallel::Decryption(tx, JobDecryption { msg, keypair })), - Err(_) => None, - } + ) -> Option<Job<Self, Inbound<E, C, T, B>>> { + Some(Job::new(self.clone(), Inbound::new(msg, dec, src))) } - pub fn send_job(&self, msg: Vec<u8>, stage: bool) -> Option<JobParallel> { + pub fn send_job(&self, msg: Vec<u8>, stage: bool) -> Option<Job<Self, Outbound>> { debug!("peer.send_job"); debug_assert!( msg.len() >= mem::size_of::<TransportHeader>(), @@ -337,22 +322,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, }?; // add job to in-order queue and return sender to device for inclusion in worker pool - let (tx, rx) = oneshot(); - match self.outbound.lock().try_send(rx) { - Ok(_) => Some(JobParallel::Encryption( - tx, - JobEncryption { - msg, - counter, - keypair, - }, - )), - Err(_) => None, - } + let job = Job::new(self.clone(), Outbound::new(msg, keypair, counter)); + self.outbound.send(job.clone()); + Some(job) } } -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, C, T, B> { /// Set the endpoint of the peer /// /// # Arguments @@ -365,7 +341,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, /// as sockets should be "unsticked" when manually updating the endpoint pub fn set_endpoint(&self, endpoint: E) { debug!("peer.set_endpoint"); - *self.state.endpoint.lock() = Some(endpoint); + *self.peer.endpoint.lock() = Some(endpoint); } /// Returns the current endpoint of the peer (for configuration) @@ -375,11 +351,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, /// Does not convey potential "sticky socket" information pub fn get_endpoint(&self) -> Option<SocketAddr> { debug!("peer.get_endpoint"); - self.state - .endpoint - .lock() - .as_ref() - .map(|e| e.into_address()) + self.peer.endpoint.lock().as_ref().map(|e| e.into_address()) } /// Zero all key-material related to the peer @@ -387,7 +359,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, debug!("peer.zero_keys"); let mut release: Vec<u32> = Vec::with_capacity(3); - let mut keys = self.state.keys.lock(); + let mut keys = self.peer.keys.lock(); // update key-wheel @@ -398,14 +370,14 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, // update inbound "recv" map { - let mut recv = self.state.device.recv.write(); + let mut recv = self.peer.device.recv.write(); for id in release { recv.remove(&id); } } // clear encryption state - *self.state.ekey.lock() = None; + *self.peer.ekey.lock() = None; } pub fn down(&self) { @@ -436,13 +408,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, let initiator = new.initiator; let release = { let new = Arc::new(new); - let mut keys = self.state.keys.lock(); + let mut keys = self.peer.keys.lock(); let mut release = mem::replace(&mut keys.retired, vec![]); // update key-wheel if new.initiator { // start using key for encryption - *self.state.ekey.lock() = Some(EncryptionState::new(&new)); + *self.peer.ekey.lock() = Some(EncryptionState::new(&new)); // move current into previous keys.previous = keys.current.as_ref().map(|v| v.clone()); @@ -456,7 +428,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, // update incoming packet id map { debug!("peer.add_keypair: updating inbound id map"); - let mut recv = self.state.device.recv.write(); + let mut recv = self.peer.device.recv.write(); // purge recv map of previous id keys.previous.as_ref().map(|k| { @@ -468,7 +440,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, debug_assert!(!recv.contains_key(&new.recv.id)); recv.insert( new.recv.id, - Arc::new(DecryptionState::new(&self.state, &new)), + Arc::new(DecryptionState::new(self.peer.clone(), &new)), ); } release @@ -476,10 +448,10 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, // schedule confirmation if initiator { - debug_assert!(self.state.ekey.lock().is_some()); + debug_assert!(self.peer.ekey.lock().is_some()); debug!("peer.add_keypair: is initiator, must confirm the key"); // attempt to confirm using staged packets - if !self.state.send_staged() { + if !self.peer.send_staged() { // fall back to keepalive packet let ok = self.send_keepalive(); debug!( @@ -499,7 +471,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, pub fn send_keepalive(&self) -> bool { debug!("peer.send_keepalive"); - self.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX]) + self.peer.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX]) } /// Map a subnet to the peer @@ -517,10 +489,10 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, /// If an identical value already exists as part of a prior peer, /// the allowed IP entry will be removed from that peer and added to this peer. pub fn add_allowed_ip(&self, ip: IpAddr, masklen: u32) { - self.state + self.peer .device .table - .insert(ip, masklen, self.state.clone()) + .insert(ip, masklen, self.peer.clone()) } /// List subnets mapped to the peer @@ -529,23 +501,21 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, /// /// A vector of subnets, represented by as mask/size pub fn list_allowed_ips(&self) -> Vec<(IpAddr, u32)> { - self.state.device.table.list(&self.state) + self.peer.device.table.list(&self.peer) } /// Clear subnets mapped to the peer. /// After the call, no subnets will be cryptkey routed to the peer. /// Used for the UAPI command "replace_allowed_ips=true" pub fn remove_allowed_ips(&self) { - self.state.device.table.remove(&self.state) + self.peer.device.table.remove(&self.peer) } pub fn clear_src(&self) { - (*self.state.endpoint.lock()) - .as_mut() - .map(|e| e.clear_src()); + (*self.peer.endpoint.lock()).as_mut().map(|e| e.clear_src()); } pub fn purge_staged_packets(&self) { - self.state.staged_packets.lock().clear(); + self.peer.staged_packets.lock().clear(); } } |