diff options
Diffstat (limited to 'src/wireguard/wireguard.rs')
-rw-r--r-- | src/wireguard/wireguard.rs | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index c0a8d9d..613c0a8 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -18,6 +18,7 @@ use std::sync::Arc; use std::thread; use std::time::{Duration, Instant, SystemTime}; +use std::collections::hash_map::Entry; use std::collections::HashMap; use log::debug; @@ -208,9 +209,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { self.state.handshake.read().get_psk(pk).ok() } - pub fn add_peer(&self, pk: PublicKey) { + pub fn add_peer(&self, pk: PublicKey) -> bool { if self.state.peers.read().contains_key(pk.as_bytes()) { - return; + return false; } let mut rng = OsRng::new().unwrap(); @@ -243,10 +244,16 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { // finally, add the peer to the wireguard device let mut peers = self.state.peers.write(); - peers.entry(*pk.as_bytes()).or_insert(peer); - - // add to the handshake device - self.state.handshake.write().add(pk).unwrap(); // TODO: handle adding of public key for interface + match peers.entry(*pk.as_bytes()) { + Entry::Occupied(_) => false, + Entry::Vacant(vacancy) => { + let ok_pk = self.state.handshake.write().add(pk).is_ok(); + if ok_pk { + vacancy.insert(peer); + } + ok_pk + } + } } /// Begin consuming messages from the reader. |