summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock1
-rw-r--r--Cargo.toml1
-rw-r--r--src/handshake/device.rs20
-rw-r--r--src/handshake/macs.rs15
-rw-r--r--src/handshake/mod.rs1
-rw-r--r--src/handshake/ratelimiter.rs162
-rw-r--r--src/handshake/types.rs2
7 files changed, 194 insertions, 8 deletions
diff --git a/Cargo.lock b/Cargo.lock
index fc8cb59..384a19d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -586,6 +586,7 @@ dependencies = [
"generic-array 0.12.3 (registry+https://github.com/rust-lang/crates.io-index)",
"hex 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
"hmac 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"proptest 0.9.4 (registry+https://github.com/rust-lang/crates.io-index)",
"rand 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)",
"sodiumoxide 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
diff --git a/Cargo.toml b/Cargo.toml
index 2f9d38d..2909dcc 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,6 +16,7 @@ zerocopy = "0.2.7"
byteorder = "1.3.1"
digest = "0.8.0"
sodiumoxide = "0.2.2"
+lazy_static = "^1.3"
[dependencies.x25519-dalek]
version = "^0.5"
diff --git a/src/handshake/device.rs b/src/handshake/device.rs
index 809d7a3..4926348 100644
--- a/src/handshake/device.rs
+++ b/src/handshake/device.rs
@@ -1,6 +1,7 @@
use spin::RwLock;
use std::collections::HashMap;
use std::net::SocketAddr;
+use std::sync::Mutex;
use zerocopy::AsBytes;
use rand::prelude::*;
@@ -13,6 +14,7 @@ use super::messages::{CookieReply, Initiation, Response};
use super::messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};
use super::noise;
use super::peer::Peer;
+use super::ratelimiter::RateLimiter;
use super::types::*;
pub struct Device<T> {
@@ -21,6 +23,7 @@ pub struct Device<T> {
macs: macs::Validator, // validator for the mac fields
pk_map: HashMap<[u8; 32], Peer<T>>, // public key -> peer state
id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key
+ limiter: Mutex<RateLimiter>,
}
/* A mutable reference to the device needs to be held during configuration.
@@ -43,6 +46,7 @@ where
macs: macs::Validator::new(pk),
pk_map: HashMap::new(),
id_map: RwLock::new(HashMap::new()),
+ limiter: Mutex::new(RateLimiter::new()),
}
}
@@ -203,8 +207,9 @@ where
// check mac1 field
self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
- // check mac2 field
+ // address validation & DoS mitigation
if let Some(src) = src {
+ // check mac2 field
if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
let mut reply = Default::default();
self.macs.create_cookie_reply(
@@ -216,6 +221,11 @@ where
);
return Ok((None, Some(reply.as_bytes().to_owned()), None));
}
+
+ // check ratelimiter
+ if !self.limiter.lock().unwrap().allow(&src.ip()) {
+ return Err(HandshakeError::RateLimited);
+ }
}
// consume the initiation
@@ -253,8 +263,9 @@ where
// check mac1 field
self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
- // check mac2 field
+ // address validation & DoS mitigation
if let Some(src) = src {
+ // check mac2 field
if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
let mut reply = Default::default();
self.macs.create_cookie_reply(
@@ -266,6 +277,11 @@ where
);
return Ok((None, Some(reply.as_bytes().to_owned()), None));
}
+
+ // check ratelimiter
+ if !self.limiter.lock().unwrap().allow(&src.ip()) {
+ return Err(HandshakeError::RateLimited);
+ }
}
// consume inner playload
diff --git a/src/handshake/macs.rs b/src/handshake/macs.rs
index d5dd95d..3070da3 100644
--- a/src/handshake/macs.rs
+++ b/src/handshake/macs.rs
@@ -1,3 +1,4 @@
+use lazy_static::lazy_static;
use rand::{CryptoRng, RngCore};
use spin::RwLock;
use std::time::{Duration, Instant};
@@ -19,7 +20,9 @@ const SIZE_COOKIE: usize = 16;
const SIZE_SECRET: usize = 32;
const SIZE_MAC: usize = 16; // blake2s-mac128
-const SECS_COOKIE_UPDATE: u64 = 120;
+lazy_static! {
+ pub static ref COOKIE_UPDATE_INTERVAL: Duration = Duration::new(120, 0);
+}
macro_rules! HASH {
( $($input:expr),* ) => {{
@@ -172,7 +175,7 @@ impl Generator {
macs.f_mac1 = MAC!(&self.mac1_key, inner);
macs.f_mac2 = match &self.cookie {
Some(cookie) => {
- if cookie.birth.elapsed() > Duration::from_secs(SECS_COOKIE_UPDATE) {
+ if cookie.birth.elapsed() > *COOKIE_UPDATE_INTERVAL {
self.cookie = None;
[0u8; SIZE_MAC]
} else {
@@ -203,14 +206,14 @@ impl Validator {
cookie_key: HASH!(LABEL_COOKIE, pk.as_bytes()).into(),
secret: RwLock::new(Secret {
value: [0u8; SIZE_SECRET],
- birth: Instant::now() - Duration::from_secs(2 * SECS_COOKIE_UPDATE),
+ birth: Instant::now() - Duration::new(86400, 0),
}),
}
}
fn get_tau(&self, src: &[u8]) -> Option<[u8; SIZE_COOKIE]> {
let secret = self.secret.read();
- if secret.birth.elapsed() < Duration::from_secs(SECS_COOKIE_UPDATE) {
+ if secret.birth.elapsed() < *COOKIE_UPDATE_INTERVAL {
Some(MAC!(&secret.value, src))
} else {
None
@@ -221,7 +224,7 @@ impl Validator {
// check if current value is still valid
{
let secret = self.secret.read();
- if secret.birth.elapsed() < Duration::from_secs(SECS_COOKIE_UPDATE) {
+ if secret.birth.elapsed() < *COOKIE_UPDATE_INTERVAL {
return MAC!(&secret.value, src);
};
}
@@ -229,7 +232,7 @@ impl Validator {
// take write lock, check again
{
let mut secret = self.secret.write();
- if secret.birth.elapsed() < Duration::from_secs(SECS_COOKIE_UPDATE) {
+ if secret.birth.elapsed() < *COOKIE_UPDATE_INTERVAL {
return MAC!(&secret.value, src);
};
diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs
index 4314925..8095147 100644
--- a/src/handshake/mod.rs
+++ b/src/handshake/mod.rs
@@ -11,6 +11,7 @@ mod macs;
mod messages;
mod noise;
mod peer;
+mod ratelimiter;
mod timestamp;
mod types;
diff --git a/src/handshake/ratelimiter.rs b/src/handshake/ratelimiter.rs
new file mode 100644
index 0000000..ce09c16
--- /dev/null
+++ b/src/handshake/ratelimiter.rs
@@ -0,0 +1,162 @@
+use std::collections::HashMap;
+use std::net::IpAddr;
+use std::time::{Duration, Instant};
+
+use lazy_static::lazy_static;
+
+const PACKETS_PER_SECOND: u64 = 20;
+const PACKETS_BURSTABLE: u64 = 5;
+const PACKET_COST: u64 = 1_000_000_000 / PACKETS_PER_SECOND;
+const MAX_TOKENS: u64 = PACKET_COST * PACKETS_BURSTABLE;
+
+lazy_static! {
+ pub static ref GC_INTERVAL: Duration = Duration::new(1, 0);
+}
+
+struct Entry {
+ pub last_time: Instant,
+ pub tokens: u64,
+}
+
+pub struct RateLimiter {
+ garbage_collect: Instant,
+ table: HashMap<IpAddr, Entry>,
+}
+
+impl RateLimiter {
+ pub fn new() -> Self {
+ RateLimiter {
+ garbage_collect: Instant::now(),
+ table: HashMap::new(),
+ }
+ }
+
+ pub fn allow(&mut self, addr: &IpAddr) -> bool {
+ // check for garbage collection
+ if self.garbage_collect.elapsed() > *GC_INTERVAL {
+ self.handle_gc();
+ }
+
+ // update existing entry
+ if let Some(entry) = self.table.get_mut(addr) {
+ // add tokens earned since last time
+ entry.tokens =
+ MAX_TOKENS.min(entry.tokens + u64::from(entry.last_time.elapsed().subsec_nanos()));
+ entry.last_time = Instant::now();
+
+ // subtract cost of packet
+ if entry.tokens > PACKET_COST {
+ entry.tokens -= PACKET_COST;
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ // add new entry
+ self.table.insert(
+ *addr,
+ Entry {
+ last_time: Instant::now(),
+ tokens: MAX_TOKENS - PACKET_COST,
+ },
+ );
+
+ true
+ }
+
+ fn handle_gc(&mut self) {
+ self.table
+ .retain(|_, ref mut entry| entry.last_time.elapsed() <= *GC_INTERVAL);
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std;
+
+ struct Result {
+ allowed: bool,
+ text: &'static str,
+ wait: Duration,
+ }
+
+ #[test]
+ fn test_ratelimiter() {
+ let mut ratelimiter = RateLimiter::new();
+ let mut expected = vec![];
+ let ips = vec![
+ "127.0.0.1".parse().unwrap(),
+ "192.168.1.1".parse().unwrap(),
+ "172.167.2.3".parse().unwrap(),
+ "97.231.252.215".parse().unwrap(),
+ "248.97.91.167".parse().unwrap(),
+ "188.208.233.47".parse().unwrap(),
+ "104.2.183.179".parse().unwrap(),
+ "72.129.46.120".parse().unwrap(),
+ "2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(),
+ "f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(),
+ "b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(),
+ "a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(),
+ "ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(),
+ "3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(),
+ ];
+
+ for _ in 0..PACKETS_BURSTABLE {
+ expected.push(Result {
+ allowed: true,
+ wait: Duration::new(0, 0),
+ text: "inital burst",
+ });
+ }
+
+ expected.push(Result {
+ allowed: false,
+ wait: Duration::new(0, 0),
+ text: "after burst",
+ });
+
+ expected.push(Result {
+ allowed: true,
+ wait: Duration::new(0, PACKET_COST as u32),
+ text: "filling tokens for single packet",
+ });
+
+ expected.push(Result {
+ allowed: false,
+ wait: Duration::new(0, 0),
+ text: "not having refilled enough",
+ });
+
+ expected.push(Result {
+ allowed: true,
+ wait: Duration::new(0, 2 * PACKET_COST as u32),
+ text: "filling tokens for 2 * packet burst",
+ });
+
+ expected.push(Result {
+ allowed: true,
+ wait: Duration::new(0, 0),
+ text: "second packet in 2 packet burst",
+ });
+
+ expected.push(Result {
+ allowed: false,
+ wait: Duration::new(0, 0),
+ text: "packet following 2 packet burst",
+ });
+
+ for item in expected {
+ std::thread::sleep(item.wait);
+ for ip in ips.iter() {
+ if ratelimiter.allow(&ip) != item.allowed {
+ panic!(
+ "test failed for {} on {}. expected: {}, got: {}",
+ ip, item.text, item.allowed, !item.allowed
+ )
+ }
+ }
+ }
+ }
+}
diff --git a/src/handshake/types.rs b/src/handshake/types.rs
index 08c43d0..967704e 100644
--- a/src/handshake/types.rs
+++ b/src/handshake/types.rs
@@ -43,6 +43,7 @@ pub enum HandshakeError {
OldTimestamp,
InvalidState,
InvalidMac1,
+ RateLimited
}
impl fmt::Display for HandshakeError {
@@ -57,6 +58,7 @@ impl fmt::Display for HandshakeError {
HandshakeError::OldTimestamp => write!(f, "Timestamp is less/equal to the newest"),
HandshakeError::InvalidState => write!(f, "Message does not apply to handshake state"),
HandshakeError::InvalidMac1 => write!(f, "Message has invalid mac1 field"),
+ HandshakeError::RateLimited => write!(f, "Message was dropped by rate limiter")
}
}
}