From a50079552ac5148ef4d30a30948ffe4095d3d0ba Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sat, 10 Aug 2019 16:01:56 +0200 Subject: Kill GC thread on Ratelimiter drop --- src/handshake/device.rs | 113 +++++++++++++++++++++--- src/handshake/macs.rs | 2 +- src/handshake/ratelimiter.rs | 206 ++++++++++++++++++++++--------------------- 3 files changed, 209 insertions(+), 112 deletions(-) (limited to 'src') diff --git a/src/handshake/device.rs b/src/handshake/device.rs index 3ec161f..86a832a 100644 --- a/src/handshake/device.rs +++ b/src/handshake/device.rs @@ -356,20 +356,18 @@ mod tests { use super::super::messages::*; use super::*; use hex; - use rand::rngs::OsRng; use std::thread; + use rand::rngs::OsRng; use std::time::Duration; + use std::net::SocketAddr; - #[test] - fn handshake() { + fn setup_devices(rng : &mut R) -> (PublicKey, Device, PublicKey, Device) { // generate new keypairs - let mut rng = OsRng::new().unwrap(); - - let sk1 = StaticSecret::new(&mut rng); + let sk1 = StaticSecret::new(rng); let pk1 = PublicKey::from(&sk1); - let sk2 = StaticSecret::new(&mut rng); + let sk2 = StaticSecret::new(rng); let pk2 = PublicKey::from(&sk2); // pick random psk @@ -388,7 +386,103 @@ mod tests { dev1.set_psk(pk2, Some(psk)).unwrap(); dev2.set_psk(pk1, Some(psk)).unwrap(); - // do a few handshakes + (pk1, dev1, pk2, dev2) + } + + /* Test longest possible handshake interaction (7 messages): + * + * 1. I -> R (initation) + * 2. I <- R (cookie reply) + * 3. I -> R (initation) + * 4. I <- R (response) + * 5. I -> R (cookie reply) + * 6. I -> R (initation) + * 7. I <- R (response) + */ + #[test] + fn handshake_under_load() { + let mut rng = OsRng::new().unwrap(); + let (_pk1, dev1, pk2, dev2) = setup_devices(&mut rng); + + let src1 : SocketAddr = "172.16.0.1:8080".parse().unwrap(); + let src2 : SocketAddr = "172.16.0.2:7070".parse().unwrap(); + + // 1. device-1 : create first initation + let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + + // 2. device-2 : responds with CookieReply + let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() { + (None, Some(msg), None) => msg, + _ => panic!("unexpected response") + }; + + // device-1 : processes CookieReply (no response) + match dev1.process(&mut rng, &msg_cookie, Some(&src2)).unwrap() { + (None, None, None) => (), + _ => panic!("unexpected response") + } + + // avoid initation flood + thread::sleep(Duration::from_millis(20)); + + // 3. device-1 : create second initation + let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + + // 4. device-2 : responds with noise response + let msg_response = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() { + (Some(_), Some(msg), Some(kp)) => { + assert_eq!(kp.confirmed, false); + msg + }, + _ => panic!("unexpected response") + }; + + // 5. device-1 : responds with CookieReply + let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() { + (None, Some(msg), None) => msg, + _ => panic!("unexpected response") + }; + + // device-2 : processes CookieReply (no response) + match dev2.process(&mut rng, &msg_cookie, Some(&src1)).unwrap() { + (None, None, None) => (), + _ => panic!("unexpected response") + } + + // avoid initation flood + thread::sleep(Duration::from_millis(20)); + + // 6. device-1 : create third initation + let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + + // 7. device-2 : responds with noise response + let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() { + (Some(_), Some(msg), Some(kp)) => { + assert_eq!(kp.confirmed, false); + (msg, kp) + }, + _ => panic!("unexpected response") + }; + + // device-1 : process noise response + let kp2 = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() { + (Some(_), None, Some(kp)) => { + assert_eq!(kp.confirmed, true); + kp + }, + _ => panic!("unexpected response") + }; + + assert_eq!(kp1.send, kp2.recv); + assert_eq!(kp1.recv, kp2.send); + } + + #[test] + fn handshake_no_load() { + let mut rng = OsRng::new().unwrap(); + let (pk1, mut dev1, pk2, mut dev2) = setup_devices(&mut rng); + + // do a few handshakes (every handshake should succeed) for i in 0..10 { println!("handshake : {}", i); @@ -430,9 +524,6 @@ mod tests { thread::sleep(Duration::from_millis(20)); } - assert_eq!(dev1.get_psk(pk2).unwrap(), psk); - assert_eq!(dev2.get_psk(pk1).unwrap(), psk); - dev1.remove(pk2).unwrap(); dev2.remove(pk1).unwrap(); } diff --git a/src/handshake/macs.rs b/src/handshake/macs.rs index 3070da3..721fc88 100644 --- a/src/handshake/macs.rs +++ b/src/handshake/macs.rs @@ -309,7 +309,7 @@ mod tests { let mut msg = CookieReply::default(); let mut rng = OsRng::new().expect("failed to create rng"); let mut macs = MacsFooter::default(); - let src = "127.0.0.1:8080".parse().unwrap(); + let src = "192.0.2.16:8080".parse().unwrap(); let (validator, mut generator) = new_validator_generator(); // generate mac1 for first message diff --git a/src/handshake/ratelimiter.rs b/src/handshake/ratelimiter.rs index adf9667..02b82e7 100644 --- a/src/handshake/ratelimiter.rs +++ b/src/handshake/ratelimiter.rs @@ -1,13 +1,10 @@ use std::collections::HashMap; use std::net::IpAddr; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; +use std::sync::{Condvar, Mutex, Arc}; +use std::thread; use std::time::{Duration, Instant}; -use spin::{RwLock, Mutex}; - -use tokio::prelude::future; -use future::{loop_fn, Future, Loop, lazy}; -use tokio::timer::Delay; +use spin; use lazy_static::lazy_static; @@ -29,15 +26,27 @@ pub struct RateLimiter(Arc); struct RateLimiterInner{ gc_running: AtomicBool, - table: RwLock>>, + gc_dropped: (Mutex, Condvar), + table: spin::RwLock>>, +} + +impl Drop for RateLimiter { + fn drop(&mut self) { + // wake up & terminate any lingering GC thread + let &(ref lock, ref cvar) = &self.0.gc_dropped; + let mut dropped = lock.lock().unwrap(); + *dropped = true; + cvar.notify_all(); + } } impl RateLimiter { pub fn new() -> Self { RateLimiter ( Arc::new(RateLimiterInner { + gc_dropped: (Mutex::new(false), Condvar::new()), gc_running: AtomicBool::from(false), - table: RwLock::new(HashMap::new()), + table: spin::RwLock::new(HashMap::new()), }) ) } @@ -45,7 +54,7 @@ impl RateLimiter { pub fn allow(&self, addr: &IpAddr) -> bool { // check if allowed let allowed = { - // check for existing entry (required read lock) + // check for existing entry (only requires read lock) if let Some(entry) = self.0.table.read().get(addr) { // update existing entry let mut entry = entry.lock(); @@ -67,7 +76,7 @@ impl RateLimiter { // add new entry (write lock) self.0.table.write().insert( *addr, - Mutex::new(Entry { + spin::Mutex::new(Entry { last_time: Instant::now(), tokens: MAX_TOKENS - PACKET_COST, }), @@ -75,27 +84,28 @@ impl RateLimiter { true }; - // check that GC is scheduled + // check that GC thread is scheduled if !self.0.gc_running.swap(true, Ordering::Relaxed) { let limiter = self.0.clone(); - tokio::spawn( - loop_fn((), move |_| { - let limiter = limiter.clone(); - let next_gc = Instant::now() + *GC_INTERVAL; - Delay::new(next_gc) - .map_err(|_| ()) - .and_then(move |_| { - let mut tw = limiter.table.write(); - tw.retain(|_, ref mut entry| entry.lock().last_time.elapsed() <= *GC_INTERVAL); - if tw.len() > 0 { - Ok(Loop::Continue(())) - } else { - limiter.gc_running.store(false, Ordering::Relaxed); - Ok(Loop::Break(())) - } - }) - }) - ); + thread::spawn(move || { + let &(ref lock, ref cvar) = &limiter.gc_dropped; + let mut dropped = lock.lock().unwrap(); + while !*dropped { + // garbage collect + { + let mut tw = limiter.table.write(); + tw.retain(|_, ref mut entry| entry.lock().last_time.elapsed() <= *GC_INTERVAL); + if tw.len() == 0 { + limiter.gc_running.store(false, Ordering::Relaxed); + return; + } + } + + // wait until stopped or new GC (~1 every sec) + let res = cvar.wait_timeout(dropped,*GC_INTERVAL).unwrap(); + dropped = res.0; + } + }); } allowed @@ -116,83 +126,79 @@ mod tests { #[test] fn test_ratelimiter() { - tokio::run(lazy(|| { - let mut ratelimiter = RateLimiter::new(); - let mut expected = vec![]; - let ips = vec![ - "127.0.0.1".parse().unwrap(), - "192.168.1.1".parse().unwrap(), - "172.167.2.3".parse().unwrap(), - "97.231.252.215".parse().unwrap(), - "248.97.91.167".parse().unwrap(), - "188.208.233.47".parse().unwrap(), - "104.2.183.179".parse().unwrap(), - "72.129.46.120".parse().unwrap(), - "2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(), - "f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(), - "b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(), - "a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(), - "ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(), - "3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(), - ]; - - for _ in 0..PACKETS_BURSTABLE { - expected.push(Result { - allowed: true, - wait: Duration::new(0, 0), - text: "inital burst", - }); - } - - expected.push(Result { - allowed: false, - wait: Duration::new(0, 0), - text: "after burst", - }); - + let ratelimiter = RateLimiter::new(); + let mut expected = vec![]; + let ips = vec![ + "127.0.0.1".parse().unwrap(), + "192.168.1.1".parse().unwrap(), + "172.167.2.3".parse().unwrap(), + "97.231.252.215".parse().unwrap(), + "248.97.91.167".parse().unwrap(), + "188.208.233.47".parse().unwrap(), + "104.2.183.179".parse().unwrap(), + "72.129.46.120".parse().unwrap(), + "2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(), + "f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(), + "b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(), + "a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(), + "ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(), + "3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(), + ]; + + for _ in 0..PACKETS_BURSTABLE { expected.push(Result { allowed: true, - wait: Duration::new(0, PACKET_COST as u32), - text: "filling tokens for single packet", - }); - - expected.push(Result { - allowed: false, wait: Duration::new(0, 0), - text: "not having refilled enough", - }); - - expected.push(Result { - allowed: true, - wait: Duration::new(0, 2 * PACKET_COST as u32), - text: "filling tokens for 2 * packet burst", - }); - - expected.push(Result { - allowed: true, - wait: Duration::new(0, 0), - text: "second packet in 2 packet burst", - }); - - expected.push(Result { - allowed: false, - wait: Duration::new(0, 0), - text: "packet following 2 packet burst", + text: "inital burst", }); + } - for item in expected { - std::thread::sleep(item.wait); - for ip in ips.iter() { - if ratelimiter.allow(&ip) != item.allowed { - panic!( - "test failed for {} on {}. expected: {}, got: {}", - ip, item.text, item.allowed, !item.allowed - ) - } + expected.push(Result { + allowed: false, + wait: Duration::new(0, 0), + text: "after burst", + }); + + expected.push(Result { + allowed: true, + wait: Duration::new(0, PACKET_COST as u32), + text: "filling tokens for single packet", + }); + + expected.push(Result { + allowed: false, + wait: Duration::new(0, 0), + text: "not having refilled enough", + }); + + expected.push(Result { + allowed: true, + wait: Duration::new(0, 2 * PACKET_COST as u32), + text: "filling tokens for 2 * packet burst", + }); + + expected.push(Result { + allowed: true, + wait: Duration::new(0, 0), + text: "second packet in 2 packet burst", + }); + + expected.push(Result { + allowed: false, + wait: Duration::new(0, 0), + text: "packet following 2 packet burst", + }); + + for item in expected { + std::thread::sleep(item.wait); + for ip in ips.iter() { + if ratelimiter.allow(&ip) != item.allowed { + panic!( + "test failed for {} on {}. expected: {}, got: {}", + ip, item.text, item.allowed, !item.allowed + ) } } - - Ok(()) - })); + } } } -- cgit v1.2.3-59-g8ed1b