summaryrefslogtreecommitdiffstats
path: root/src/handshake/ratelimiter.rs
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-08-07 22:51:58 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-08-07 22:51:58 +0200
commitb33381331f5c33037892bbc3c197376f3b8d945d (patch)
treee574fe1201a3107a12fd38e15ad8957613c28444 /src/handshake/ratelimiter.rs
parentAdded initiation flood protection (diff)
downloadwireguard-rs-b33381331f5c33037892bbc3c197376f3b8d945d.tar.xz
wireguard-rs-b33381331f5c33037892bbc3c197376f3b8d945d.zip
Concurrent rate limiter
The new rate limiter allows multiple simultaneous .allow calls. Also delegated GC to tokio.
Diffstat (limited to '')
-rw-r--r--src/handshake/ratelimiter.rs250
1 files changed, 143 insertions, 107 deletions
diff --git a/src/handshake/ratelimiter.rs b/src/handshake/ratelimiter.rs
index ce09c16..adf9667 100644
--- a/src/handshake/ratelimiter.rs
+++ b/src/handshake/ratelimiter.rs
@@ -1,6 +1,13 @@
use std::collections::HashMap;
use std::net::IpAddr;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
use std::time::{Duration, Instant};
+use spin::{RwLock, Mutex};
+
+use tokio::prelude::future;
+use future::{loop_fn, Future, Loop, lazy};
+use tokio::timer::Delay;
use lazy_static::lazy_static;
@@ -18,57 +25,82 @@ struct Entry {
pub tokens: u64,
}
-pub struct RateLimiter {
- garbage_collect: Instant,
- table: HashMap<IpAddr, Entry>,
+pub struct RateLimiter(Arc<RateLimiterInner>);
+
+struct RateLimiterInner{
+ gc_running: AtomicBool,
+ table: RwLock<HashMap<IpAddr, Mutex<Entry>>>,
}
impl RateLimiter {
pub fn new() -> Self {
- RateLimiter {
- garbage_collect: Instant::now(),
- table: HashMap::new(),
- }
+ RateLimiter (
+ Arc::new(RateLimiterInner {
+ gc_running: AtomicBool::from(false),
+ table: RwLock::new(HashMap::new()),
+ })
+ )
}
- pub fn allow(&mut self, addr: &IpAddr) -> bool {
- // check for garbage collection
- if self.garbage_collect.elapsed() > *GC_INTERVAL {
- self.handle_gc();
- }
-
- // update existing entry
- if let Some(entry) = self.table.get_mut(addr) {
- // add tokens earned since last time
- entry.tokens =
- MAX_TOKENS.min(entry.tokens + u64::from(entry.last_time.elapsed().subsec_nanos()));
- entry.last_time = Instant::now();
-
- // subtract cost of packet
- if entry.tokens > PACKET_COST {
- entry.tokens -= PACKET_COST;
- return true;
- } else {
- return false;
+ pub fn allow(&self, addr: &IpAddr) -> bool {
+ // check if allowed
+ let allowed = {
+ // check for existing entry (required read lock)
+ if let Some(entry) = self.0.table.read().get(addr) {
+ // update existing entry
+ let mut entry = entry.lock();
+
+ // add tokens earned since last time
+ entry.tokens =
+ MAX_TOKENS.min(entry.tokens + u64::from(entry.last_time.elapsed().subsec_nanos()));
+ entry.last_time = Instant::now();
+
+ // subtract cost of packet
+ if entry.tokens > PACKET_COST {
+ entry.tokens -= PACKET_COST;
+ return true;
+ } else {
+ return false;
+ }
}
+
+ // add new entry (write lock)
+ self.0.table.write().insert(
+ *addr,
+ Mutex::new(Entry {
+ last_time: Instant::now(),
+ tokens: MAX_TOKENS - PACKET_COST,
+ }),
+ );
+ true
+ };
+
+ // check that GC is scheduled
+ if !self.0.gc_running.swap(true, Ordering::Relaxed) {
+ let limiter = self.0.clone();
+ tokio::spawn(
+ loop_fn((), move |_| {
+ let limiter = limiter.clone();
+ let next_gc = Instant::now() + *GC_INTERVAL;
+ Delay::new(next_gc)
+ .map_err(|_| ())
+ .and_then(move |_| {
+ let mut tw = limiter.table.write();
+ tw.retain(|_, ref mut entry| entry.lock().last_time.elapsed() <= *GC_INTERVAL);
+ if tw.len() > 0 {
+ Ok(Loop::Continue(()))
+ } else {
+ limiter.gc_running.store(false, Ordering::Relaxed);
+ Ok(Loop::Break(()))
+ }
+ })
+ })
+ );
}
- // add new entry
- self.table.insert(
- *addr,
- Entry {
- last_time: Instant::now(),
- tokens: MAX_TOKENS - PACKET_COST,
- },
- );
-
- true
+ allowed
}
- fn handle_gc(&mut self) {
- self.table
- .retain(|_, ref mut entry| entry.last_time.elapsed() <= *GC_INTERVAL);
- }
}
#[cfg(test)]
@@ -84,79 +116,83 @@ mod tests {
#[test]
fn test_ratelimiter() {
- let mut ratelimiter = RateLimiter::new();
- let mut expected = vec![];
- let ips = vec![
- "127.0.0.1".parse().unwrap(),
- "192.168.1.1".parse().unwrap(),
- "172.167.2.3".parse().unwrap(),
- "97.231.252.215".parse().unwrap(),
- "248.97.91.167".parse().unwrap(),
- "188.208.233.47".parse().unwrap(),
- "104.2.183.179".parse().unwrap(),
- "72.129.46.120".parse().unwrap(),
- "2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(),
- "f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(),
- "b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(),
- "a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(),
- "ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(),
- "3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(),
- ];
-
- for _ in 0..PACKETS_BURSTABLE {
+ tokio::run(lazy(|| {
+ let mut ratelimiter = RateLimiter::new();
+ let mut expected = vec![];
+ let ips = vec![
+ "127.0.0.1".parse().unwrap(),
+ "192.168.1.1".parse().unwrap(),
+ "172.167.2.3".parse().unwrap(),
+ "97.231.252.215".parse().unwrap(),
+ "248.97.91.167".parse().unwrap(),
+ "188.208.233.47".parse().unwrap(),
+ "104.2.183.179".parse().unwrap(),
+ "72.129.46.120".parse().unwrap(),
+ "2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(),
+ "f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(),
+ "b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(),
+ "a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(),
+ "ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(),
+ "3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(),
+ ];
+
+ for _ in 0..PACKETS_BURSTABLE {
+ expected.push(Result {
+ allowed: true,
+ wait: Duration::new(0, 0),
+ text: "inital burst",
+ });
+ }
+
+ expected.push(Result {
+ allowed: false,
+ wait: Duration::new(0, 0),
+ text: "after burst",
+ });
+
expected.push(Result {
allowed: true,
+ wait: Duration::new(0, PACKET_COST as u32),
+ text: "filling tokens for single packet",
+ });
+
+ expected.push(Result {
+ allowed: false,
wait: Duration::new(0, 0),
- text: "inital burst",
+ text: "not having refilled enough",
+ });
+
+ expected.push(Result {
+ allowed: true,
+ wait: Duration::new(0, 2 * PACKET_COST as u32),
+ text: "filling tokens for 2 * packet burst",
+ });
+
+ expected.push(Result {
+ allowed: true,
+ wait: Duration::new(0, 0),
+ text: "second packet in 2 packet burst",
+ });
+
+ expected.push(Result {
+ allowed: false,
+ wait: Duration::new(0, 0),
+ text: "packet following 2 packet burst",
});
- }
- expected.push(Result {
- allowed: false,
- wait: Duration::new(0, 0),
- text: "after burst",
- });
-
- expected.push(Result {
- allowed: true,
- wait: Duration::new(0, PACKET_COST as u32),
- text: "filling tokens for single packet",
- });
-
- expected.push(Result {
- allowed: false,
- wait: Duration::new(0, 0),
- text: "not having refilled enough",
- });
-
- expected.push(Result {
- allowed: true,
- wait: Duration::new(0, 2 * PACKET_COST as u32),
- text: "filling tokens for 2 * packet burst",
- });
-
- expected.push(Result {
- allowed: true,
- wait: Duration::new(0, 0),
- text: "second packet in 2 packet burst",
- });
-
- expected.push(Result {
- allowed: false,
- wait: Duration::new(0, 0),
- text: "packet following 2 packet burst",
- });
-
- for item in expected {
- std::thread::sleep(item.wait);
- for ip in ips.iter() {
- if ratelimiter.allow(&ip) != item.allowed {
- panic!(
- "test failed for {} on {}. expected: {}, got: {}",
- ip, item.text, item.allowed, !item.allowed
- )
+ for item in expected {
+ std::thread::sleep(item.wait);
+ for ip in ips.iter() {
+ if ratelimiter.allow(&ip) != item.allowed {
+ panic!(
+ "test failed for {} on {}. expected: {}, got: {}",
+ ip, item.text, item.allowed, !item.allowed
+ )
+ }
}
}
- }
+
+ Ok(())
+ }));
}
}