diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-08-06 13:02:13 +0200 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-08-06 13:02:13 +0200 |
commit | a12e6e139c963cfe9c78edc9ae83a71049722f29 (patch) | |
tree | b6bfd82f28553c84eae5a48f6b43b4cf88cd5bbc /src/handshake/ratelimiter.rs | |
parent | Prepare for resuse of message buffers for response (diff) | |
download | wireguard-rs-a12e6e139c963cfe9c78edc9ae83a71049722f29.tar.xz wireguard-rs-a12e6e139c963cfe9c78edc9ae83a71049722f29.zip |
Add rate limiter check to handshake messages.
Diffstat (limited to 'src/handshake/ratelimiter.rs')
-rw-r--r-- | src/handshake/ratelimiter.rs | 162 |
1 files changed, 162 insertions, 0 deletions
diff --git a/src/handshake/ratelimiter.rs b/src/handshake/ratelimiter.rs new file mode 100644 index 0000000..ce09c16 --- /dev/null +++ b/src/handshake/ratelimiter.rs @@ -0,0 +1,162 @@ +use std::collections::HashMap; +use std::net::IpAddr; +use std::time::{Duration, Instant}; + +use lazy_static::lazy_static; + +const PACKETS_PER_SECOND: u64 = 20; +const PACKETS_BURSTABLE: u64 = 5; +const PACKET_COST: u64 = 1_000_000_000 / PACKETS_PER_SECOND; +const MAX_TOKENS: u64 = PACKET_COST * PACKETS_BURSTABLE; + +lazy_static! { + pub static ref GC_INTERVAL: Duration = Duration::new(1, 0); +} + +struct Entry { + pub last_time: Instant, + pub tokens: u64, +} + +pub struct RateLimiter { + garbage_collect: Instant, + table: HashMap<IpAddr, Entry>, +} + +impl RateLimiter { + pub fn new() -> Self { + RateLimiter { + garbage_collect: Instant::now(), + table: HashMap::new(), + } + } + + pub fn allow(&mut self, addr: &IpAddr) -> bool { + // check for garbage collection + if self.garbage_collect.elapsed() > *GC_INTERVAL { + self.handle_gc(); + } + + // update existing entry + if let Some(entry) = self.table.get_mut(addr) { + // add tokens earned since last time + entry.tokens = + MAX_TOKENS.min(entry.tokens + u64::from(entry.last_time.elapsed().subsec_nanos())); + entry.last_time = Instant::now(); + + // subtract cost of packet + if entry.tokens > PACKET_COST { + entry.tokens -= PACKET_COST; + return true; + } else { + return false; + } + } + + // add new entry + self.table.insert( + *addr, + Entry { + last_time: Instant::now(), + tokens: MAX_TOKENS - PACKET_COST, + }, + ); + + true + } + + fn handle_gc(&mut self) { + self.table + .retain(|_, ref mut entry| entry.last_time.elapsed() <= *GC_INTERVAL); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std; + + struct Result { + allowed: bool, + text: &'static str, + wait: Duration, + } + + #[test] + fn test_ratelimiter() { + let mut ratelimiter = RateLimiter::new(); + let mut expected = vec![]; + let ips = vec![ + "127.0.0.1".parse().unwrap(), + "192.168.1.1".parse().unwrap(), + "172.167.2.3".parse().unwrap(), + "97.231.252.215".parse().unwrap(), + "248.97.91.167".parse().unwrap(), + "188.208.233.47".parse().unwrap(), + "104.2.183.179".parse().unwrap(), + "72.129.46.120".parse().unwrap(), + "2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(), + "f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(), + "b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(), + "a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(), + "ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(), + "3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(), + ]; + + for _ in 0..PACKETS_BURSTABLE { + expected.push(Result { + allowed: true, + wait: Duration::new(0, 0), + text: "inital burst", + }); + } + + expected.push(Result { + allowed: false, + wait: Duration::new(0, 0), + text: "after burst", + }); + + expected.push(Result { + allowed: true, + wait: Duration::new(0, PACKET_COST as u32), + text: "filling tokens for single packet", + }); + + expected.push(Result { + allowed: false, + wait: Duration::new(0, 0), + text: "not having refilled enough", + }); + + expected.push(Result { + allowed: true, + wait: Duration::new(0, 2 * PACKET_COST as u32), + text: "filling tokens for 2 * packet burst", + }); + + expected.push(Result { + allowed: true, + wait: Duration::new(0, 0), + text: "second packet in 2 packet burst", + }); + + expected.push(Result { + allowed: false, + wait: Duration::new(0, 0), + text: "packet following 2 packet burst", + }); + + for item in expected { + std::thread::sleep(item.wait); + for ip in ips.iter() { + if ratelimiter.allow(&ip) != item.allowed { + panic!( + "test failed for {} on {}. expected: {}, got: {}", + ip, item.text, item.allowed, !item.allowed + ) + } + } + } + } +} |