diff options
author | Guanhao Yin <sopium@mysterious.site> | 2017-03-27 22:27:01 +0800 |
---|---|---|
committer | Guanhao Yin <sopium@mysterious.site> | 2017-03-27 22:27:01 +0800 |
commit | b15c04948d3a32892f0ec5307186fc572dd06e10 (patch) | |
tree | 1d682dedc71f962d3c359fa5aa9d24168145742f | |
parent | Work with (super) jumbo frames (diff) | |
download | wireguard-rs-b15c04948d3a32892f0ec5307186fc572dd06e10.tar.xz wireguard-rs-b15c04948d3a32892f0ec5307186fc572dd06e10.zip |
Implement handshake load monitoring
-rw-r--r-- | src/protocol/controller.rs | 31 | ||||
-rw-r--r-- | src/protocol/load_monitor.rs | 108 | ||||
-rw-r--r-- | src/protocol/mod.rs | 3 |
3 files changed, 133 insertions, 9 deletions
diff --git a/src/protocol/controller.rs b/src/protocol/controller.rs index c177c08..bb9f3e9 100644 --- a/src/protocol/controller.rs +++ b/src/protocol/controller.rs @@ -54,8 +54,12 @@ const REJECT_AFTER_TIME: u64 = 180; const REKEY_TIMEOUT: u64 = 5; const KEEPALIVE_TIMEOUT: u64 = 10; + const BUFSIZE: usize = 65536; +// How many handshake messages per second is considered normal load. +const HANDSHAKES_PER_SEC: u32 = 250; + // How many packets to queue. const QUEUE_SIZE: usize = 16; @@ -76,6 +80,7 @@ pub struct WgState { rt4: RwLock<IpLookupTable<Ipv4Addr, SharedPeerState>>, rt6: RwLock<IpLookupTable<Ipv6Addr, SharedPeerState>>, + load_monitor: Mutex<LoadMonitor>, // The secret used to calc cookie. cookie_secret: Mutex<([u8; 32], Instant)>, } @@ -167,11 +172,6 @@ struct Transport { reject_after_time: Mutex<TimerHandle>, } -// TODO determine / detect load. -fn is_under_load() -> bool { - false -} - fn udp_process_handshake_init(wg: Arc<WgState>, sock: &UdpSocket, p: &[u8], addr: SocketAddr) { if p.len() != 148 { return; @@ -180,8 +180,8 @@ fn udp_process_handshake_init(wg: Arc<WgState>, sock: &UdpSocket, p: &[u8], addr // Lock info. let info = wg.info.read().unwrap(); - if is_under_load() { - let cookie = calc_cookie(&wg.get_cookie_secret(), addr.to_string().as_bytes()); + if wg.check_handshake_load() { + let cookie = calc_cookie(&wg.get_cookie_secret(), &socket_addr_to_bytes(&addr)); if !cookie_verify(p, &cookie) { debug!("Mac2 verify failed, send cookie reply."); let peer_id = Id::from_slice(&p[4..8]); @@ -247,8 +247,8 @@ fn udp_process_handshake_resp(wg: &WgState, sock: &UdpSocket, p: &[u8], addr: So // Lock info. let info = wg.info.read().unwrap(); - if is_under_load() { - let cookie = calc_cookie(&wg.get_cookie_secret(), addr.to_string().as_bytes()); + if wg.check_handshake_load() { + let cookie = calc_cookie(&wg.get_cookie_secret(), &socket_addr_to_bytes(&addr)); if !cookie_verify(p, &cookie) { debug!("Mac2 verify failed, send cookie reply."); let peer_id = Id::from_slice(&p[4..8]); @@ -318,6 +318,14 @@ fn udp_process_handshake_resp(wg: &WgState, sock: &UdpSocket, p: &[u8], addr: So } } +/// Maps a `SocketAddr` to bytes. +fn socket_addr_to_bytes(a: &SocketAddr) -> [u8; 16] { + match a.ip() { + IpAddr::V4(a) => a.to_ipv6_mapped().octets(), + IpAddr::V6(a) => a.octets(), + } +} + fn udp_process_cookie_reply(wg: &WgState, p: &[u8]) { let self_id = Id::from_slice(&p[4..8]); @@ -864,6 +872,7 @@ impl WgState { id_map: RwLock::new(HashMap::with_capacity(4)), rt4: RwLock::new(IpLookupTable::new()), rt6: RwLock::new(IpLookupTable::new()), + load_monitor: Mutex::new(LoadMonitor::new(HANDSHAKES_PER_SEC)), cookie_secret: Mutex::new((cookie, Instant::now())), } } @@ -901,6 +910,10 @@ impl WgState { } } + fn check_handshake_load(&self) -> bool { + self.load_monitor.lock().unwrap().check() + } + fn get_cookie_secret(&self) -> [u8; 32] { let mut cs = self.cookie_secret.lock().unwrap(); let now = Instant::now(); diff --git a/src/protocol/load_monitor.rs b/src/protocol/load_monitor.rs new file mode 100644 index 0000000..5cb5475 --- /dev/null +++ b/src/protocol/load_monitor.rs @@ -0,0 +1,108 @@ +// Copyright 2017 Guanhao Yin <sopium@mysterious.site> + +// This file is part of WireGuard.rs. + +// WireGuard.rs is free software: you can redistribute it and/or +// modify it under the terms of the GNU General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. + +// WireGuard.rs is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with WireGuard.rs. If not, see <https://www.gnu.org/licenses/>. + +use std::time::Instant; + +/// Monitors the frequency of handshake messages and determine whether +/// they are arriving too quickly. +/// +/// Implemented with a (deep) token bucket. +pub struct LoadMonitor { + // Constant. How many messages are allowed every seconds. + freq: u32, + // Scaled to 10^9, to match timer precision. + bucket: u64, + last_check: Instant, + under_load: bool, +} + +const CAP_RATIO: u64 = 4; + +impl LoadMonitor { + /// Create a new load monitor that allows `freq` messages per second. + /// + /// If there are more than `freq` messages per second, it will soon indicate + /// that we are under load. + /// + /// If there are less than `freq` messages per second, it will slowly + /// but eventually determine that we are no longer under load. + pub fn new(freq: u32) -> Self { + LoadMonitor { + freq: freq, + bucket: CAP_RATIO * freq as u64 * NANOS_PER_SEC, + last_check: Instant::now(), + under_load: false, + } + } + + /// Call this when receiving a message. + /// + /// Returns whether we are under load. + pub fn check(&mut self) -> bool { + let freq = self.freq as u64; + let cap = CAP_RATIO * freq * NANOS_PER_SEC; + + let now = Instant::now(); + let passed = now.duration_since(self.last_check); + let bucket_add = (passed.as_secs() * freq * NANOS_PER_SEC) + + passed.subsec_nanos() as u64 * freq; + self.last_check = now; + self.bucket = ::std::cmp::min(cap, self.bucket + bucket_add); + + self.bucket = self.bucket.saturating_sub(NANOS_PER_SEC); + + // println!("bucket: {}", self.bucket as f64 / NANOS_PER_SEC as f64); + + if self.under_load { + if self.bucket >= 7 * cap / 8 { + self.under_load = false; + debug!("No longer under load."); + } + } else { + if self.bucket <= 3 * cap / 4 { + self.under_load = true; + debug!("Under load!"); + } + } + + self.under_load + } +} + +const NANOS_PER_SEC: u64 = 1_000_000_000; + +#[cfg(test)] +mod tests { + use super::*; + use std::thread::sleep; + use std::time::Duration; + + #[test] + fn load_monitor() { + let mut u = LoadMonitor::new(100); + + for _ in 0..110 { + u.check(); + } + + assert!(u.check()); + + sleep(Duration::from_secs(1)); + + assert!(!u.check()); + } +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index ae1d851..ee7e511 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -31,6 +31,8 @@ mod ip; mod controller; /// A generic timer, but optimised for operations mostly used in WG. mod timer; +/// Determine load. +mod load_monitor; /// Re-export some types and functions from other crates, so users /// of this module won't have to manually pull in all these crates. @@ -44,3 +46,4 @@ use self::ip::*; use self::timer::*; pub use self::types::{WgInfo, PeerInfo, WgStateOut, PeerStateOut}; use self::types::*; +use self::load_monitor::*; |