diff options
Diffstat (limited to 'src/wireguard/wireguard.rs')
-rw-r--r-- | src/wireguard/wireguard.rs | 85 |
1 files changed, 30 insertions, 55 deletions
diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index bf550ef..ecbb9c1 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -21,9 +21,6 @@ use std::sync::Mutex as StdMutex; use std::thread; use std::time::Instant; -use std::collections::hash_map::Entry; -use std::collections::HashMap; - use hjul::Runner; use rand::rngs::OsRng; use rand::Rng; @@ -50,14 +47,13 @@ pub struct WireguardInner<T: Tun, B: UDP> { // outbound writer pub send: RwLock<Option<B::Writer>>, - // identity and configuration map - pub peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>, + // peer map + pub peers: RwLock<handshake::Device<Peer<T, B>>>, // cryptokey router pub router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>, // handshake related state - pub handshake: RwLock<handshake::Device>, pub last_under_load: Mutex<Instant>, pub pending: AtomicUsize, // number of pending handshake packets in queue pub queue: ParallelQueue<HandshakeJob<B::Endpoint>>, @@ -142,7 +138,7 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { self.router.down(); // set all peers down (stops timers) - for peer in self.peers.write().values() { + for (_, peer) in self.peers.write().iter() { peer.down(); } @@ -163,11 +159,11 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { return; } - // enable tranmission from router + // enable transmission from router self.router.up(); // set all peers up (restarts timers) - for peer in self.peers.write().values() { + for (_, peer) in self.peers.write().iter() { peer.up(); } @@ -179,54 +175,51 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { } pub fn remove_peer(&self, pk: &PublicKey) { - if self.handshake.write().remove(pk).is_ok() { - self.peers.write().remove(pk.as_bytes()); - } + let _ = self.peers.write().remove(pk); } pub fn lookup_peer(&self, pk: &PublicKey) -> Option<Peer<T, B>> { - self.peers.read().get(pk.as_bytes()).map(|p| p.clone()) + self.peers.read().get(pk).map(|p| p.clone()) } pub fn list_peers(&self) -> Vec<Peer<T, B>> { let peers = self.peers.read(); let mut list = Vec::with_capacity(peers.len()); for (k, v) in peers.iter() { - debug_assert!(k == v.pk.as_bytes()); + debug_assert!(k.as_bytes() == v.pk.as_bytes()); list.push(v.clone()); } list } pub fn set_key(&self, sk: Option<StaticSecret>) { - let mut handshake = self.handshake.write(); - handshake.set_sk(sk); + let mut peers = self.peers.write(); + peers.set_sk(sk); self.router.clear_sending_keys(); - // handshake lock is released and new handshakes can be initated } pub fn get_sk(&self) -> Option<StaticSecret> { - self.handshake + self.peers .read() .get_sk() .map(|sk| StaticSecret::from(sk.to_bytes())) } pub fn set_psk(&self, pk: PublicKey, psk: [u8; 32]) -> bool { - self.handshake.write().set_psk(pk, psk).is_ok() + self.peers.write().set_psk(pk, psk).is_ok() } pub fn get_psk(&self, pk: &PublicKey) -> Option<[u8; 32]> { - self.handshake.read().get_psk(pk).ok() + self.peers.read().get_psk(pk).ok() } pub fn add_peer(&self, pk: PublicKey) -> bool { - if self.peers.read().contains_key(pk.as_bytes()) { + let mut peers = self.peers.write(); + if peers.contains_key(&pk) { return false; } - let mut rng = OsRng::new().unwrap(); let state = Arc::new(PeerInner { - id: rng.gen(), + id: OsRng.gen(), pk, wg: self.clone(), walltime_last_handshake: Mutex::new(None), @@ -243,33 +236,19 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { // form WireGuard peer let peer = Peer { router, state }; + // prevent up/down while inserting + let enabled = self.enabled.read(); + + /* The need for dummy timers arises from the chicken-egg + * problem of the timer callbacks being able to set timers themselves. + * + * This is in fact the only place where the write lock is ever taken. + * TODO: Consider the ease of using atomic pointers instead. + */ + *peer.timers.write() = Timers::new(&*self.runner.lock(), *enabled, peer.clone()); + // finally, add the peer to the wireguard device - let mut peers = self.peers.write(); - match peers.entry(*pk.as_bytes()) { - Entry::Occupied(_) => false, - Entry::Vacant(vacancy) => { - // check that the public key does not cause conflict with the private key of the device - let ok_pk = self.handshake.write().add(pk).is_ok(); - if !ok_pk { - return false; - } - - // prevent up/down while inserting - let enabled = self.enabled.read(); - - /* The need for dummy timers arises from the chicken-egg - * problem of the timer callbacks being able to set timers themselves. - * - * This is in fact the only place where the write lock is ever taken. - * TODO: Consider the ease of using atomic pointers instead. - */ - *peer.timers.write() = Timers::new(&*self.runner.lock(), *enabled, peer.clone()); - - // insert into peer map (takes ownership and ensures that the peer is not dropped) - vacancy.insert(peer); - true - } - } + peers.add(pk, peer).is_ok() } /// Begin consuming messages from the reader. @@ -311,9 +290,6 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { // workers equal to number of physical cores let cpus = num_cpus::get(); - // create device state - let mut rng = OsRng::new().unwrap(); - // create handshake queue let (tx, mut rxs) = ParallelQueue::new(cpus, 128); @@ -322,14 +298,13 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { inner: Arc::new(WireguardInner { enabled: RwLock::new(false), tun_readers: WaitCounter::new(), - id: rng.gen(), + id: OsRng.gen(), mtu: AtomicUsize::new(0), - peers: RwLock::new(HashMap::new()), last_under_load: Mutex::new(Instant::now() - TIME_HORIZON), send: RwLock::new(None), router: router::Device::new(num_cpus::get(), writer), // router owns the writing half pending: AtomicUsize::new(0), - handshake: RwLock::new(handshake::Device::new()), + peers: RwLock::new(handshake::Device::new()), runner: Mutex::new(Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY)), queue: tx, }), |