From b33381331f5c33037892bbc3c197376f3b8d945d Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Wed, 7 Aug 2019 22:51:58 +0200 Subject: Concurrent rate limiter The new rate limiter allows multiple simultaneous .allow calls. Also delegated GC to tokio. --- src/handshake/ratelimiter.rs | 250 +++++++++++++++++++++++++------------------ 1 file changed, 143 insertions(+), 107 deletions(-) (limited to 'src') diff --git a/src/handshake/ratelimiter.rs b/src/handshake/ratelimiter.rs index ce09c16..adf9667 100644 --- a/src/handshake/ratelimiter.rs +++ b/src/handshake/ratelimiter.rs @@ -1,6 +1,13 @@ use std::collections::HashMap; use std::net::IpAddr; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; use std::time::{Duration, Instant}; +use spin::{RwLock, Mutex}; + +use tokio::prelude::future; +use future::{loop_fn, Future, Loop, lazy}; +use tokio::timer::Delay; use lazy_static::lazy_static; @@ -18,57 +25,82 @@ struct Entry { pub tokens: u64, } -pub struct RateLimiter { - garbage_collect: Instant, - table: HashMap, +pub struct RateLimiter(Arc); + +struct RateLimiterInner{ + gc_running: AtomicBool, + table: RwLock>>, } impl RateLimiter { pub fn new() -> Self { - RateLimiter { - garbage_collect: Instant::now(), - table: HashMap::new(), - } + RateLimiter ( + Arc::new(RateLimiterInner { + gc_running: AtomicBool::from(false), + table: RwLock::new(HashMap::new()), + }) + ) } - pub fn allow(&mut self, addr: &IpAddr) -> bool { - // check for garbage collection - if self.garbage_collect.elapsed() > *GC_INTERVAL { - self.handle_gc(); - } - - // update existing entry - if let Some(entry) = self.table.get_mut(addr) { - // add tokens earned since last time - entry.tokens = - MAX_TOKENS.min(entry.tokens + u64::from(entry.last_time.elapsed().subsec_nanos())); - entry.last_time = Instant::now(); - - // subtract cost of packet - if entry.tokens > PACKET_COST { - entry.tokens -= PACKET_COST; - return true; - } else { - return false; + pub fn allow(&self, addr: &IpAddr) -> bool { + // check if allowed + let allowed = { + // check for existing entry (required read lock) + if let Some(entry) = self.0.table.read().get(addr) { + // update existing entry + let mut entry = entry.lock(); + + // add tokens earned since last time + entry.tokens = + MAX_TOKENS.min(entry.tokens + u64::from(entry.last_time.elapsed().subsec_nanos())); + entry.last_time = Instant::now(); + + // subtract cost of packet + if entry.tokens > PACKET_COST { + entry.tokens -= PACKET_COST; + return true; + } else { + return false; + } } + + // add new entry (write lock) + self.0.table.write().insert( + *addr, + Mutex::new(Entry { + last_time: Instant::now(), + tokens: MAX_TOKENS - PACKET_COST, + }), + ); + true + }; + + // check that GC is scheduled + if !self.0.gc_running.swap(true, Ordering::Relaxed) { + let limiter = self.0.clone(); + tokio::spawn( + loop_fn((), move |_| { + let limiter = limiter.clone(); + let next_gc = Instant::now() + *GC_INTERVAL; + Delay::new(next_gc) + .map_err(|_| ()) + .and_then(move |_| { + let mut tw = limiter.table.write(); + tw.retain(|_, ref mut entry| entry.lock().last_time.elapsed() <= *GC_INTERVAL); + if tw.len() > 0 { + Ok(Loop::Continue(())) + } else { + limiter.gc_running.store(false, Ordering::Relaxed); + Ok(Loop::Break(())) + } + }) + }) + ); } - // add new entry - self.table.insert( - *addr, - Entry { - last_time: Instant::now(), - tokens: MAX_TOKENS - PACKET_COST, - }, - ); - - true + allowed } - fn handle_gc(&mut self) { - self.table - .retain(|_, ref mut entry| entry.last_time.elapsed() <= *GC_INTERVAL); - } } #[cfg(test)] @@ -84,79 +116,83 @@ mod tests { #[test] fn test_ratelimiter() { - let mut ratelimiter = RateLimiter::new(); - let mut expected = vec![]; - let ips = vec![ - "127.0.0.1".parse().unwrap(), - "192.168.1.1".parse().unwrap(), - "172.167.2.3".parse().unwrap(), - "97.231.252.215".parse().unwrap(), - "248.97.91.167".parse().unwrap(), - "188.208.233.47".parse().unwrap(), - "104.2.183.179".parse().unwrap(), - "72.129.46.120".parse().unwrap(), - "2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(), - "f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(), - "b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(), - "a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(), - "ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(), - "3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(), - ]; - - for _ in 0..PACKETS_BURSTABLE { + tokio::run(lazy(|| { + let mut ratelimiter = RateLimiter::new(); + let mut expected = vec![]; + let ips = vec![ + "127.0.0.1".parse().unwrap(), + "192.168.1.1".parse().unwrap(), + "172.167.2.3".parse().unwrap(), + "97.231.252.215".parse().unwrap(), + "248.97.91.167".parse().unwrap(), + "188.208.233.47".parse().unwrap(), + "104.2.183.179".parse().unwrap(), + "72.129.46.120".parse().unwrap(), + "2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(), + "f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(), + "b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(), + "a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(), + "ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(), + "3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(), + ]; + + for _ in 0..PACKETS_BURSTABLE { + expected.push(Result { + allowed: true, + wait: Duration::new(0, 0), + text: "inital burst", + }); + } + + expected.push(Result { + allowed: false, + wait: Duration::new(0, 0), + text: "after burst", + }); + expected.push(Result { allowed: true, + wait: Duration::new(0, PACKET_COST as u32), + text: "filling tokens for single packet", + }); + + expected.push(Result { + allowed: false, wait: Duration::new(0, 0), - text: "inital burst", + text: "not having refilled enough", + }); + + expected.push(Result { + allowed: true, + wait: Duration::new(0, 2 * PACKET_COST as u32), + text: "filling tokens for 2 * packet burst", + }); + + expected.push(Result { + allowed: true, + wait: Duration::new(0, 0), + text: "second packet in 2 packet burst", + }); + + expected.push(Result { + allowed: false, + wait: Duration::new(0, 0), + text: "packet following 2 packet burst", }); - } - expected.push(Result { - allowed: false, - wait: Duration::new(0, 0), - text: "after burst", - }); - - expected.push(Result { - allowed: true, - wait: Duration::new(0, PACKET_COST as u32), - text: "filling tokens for single packet", - }); - - expected.push(Result { - allowed: false, - wait: Duration::new(0, 0), - text: "not having refilled enough", - }); - - expected.push(Result { - allowed: true, - wait: Duration::new(0, 2 * PACKET_COST as u32), - text: "filling tokens for 2 * packet burst", - }); - - expected.push(Result { - allowed: true, - wait: Duration::new(0, 0), - text: "second packet in 2 packet burst", - }); - - expected.push(Result { - allowed: false, - wait: Duration::new(0, 0), - text: "packet following 2 packet burst", - }); - - for item in expected { - std::thread::sleep(item.wait); - for ip in ips.iter() { - if ratelimiter.allow(&ip) != item.allowed { - panic!( - "test failed for {} on {}. expected: {}, got: {}", - ip, item.text, item.allowed, !item.allowed - ) + for item in expected { + std::thread::sleep(item.wait); + for ip in ips.iter() { + if ratelimiter.allow(&ip) != item.allowed { + panic!( + "test failed for {} on {}. expected: {}, got: {}", + ip, item.text, item.allowed, !item.allowed + ) + } } } - } + + Ok(()) + })); } } -- cgit v1.2.3-59-g8ed1b