summaryrefslogtreecommitdiffstats
path: root/src/wireguard/wireguard.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/wireguard/wireguard.rs')
-rw-r--r--src/wireguard/wireguard.rs19
1 files changed, 13 insertions, 6 deletions
diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs
index c0a8d9d..613c0a8 100644
--- a/src/wireguard/wireguard.rs
+++ b/src/wireguard/wireguard.rs
@@ -18,6 +18,7 @@ use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant, SystemTime};
+use std::collections::hash_map::Entry;
use std::collections::HashMap;
use log::debug;
@@ -208,9 +209,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
self.state.handshake.read().get_psk(pk).ok()
}
- pub fn add_peer(&self, pk: PublicKey) {
+ pub fn add_peer(&self, pk: PublicKey) -> bool {
if self.state.peers.read().contains_key(pk.as_bytes()) {
- return;
+ return false;
}
let mut rng = OsRng::new().unwrap();
@@ -243,10 +244,16 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
// finally, add the peer to the wireguard device
let mut peers = self.state.peers.write();
- peers.entry(*pk.as_bytes()).or_insert(peer);
-
- // add to the handshake device
- self.state.handshake.write().add(pk).unwrap(); // TODO: handle adding of public key for interface
+ match peers.entry(*pk.as_bytes()) {
+ Entry::Occupied(_) => false,
+ Entry::Vacant(vacancy) => {
+ let ok_pk = self.state.handshake.write().add(pk).is_ok();
+ if ok_pk {
+ vacancy.insert(peer);
+ }
+ ok_pk
+ }
+ }
}
/// Begin consuming messages from the reader.