aboutsummaryrefslogtreecommitdiffstats
path: root/src/wireguard/handshake
diff options
context:
space:
mode:
Diffstat (limited to 'src/wireguard/handshake')
-rw-r--r--src/wireguard/handshake/device.rs64
-rw-r--r--src/wireguard/handshake/macs.rs19
-rw-r--r--src/wireguard/handshake/noise.rs35
-rw-r--r--src/wireguard/handshake/peer.rs34
-rw-r--r--src/wireguard/handshake/ratelimiter.rs5
-rw-r--r--src/wireguard/handshake/tests.rs18
-rw-r--r--src/wireguard/handshake/timestamp.rs2
7 files changed, 85 insertions, 92 deletions
diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs
index 3a3d023..47ca401 100644
--- a/src/wireguard/handshake/device.rs
+++ b/src/wireguard/handshake/device.rs
@@ -1,14 +1,15 @@
-use spin::RwLock;
use std::collections::hash_map;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Mutex;
-use zerocopy::AsBytes;
use byteorder::{ByteOrder, LittleEndian};
+use dashmap::mapref::entry::Entry;
+use dashmap::DashMap;
+use zerocopy::AsBytes;
-use rand::prelude::{CryptoRng, RngCore};
use rand::Rng;
+use rand_core::{CryptoRng, RngCore};
use clear_on_drop::clear::Clear;
@@ -36,7 +37,7 @@ pub struct KeyState {
/// (the instance is a Peer object in the parent module)
pub struct Device<O> {
keyst: Option<KeyState>,
- id_map: RwLock<HashMap<u32, [u8; 32]>>,
+ id_map: DashMap<u32, [u8; 32]>, // concurrent map
pk_map: HashMap<[u8; 32], Peer<O>>,
limiter: Mutex<RateLimiter>,
}
@@ -62,7 +63,7 @@ impl<'a, O> Iterator for Iter<'a, O> {
*/
impl<O> Device<O> {
pub fn clear(&mut self) {
- self.id_map.write().clear();
+ self.id_map.clear();
self.pk_map.clear();
}
@@ -96,7 +97,7 @@ impl<O> Device<O> {
pub fn new() -> Device<O> {
Device {
keyst: None,
- id_map: RwLock::new(HashMap::new()),
+ id_map: DashMap::new(),
pk_map: HashMap::new(),
limiter: Mutex::new(RateLimiter::new()),
}
@@ -117,7 +118,9 @@ impl<O> Device<O> {
} else {
peer.ss.clear();
}
- peer.reset_state().map(|id| ids.push(id));
+ if let Some(id) = peer.reset_state() {
+ ids.push(id)
+ }
}
(ids, same)
@@ -208,16 +211,14 @@ impl<O> Device<O> {
///
/// The call might fail if the public key is not found
pub fn remove(&mut self, pk: &PublicKey) -> Result<(), ConfigError> {
- // take write-lock on receive id table
- let mut id_map = self.id_map.write();
-
// remove the peer
self.pk_map
.remove(pk.as_bytes())
- .ok_or(ConfigError::new("Public key not in device"))?;
+ .ok_or_else(|| ConfigError::new("Public key not in device"))?;
- // purge the id map (linear scan)
- id_map.retain(|_, v| v != pk.as_bytes());
+ // remove every id entry for the peer in the public key map
+ // O(n) operations, however it is rare: only when removing peers.
+ self.id_map.retain(|_, v| v != pk.as_bytes());
Ok(())
}
@@ -265,9 +266,8 @@ impl<O> Device<O> {
///
/// * `id` - The (sender) id to release
pub fn release(&self, id: u32) {
- let mut m = self.id_map.write();
- debug_assert!(m.contains_key(&id), "Releasing id not allocated");
- m.remove(&id);
+ let old = self.id_map.remove(&id);
+ assert!(old.is_some(), "released id not allocated");
}
/// Begin a new handshake
@@ -391,9 +391,6 @@ impl<O> Device<O> {
// address validation & DoS mitigation
if let Some(src) = src {
- // obtain ref to socket addr
- let src = src.into();
-
// check mac2 field
if !keyst.macs.check_mac2(msg.noise.as_bytes(), &src, &msg.macs) {
let mut reply = Default::default();
@@ -446,32 +443,37 @@ impl<O> Device<O> {
//
// Return the peer currently associated with the receiver identifier
pub(super) fn lookup_id(&self, id: u32) -> Result<(&Peer<O>, PublicKey), HandshakeError> {
- let im = self.id_map.read();
- let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?;
- match self.pk_map.get(pk) {
+ // obtain a read reference to entry in the id_map
+ let pk = self
+ .id_map
+ .get(&id)
+ .ok_or(HandshakeError::UnknownReceiverId)?;
+
+ // lookup the public key from the pk map
+ match self.pk_map.get(&*pk) {
Some(peer) => Ok((peer, PublicKey::from(*pk))),
- _ => unreachable!(), // if the id-lookup succeeded, the peer should exist
+ _ => unreachable!(),
}
}
// Internal function
//
- // Allocated a new receiver identifier for the peer
+ // Allocated a new receiver identifier for the peer.
+ // Implemented via rejection sampling.
fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, pk: &PublicKey) -> u32 {
loop {
let id = rng.gen();
- // check membership with read lock
- if self.id_map.read().contains_key(&id) {
+ // read lock the shard and do quick check
+ if self.id_map.contains_key(&id) {
continue;
}
- // take write lock and add index
- let mut m = self.id_map.write();
- if !m.contains_key(&id) {
- m.insert(id, *pk.as_bytes());
+ // write lock the shard and insert
+ if let Entry::Vacant(entry) = self.id_map.entry(id) {
+ entry.insert(*pk.as_bytes());
return id;
- }
+ };
}
}
}
diff --git a/src/wireguard/handshake/macs.rs b/src/wireguard/handshake/macs.rs
index cb5d7d4..f4f5586 100644
--- a/src/wireguard/handshake/macs.rs
+++ b/src/wireguard/handshake/macs.rs
@@ -1,5 +1,5 @@
use generic_array::GenericArray;
-use rand::{CryptoRng, RngCore};
+use rand_core::{CryptoRng, RngCore};
use spin::RwLock;
use std::time::{Duration, Instant};
@@ -8,6 +8,7 @@ use std::net::SocketAddr;
use x25519_dalek::PublicKey;
// AEAD
+
use aead::{Aead, NewAead, Payload};
use chacha20poly1305::XChaCha20Poly1305;
@@ -33,30 +34,29 @@ macro_rules! HASH {
use blake2::Digest;
let mut hsh = Blake2s::new();
$(
- hsh.input($input);
+ hsh.update($input);
)*
- hsh.result()
+ hsh.finalize()
}};
}
macro_rules! MAC {
( $key:expr, $($input:expr),* ) => {{
use blake2::VarBlake2s;
- use digest::Input;
- use digest::VariableOutput;
+ use blake2::digest::{Update, VariableOutput};
let mut tag = [0u8; SIZE_MAC];
let mut mac = VarBlake2s::new_keyed($key, SIZE_MAC);
$(
- mac.input($input);
+ mac.update($input);
)*
- mac.variable_result(|buf| tag.copy_from_slice(buf));
+ mac.finalize_variable(|buf| tag.copy_from_slice(buf));
tag
}};
}
macro_rules! XSEAL {
($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{
- let ct = XChaCha20Poly1305::new(*GenericArray::from_slice($key))
+ let ct = XChaCha20Poly1305::new(GenericArray::from_slice($key))
.encrypt(
GenericArray::from_slice($nonce),
Payload { msg: $pt, aad: $ad },
@@ -70,7 +70,7 @@ macro_rules! XSEAL {
macro_rules! XOPEN {
($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{
debug_assert_eq!($ct.len(), $pt.len() + SIZE_TAG);
- XChaCha20Poly1305::new(*GenericArray::from_slice($key))
+ XChaCha20Poly1305::new(GenericArray::from_slice($key))
.decrypt(
GenericArray::from_slice($nonce),
Payload { msg: $ct, aad: $ad },
@@ -141,6 +141,7 @@ impl Generator {
pub fn process(&mut self, reply: &CookieReply) -> Result<(), HandshakeError> {
let mac1 = self.last_mac1.ok_or(HandshakeError::InvalidState)?;
let mut tau = [0u8; SIZE_COOKIE];
+ #[allow(clippy::unnecessary_mut_passed)]
XOPEN!(
&self.cookie_key, // key
&reply.f_nonce, // nonce
diff --git a/src/wireguard/handshake/noise.rs b/src/wireguard/handshake/noise.rs
index beb99c2..92c8c5f 100644
--- a/src/wireguard/handshake/noise.rs
+++ b/src/wireguard/handshake/noise.rs
@@ -1,7 +1,7 @@
use std::time::Instant;
// DH
-use x25519_dalek::{PublicKey, StaticSecret, SharedSecret};
+use x25519_dalek::{PublicKey, SharedSecret, StaticSecret};
// HASH & MAC
use blake2::Blake2s;
@@ -11,15 +11,13 @@ use hmac::Hmac;
use aead::{Aead, NewAead, Payload};
use chacha20poly1305::ChaCha20Poly1305;
-use log;
-
-use rand::prelude::{CryptoRng, RngCore};
+use rand_core::{CryptoRng, RngCore};
use generic_array::typenum::*;
use generic_array::*;
use clear_on_drop::clear::Clear;
-use clear_on_drop::clear_stack_on_return;
+use clear_on_drop::clear_stack_on_return_fnonce;
use subtle::ConstantTimeEq;
@@ -65,20 +63,20 @@ macro_rules! HASH {
use blake2::Digest;
let mut hsh = Blake2s::new();
$(
- hsh.input($input);
+ hsh.update($input);
)*
- hsh.result()
+ hsh.finalize()
}};
}
macro_rules! HMAC {
($key:expr, $($input:expr),*) => {{
- use hmac::Mac;
+ use hmac::{Mac, NewMac};
let mut mac = HMACBlake2s::new_varkey($key).unwrap();
$(
- mac.input($input);
+ mac.update($input);
)*
- mac.result().code()
+ mac.finalize().into_bytes()
}};
}
@@ -114,7 +112,7 @@ macro_rules! KDF3 {
macro_rules! SEAL {
($key:expr, $ad:expr, $pt:expr, $ct:expr) => {
- ChaCha20Poly1305::new(*GenericArray::from_slice($key))
+ ChaCha20Poly1305::new(GenericArray::from_slice($key))
.encrypt(&ZERO_NONCE.into(), Payload { msg: $pt, aad: $ad })
.map(|ct| $ct.copy_from_slice(&ct))
.unwrap()
@@ -123,7 +121,7 @@ macro_rules! SEAL {
macro_rules! OPEN {
($key:expr, $ad:expr, $pt:expr, $ct:expr) => {
- ChaCha20Poly1305::new(*GenericArray::from_slice($key))
+ ChaCha20Poly1305::new(GenericArray::from_slice($key))
.decrypt(&ZERO_NONCE.into(), Payload { msg: $ct, aad: $ad })
.map_err(|_| HandshakeError::DecryptionFailure)
.map(|pt| $pt.copy_from_slice(&pt))
@@ -215,7 +213,7 @@ mod tests {
}
// Computes an X25519 shared secret.
-//
+//
// This function wraps dalek to add a zero-check.
// This is not recommended by the Noise specification,
// but implemented in the kernel with which we strive for absolute equivalent behavior.
@@ -244,7 +242,7 @@ pub(super) fn create_initiation<R: RngCore + CryptoRng, O>(
return Err(HandshakeError::InvalidSharedSecret);
}
- clear_stack_on_return(CLEAR_PAGES, || {
+ clear_stack_on_return_fnonce(CLEAR_PAGES, || {
// initialize state
let ck = INITIAL_CK;
@@ -290,7 +288,6 @@ pub(super) fn create_initiation<R: RngCore + CryptoRng, O>(
// (C, k) := Kdf2(C, DH(S_priv, S_pub))
-
let (ck, key) = KDF2!(&ck, &peer.ss);
// msg.timestamp := Aead(k, 0, Timestamp(), H)
@@ -326,7 +323,7 @@ pub(super) fn consume_initiation<'a, O>(
) -> Result<(&'a Peer<O>, PublicKey, TemporaryState), HandshakeError> {
log::debug!("consume initiation");
- clear_stack_on_return(CLEAR_PAGES, || {
+ clear_stack_on_return_fnonce(CLEAR_PAGES, || {
// initialize new state
let ck = INITIAL_CK;
@@ -360,7 +357,7 @@ pub(super) fn consume_initiation<'a, O>(
let peer = device.lookup_pk(&PublicKey::from(pk))?;
// check for zero shared-secret (see "shared_secret" note).
-
+
if peer.ss.ct_eq(&[0u8; 32]).into() {
return Err(HandshakeError::InvalidSharedSecret);
}
@@ -415,7 +412,7 @@ pub(super) fn create_response<R: RngCore + CryptoRng, O>(
msg: &mut NoiseResponse, // resulting response
) -> Result<KeyPair, HandshakeError> {
log::debug!("create response");
- clear_stack_on_return(CLEAR_PAGES, || {
+ clear_stack_on_return_fnonce(CLEAR_PAGES, || {
// unpack state
let (receiver, eph_r_pk, hs, ck) = state;
@@ -500,7 +497,7 @@ pub(super) fn consume_response<'a, O>(
msg: &NoiseResponse,
) -> Result<Output<'a, O>, HandshakeError> {
log::debug!("consume response");
- clear_stack_on_return(CLEAR_PAGES, || {
+ clear_stack_on_return_fnonce(CLEAR_PAGES, || {
// retrieve peer and copy initiation state
let (peer, _) = device.lookup_id(msg.f_receiver.get())?;
diff --git a/src/wireguard/handshake/peer.rs b/src/wireguard/handshake/peer.rs
index 1636e62..f847725 100644
--- a/src/wireguard/handshake/peer.rs
+++ b/src/wireguard/handshake/peer.rs
@@ -50,13 +50,10 @@ pub enum State {
impl Drop for State {
fn drop(&mut self) {
- match self {
- State::InitiationSent { hs, ck, .. } => {
- // eph_sk already cleared by dalek-x25519
- hs.clear();
- ck.clear();
- }
- _ => (),
+ if let State::InitiationSent { hs, ck, .. } = self {
+ // eph_sk already cleared by dalek-x25519
+ hs.clear();
+ ck.clear();
}
}
}
@@ -97,29 +94,22 @@ impl<O> Peer<O> {
let mut last_initiation_consumption = self.last_initiation_consumption.lock();
// check replay attack
- match *timestamp {
- Some(timestamp_old) => {
- if !timestamp::compare(&timestamp_old, &timestamp_new) {
- return Err(HandshakeError::OldTimestamp);
- }
+ if let Some(timestamp_old) = *timestamp {
+ if !timestamp::compare(&timestamp_old, &timestamp_new) {
+ return Err(HandshakeError::OldTimestamp);
}
- _ => (),
};
// check flood attack
- match *last_initiation_consumption {
- Some(last) => {
- if last.elapsed() < TIME_BETWEEN_INITIATIONS {
- return Err(HandshakeError::InitiationFlood);
- }
+ if let Some(last) = *last_initiation_consumption {
+ if last.elapsed() < TIME_BETWEEN_INITIATIONS {
+ return Err(HandshakeError::InitiationFlood);
}
- _ => (),
}
// reset state
- match *state {
- State::InitiationSent { local, .. } => device.release(local),
- _ => (),
+ if let State::InitiationSent { local, .. } = *state {
+ device.release(local)
}
// update replay & flood protection
diff --git a/src/wireguard/handshake/ratelimiter.rs b/src/wireguard/handshake/ratelimiter.rs
index 89109e9..9e796a0 100644
--- a/src/wireguard/handshake/ratelimiter.rs
+++ b/src/wireguard/handshake/ratelimiter.rs
@@ -5,8 +5,6 @@ use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::{Duration, Instant};
-use spin;
-
const PACKETS_PER_SECOND: u64 = 20;
const PACKETS_BURSTABLE: u64 = 5;
const PACKET_COST: u64 = 1_000_000_000 / PACKETS_PER_SECOND;
@@ -39,6 +37,7 @@ impl Drop for RateLimiter {
impl RateLimiter {
pub fn new() -> Self {
+ #[allow(clippy::mutex_atomic)]
RateLimiter(Arc::new(RateLimiterInner {
gc_dropped: (Mutex::new(false), Condvar::new()),
gc_running: AtomicBool::from(false),
@@ -145,7 +144,7 @@ mod tests {
expected.push(Result {
allowed: true,
wait: Duration::new(0, 0),
- text: "inital burst",
+ text: "initial burst",
});
}
diff --git a/src/wireguard/handshake/tests.rs b/src/wireguard/handshake/tests.rs
index 5174d2e..35ff152 100644
--- a/src/wireguard/handshake/tests.rs
+++ b/src/wireguard/handshake/tests.rs
@@ -6,8 +6,8 @@ use std::time::Duration;
use hex;
-use rand::prelude::{CryptoRng, RngCore};
use rand::rngs::OsRng;
+use rand_core::{CryptoRng, RngCore};
use x25519_dalek::PublicKey;
use x25519_dalek::StaticSecret;
@@ -15,20 +15,22 @@ use x25519_dalek::StaticSecret;
use super::messages::{Initiation, Response};
fn setup_devices<R: RngCore + CryptoRng, O: Default>(
- rng: &mut R,
+ rng1: &mut R,
+ rng2: &mut R,
+ rng3: &mut R,
) -> (PublicKey, Device<O>, PublicKey, Device<O>) {
// generate new key pairs
- let sk1 = StaticSecret::new(rng);
+ let sk1 = StaticSecret::new(rng1);
let pk1 = PublicKey::from(&sk1);
- let sk2 = StaticSecret::new(rng);
+ let sk2 = StaticSecret::new(rng2);
let pk2 = PublicKey::from(&sk2);
// pick random psk
let mut psk = [0u8; 32];
- rng.fill_bytes(&mut psk[..]);
+ rng3.fill_bytes(&mut psk[..]);
// initialize devices on both ends
@@ -63,7 +65,8 @@ fn wait() {
*/
#[test]
fn handshake_under_load() {
- let (_pk1, dev1, pk2, dev2): (_, Device<usize>, _, _) = setup_devices(&mut OsRng);
+ let (_pk1, dev1, pk2, dev2): (_, Device<usize>, _, _) =
+ setup_devices(&mut OsRng, &mut OsRng, &mut OsRng);
let src1: SocketAddr = "172.16.0.1:8080".parse().unwrap();
let src2: SocketAddr = "172.16.0.2:7070".parse().unwrap();
@@ -140,7 +143,8 @@ fn handshake_under_load() {
#[test]
fn handshake_no_load() {
- let (pk1, mut dev1, pk2, mut dev2): (_, Device<usize>, _, _) = setup_devices(&mut OsRng);
+ let (pk1, mut dev1, pk2, mut dev2): (_, Device<usize>, _, _) =
+ setup_devices(&mut OsRng, &mut OsRng, &mut OsRng);
// do a few handshakes (every handshake should succeed)
diff --git a/src/wireguard/handshake/timestamp.rs b/src/wireguard/handshake/timestamp.rs
index b5bd9f0..485bb8d 100644
--- a/src/wireguard/handshake/timestamp.rs
+++ b/src/wireguard/handshake/timestamp.rs
@@ -28,5 +28,5 @@ pub fn compare(old: &TAI64N, new: &TAI64N) -> bool {
return true;
}
}
- return false;
+ false
}