diff options
Diffstat (limited to 'src/wireguard/router/peer.rs')
-rw-r--r-- | src/wireguard/router/peer.rs | 192 |
1 files changed, 79 insertions, 113 deletions
diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index 7312bc7..710cf32 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -1,13 +1,3 @@ -use std::mem; -use std::net::{IpAddr, SocketAddr}; -use std::ops::Deref; -use std::sync::atomic::AtomicBool; -use std::sync::Arc; - -use arraydeque::{ArrayDeque, Wrapping}; -use log::debug; -use spin::Mutex; - use super::super::constants::*; use super::super::{tun, udp, Endpoint, KeyPair}; @@ -15,20 +5,25 @@ use super::anti_replay::AntiReplay; use super::device::DecryptionState; use super::device::Device; use super::device::EncryptionState; -use super::messages::TransportHeader; use super::constants::*; -use super::runq::ToKey; use super::types::{Callbacks, RouterError}; use super::SIZE_MESSAGE_PREFIX; -// worker pool related -use super::inbound::Inbound; -use super::outbound::Outbound; use super::queue::Queue; - -use super::send::SendJob; use super::receive::ReceiveJob; +use super::send::SendJob; +use super::worker::JobUnion; + +use std::mem; +use std::net::{IpAddr, SocketAddr}; +use std::ops::Deref; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; + +use arraydeque::{ArrayDeque, Wrapping}; +use log::debug; +use spin::Mutex; pub struct KeyWheel { next: Option<Arc<KeyPair>>, // next key state (unconfirmed) @@ -44,7 +39,7 @@ pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E pub inbound: Queue<ReceiveJob<E, C, T, B>>, pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_QUEUED_PACKETS], Wrapping>>, pub keys: Mutex<KeyWheel>, - pub ekey: Mutex<Option<EncryptionState>>, + pub enc_key: Mutex<Option<EncryptionState>>, pub endpoint: Mutex<Option<E>>, } @@ -69,13 +64,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PartialEq for } } -impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ToKey for Peer<E, C, T, B> { - type Key = usize; - fn to_key(&self) -> usize { - Arc::downgrade(&self.inner).into_raw() as usize - } -} - impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Eq for Peer<E, C, T, B> {} /* A peer is transparently dereferenced to the inner type @@ -157,7 +145,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Drop for Peer keys.current = None; keys.previous = None; - *peer.ekey.lock() = None; + *peer.enc_key.lock() = None; *peer.endpoint.lock() = None; debug!("peer dropped & removed from device"); @@ -175,9 +163,9 @@ pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>( inner: Arc::new(PeerInner { opaque, device, - inbound: InorderQueue::new(), - outbound: InorderQueue::new(), - ekey: spin::Mutex::new(None), + inbound: Queue::new(), + outbound: Queue::new(), + enc_key: spin::Mutex::new(None), endpoint: spin::Mutex::new(None), keys: spin::Mutex::new(KeyWheel { next: None, @@ -203,7 +191,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, /// # Returns /// /// Unit if packet was sent, or an error indicating why sending failed - pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> { + pub fn send_raw(&self, msg: &[u8]) -> Result<(), RouterError> { debug!("peer.send"); // send to endpoint (if known) @@ -226,6 +214,57 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, } impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, B> { + /// Encrypt and send a message to the peer + /// + /// Arguments: + /// + /// - `msg` : A padded vector holding the message (allows in-place construction of the transport header) + /// - `stage`: Should the message be staged if no key is available + /// + pub(super) fn send(&self, msg: Vec<u8>, stage: bool) { + // check if key available + let (job, need_key) = { + let mut enc_key = self.enc_key.lock(); + match enc_key.as_mut() { + None => { + if stage { + self.staged_packets.lock().push_back(msg); + }; + (None, true) + } + Some(mut state) => { + // avoid integer overflow in nonce + if state.nonce >= REJECT_AFTER_MESSAGES - 1 { + *enc_key = None; + if stage { + self.staged_packets.lock().push_back(msg); + } + (None, true) + } else { + debug!("encryption state available, nonce = {}", state.nonce); + let job = + SendJob::new(msg, state.nonce, state.keypair.clone(), self.clone()); + if self.outbound.push(job.clone()) { + state.nonce += 1; + (Some(job), false) + } else { + (None, false) + } + } + } + } + }; + + if need_key { + debug_assert!(job.is_none()); + C::need_key(&self.opaque); + }; + + if let Some(job) = job { + self.device.work.send(JobUnion::Outbound(job)) + } + } + // Transmit all staged packets fn send_staged(&self) -> bool { debug!("peer.send_staged"); @@ -235,28 +274,14 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, match staged.pop_front() { Some(msg) => { sent = true; - self.send_raw(msg); + self.send(msg, false); } None => break sent, } } } - // Treat the msg as the payload of a transport message - // Unlike device.send, peer.send_raw does not buffer messages when a key is not available. - fn send_raw(&self, msg: Vec<u8>) -> bool { - log::debug!("peer.send_raw"); - match self.send_job(msg, false) { - Some(job) => { - self.device.queue_outbound.send(job); - debug!("send_raw: got obtained send_job"); - true - } - None => false, - } - } - - pub fn confirm_key(&self, keypair: &Arc<KeyPair>) { + pub(super) fn confirm_key(&self, keypair: &Arc<KeyPair>) { debug!("peer.confirm_key"); { // take lock and check keypair = keys.next @@ -284,68 +309,12 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T, C::key_confirmed(&self.opaque); // set new key for encryption - *self.ekey.lock() = ekey; + *self.enc_key.lock() = ekey; } // start transmission of staged packets self.send_staged(); } - - pub fn send_job(&self, msg: Vec<u8>, stage: bool) -> Option<SendJob<E, C, T, B>> { - debug!("peer.send_job"); - debug_assert!( - msg.len() >= mem::size_of::<TransportHeader>(), - "received TUN message with size: {:}", - msg.len() - ); - - // check if has key - let (keypair, counter) = { - let keypair = { - // TODO: consider using atomic ptr for ekey state - let mut ekey = self.ekey.lock(); - match ekey.as_mut() { - None => None, - Some(mut state) => { - // avoid integer overflow in nonce - if state.nonce >= REJECT_AFTER_MESSAGES - 1 { - *ekey = None; - None - } else { - debug!("encryption state available, nonce = {}", state.nonce); - let counter = state.nonce; - state.nonce += 1; - - SendJob::new( - msg, - state.nonce, - state.keypair.clone(), - self.clone() - ); - - Some((state.keypair.clone(), counter)) - } - } - } - }; - - // If not suitable key was found: - // 1. Stage packet for later transmission - // 2. Request new key - if keypair.is_none() && stage { - self.staged_packets.lock().push_back(msg); - C::need_key(&self.opaque); - return None; - }; - - keypair - }?; - - // add job to in-order queue and return sender to device for inclusion in worker pool - let job = Job::new(self.clone(), Outbound::new(msg, keypair, counter)); - self.outbound.send(job.clone()); - Some(job) - } } impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, C, T, B> { @@ -397,7 +366,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, } // clear encryption state - *self.peer.ekey.lock() = None; + *self.peer.enc_key.lock() = None; } pub fn down(&self) { @@ -434,7 +403,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, // update key-wheel if new.initiator { // start using key for encryption - *self.peer.ekey.lock() = Some(EncryptionState::new(&new)); + *self.peer.enc_key.lock() = Some(EncryptionState::new(&new)); // move current into previous keys.previous = keys.current.as_ref().map(|v| v.clone()); @@ -468,16 +437,13 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, // schedule confirmation if initiator { - debug_assert!(self.peer.ekey.lock().is_some()); + debug_assert!(self.peer.enc_key.lock().is_some()); debug!("peer.add_keypair: is initiator, must confirm the key"); // attempt to confirm using staged packets if !self.peer.send_staged() { // fall back to keepalive packet - let ok = self.send_keepalive(); - debug!( - "peer.add_keypair: keepalive for confirmation, sent = {}", - ok - ); + self.send_keepalive(); + debug!("peer.add_keypair: keepalive for confirmation",); } debug!("peer.add_keypair: key attempted confirmed"); } @@ -489,9 +455,9 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerHandle<E, release } - pub fn send_keepalive(&self) -> bool { + pub fn send_keepalive(&self) { debug!("peer.send_keepalive"); - self.peer.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX]) + self.peer.send(vec![0u8; SIZE_MESSAGE_PREFIX], false) } /// Map a subnet to the peer |