diff options
Diffstat (limited to 'src/wireguard')
-rw-r--r-- | src/wireguard/handshake/device.rs | 16 | ||||
-rw-r--r-- | src/wireguard/handshake/peer.rs | 1 | ||||
-rw-r--r-- | src/wireguard/peer.rs | 3 | ||||
-rw-r--r-- | src/wireguard/timers.rs | 2 | ||||
-rw-r--r-- | src/wireguard/wireguard.rs | 19 |
5 files changed, 25 insertions, 16 deletions
diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs index 02e6929..030c0f8 100644 --- a/src/wireguard/handshake/device.rs +++ b/src/wireguard/handshake/device.rs @@ -469,6 +469,10 @@ mod tests { (pk1, dev1, pk2, dev2) } + fn wait() { + thread::sleep(Duration::from_millis(20)); + } + /* Test longest possible handshake interaction (7 messages): * * 1. I -> R (initation) @@ -502,8 +506,8 @@ mod tests { _ => panic!("unexpected response"), } - // avoid initation flood - thread::sleep(Duration::from_millis(20)); + // avoid initation flood detection + wait(); // 3. device-1 : create second initation let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); @@ -529,8 +533,8 @@ mod tests { _ => panic!("unexpected response"), } - // avoid initation flood - thread::sleep(Duration::from_millis(20)); + // avoid initation flood detection + wait(); // 6. device-1 : create third initation let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); @@ -600,8 +604,8 @@ mod tests { dev1.release(ks_i.send.id); dev2.release(ks_r.send.id); - // to avoid flood detection - thread::sleep(Duration::from_millis(20)); + // avoid initation flood detection + wait(); } dev1.remove(pk2).unwrap(); diff --git a/src/wireguard/handshake/peer.rs b/src/wireguard/handshake/peer.rs index abb36eb..2d69244 100644 --- a/src/wireguard/handshake/peer.rs +++ b/src/wireguard/handshake/peer.rs @@ -7,7 +7,6 @@ use generic_array::typenum::U32; use generic_array::GenericArray; use x25519_dalek::PublicKey; -use x25519_dalek::SharedSecret; use x25519_dalek::StaticSecret; use clear_on_drop::clear::Clear; diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs index 4f9d19f..7d95493 100644 --- a/src/wireguard/peer.rs +++ b/src/wireguard/peer.rs @@ -1,4 +1,3 @@ -use super::constants::*; use super::router; use super::timers::{Events, Timers}; use super::HandshakeJob; @@ -9,7 +8,7 @@ use super::wireguard::WireguardInner; use std::fmt; use std::ops::Deref; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64}; use std::sync::Arc; use std::time::{Instant, SystemTime}; diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index 33b089f..8f6b3ee 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -63,7 +63,7 @@ impl<T: tun::Tun, B: bind::Bind> PeerInner<T, B> { // take a write lock preventing simultaneous "stop_timers" call let mut timers = self.timers_mut(); - // set flag to renable timer events + // set flag to reenable timer events if timers.enabled { return; } diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index c0a8d9d..613c0a8 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -18,6 +18,7 @@ use std::sync::Arc; use std::thread; use std::time::{Duration, Instant, SystemTime}; +use std::collections::hash_map::Entry; use std::collections::HashMap; use log::debug; @@ -208,9 +209,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { self.state.handshake.read().get_psk(pk).ok() } - pub fn add_peer(&self, pk: PublicKey) { + pub fn add_peer(&self, pk: PublicKey) -> bool { if self.state.peers.read().contains_key(pk.as_bytes()) { - return; + return false; } let mut rng = OsRng::new().unwrap(); @@ -243,10 +244,16 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { // finally, add the peer to the wireguard device let mut peers = self.state.peers.write(); - peers.entry(*pk.as_bytes()).or_insert(peer); - - // add to the handshake device - self.state.handshake.write().add(pk).unwrap(); // TODO: handle adding of public key for interface + match peers.entry(*pk.as_bytes()) { + Entry::Occupied(_) => false, + Entry::Vacant(vacancy) => { + let ok_pk = self.state.handshake.write().add(pk).is_ok(); + if ok_pk { + vacancy.insert(peer); + } + ok_pk + } + } } /// Begin consuming messages from the reader. |