From a12e6e139c963cfe9c78edc9ae83a71049722f29 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Tue, 6 Aug 2019 13:02:13 +0200 Subject: Add rate limiter check to handshake messages. --- Cargo.lock | 1 + Cargo.toml | 1 + src/handshake/device.rs | 20 +++++- src/handshake/macs.rs | 15 ++-- src/handshake/mod.rs | 1 + src/handshake/ratelimiter.rs | 162 +++++++++++++++++++++++++++++++++++++++++++ src/handshake/types.rs | 2 + 7 files changed, 194 insertions(+), 8 deletions(-) create mode 100644 src/handshake/ratelimiter.rs diff --git a/Cargo.lock b/Cargo.lock index fc8cb59..384a19d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -586,6 +586,7 @@ dependencies = [ "generic-array 0.12.3 (registry+https://github.com/rust-lang/crates.io-index)", "hex 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)", "hmac 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)", + "lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "proptest 0.9.4 (registry+https://github.com/rust-lang/crates.io-index)", "rand 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)", "sodiumoxide 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/Cargo.toml b/Cargo.toml index 2f9d38d..2909dcc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ zerocopy = "0.2.7" byteorder = "1.3.1" digest = "0.8.0" sodiumoxide = "0.2.2" +lazy_static = "^1.3" [dependencies.x25519-dalek] version = "^0.5" diff --git a/src/handshake/device.rs b/src/handshake/device.rs index 809d7a3..4926348 100644 --- a/src/handshake/device.rs +++ b/src/handshake/device.rs @@ -1,6 +1,7 @@ use spin::RwLock; use std::collections::HashMap; use std::net::SocketAddr; +use std::sync::Mutex; use zerocopy::AsBytes; use rand::prelude::*; @@ -13,6 +14,7 @@ use super::messages::{CookieReply, Initiation, Response}; use super::messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE}; use super::noise; use super::peer::Peer; +use super::ratelimiter::RateLimiter; use super::types::*; pub struct Device { @@ -21,6 +23,7 @@ pub struct Device { macs: macs::Validator, // validator for the mac fields pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state id_map: RwLock>, // receiver ids -> public key + limiter: Mutex, } /* A mutable reference to the device needs to be held during configuration. @@ -43,6 +46,7 @@ where macs: macs::Validator::new(pk), pk_map: HashMap::new(), id_map: RwLock::new(HashMap::new()), + limiter: Mutex::new(RateLimiter::new()), } } @@ -203,8 +207,9 @@ where // check mac1 field self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; - // check mac2 field + // address validation & DoS mitigation if let Some(src) = src { + // check mac2 field if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { let mut reply = Default::default(); self.macs.create_cookie_reply( @@ -216,6 +221,11 @@ where ); return Ok((None, Some(reply.as_bytes().to_owned()), None)); } + + // check ratelimiter + if !self.limiter.lock().unwrap().allow(&src.ip()) { + return Err(HandshakeError::RateLimited); + } } // consume the initiation @@ -253,8 +263,9 @@ where // check mac1 field self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; - // check mac2 field + // address validation & DoS mitigation if let Some(src) = src { + // check mac2 field if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { let mut reply = Default::default(); self.macs.create_cookie_reply( @@ -266,6 +277,11 @@ where ); return Ok((None, Some(reply.as_bytes().to_owned()), None)); } + + // check ratelimiter + if !self.limiter.lock().unwrap().allow(&src.ip()) { + return Err(HandshakeError::RateLimited); + } } // consume inner playload diff --git a/src/handshake/macs.rs b/src/handshake/macs.rs index d5dd95d..3070da3 100644 --- a/src/handshake/macs.rs +++ b/src/handshake/macs.rs @@ -1,3 +1,4 @@ +use lazy_static::lazy_static; use rand::{CryptoRng, RngCore}; use spin::RwLock; use std::time::{Duration, Instant}; @@ -19,7 +20,9 @@ const SIZE_COOKIE: usize = 16; const SIZE_SECRET: usize = 32; const SIZE_MAC: usize = 16; // blake2s-mac128 -const SECS_COOKIE_UPDATE: u64 = 120; +lazy_static! { + pub static ref COOKIE_UPDATE_INTERVAL: Duration = Duration::new(120, 0); +} macro_rules! HASH { ( $($input:expr),* ) => {{ @@ -172,7 +175,7 @@ impl Generator { macs.f_mac1 = MAC!(&self.mac1_key, inner); macs.f_mac2 = match &self.cookie { Some(cookie) => { - if cookie.birth.elapsed() > Duration::from_secs(SECS_COOKIE_UPDATE) { + if cookie.birth.elapsed() > *COOKIE_UPDATE_INTERVAL { self.cookie = None; [0u8; SIZE_MAC] } else { @@ -203,14 +206,14 @@ impl Validator { cookie_key: HASH!(LABEL_COOKIE, pk.as_bytes()).into(), secret: RwLock::new(Secret { value: [0u8; SIZE_SECRET], - birth: Instant::now() - Duration::from_secs(2 * SECS_COOKIE_UPDATE), + birth: Instant::now() - Duration::new(86400, 0), }), } } fn get_tau(&self, src: &[u8]) -> Option<[u8; SIZE_COOKIE]> { let secret = self.secret.read(); - if secret.birth.elapsed() < Duration::from_secs(SECS_COOKIE_UPDATE) { + if secret.birth.elapsed() < *COOKIE_UPDATE_INTERVAL { Some(MAC!(&secret.value, src)) } else { None @@ -221,7 +224,7 @@ impl Validator { // check if current value is still valid { let secret = self.secret.read(); - if secret.birth.elapsed() < Duration::from_secs(SECS_COOKIE_UPDATE) { + if secret.birth.elapsed() < *COOKIE_UPDATE_INTERVAL { return MAC!(&secret.value, src); }; } @@ -229,7 +232,7 @@ impl Validator { // take write lock, check again { let mut secret = self.secret.write(); - if secret.birth.elapsed() < Duration::from_secs(SECS_COOKIE_UPDATE) { + if secret.birth.elapsed() < *COOKIE_UPDATE_INTERVAL { return MAC!(&secret.value, src); }; diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index 4314925..8095147 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -11,6 +11,7 @@ mod macs; mod messages; mod noise; mod peer; +mod ratelimiter; mod timestamp; mod types; diff --git a/src/handshake/ratelimiter.rs b/src/handshake/ratelimiter.rs new file mode 100644 index 0000000..ce09c16 --- /dev/null +++ b/src/handshake/ratelimiter.rs @@ -0,0 +1,162 @@ +use std::collections::HashMap; +use std::net::IpAddr; +use std::time::{Duration, Instant}; + +use lazy_static::lazy_static; + +const PACKETS_PER_SECOND: u64 = 20; +const PACKETS_BURSTABLE: u64 = 5; +const PACKET_COST: u64 = 1_000_000_000 / PACKETS_PER_SECOND; +const MAX_TOKENS: u64 = PACKET_COST * PACKETS_BURSTABLE; + +lazy_static! { + pub static ref GC_INTERVAL: Duration = Duration::new(1, 0); +} + +struct Entry { + pub last_time: Instant, + pub tokens: u64, +} + +pub struct RateLimiter { + garbage_collect: Instant, + table: HashMap, +} + +impl RateLimiter { + pub fn new() -> Self { + RateLimiter { + garbage_collect: Instant::now(), + table: HashMap::new(), + } + } + + pub fn allow(&mut self, addr: &IpAddr) -> bool { + // check for garbage collection + if self.garbage_collect.elapsed() > *GC_INTERVAL { + self.handle_gc(); + } + + // update existing entry + if let Some(entry) = self.table.get_mut(addr) { + // add tokens earned since last time + entry.tokens = + MAX_TOKENS.min(entry.tokens + u64::from(entry.last_time.elapsed().subsec_nanos())); + entry.last_time = Instant::now(); + + // subtract cost of packet + if entry.tokens > PACKET_COST { + entry.tokens -= PACKET_COST; + return true; + } else { + return false; + } + } + + // add new entry + self.table.insert( + *addr, + Entry { + last_time: Instant::now(), + tokens: MAX_TOKENS - PACKET_COST, + }, + ); + + true + } + + fn handle_gc(&mut self) { + self.table + .retain(|_, ref mut entry| entry.last_time.elapsed() <= *GC_INTERVAL); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std; + + struct Result { + allowed: bool, + text: &'static str, + wait: Duration, + } + + #[test] + fn test_ratelimiter() { + let mut ratelimiter = RateLimiter::new(); + let mut expected = vec![]; + let ips = vec![ + "127.0.0.1".parse().unwrap(), + "192.168.1.1".parse().unwrap(), + "172.167.2.3".parse().unwrap(), + "97.231.252.215".parse().unwrap(), + "248.97.91.167".parse().unwrap(), + "188.208.233.47".parse().unwrap(), + "104.2.183.179".parse().unwrap(), + "72.129.46.120".parse().unwrap(), + "2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(), + "f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(), + "b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(), + "a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(), + "ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(), + "3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(), + ]; + + for _ in 0..PACKETS_BURSTABLE { + expected.push(Result { + allowed: true, + wait: Duration::new(0, 0), + text: "inital burst", + }); + } + + expected.push(Result { + allowed: false, + wait: Duration::new(0, 0), + text: "after burst", + }); + + expected.push(Result { + allowed: true, + wait: Duration::new(0, PACKET_COST as u32), + text: "filling tokens for single packet", + }); + + expected.push(Result { + allowed: false, + wait: Duration::new(0, 0), + text: "not having refilled enough", + }); + + expected.push(Result { + allowed: true, + wait: Duration::new(0, 2 * PACKET_COST as u32), + text: "filling tokens for 2 * packet burst", + }); + + expected.push(Result { + allowed: true, + wait: Duration::new(0, 0), + text: "second packet in 2 packet burst", + }); + + expected.push(Result { + allowed: false, + wait: Duration::new(0, 0), + text: "packet following 2 packet burst", + }); + + for item in expected { + std::thread::sleep(item.wait); + for ip in ips.iter() { + if ratelimiter.allow(&ip) != item.allowed { + panic!( + "test failed for {} on {}. expected: {}, got: {}", + ip, item.text, item.allowed, !item.allowed + ) + } + } + } + } +} diff --git a/src/handshake/types.rs b/src/handshake/types.rs index 08c43d0..967704e 100644 --- a/src/handshake/types.rs +++ b/src/handshake/types.rs @@ -43,6 +43,7 @@ pub enum HandshakeError { OldTimestamp, InvalidState, InvalidMac1, + RateLimited } impl fmt::Display for HandshakeError { @@ -57,6 +58,7 @@ impl fmt::Display for HandshakeError { HandshakeError::OldTimestamp => write!(f, "Timestamp is less/equal to the newest"), HandshakeError::InvalidState => write!(f, "Message does not apply to handshake state"), HandshakeError::InvalidMac1 => write!(f, "Message has invalid mac1 field"), + HandshakeError::RateLimited => write!(f, "Message was dropped by rate limiter") } } } -- cgit v1.2.3-59-g8ed1b