aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGuanhao Yin <sopium@mysterious.site>2017-03-27 22:27:01 +0800
committerGuanhao Yin <sopium@mysterious.site>2017-03-27 22:27:01 +0800
commitb15c04948d3a32892f0ec5307186fc572dd06e10 (patch)
tree1d682dedc71f962d3c359fa5aa9d24168145742f
parentWork with (super) jumbo frames (diff)
downloadwireguard-rs-b15c04948d3a32892f0ec5307186fc572dd06e10.tar.xz
wireguard-rs-b15c04948d3a32892f0ec5307186fc572dd06e10.zip
Implement handshake load monitoring
-rw-r--r--src/protocol/controller.rs31
-rw-r--r--src/protocol/load_monitor.rs108
-rw-r--r--src/protocol/mod.rs3
3 files changed, 133 insertions, 9 deletions
diff --git a/src/protocol/controller.rs b/src/protocol/controller.rs
index c177c08..bb9f3e9 100644
--- a/src/protocol/controller.rs
+++ b/src/protocol/controller.rs
@@ -54,8 +54,12 @@ const REJECT_AFTER_TIME: u64 = 180;
const REKEY_TIMEOUT: u64 = 5;
const KEEPALIVE_TIMEOUT: u64 = 10;
+
const BUFSIZE: usize = 65536;
+// How many handshake messages per second is considered normal load.
+const HANDSHAKES_PER_SEC: u32 = 250;
+
// How many packets to queue.
const QUEUE_SIZE: usize = 16;
@@ -76,6 +80,7 @@ pub struct WgState {
rt4: RwLock<IpLookupTable<Ipv4Addr, SharedPeerState>>,
rt6: RwLock<IpLookupTable<Ipv6Addr, SharedPeerState>>,
+ load_monitor: Mutex<LoadMonitor>,
// The secret used to calc cookie.
cookie_secret: Mutex<([u8; 32], Instant)>,
}
@@ -167,11 +172,6 @@ struct Transport {
reject_after_time: Mutex<TimerHandle>,
}
-// TODO determine / detect load.
-fn is_under_load() -> bool {
- false
-}
-
fn udp_process_handshake_init(wg: Arc<WgState>, sock: &UdpSocket, p: &[u8], addr: SocketAddr) {
if p.len() != 148 {
return;
@@ -180,8 +180,8 @@ fn udp_process_handshake_init(wg: Arc<WgState>, sock: &UdpSocket, p: &[u8], addr
// Lock info.
let info = wg.info.read().unwrap();
- if is_under_load() {
- let cookie = calc_cookie(&wg.get_cookie_secret(), addr.to_string().as_bytes());
+ if wg.check_handshake_load() {
+ let cookie = calc_cookie(&wg.get_cookie_secret(), &socket_addr_to_bytes(&addr));
if !cookie_verify(p, &cookie) {
debug!("Mac2 verify failed, send cookie reply.");
let peer_id = Id::from_slice(&p[4..8]);
@@ -247,8 +247,8 @@ fn udp_process_handshake_resp(wg: &WgState, sock: &UdpSocket, p: &[u8], addr: So
// Lock info.
let info = wg.info.read().unwrap();
- if is_under_load() {
- let cookie = calc_cookie(&wg.get_cookie_secret(), addr.to_string().as_bytes());
+ if wg.check_handshake_load() {
+ let cookie = calc_cookie(&wg.get_cookie_secret(), &socket_addr_to_bytes(&addr));
if !cookie_verify(p, &cookie) {
debug!("Mac2 verify failed, send cookie reply.");
let peer_id = Id::from_slice(&p[4..8]);
@@ -318,6 +318,14 @@ fn udp_process_handshake_resp(wg: &WgState, sock: &UdpSocket, p: &[u8], addr: So
}
}
+/// Maps a `SocketAddr` to bytes.
+fn socket_addr_to_bytes(a: &SocketAddr) -> [u8; 16] {
+ match a.ip() {
+ IpAddr::V4(a) => a.to_ipv6_mapped().octets(),
+ IpAddr::V6(a) => a.octets(),
+ }
+}
+
fn udp_process_cookie_reply(wg: &WgState, p: &[u8]) {
let self_id = Id::from_slice(&p[4..8]);
@@ -864,6 +872,7 @@ impl WgState {
id_map: RwLock::new(HashMap::with_capacity(4)),
rt4: RwLock::new(IpLookupTable::new()),
rt6: RwLock::new(IpLookupTable::new()),
+ load_monitor: Mutex::new(LoadMonitor::new(HANDSHAKES_PER_SEC)),
cookie_secret: Mutex::new((cookie, Instant::now())),
}
}
@@ -901,6 +910,10 @@ impl WgState {
}
}
+ fn check_handshake_load(&self) -> bool {
+ self.load_monitor.lock().unwrap().check()
+ }
+
fn get_cookie_secret(&self) -> [u8; 32] {
let mut cs = self.cookie_secret.lock().unwrap();
let now = Instant::now();
diff --git a/src/protocol/load_monitor.rs b/src/protocol/load_monitor.rs
new file mode 100644
index 0000000..5cb5475
--- /dev/null
+++ b/src/protocol/load_monitor.rs
@@ -0,0 +1,108 @@
+// Copyright 2017 Guanhao Yin <sopium@mysterious.site>
+
+// This file is part of WireGuard.rs.
+
+// WireGuard.rs is free software: you can redistribute it and/or
+// modify it under the terms of the GNU General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+
+// WireGuard.rs is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// General Public License for more details.
+
+// You should have received a copy of the GNU General Public License
+// along with WireGuard.rs. If not, see <https://www.gnu.org/licenses/>.
+
+use std::time::Instant;
+
+/// Monitors the frequency of handshake messages and determine whether
+/// they are arriving too quickly.
+///
+/// Implemented with a (deep) token bucket.
+pub struct LoadMonitor {
+ // Constant. How many messages are allowed every seconds.
+ freq: u32,
+ // Scaled to 10^9, to match timer precision.
+ bucket: u64,
+ last_check: Instant,
+ under_load: bool,
+}
+
+const CAP_RATIO: u64 = 4;
+
+impl LoadMonitor {
+ /// Create a new load monitor that allows `freq` messages per second.
+ ///
+ /// If there are more than `freq` messages per second, it will soon indicate
+ /// that we are under load.
+ ///
+ /// If there are less than `freq` messages per second, it will slowly
+ /// but eventually determine that we are no longer under load.
+ pub fn new(freq: u32) -> Self {
+ LoadMonitor {
+ freq: freq,
+ bucket: CAP_RATIO * freq as u64 * NANOS_PER_SEC,
+ last_check: Instant::now(),
+ under_load: false,
+ }
+ }
+
+ /// Call this when receiving a message.
+ ///
+ /// Returns whether we are under load.
+ pub fn check(&mut self) -> bool {
+ let freq = self.freq as u64;
+ let cap = CAP_RATIO * freq * NANOS_PER_SEC;
+
+ let now = Instant::now();
+ let passed = now.duration_since(self.last_check);
+ let bucket_add = (passed.as_secs() * freq * NANOS_PER_SEC) +
+ passed.subsec_nanos() as u64 * freq;
+ self.last_check = now;
+ self.bucket = ::std::cmp::min(cap, self.bucket + bucket_add);
+
+ self.bucket = self.bucket.saturating_sub(NANOS_PER_SEC);
+
+ // println!("bucket: {}", self.bucket as f64 / NANOS_PER_SEC as f64);
+
+ if self.under_load {
+ if self.bucket >= 7 * cap / 8 {
+ self.under_load = false;
+ debug!("No longer under load.");
+ }
+ } else {
+ if self.bucket <= 3 * cap / 4 {
+ self.under_load = true;
+ debug!("Under load!");
+ }
+ }
+
+ self.under_load
+ }
+}
+
+const NANOS_PER_SEC: u64 = 1_000_000_000;
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::thread::sleep;
+ use std::time::Duration;
+
+ #[test]
+ fn load_monitor() {
+ let mut u = LoadMonitor::new(100);
+
+ for _ in 0..110 {
+ u.check();
+ }
+
+ assert!(u.check());
+
+ sleep(Duration::from_secs(1));
+
+ assert!(!u.check());
+ }
+}
diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs
index ae1d851..ee7e511 100644
--- a/src/protocol/mod.rs
+++ b/src/protocol/mod.rs
@@ -31,6 +31,8 @@ mod ip;
mod controller;
/// A generic timer, but optimised for operations mostly used in WG.
mod timer;
+/// Determine load.
+mod load_monitor;
/// Re-export some types and functions from other crates, so users
/// of this module won't have to manually pull in all these crates.
@@ -44,3 +46,4 @@ use self::ip::*;
use self::timer::*;
pub use self::types::{WgInfo, PeerInfo, WgStateOut, PeerStateOut};
use self::types::*;
+use self::load_monitor::*;