summaryrefslogtreecommitdiffstats
path: root/src/handshake
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-08-10 16:01:56 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-08-10 16:01:56 +0200
commita50079552ac5148ef4d30a30948ffe4095d3d0ba (patch)
tree267470bef8498fa8b7131ff001df37d799f627f6 /src/handshake
parentConcurrent rate limiter (diff)
downloadwireguard-rs-a50079552ac5148ef4d30a30948ffe4095d3d0ba.tar.xz
wireguard-rs-a50079552ac5148ef4d30a30948ffe4095d3d0ba.zip
Kill GC thread on Ratelimiter drop
Diffstat (limited to 'src/handshake')
-rw-r--r--src/handshake/device.rs113
-rw-r--r--src/handshake/macs.rs2
-rw-r--r--src/handshake/ratelimiter.rs206
3 files changed, 209 insertions, 112 deletions
diff --git a/src/handshake/device.rs b/src/handshake/device.rs
index 3ec161f..86a832a 100644
--- a/src/handshake/device.rs
+++ b/src/handshake/device.rs
@@ -356,20 +356,18 @@ mod tests {
use super::super::messages::*;
use super::*;
use hex;
- use rand::rngs::OsRng;
use std::thread;
+ use rand::rngs::OsRng;
use std::time::Duration;
+ use std::net::SocketAddr;
- #[test]
- fn handshake() {
+ fn setup_devices<R: RngCore + CryptoRng>(rng : &mut R) -> (PublicKey, Device<usize>, PublicKey, Device<usize>) {
// generate new keypairs
- let mut rng = OsRng::new().unwrap();
-
- let sk1 = StaticSecret::new(&mut rng);
+ let sk1 = StaticSecret::new(rng);
let pk1 = PublicKey::from(&sk1);
- let sk2 = StaticSecret::new(&mut rng);
+ let sk2 = StaticSecret::new(rng);
let pk2 = PublicKey::from(&sk2);
// pick random psk
@@ -388,7 +386,103 @@ mod tests {
dev1.set_psk(pk2, Some(psk)).unwrap();
dev2.set_psk(pk1, Some(psk)).unwrap();
- // do a few handshakes
+ (pk1, dev1, pk2, dev2)
+ }
+
+ /* Test longest possible handshake interaction (7 messages):
+ *
+ * 1. I -> R (initation)
+ * 2. I <- R (cookie reply)
+ * 3. I -> R (initation)
+ * 4. I <- R (response)
+ * 5. I -> R (cookie reply)
+ * 6. I -> R (initation)
+ * 7. I <- R (response)
+ */
+ #[test]
+ fn handshake_under_load() {
+ let mut rng = OsRng::new().unwrap();
+ let (_pk1, dev1, pk2, dev2) = setup_devices(&mut rng);
+
+ let src1 : SocketAddr = "172.16.0.1:8080".parse().unwrap();
+ let src2 : SocketAddr = "172.16.0.2:7070".parse().unwrap();
+
+ // 1. device-1 : create first initation
+ let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
+
+ // 2. device-2 : responds with CookieReply
+ let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
+ (None, Some(msg), None) => msg,
+ _ => panic!("unexpected response")
+ };
+
+ // device-1 : processes CookieReply (no response)
+ match dev1.process(&mut rng, &msg_cookie, Some(&src2)).unwrap() {
+ (None, None, None) => (),
+ _ => panic!("unexpected response")
+ }
+
+ // avoid initation flood
+ thread::sleep(Duration::from_millis(20));
+
+ // 3. device-1 : create second initation
+ let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
+
+ // 4. device-2 : responds with noise response
+ let msg_response = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
+ (Some(_), Some(msg), Some(kp)) => {
+ assert_eq!(kp.confirmed, false);
+ msg
+ },
+ _ => panic!("unexpected response")
+ };
+
+ // 5. device-1 : responds with CookieReply
+ let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
+ (None, Some(msg), None) => msg,
+ _ => panic!("unexpected response")
+ };
+
+ // device-2 : processes CookieReply (no response)
+ match dev2.process(&mut rng, &msg_cookie, Some(&src1)).unwrap() {
+ (None, None, None) => (),
+ _ => panic!("unexpected response")
+ }
+
+ // avoid initation flood
+ thread::sleep(Duration::from_millis(20));
+
+ // 6. device-1 : create third initation
+ let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
+
+ // 7. device-2 : responds with noise response
+ let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
+ (Some(_), Some(msg), Some(kp)) => {
+ assert_eq!(kp.confirmed, false);
+ (msg, kp)
+ },
+ _ => panic!("unexpected response")
+ };
+
+ // device-1 : process noise response
+ let kp2 = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
+ (Some(_), None, Some(kp)) => {
+ assert_eq!(kp.confirmed, true);
+ kp
+ },
+ _ => panic!("unexpected response")
+ };
+
+ assert_eq!(kp1.send, kp2.recv);
+ assert_eq!(kp1.recv, kp2.send);
+ }
+
+ #[test]
+ fn handshake_no_load() {
+ let mut rng = OsRng::new().unwrap();
+ let (pk1, mut dev1, pk2, mut dev2) = setup_devices(&mut rng);
+
+ // do a few handshakes (every handshake should succeed)
for i in 0..10 {
println!("handshake : {}", i);
@@ -430,9 +524,6 @@ mod tests {
thread::sleep(Duration::from_millis(20));
}
- assert_eq!(dev1.get_psk(pk2).unwrap(), psk);
- assert_eq!(dev2.get_psk(pk1).unwrap(), psk);
-
dev1.remove(pk2).unwrap();
dev2.remove(pk1).unwrap();
}
diff --git a/src/handshake/macs.rs b/src/handshake/macs.rs
index 3070da3..721fc88 100644
--- a/src/handshake/macs.rs
+++ b/src/handshake/macs.rs
@@ -309,7 +309,7 @@ mod tests {
let mut msg = CookieReply::default();
let mut rng = OsRng::new().expect("failed to create rng");
let mut macs = MacsFooter::default();
- let src = "127.0.0.1:8080".parse().unwrap();
+ let src = "192.0.2.16:8080".parse().unwrap();
let (validator, mut generator) = new_validator_generator();
// generate mac1 for first message
diff --git a/src/handshake/ratelimiter.rs b/src/handshake/ratelimiter.rs
index adf9667..02b82e7 100644
--- a/src/handshake/ratelimiter.rs
+++ b/src/handshake/ratelimiter.rs
@@ -1,13 +1,10 @@
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
+use std::sync::{Condvar, Mutex, Arc};
+use std::thread;
use std::time::{Duration, Instant};
-use spin::{RwLock, Mutex};
-
-use tokio::prelude::future;
-use future::{loop_fn, Future, Loop, lazy};
-use tokio::timer::Delay;
+use spin;
use lazy_static::lazy_static;
@@ -29,15 +26,27 @@ pub struct RateLimiter(Arc<RateLimiterInner>);
struct RateLimiterInner{
gc_running: AtomicBool,
- table: RwLock<HashMap<IpAddr, Mutex<Entry>>>,
+ gc_dropped: (Mutex<bool>, Condvar),
+ table: spin::RwLock<HashMap<IpAddr, spin::Mutex<Entry>>>,
+}
+
+impl Drop for RateLimiter {
+ fn drop(&mut self) {
+ // wake up & terminate any lingering GC thread
+ let &(ref lock, ref cvar) = &self.0.gc_dropped;
+ let mut dropped = lock.lock().unwrap();
+ *dropped = true;
+ cvar.notify_all();
+ }
}
impl RateLimiter {
pub fn new() -> Self {
RateLimiter (
Arc::new(RateLimiterInner {
+ gc_dropped: (Mutex::new(false), Condvar::new()),
gc_running: AtomicBool::from(false),
- table: RwLock::new(HashMap::new()),
+ table: spin::RwLock::new(HashMap::new()),
})
)
}
@@ -45,7 +54,7 @@ impl RateLimiter {
pub fn allow(&self, addr: &IpAddr) -> bool {
// check if allowed
let allowed = {
- // check for existing entry (required read lock)
+ // check for existing entry (only requires read lock)
if let Some(entry) = self.0.table.read().get(addr) {
// update existing entry
let mut entry = entry.lock();
@@ -67,7 +76,7 @@ impl RateLimiter {
// add new entry (write lock)
self.0.table.write().insert(
*addr,
- Mutex::new(Entry {
+ spin::Mutex::new(Entry {
last_time: Instant::now(),
tokens: MAX_TOKENS - PACKET_COST,
}),
@@ -75,27 +84,28 @@ impl RateLimiter {
true
};
- // check that GC is scheduled
+ // check that GC thread is scheduled
if !self.0.gc_running.swap(true, Ordering::Relaxed) {
let limiter = self.0.clone();
- tokio::spawn(
- loop_fn((), move |_| {
- let limiter = limiter.clone();
- let next_gc = Instant::now() + *GC_INTERVAL;
- Delay::new(next_gc)
- .map_err(|_| ())
- .and_then(move |_| {
- let mut tw = limiter.table.write();
- tw.retain(|_, ref mut entry| entry.lock().last_time.elapsed() <= *GC_INTERVAL);
- if tw.len() > 0 {
- Ok(Loop::Continue(()))
- } else {
- limiter.gc_running.store(false, Ordering::Relaxed);
- Ok(Loop::Break(()))
- }
- })
- })
- );
+ thread::spawn(move || {
+ let &(ref lock, ref cvar) = &limiter.gc_dropped;
+ let mut dropped = lock.lock().unwrap();
+ while !*dropped {
+ // garbage collect
+ {
+ let mut tw = limiter.table.write();
+ tw.retain(|_, ref mut entry| entry.lock().last_time.elapsed() <= *GC_INTERVAL);
+ if tw.len() == 0 {
+ limiter.gc_running.store(false, Ordering::Relaxed);
+ return;
+ }
+ }
+
+ // wait until stopped or new GC (~1 every sec)
+ let res = cvar.wait_timeout(dropped,*GC_INTERVAL).unwrap();
+ dropped = res.0;
+ }
+ });
}
allowed
@@ -116,83 +126,79 @@ mod tests {
#[test]
fn test_ratelimiter() {
- tokio::run(lazy(|| {
- let mut ratelimiter = RateLimiter::new();
- let mut expected = vec![];
- let ips = vec![
- "127.0.0.1".parse().unwrap(),
- "192.168.1.1".parse().unwrap(),
- "172.167.2.3".parse().unwrap(),
- "97.231.252.215".parse().unwrap(),
- "248.97.91.167".parse().unwrap(),
- "188.208.233.47".parse().unwrap(),
- "104.2.183.179".parse().unwrap(),
- "72.129.46.120".parse().unwrap(),
- "2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(),
- "f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(),
- "b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(),
- "a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(),
- "ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(),
- "3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(),
- ];
-
- for _ in 0..PACKETS_BURSTABLE {
- expected.push(Result {
- allowed: true,
- wait: Duration::new(0, 0),
- text: "inital burst",
- });
- }
-
- expected.push(Result {
- allowed: false,
- wait: Duration::new(0, 0),
- text: "after burst",
- });
-
+ let ratelimiter = RateLimiter::new();
+ let mut expected = vec![];
+ let ips = vec![
+ "127.0.0.1".parse().unwrap(),
+ "192.168.1.1".parse().unwrap(),
+ "172.167.2.3".parse().unwrap(),
+ "97.231.252.215".parse().unwrap(),
+ "248.97.91.167".parse().unwrap(),
+ "188.208.233.47".parse().unwrap(),
+ "104.2.183.179".parse().unwrap(),
+ "72.129.46.120".parse().unwrap(),
+ "2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(),
+ "f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(),
+ "b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(),
+ "a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(),
+ "ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(),
+ "3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(),
+ ];
+
+ for _ in 0..PACKETS_BURSTABLE {
expected.push(Result {
allowed: true,
- wait: Duration::new(0, PACKET_COST as u32),
- text: "filling tokens for single packet",
- });
-
- expected.push(Result {
- allowed: false,
wait: Duration::new(0, 0),
- text: "not having refilled enough",
- });
-
- expected.push(Result {
- allowed: true,
- wait: Duration::new(0, 2 * PACKET_COST as u32),
- text: "filling tokens for 2 * packet burst",
- });
-
- expected.push(Result {
- allowed: true,
- wait: Duration::new(0, 0),
- text: "second packet in 2 packet burst",
- });
-
- expected.push(Result {
- allowed: false,
- wait: Duration::new(0, 0),
- text: "packet following 2 packet burst",
+ text: "inital burst",
});
+ }
- for item in expected {
- std::thread::sleep(item.wait);
- for ip in ips.iter() {
- if ratelimiter.allow(&ip) != item.allowed {
- panic!(
- "test failed for {} on {}. expected: {}, got: {}",
- ip, item.text, item.allowed, !item.allowed
- )
- }
+ expected.push(Result {
+ allowed: false,
+ wait: Duration::new(0, 0),
+ text: "after burst",
+ });
+
+ expected.push(Result {
+ allowed: true,
+ wait: Duration::new(0, PACKET_COST as u32),
+ text: "filling tokens for single packet",
+ });
+
+ expected.push(Result {
+ allowed: false,
+ wait: Duration::new(0, 0),
+ text: "not having refilled enough",
+ });
+
+ expected.push(Result {
+ allowed: true,
+ wait: Duration::new(0, 2 * PACKET_COST as u32),
+ text: "filling tokens for 2 * packet burst",
+ });
+
+ expected.push(Result {
+ allowed: true,
+ wait: Duration::new(0, 0),
+ text: "second packet in 2 packet burst",
+ });
+
+ expected.push(Result {
+ allowed: false,
+ wait: Duration::new(0, 0),
+ text: "packet following 2 packet burst",
+ });
+
+ for item in expected {
+ std::thread::sleep(item.wait);
+ for ip in ips.iter() {
+ if ratelimiter.allow(&ip) != item.allowed {
+ panic!(
+ "test failed for {} on {}. expected: {}, got: {}",
+ ip, item.text, item.allowed, !item.allowed
+ )
}
}
-
- Ok(())
- }));
+ }
}
}