aboutsummaryrefslogtreecommitdiffstats
path: root/src/handshake/device.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/handshake/device.rs')
-rw-r--r--src/handshake/device.rs20
1 files changed, 18 insertions, 2 deletions
diff --git a/src/handshake/device.rs b/src/handshake/device.rs
index 809d7a3..4926348 100644
--- a/src/handshake/device.rs
+++ b/src/handshake/device.rs
@@ -1,6 +1,7 @@
use spin::RwLock;
use std::collections::HashMap;
use std::net::SocketAddr;
+use std::sync::Mutex;
use zerocopy::AsBytes;
use rand::prelude::*;
@@ -13,6 +14,7 @@ use super::messages::{CookieReply, Initiation, Response};
use super::messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};
use super::noise;
use super::peer::Peer;
+use super::ratelimiter::RateLimiter;
use super::types::*;
pub struct Device<T> {
@@ -21,6 +23,7 @@ pub struct Device<T> {
macs: macs::Validator, // validator for the mac fields
pk_map: HashMap<[u8; 32], Peer<T>>, // public key -> peer state
id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key
+ limiter: Mutex<RateLimiter>,
}
/* A mutable reference to the device needs to be held during configuration.
@@ -43,6 +46,7 @@ where
macs: macs::Validator::new(pk),
pk_map: HashMap::new(),
id_map: RwLock::new(HashMap::new()),
+ limiter: Mutex::new(RateLimiter::new()),
}
}
@@ -203,8 +207,9 @@ where
// check mac1 field
self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
- // check mac2 field
+ // address validation & DoS mitigation
if let Some(src) = src {
+ // check mac2 field
if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
let mut reply = Default::default();
self.macs.create_cookie_reply(
@@ -216,6 +221,11 @@ where
);
return Ok((None, Some(reply.as_bytes().to_owned()), None));
}
+
+ // check ratelimiter
+ if !self.limiter.lock().unwrap().allow(&src.ip()) {
+ return Err(HandshakeError::RateLimited);
+ }
}
// consume the initiation
@@ -253,8 +263,9 @@ where
// check mac1 field
self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
- // check mac2 field
+ // address validation & DoS mitigation
if let Some(src) = src {
+ // check mac2 field
if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
let mut reply = Default::default();
self.macs.create_cookie_reply(
@@ -266,6 +277,11 @@ where
);
return Ok((None, Some(reply.as_bytes().to_owned()), None));
}
+
+ // check ratelimiter
+ if !self.limiter.lock().unwrap().allow(&src.ip()) {
+ return Err(HandshakeError::RateLimited);
+ }
}
// consume inner playload