diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-11-04 13:19:27 +0100 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-11-04 13:19:27 +0100 |
commit | 6ba40f17cb484c0b9b76caf926ef24539892d5a6 (patch) | |
tree | 4df08b852aab26bfe37c144123a377e4fd28acb2 /src/wireguard | |
parent | Work on UAPI parser (diff) | |
download | wireguard-rs-6ba40f17cb484c0b9b76caf926ef24539892d5a6.tar.xz wireguard-rs-6ba40f17cb484c0b9b76caf926ef24539892d5a6.zip |
Work on Up/Down operation on WireGuard device
Diffstat (limited to '')
-rw-r--r-- | src/wireguard/mod.rs | 13 | ||||
-rw-r--r-- | src/wireguard/peer.rs | 111 | ||||
-rw-r--r-- | src/wireguard/router/device.rs | 13 | ||||
-rw-r--r-- | src/wireguard/router/peer.rs | 68 | ||||
-rw-r--r-- | src/wireguard/router/workers.rs | 22 | ||||
-rw-r--r-- | src/wireguard/timers.rs | 59 | ||||
-rw-r--r-- | src/wireguard/wireguard.rs | 118 |
7 files changed, 261 insertions, 143 deletions
diff --git a/src/wireguard/mod.rs b/src/wireguard/mod.rs index 83f9e8a..79feed7 100644 --- a/src/wireguard/mod.rs +++ b/src/wireguard/mod.rs @@ -4,18 +4,15 @@ mod wireguard; mod endpoint; mod handshake; +mod peer; mod router; mod types; #[cfg(test)] mod tests; -/// The WireGuard sub-module contains a pure, configurable implementation of WireGuard. -/// The implementation is generic over: -/// -/// - TUN type, specifying how packets are received on the interface side: a reader/writer and MTU reporting interface. -/// - Bind type, specifying how WireGuard messages are sent/received from the internet and what constitutes an "endpoint" -pub use wireguard::{Peer, Wireguard}; +pub use peer::Peer; +pub use wireguard::Wireguard; #[cfg(test)] pub use types::dummy_keypair; @@ -24,4 +21,6 @@ pub use types::dummy_keypair; use super::platform::dummy; use super::platform::{bind, tun, Endpoint}; -use types::{Key, KeyPair}; +use peer::PeerInner; +use types::KeyPair; +use wireguard::HandshakeJob; diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs new file mode 100644 index 0000000..9f24dea --- /dev/null +++ b/src/wireguard/peer.rs @@ -0,0 +1,111 @@ +use super::constants::*; +use super::router; +use super::timers::{Events, Timers}; +use super::HandshakeJob; + +use super::bind::Bind; +use super::bind::Reader as BindReader; +use super::tun::{Reader, Tun}; + +use std::fmt; +use std::ops::Deref; +use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::{Instant, SystemTime}; + +use spin::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard}; + +use crossbeam_channel::Sender; +use x25519_dalek::PublicKey; + +pub struct Peer<T: Tun, B: Bind> { + pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>, + pub state: Arc<PeerInner<B>>, +} + +pub struct PeerInner<B: Bind> { + // internal id (for logging) + pub id: u64, + + // handshake state + pub walltime_last_handshake: Mutex<SystemTime>, + pub last_handshake_sent: Mutex<Instant>, // instant for last handshake + pub handshake_queued: AtomicBool, // is a handshake job currently queued for the peer? + pub queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue + + // stats and configuration + pub pk: PublicKey, // public key, DISCUSS: avoid this. TODO: remove + pub keepalive_interval: AtomicU64, // keepalive interval + pub rx_bytes: AtomicU64, // received bytes + pub tx_bytes: AtomicU64, // transmitted bytes + + // timer model + pub timers: RwLock<Timers>, +} + +impl<T: Tun, B: Bind> Clone for Peer<T, B> { + fn clone(&self) -> Peer<T, B> { + Peer { + router: self.router.clone(), + state: self.state.clone(), + } + } +} + +impl<B: Bind> PeerInner<B> { + #[inline(always)] + pub fn timers(&self) -> RwLockReadGuard<Timers> { + self.timers.read() + } + + #[inline(always)] + pub fn timers_mut(&self) -> RwLockWriteGuard<Timers> { + self.timers.write() + } +} + +impl<T: Tun, B: Bind> fmt::Display for Peer<T, B> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "peer(id = {})", self.id) + } +} + +impl<T: Tun, B: Bind> Deref for Peer<T, B> { + type Target = PeerInner<B>; + fn deref(&self) -> &Self::Target { + &self.state + } +} + +impl<T: Tun, B: Bind> Peer<T, B> { + pub fn down(&self) { + self.stop_timers(); + } + + pub fn up(&self) {} +} + +impl<B: Bind> PeerInner<B> { + /* Queue a handshake request for the parallel workers + * (if one does not already exist) + * + * The function is ratelimited. + */ + pub fn packet_send_handshake_initiation(&self) { + // the function is rate limited + + { + let mut lhs = self.last_handshake_sent.lock(); + if lhs.elapsed() < REKEY_TIMEOUT { + return; + } + *lhs = Instant::now(); + } + + // create a new handshake job for the peer + + if !self.handshake_queued.swap(true, Ordering::SeqCst) { + self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap(); + } + } +} diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index 7c3b0a1..a5028e1 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -27,6 +27,8 @@ use super::route::get_route; use super::super::{bind, tun, Endpoint, KeyPair}; pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { + pub enabled: AtomicBool, + // inbound writer (TUN) pub inbound: T, @@ -91,6 +93,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, // allocate shared device state let inner = DeviceInner { inbound: tun, + enabled: AtomicBool::new(true), outbound: RwLock::new(None), queues: Mutex::new(Vec::with_capacity(num_workers)), queue_next: AtomicUsize::new(0), @@ -114,6 +117,16 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, } } + /// Brings the router down. + /// When the router is brought down it: + /// - Prevents transmission of outbound messages. + /// - Erases all key state (key-wheels) of all peers + pub fn down(&self) {} + + /// Brints the router up + /// When the router is brought up it enables the transmission of outbound messages. + pub fn up(&self) {} + /// A new secret key has been set for the device. /// According to WireGuard semantics, this should cause all "sending" keys to be discarded. pub fn new_sk(&self) {} diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index 5522a3e..7527a60 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -1,5 +1,6 @@ use std::mem; use std::net::{IpAddr, SocketAddr}; +use std::ops::Deref; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; use std::sync::mpsc::{sync_channel, SyncSender}; @@ -11,7 +12,6 @@ use log::debug; use spin::Mutex; use treebitmap::address::Address; use treebitmap::IpLookupTable; -use zerocopy::LayoutVerified; use super::super::constants::*; use super::super::{bind, tun, Endpoint, KeyPair}; @@ -55,6 +55,14 @@ pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { thread_inbound: Option<thread::JoinHandle<()>>, } +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Deref for Peer<E, C, T, B> { + type Target = Arc<PeerInner<E, C, T, B>>; + + fn deref(&self) -> &Self::Target { + &self.state + } +} + fn treebit_list<A, R, E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( peer: &Arc<PeerInner<E, C, T, B>>, table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>, @@ -199,7 +207,7 @@ pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( let thread_inbound = { let peer = peer.clone(); let device = device.clone(); - thread::spawn(move || worker_outbound(device, peer, out_rx)) + thread::spawn(move || worker_outbound(peer, out_rx)) }; // spawn inbound thread @@ -217,6 +225,36 @@ pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( } impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E, C, T, B> { + /// Send a raw message to the peer (used for handshake messages) + /// + /// # Arguments + /// + /// - `msg`, message body to send to peer + /// + /// # Returns + /// + /// Unit if packet was sent, or an error indicating why sending failed + pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> { + debug!("peer.send"); + + // check if device is enabled + if !self.device.enabled.load(Ordering::Acquire) { + return Ok(()); + } + + // send to endpoint (if known) + match self.endpoint.lock().as_ref() { + Some(endpoint) => self + .device + .outbound + .read() + .as_ref() + .ok_or(RouterError::SendError) + .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError)), + None => Err(RouterError::NoEndpoint), + } + } + fn send_staged(&self) -> bool { debug!("peer.send_staged"); let mut sent = false; @@ -498,7 +536,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T pub fn send_keepalive(&self) -> bool { debug!("peer.send_keepalive"); - self.state.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX]) + self.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX]) } /// Map a subnet to the peer @@ -565,30 +603,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T treebit_remove(self, &self.state.device.ipv6); } - /// Send a raw message to the peer (used for handshake messages) - /// - /// # Arguments - /// - /// - `msg`, message body to send to peer - /// - /// # Returns - /// - /// Unit if packet was sent, or an error indicating why sending failed - pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> { - debug!("peer.send"); - let inner = &self.state; - match inner.endpoint.lock().as_ref() { - Some(endpoint) => inner - .device - .outbound - .read() - .as_ref() - .ok_or(RouterError::SendError) - .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError)), - None => Err(RouterError::NoEndpoint), - } - } - pub fn clear_src(&self) { (*self.state.endpoint.lock()) .as_mut() diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs index 08c2db9..5482cee 100644 --- a/src/wireguard/router/workers.rs +++ b/src/wireguard/router/workers.rs @@ -141,8 +141,7 @@ pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer /* TODO: Replace with run-queue */ pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( - device: Arc<DeviceInner<E, C, T, B>>, // related device - peer: Arc<PeerInner<E, C, T, B>>, // related peer + peer: Arc<PeerInner<E, C, T, B>>, receiver: Receiver<JobOutbound>, ) { loop { @@ -160,23 +159,8 @@ pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Write .map(|buf| { debug!("outbound worker: job complete"); - // write to UDP bind - let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() { - let send: &Option<B> = &*device.outbound.read(); - if let Some(writer) = send.as_ref() { - match writer.write(&buf.msg[..], dst) { - Err(e) => { - debug!("failed to send outbound packet: {:?}", e); - false - } - Ok(_) => true, - } - } else { - false - } - } else { - false - }; + // send to peer + let xmit = peer.send(&buf.msg[..]).is_ok(); // trigger callback C::send(&peer.opaque, buf.msg.len(), xmit, &buf.keypair, buf.counter); diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index 22a0ff1..e844f4d 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -3,17 +3,18 @@ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant, SystemTime}; -use log::{debug, info}; +use log::debug; use hjul::{Runner, Timer}; use super::constants::*; use super::router::{message_data_len, Callbacks}; -use super::wireguard::{Peer, PeerInner}; +use super::{Peer, PeerInner}; use super::{bind, tun}; - use super::types::KeyPair; pub struct Timers { + enabled: bool, + handshake_attempts: AtomicUsize, sent_lastminute_handshake: AtomicBool, need_another_keepalive: AtomicBool, @@ -33,6 +34,48 @@ impl Timers { } impl<B: bind::Bind> PeerInner<B> { + pub fn stop_timers(&self) { + // take a write lock preventing simultaneous timer events or "start_timers" call + let mut timers = self.timers_mut(); + + // set flag to prevent future timer events + if !timers.enabled { + return; + } + timers.enabled = false; + + // stop all pending timers + timers.retransmit_handshake.stop(); + timers.send_keepalive.stop(); + timers.send_persistent_keepalive.stop(); + timers.zero_key_material.stop(); + timers.new_handshake.stop(); + + // reset all timer state + timers.handshake_attempts.store(0, Ordering::SeqCst); + timers.sent_lastminute_handshake.store(false, Ordering::SeqCst); + timers.need_another_keepalive.store(false, Ordering::SeqCst); + } + + pub fn start_timers(&self) { + // take a write lock preventing simultaneous "stop_timers" call + let mut timers = self.timers_mut(); + + // set flag to renable timer events + if timers.enabled { + return; + } + timers.enabled = true; + + // start send_persistent_keepalive + let interval = self.keepalive_interval.load(Ordering::Acquire); + if interval > 0 { + timers.send_persistent_keepalive.start( + Duration::from_secs(interval) + ); + } + } + /* should be called after an authenticated data packet is sent */ pub fn timers_data_sent(&self) { self.timers() @@ -95,7 +138,7 @@ impl<B: bind::Bind> PeerInner<B> { * keepalive, data, or handshake is sent, or after one is received. */ pub fn timers_any_authenticated_packet_traversal(&self) { - let keepalive = self.keepalive.load(Ordering::Acquire); + let keepalive = self.keepalive_interval.load(Ordering::Acquire); if keepalive > 0 { // push persistent_keepalive into the future self.timers() @@ -125,9 +168,9 @@ impl<B: bind::Bind> PeerInner<B> { } - pub fn set_persistent_keepalive_interval(&self, interval: usize) { + pub fn set_persistent_keepalive_interval(&self, interval: u64) { self.timers().send_persistent_keepalive.stop(); - self.keepalive.store(interval, Ordering::SeqCst); + self.keepalive_interval.store(interval, Ordering::SeqCst); if interval > 0 { self.timers() .send_persistent_keepalive @@ -154,6 +197,7 @@ impl Timers { { // create a timer instance for the provided peer Timers { + enabled: true, need_another_keepalive: AtomicBool::new(false), sent_lastminute_handshake: AtomicBool::new(false), handshake_attempts: AtomicUsize::new(0), @@ -213,7 +257,7 @@ impl Timers { send_persistent_keepalive: { let peer = peer.clone(); runner.timer(move || { - let keepalive = peer.state.keepalive.load(Ordering::Acquire); + let keepalive = peer.state.keepalive_interval.load(Ordering::Acquire); if keepalive > 0 { peer.router.send_keepalive(); peer.timers().send_keepalive.stop(); @@ -235,6 +279,7 @@ impl Timers { pub fn dummy(runner: &Runner) -> Timers { Timers { + enabled: false, need_another_keepalive: AtomicBool::new(false), sent_lastminute_handshake: AtomicBool::new(false), handshake_attempts: AtomicUsize::new(0), diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 6cdae6c..722f64a 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -2,6 +2,7 @@ use super::constants::*; use super::handshake; use super::router; use super::timers::{Events, Timers}; +use super::{Peer, PeerInner}; use super::bind::Reader as BindReader; use super::bind::{Bind, Writer}; @@ -22,7 +23,7 @@ use std::collections::HashMap; use log::debug; use rand::rngs::OsRng; use rand::Rng; -use spin::{Mutex, RwLock, RwLockReadGuard}; +use spin::{Mutex, RwLock}; use byteorder::{ByteOrder, LittleEndian}; use crossbeam_channel::{bounded, Sender}; @@ -32,45 +33,19 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128; const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4; const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000); -pub struct Peer<T: Tun, B: Bind> { - pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>, - pub state: Arc<PeerInner<B>>, -} - -pub struct PeerInner<B: Bind> { - // internal id (for logging) - pub id: u64, - - // handshake state - pub walltime_last_handshake: Mutex<SystemTime>, - pub last_handshake_sent: Mutex<Instant>, // instant for last handshake - pub handshake_queued: AtomicBool, // is a handshake job currently queued for the peer? - pub queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue - - // stats and configuration - pub pk: PublicKey, // public key, DISCUSS: avoid this. TODO: remove - pub keepalive: AtomicUsize, // keepalive interval - pub rx_bytes: AtomicU64, // received bytes - pub tx_bytes: AtomicU64, // transmitted bytes - - // timer model - pub timers: RwLock<Timers>, -} - pub struct WireguardInner<T: Tun, B: Bind> { // identifier (for logging) id: u32, start: Instant, // provides access to the MTU value of the tun device - // (otherwise owned solely by the router and a dedicated read IO thread) mtu: T::MTU, send: RwLock<Option<B::Writer>>, // identify and configuration map peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>, - // cryptkey router + // cryptokey router router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>, // handshake related state @@ -90,66 +65,12 @@ pub struct WireguardHandle<T: Tun, B: Bind> { inner: Arc<WireguardInner<T, B>>, } -impl<T: Tun, B: Bind> Clone for Peer<T, B> { - fn clone(&self) -> Peer<T, B> { - Peer { - router: self.router.clone(), - state: self.state.clone(), - } - } -} - -impl<B: Bind> PeerInner<B> { - #[inline(always)] - pub fn timers(&self) -> RwLockReadGuard<Timers> { - self.timers.read() - } -} - -impl<T: Tun, B: Bind> fmt::Display for Peer<T, B> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "peer(id = {})", self.id) - } -} - impl<T: Tun, B: Bind> fmt::Display for WireguardInner<T, B> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "wireguard({:x})", self.id) } } -impl<T: Tun, B: Bind> Deref for Peer<T, B> { - type Target = PeerInner<B>; - fn deref(&self) -> &Self::Target { - &self.state - } -} - -impl<B: Bind> PeerInner<B> { - /* Queue a handshake request for the parallel workers - * (if one does not already exist) - * - * The function is ratelimited. - */ - pub fn packet_send_handshake_initiation(&self) { - // the function is rate limited - - { - let mut lhs = self.last_handshake_sent.lock(); - if lhs.elapsed() < REKEY_TIMEOUT { - return; - } - *lhs = Instant::now(); - } - - // create a new handshake job for the peer - - if !self.handshake_queued.swap(true, Ordering::SeqCst) { - self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap(); - } - } -} - struct Handshake { device: handshake::Device, active: bool, @@ -196,6 +117,37 @@ const fn padding(size: usize, mtu: usize) -> usize { } impl<T: Tun, B: Bind> Wireguard<T, B> { + /// Brings the WireGuard device down. + /// Usually called when the associated interface is brought down. + /// + /// This stops any further action/timer on any peer + /// and prevents transmission of further messages, + /// however the device retrains its state. + /// + /// The instance will continue to consume and discard messages + /// on both ends of the device. + pub fn down(&self) { + // ensure exclusive access (to avoid race with "up" call) + let peers = self.peers.write(); + + // set all peers down (stops timers) + for peer in peers.values() { + peer.down(); + } + } + + /// Brings the WireGuard device up. + /// Usually called when the associated interface is brought up. + pub fn up(&self) { + // ensure exclusive access (to avoid race with "down" call) + let peers = self.peers.write(); + + // set all peers up (restarts timers) + for peer in peers.values() { + peer.up(); + } + } + pub fn clear_peers(&self) { self.state.peers.write().clear(); } @@ -263,7 +215,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { last_handshake_sent: Mutex::new(self.state.start - TIME_HORIZON), handshake_queued: AtomicBool::new(false), queue: Mutex::new(self.state.queue.lock().clone()), - keepalive: AtomicUsize::new(0), + keepalive_interval: AtomicU64::new(0), // disabled rx_bytes: AtomicU64::new(0), tx_bytes: AtomicU64::new(0), timers: RwLock::new(Timers::dummy(&self.runner)), |