diff options
Diffstat (limited to 'src/handshake/device.rs')
-rw-r--r-- | src/handshake/device.rs | 20 |
1 files changed, 18 insertions, 2 deletions
diff --git a/src/handshake/device.rs b/src/handshake/device.rs index 809d7a3..4926348 100644 --- a/src/handshake/device.rs +++ b/src/handshake/device.rs @@ -1,6 +1,7 @@ use spin::RwLock; use std::collections::HashMap; use std::net::SocketAddr; +use std::sync::Mutex; use zerocopy::AsBytes; use rand::prelude::*; @@ -13,6 +14,7 @@ use super::messages::{CookieReply, Initiation, Response}; use super::messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE}; use super::noise; use super::peer::Peer; +use super::ratelimiter::RateLimiter; use super::types::*; pub struct Device<T> { @@ -21,6 +23,7 @@ pub struct Device<T> { macs: macs::Validator, // validator for the mac fields pk_map: HashMap<[u8; 32], Peer<T>>, // public key -> peer state id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key + limiter: Mutex<RateLimiter>, } /* A mutable reference to the device needs to be held during configuration. @@ -43,6 +46,7 @@ where macs: macs::Validator::new(pk), pk_map: HashMap::new(), id_map: RwLock::new(HashMap::new()), + limiter: Mutex::new(RateLimiter::new()), } } @@ -203,8 +207,9 @@ where // check mac1 field self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; - // check mac2 field + // address validation & DoS mitigation if let Some(src) = src { + // check mac2 field if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { let mut reply = Default::default(); self.macs.create_cookie_reply( @@ -216,6 +221,11 @@ where ); return Ok((None, Some(reply.as_bytes().to_owned()), None)); } + + // check ratelimiter + if !self.limiter.lock().unwrap().allow(&src.ip()) { + return Err(HandshakeError::RateLimited); + } } // consume the initiation @@ -253,8 +263,9 @@ where // check mac1 field self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; - // check mac2 field + // address validation & DoS mitigation if let Some(src) = src { + // check mac2 field if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { let mut reply = Default::default(); self.macs.create_cookie_reply( @@ -266,6 +277,11 @@ where ); return Ok((None, Some(reply.as_bytes().to_owned()), None)); } + + // check ratelimiter + if !self.limiter.lock().unwrap().allow(&src.ip()) { + return Err(HandshakeError::RateLimited); + } } // consume inner playload |