aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-08-07 11:29:39 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-08-07 11:29:39 +0200
commitf7f10881235e3b9b0c272772a0f4c646f987a8d3 (patch)
tree1fb3c1641dce2315d9f18c2cb8af67ce062c3623 /src
parentAdd rate limiter check to handshake messages. (diff)
downloadwireguard-rs-f7f10881235e3b9b0c272772a0f4c646f987a8d3.tar.xz
wireguard-rs-f7f10881235e3b9b0c272772a0f4c646f987a8d3.zip
Added initiation flood protection
Diffstat (limited to 'src')
-rw-r--r--src/handshake/device.rs15
-rw-r--r--src/handshake/noise.rs2
-rw-r--r--src/handshake/peer.rs54
-rw-r--r--src/handshake/types.rs8
4 files changed, 53 insertions, 26 deletions
diff --git a/src/handshake/device.rs b/src/handshake/device.rs
index 4926348..3ec161f 100644
--- a/src/handshake/device.rs
+++ b/src/handshake/device.rs
@@ -17,6 +17,8 @@ use super::peer::Peer;
use super::ratelimiter::RateLimiter;
use super::types::*;
+const MAX_PEER_PER_DEVICE: usize = 1 << 20;
+
pub struct Device<T> {
pub sk: StaticSecret, // static secret key
pub pk: PublicKey, // static public key
@@ -59,21 +61,23 @@ where
/// * `identifier` - Associated identifier which can be used to distinguish the peers
pub fn add(&mut self, pk: PublicKey, identifier: T) -> Result<(), ConfigError> {
// check that the pk is not added twice
-
if let Some(_) = self.pk_map.get(pk.as_bytes()) {
return Err(ConfigError::new("Duplicate public key"));
};
// check that the pk is not that of the device
-
if *self.pk.as_bytes() == *pk.as_bytes() {
return Err(ConfigError::new(
"Public key corresponds to secret key of interface",
));
}
- // map : pk -> new index
+ // ensure less than 2^20 peers
+ if self.pk_map.len() > MAX_PEER_PER_DEVICE {
+ return Err(ConfigError::new("Too many peers for device"));
+ }
+ // map the public key to the peer state
self.pk_map.insert(
*pk.as_bytes(),
Peer::new(identifier, pk, self.sk.diffie_hellman(&pk)),
@@ -353,6 +357,8 @@ mod tests {
use super::*;
use hex;
use rand::rngs::OsRng;
+ use std::thread;
+ use std::time::Duration;
#[test]
fn handshake() {
@@ -419,6 +425,9 @@ mod tests {
dev1.release(ks_i.send.id);
dev2.release(ks_r.send.id);
+
+ // to avoid flood detection
+ thread::sleep(Duration::from_millis(20));
}
assert_eq!(dev1.get_psk(pk2).unwrap(), psk);
diff --git a/src/handshake/noise.rs b/src/handshake/noise.rs
index 6532f4d..4eea627 100644
--- a/src/handshake/noise.rs
+++ b/src/handshake/noise.rs
@@ -306,7 +306,7 @@ pub fn consume_initiation<'a, T: Copy>(
// check and update timestamp
- peer.check_timestamp(device, &ts)?;
+ peer.check_replay_flood(device, &ts)?;
// H := Hash(H || msg.timestamp)
diff --git a/src/handshake/peer.rs b/src/handshake/peer.rs
index 9645799..9629a7f 100644
--- a/src/handshake/peer.rs
+++ b/src/handshake/peer.rs
@@ -1,4 +1,6 @@
+use lazy_static::lazy_static;
use spin::Mutex;
+use std::time::{Duration, Instant};
use generic_array::typenum::U32;
use generic_array::GenericArray;
@@ -8,15 +10,18 @@ use x25519_dalek::SharedSecret;
use x25519_dalek::StaticSecret;
use super::device::Device;
+use super::macs;
use super::timestamp;
use super::types::*;
-use super::macs;
+
+lazy_static! {
+ pub static ref TIME_BETWEEN_INITIATIONS: Duration = Duration::from_millis(20);
+}
/* Represents the recomputation and state of a peer.
*
* This type is only for internal use and not exposed.
*/
-
pub struct Peer<T> {
// external identifier
pub(crate) identifier: T,
@@ -24,6 +29,7 @@ pub struct Peer<T> {
// mutable state
state: Mutex<State>,
timestamp: Mutex<Option<timestamp::TAI64N>>,
+ last_initiation_consumption: Mutex<Option<Instant>>,
// state related to DoS mitigation fields
pub(crate) macs: Mutex<macs::Generator>,
@@ -77,6 +83,7 @@ where
identifier: identifier,
state: Mutex::new(State::Reset),
timestamp: Mutex::new(None),
+ last_initiation_consumption: Mutex::new(None),
pk: pk,
ss: ss,
psk: [0u8; 32],
@@ -104,38 +111,45 @@ where
///
/// * st_new - The updated state of the peer
/// * ts_new - The associated timestamp
- pub fn check_timestamp(
+ pub fn check_replay_flood(
&self,
device: &Device<T>,
timestamp_new: &timestamp::TAI64N,
) -> Result<(), HandshakeError> {
let mut state = self.state.lock();
let mut timestamp = self.timestamp.lock();
+ let mut last_initiation_consumption = self.last_initiation_consumption.lock();
- let update = match *timestamp {
- None => true,
+ // check replay attack
+ match *timestamp {
Some(timestamp_old) => {
- if timestamp::compare(&timestamp_old, &timestamp_new) {
- true
- } else {
- false
+ if !timestamp::compare(&timestamp_old, &timestamp_new) {
+ return Err(HandshakeError::OldTimestamp);
}
}
+ _ => (),
};
- if update {
- // release existing identifier
- match *state {
- State::InitiationSent { sender, .. } => device.release(sender),
- _ => (),
+ // check flood attack
+ match *last_initiation_consumption {
+ Some(last) => {
+ if last.elapsed() < *TIME_BETWEEN_INITIATIONS {
+ return Err(HandshakeError::InitiationFlood);
+ }
}
+ _ => (),
+ }
- // reset state and update timestamp
- *state = State::Reset;
- *timestamp = Some(*timestamp_new);
- Ok(())
- } else {
- Err(HandshakeError::OldTimestamp)
+ // reset state
+ match *state {
+ State::InitiationSent { sender, .. } => device.release(sender),
+ _ => (),
}
+
+ // update replay & flood protection
+ *state = State::Reset;
+ *timestamp = Some(*timestamp_new);
+ *last_initiation_consumption = Some(Instant::now());
+ Ok(())
}
}
diff --git a/src/handshake/types.rs b/src/handshake/types.rs
index 967704e..7b190ec 100644
--- a/src/handshake/types.rs
+++ b/src/handshake/types.rs
@@ -43,7 +43,8 @@ pub enum HandshakeError {
OldTimestamp,
InvalidState,
InvalidMac1,
- RateLimited
+ RateLimited,
+ InitiationFlood,
}
impl fmt::Display for HandshakeError {
@@ -58,7 +59,10 @@ impl fmt::Display for HandshakeError {
HandshakeError::OldTimestamp => write!(f, "Timestamp is less/equal to the newest"),
HandshakeError::InvalidState => write!(f, "Message does not apply to handshake state"),
HandshakeError::InvalidMac1 => write!(f, "Message has invalid mac1 field"),
- HandshakeError::RateLimited => write!(f, "Message was dropped by rate limiter")
+ HandshakeError::RateLimited => write!(f, "Message was dropped by rate limiter"),
+ HandshakeError::InitiationFlood => {
+ write!(f, "Message was dropped because of initiation flood")
+ }
}
}
}