From 6c386146a77ecb8ff317d76823c0f788bd70d8c3 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sun, 10 May 2020 21:23:34 +0200 Subject: Refactoring timer code: - Remove the Events struct - Implement Callbacks on the PeerInner, elimiting an Arc. --- src/configuration/config.rs | 26 ++++----- src/main.rs | 4 +- src/wireguard/mod.rs | 3 - src/wireguard/peer.rs | 51 +---------------- src/wireguard/router/peer.rs | 31 +++++++++++ src/wireguard/tests.rs | 10 +--- src/wireguard/timers.rs | 129 +++++++++++++++++++++++++------------------ src/wireguard/wireguard.rs | 81 ++++++++++++++------------- src/wireguard/workers.rs | 26 +++++---- 9 files changed, 186 insertions(+), 175 deletions(-) diff --git a/src/configuration/config.rs b/src/configuration/config.rs index 3f3c2c5..77f8d9a 100644 --- a/src/configuration/config.rs +++ b/src/configuration/config.rs @@ -310,25 +310,25 @@ impl Configuration for WireGuardConfig { fn set_endpoint(&self, peer: &PublicKey, addr: SocketAddr) { if let Some(peer) = self.lock().wireguard.lookup_peer(peer) { - peer.router.set_endpoint(B::Endpoint::from_address(addr)); + peer.set_endpoint(B::Endpoint::from_address(addr)); } } fn set_persistent_keepalive_interval(&self, peer: &PublicKey, secs: u64) { if let Some(peer) = self.lock().wireguard.lookup_peer(peer) { - peer.set_persistent_keepalive_interval(secs); + peer.opaque().set_persistent_keepalive_interval(secs); } } fn replace_allowed_ips(&self, peer: &PublicKey) { if let Some(peer) = self.lock().wireguard.lookup_peer(peer) { - peer.router.remove_allowed_ips(); + peer.remove_allowed_ips(); } } fn add_allowed_ip(&self, peer: &PublicKey, ip: IpAddr, masklen: u32) { if let Some(peer) = self.lock().wireguard.lookup_peer(peer) { - peer.router.add_allowed_ip(ip, masklen); + peer.add_allowed_ip(ip, masklen); } } @@ -337,26 +337,26 @@ impl Configuration for WireGuardConfig { let peers = cfg.wireguard.list_peers(); let mut state = Vec::with_capacity(peers.len()); - for p in peers { + for (pk, p) in peers { // convert the system time to (secs, nano) since epoch - let last_handshake_time = (*p.walltime_last_handshake.lock()).and_then(|t| { + let last_handshake_time = (*p.opaque().walltime_last_handshake.lock()).and_then(|t| { let duration = t .duration_since(SystemTime::UNIX_EPOCH) .unwrap_or(Duration::from_secs(0)); Some((duration.as_secs(), duration.subsec_nanos() as u64)) }); - if let Some(psk) = cfg.wireguard.get_psk(&p.pk) { + if let Some(psk) = cfg.wireguard.get_psk(&pk) { // extract state into PeerState state.push(PeerState { preshared_key: psk, - endpoint: p.router.get_endpoint(), - rx_bytes: p.rx_bytes.load(Ordering::Relaxed), - tx_bytes: p.tx_bytes.load(Ordering::Relaxed), - persistent_keepalive_interval: p.get_keepalive_interval(), - allowed_ips: p.router.list_allowed_ips(), + endpoint: p.get_endpoint(), + rx_bytes: p.opaque().rx_bytes.load(Ordering::Relaxed), + tx_bytes: p.opaque().tx_bytes.load(Ordering::Relaxed), + persistent_keepalive_interval: p.opaque().get_keepalive_interval(), + allowed_ips: p.list_allowed_ips(), last_handshake_time, - public_key: p.pk, + public_key: pk, }) } } diff --git a/src/main.rs b/src/main.rs index 8877422..bf42706 100644 --- a/src/main.rs +++ b/src/main.rs @@ -100,7 +100,7 @@ fn main() { // daemonize if !foreground { let daemonize = Daemonize::new() - .pid_file(format!("/tmp/wgrs-{}.pid", name)) + .pid_file(format!("/tmp/wireguard-rs-{}.pid", name)) .chown_pid_file(true) .working_directory("/tmp") .user("nobody") @@ -170,7 +170,7 @@ fn main() { Err(err) => { log::info!("UAPI connection error: {}", err); profiler_stop(); - exit(0); + exit(-1); } } }); diff --git a/src/wireguard/mod.rs b/src/wireguard/mod.rs index ee1fd78..ca17737 100644 --- a/src/wireguard/mod.rs +++ b/src/wireguard/mod.rs @@ -20,9 +20,6 @@ mod workers; #[cfg(test)] mod tests; -// represents a peer -pub use peer::Peer; - // represents a WireGuard interface pub use wireguard::WireGuard; diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs index b3656fe..27d39bd 100644 --- a/src/wireguard/peer.rs +++ b/src/wireguard/peer.rs @@ -1,5 +1,4 @@ -use super::router; -use super::timers::{Events, Timers}; +use super::timers::Timers; use super::tun::Tun; use super::udp::UDP; @@ -9,9 +8,7 @@ use super::wireguard::WireGuard; use super::workers::HandshakeJob; use std::fmt; -use std::ops::Deref; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::sync::Arc; use std::time::{Instant, SystemTime}; use spin::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard}; @@ -31,7 +28,7 @@ pub struct PeerInner { pub handshake_queued: AtomicBool, // is a handshake job currently queued for the peer? // stats and configuration - pub pk: PublicKey, // public key + pub pk: PublicKey, // public key (TODO: there has to be a way to remove this) pub rx_bytes: AtomicU64, // received bytes pub tx_bytes: AtomicU64, // transmitted bytes @@ -39,20 +36,6 @@ pub struct PeerInner { pub timers: RwLock, } -pub struct Peer { - pub router: Arc, T::Writer, B::Writer>>, - pub state: Arc>, -} - -impl Clone for Peer { - fn clone(&self) -> Peer { - Peer { - router: self.router.clone(), - state: self.state.clone(), - } - } -} - impl PeerInner { /* Queue a handshake request for the parallel workers * (if one does not already exist) @@ -104,33 +87,3 @@ impl fmt::Display for PeerInner { write!(f, "peer(id = {})", self.id) } } - -impl fmt::Display for Peer { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "peer(id = {})", self.id) - } -} - -impl Deref for Peer { - type Target = PeerInner; - fn deref(&self) -> &Self::Target { - &self.state - } -} - -impl Peer { - /// Bring the peer down. Causing: - /// - /// - Timers to be stopped and disabled. - /// - All keystate to be zeroed - pub fn down(&self) { - self.stop_timers(); - self.router.down(); - } - - /// Bring the peer up. - pub fn up(&self) { - self.router.up(); - self.start_timers(); - } -} diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index 67d90d8..3eed7c7 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -22,6 +22,7 @@ use core::sync::atomic::AtomicBool; use alloc::sync::Arc; // TODO: consider no_std alternatives +use std::fmt; use std::net::{IpAddr, SocketAddr}; use arraydeque::{ArrayDeque, Wrapping}; @@ -46,6 +47,14 @@ pub struct PeerInner>, } +impl> Deref for PeerInner { + type Target = C::Opaque; + + fn deref(&self) -> &Self::Target { + &self.opaque + } +} + pub struct Peer> { inner: Arc>, } @@ -87,6 +96,16 @@ pub struct PeerHandle, } +impl> Clone + for PeerHandle +{ + fn clone(&self) -> Self { + PeerHandle { + peer: self.peer.clone(), + } + } +} + impl> Deref for PeerHandle { @@ -96,6 +115,14 @@ impl> Deref } } +impl> fmt::Display + for PeerHandle +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "PeerHandle(format: TODO)") + } +} + impl EncryptionState { fn new(keypair: &Arc) -> EncryptionState { EncryptionState { @@ -338,6 +365,10 @@ impl> PeerHandle &C::Opaque { + &self.opaque + } + /// Returns the current endpoint of the peer (for configuration) /// /// # Note diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs index bb1f87a..6bc4be3 100644 --- a/src/wireguard/tests.rs +++ b/src/wireguard/tests.rs @@ -123,17 +123,13 @@ fn test_pure_wireguard() { let peer2 = wg1.lookup_peer(&pk2).unwrap(); let peer1 = wg2.lookup_peer(&pk1).unwrap(); - peer1 - .router - .add_allowed_ip("192.168.1.0".parse().unwrap(), 24); + peer1.add_allowed_ip("192.168.1.0".parse().unwrap(), 24); - peer2 - .router - .add_allowed_ip("192.168.2.0".parse().unwrap(), 24); + peer2.add_allowed_ip("192.168.2.0".parse().unwrap(), 24); // set endpoint (the other should be learned dynamically) - peer2.router.set_endpoint(dummy::UnitEndpoint::new()); + peer2.set_endpoint(dummy::UnitEndpoint::new()); let num_packets = 20; diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index 0197a9e..a435e5c 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -1,17 +1,19 @@ -use std::marker::PhantomData; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant, SystemTime}; -use hjul::{Runner, Timer}; use log::debug; +use hjul::Timer; +use x25519_dalek::PublicKey; + use super::constants::*; -use super::peer::{Peer, PeerInner}; +use super::peer::PeerInner; use super::router::{message_data_len, Callbacks}; use super::tun::Tun; use super::types::KeyPair; use super::udp::UDP; +use super::WireGuard; pub struct Timers { // only updated during configuration @@ -229,7 +231,35 @@ impl PeerInner { } impl Timers { - pub fn new(runner: &Runner, running: bool, peer: Peer) -> Timers { + pub fn new( + wg: WireGuard, // WireGuard device + pk: PublicKey, // public key of peer + running: bool, // timers started + ) -> Timers { + macro_rules! fetch_peer { + ( $wg:expr, $pk:expr ) => { + match $wg.lookup_peer(&$pk) { + None => { + return; + } + Some(peer) => peer, + } + }; + } + + macro_rules! fetch_timer { + ( $peer:expr ) => {{ + let timers = $peer.timers(); + if timers.enabled { + timers + } else { + return; + } + }}; + } + + let runner = wg.runner.lock(); + // create a timer instance for the provided peer Timers { enabled: running, @@ -238,21 +268,16 @@ impl Timers { sent_lastminute_handshake: AtomicBool::new(false), handshake_attempts: AtomicUsize::new(0), retransmit_handshake: { - let peer = peer.clone(); + let wg = wg.clone(); + let pk = pk.clone(); runner.timer(move || { + // fetch peer by public key + let peer = fetch_peer!(wg, pk); + let timers = fetch_timer!(peer); log::trace!("{} : timer fired (retransmit_handshake)", peer); - // ignore if timers are disabled - let timers = peer.timers(); - if !timers.enabled { - return; - } - // check if handshake attempts remaining - let attempts = peer - .timers() - .handshake_attempts - .fetch_add(1, Ordering::SeqCst); + let attempts = timers.handshake_attempts.fetch_add(1, Ordering::SeqCst); if attempts > MAX_TIMER_HANDSHAKES { debug!( "Handshake for peer {} did not complete after {} attempts, giving up", @@ -261,7 +286,7 @@ impl Timers { ); timers.send_keepalive.stop(); timers.zero_key_material.start(REJECT_AFTER_TIME * 3); - peer.router.purge_staged_packets(); + peer.purge_staged_packets(); } else { debug!( "Handshake for {} did not complete after {} seconds, retrying (try {})", @@ -270,56 +295,72 @@ impl Timers { attempts ); timers.retransmit_handshake.reset(REKEY_TIMEOUT); - peer.router.clear_src(); + peer.clear_src(); peer.packet_send_queued_handshake_initiation(true); } }) }, send_keepalive: { - let peer = peer.clone(); + let wg = wg.clone(); + let pk = pk.clone(); runner.timer(move || { + // fetch peer by public key + let peer = fetch_peer!(wg, pk); + let timers = fetch_timer!(peer); log::trace!("{} : timer fired (send_keepalive)", peer); - // ignore if timers are disabled - let timers = peer.timers(); - if !timers.enabled { - return; - } - - peer.router.send_keepalive(); + // send keepalive and schedule next keepalive + peer.send_keepalive(); if timers.need_another_keepalive() { timers.send_keepalive.start(KEEPALIVE_TIMEOUT); } }) }, new_handshake: { - let peer = peer.clone(); + let wg = wg.clone(); + let pk = pk.clone(); runner.timer(move || { + // fetch peer by public key + let peer = fetch_peer!(wg, pk); + let _timers = fetch_timer!(peer); log::trace!("{} : timer fired (new_handshake)", peer); + + // clear source and retry log::debug!( "Retrying handshake with {} because we stopped hearing back after {} seconds", peer, (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT).as_secs() ); - peer.router.clear_src(); + peer.clear_src(); peer.packet_send_queued_handshake_initiation(false); }) }, zero_key_material: { - let peer = peer.clone(); + let wg = wg.clone(); + let pk = pk.clone(); runner.timer(move || { + // fetch peer by public key + let peer = fetch_peer!(wg, pk); + let _timers = fetch_timer!(peer); log::trace!("{} : timer fired (zero_key_material)", peer); - peer.router.zero_keys(); + + // null all key-material + peer.zero_keys(); }) }, send_persistent_keepalive: { - let peer = peer.clone(); + let wg = wg.clone(); + let pk = pk.clone(); runner.timer(move || { + // fetch peer by public key + let peer = fetch_peer!(wg, pk); + let timers = fetch_timer!(peer); log::trace!("{} : timer fired (send_persistent_keepalive)", peer); - let timers = peer.timers(); - if timers.enabled && timers.keepalive_interval > 0 { + + // send and schedule persistent keepalive + if timers.keepalive_interval > 0 { timers.send_keepalive.stop(); - peer.router.send_keepalive(); + peer.send_keepalive(); log::trace!("{} : keepalive queued", peer); timers .send_persistent_keepalive @@ -329,28 +370,10 @@ impl Timers { }, } } - - pub fn dummy(runner: &Runner) -> Timers { - Timers { - enabled: false, - keepalive_interval: 0, - need_another_keepalive: AtomicBool::new(false), - sent_lastminute_handshake: AtomicBool::new(false), - handshake_attempts: AtomicUsize::new(0), - retransmit_handshake: runner.timer(|| {}), - new_handshake: runner.timer(|| {}), - send_keepalive: runner.timer(|| {}), - send_persistent_keepalive: runner.timer(|| {}), - zero_key_material: runner.timer(|| {}), - } - } } -/* instance of the router callbacks */ -pub struct Events(PhantomData<(T, B)>); - -impl Callbacks for Events { - type Opaque = Arc>; +impl Callbacks for PeerInner { + type Opaque = Self; /* Called after the router encrypts a transport message destined for the peer. * This method is called, even if the encrypted payload is empty (keepalive) diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index b878adc..7490703 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -1,8 +1,8 @@ use super::constants::*; use super::handshake; -use super::peer::{Peer, PeerInner}; +use super::peer::PeerInner; use super::router; -use super::timers::{Events, Timers}; +use super::timers::Timers; use super::queue::ParallelQueue; use super::workers::HandshakeJob; @@ -45,10 +45,12 @@ pub struct WireguardInner { pub mtu: AtomicUsize, // peer map - pub peers: RwLock>>, + pub peers: RwLock< + handshake::Device, T::Writer, B::Writer>>, + >, // cryptokey router - pub router: router::Device, T::Writer, B::Writer>, + pub router: router::Device, T::Writer, B::Writer>, // handshake related state pub last_under_load: Mutex, @@ -136,6 +138,7 @@ impl WireGuard { // set all peers down (stops timers) for (_, peer) in self.peers.write().iter() { + peer.stop_timers(); peer.down(); } @@ -162,6 +165,7 @@ impl WireGuard { // set all peers up (restarts timers) for (_, peer) in self.peers.write().iter() { peer.up(); + peer.start_timers(); } *enabled = true; @@ -175,16 +179,24 @@ impl WireGuard { let _ = self.peers.write().remove(pk); } - pub fn lookup_peer(&self, pk: &PublicKey) -> Option> { - self.peers.read().get(pk).map(|p| p.clone()) + pub fn lookup_peer( + &self, + pk: &PublicKey, + ) -> Option, T::Writer, B::Writer>> { + self.peers.read().get(pk).map(|handle| handle.clone()) } - pub fn list_peers(&self) -> Vec> { + pub fn list_peers( + &self, + ) -> Vec<( + PublicKey, + router::PeerHandle, T::Writer, B::Writer>, + )> { let peers = self.peers.read(); let mut list = Vec::with_capacity(peers.len()); for (k, v) in peers.iter() { - debug_assert!(k.as_bytes() == v.pk.as_bytes()); - list.push(v.clone()); + debug_assert!(k.as_bytes() == v.opaque().pk.as_bytes()); + list.push((k.clone(), v.clone())); } list } @@ -215,36 +227,27 @@ impl WireGuard { return false; } - let state = Arc::new(PeerInner { - id: OsRng.gen(), - pk, - wg: self.clone(), - walltime_last_handshake: Mutex::new(None), - last_handshake_sent: Mutex::new(Instant::now() - TIME_HORIZON), - handshake_queued: AtomicBool::new(false), - rx_bytes: AtomicU64::new(0), - tx_bytes: AtomicU64::new(0), - timers: RwLock::new(Timers::dummy(&*self.runner.lock())), - }); - - // create a router peer - let router = Arc::new(self.router.new_peer(state.clone())); - - // form WireGuard peer - let peer = Peer { router, state }; - // prevent up/down while inserting - let enabled = self.enabled.read(); + let enabled = *self.enabled.read(); - /* The need for dummy timers arises from the chicken-egg - * problem of the timer callbacks being able to set timers themselves. - * - * This is in fact the only place where the write lock is ever taken. - * TODO: Consider the ease of using atomic pointers instead. - */ - *peer.timers.write() = Timers::new(&*self.runner.lock(), *enabled, peer.clone()); + // create timers (lookup by public key) + let timers = Timers::new::(self.clone(), pk.clone(), enabled); - // finally, add the peer to the wireguard device + // create new router peer + let peer: router::PeerHandle, T::Writer, B::Writer> = + self.router.new_peer(PeerInner { + id: OsRng.gen(), + pk, + wg: self.clone(), + walltime_last_handshake: Mutex::new(None), + last_handshake_sent: Mutex::new(Instant::now() - TIME_HORIZON), + handshake_queued: AtomicBool::new(false), + rx_bytes: AtomicU64::new(0), + tx_bytes: AtomicU64::new(0), + timers: RwLock::new(timers), + }); + + // finally, add the peer to the handshake device peers.add(pk, peer).is_ok() } @@ -288,6 +291,10 @@ impl WireGuard { // create handshake queue let (tx, mut rxs) = ParallelQueue::new(cpus, 128); + // create router + let router: router::Device, T::Writer, B::Writer> = + router::Device::new(num_cpus::get(), writer); + // create arc to state let wg = WireGuard { inner: Arc::new(WireguardInner { @@ -296,7 +303,7 @@ impl WireGuard { id: OsRng.gen(), mtu: AtomicUsize::new(0), last_under_load: Mutex::new(Instant::now() - TIME_HORIZON), - router: router::Device::new(num_cpus::get(), writer), + router, pending: AtomicUsize::new(0), peers: RwLock::new(handshake::Device::new()), runner: Mutex::new(Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY)), diff --git a/src/wireguard/workers.rs b/src/wireguard/workers.rs index 70e3b3a..b4673cd 100644 --- a/src/wireguard/workers.rs +++ b/src/wireguard/workers.rs @@ -209,23 +209,25 @@ pub fn handshake_worker( // add to rx_bytes and tx_bytes let req_len = msg.len() as u64; - peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed); - peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed); + peer.opaque().rx_bytes.fetch_add(req_len, Ordering::Relaxed); + peer.opaque() + .tx_bytes + .fetch_add(resp_len, Ordering::Relaxed); // update endpoint - peer.router.set_endpoint(src); + peer.set_endpoint(src); if resp_len > 0 { // update timers after sending handshake response debug!("{} : handshake worker, handshake response sent", wg); - peer.state.sent_handshake_response(); + peer.opaque().sent_handshake_response(); } else { // update timers after receiving handshake response debug!( "{} : handshake worker, handshake response was received", wg ); - peer.state.timers_handshake_complete(); + peer.opaque().timers_handshake_complete(); } // add any new keypair to peer @@ -233,10 +235,10 @@ pub fn handshake_worker( debug!("{} : handshake worker, new keypair for {}", wg, peer); // this means that a handshake response was processed or sent - peer.timers_session_derived(); + peer.opaque().timers_session_derived(); // free any unused ids - for id in peer.router.add_keypair(kp) { + for id in peer.add_keypair(kp) { device.release(id); } }); @@ -252,13 +254,15 @@ pub fn handshake_worker( wg, peer ); let device = wg.peers.read(); - let _ = device.begin(&mut OsRng, &peer.pk).map(|msg| { - let _ = peer.router.send_raw(&msg[..]).map_err(|e| { + let _ = device.begin(&mut OsRng, &pk).map(|msg| { + let _ = peer.send_raw(&msg[..]).map_err(|e| { debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e) }); - peer.state.sent_handshake_initiation(); + peer.opaque().sent_handshake_initiation(); }); - peer.handshake_queued.store(false, Ordering::SeqCst); + peer.opaque() + .handshake_queued + .store(false, Ordering::SeqCst); } } } -- cgit v1.2.3-59-g8ed1b