diff options
Diffstat (limited to 'src/wireguard/handshake/device.rs')
-rw-r--r-- | src/wireguard/handshake/device.rs | 198 |
1 files changed, 127 insertions, 71 deletions
diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs index edd1a07..4b5d8f6 100644 --- a/src/wireguard/handshake/device.rs +++ b/src/wireguard/handshake/device.rs @@ -1,4 +1,5 @@ use spin::RwLock; +use std::collections::hash_map; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Mutex; @@ -6,7 +7,10 @@ use zerocopy::AsBytes; use byteorder::{ByteOrder, LittleEndian}; -use rand::prelude::*; +use rand::Rng; +use rand_core::{CryptoRng, RngCore}; + +use clear_on_drop::clear::Clear; use x25519_dalek::PublicKey; use x25519_dalek::StaticSecret; @@ -22,42 +26,101 @@ use super::types::*; const MAX_PEER_PER_DEVICE: usize = 1 << 20; pub struct KeyState { - pub sk: StaticSecret, // static secret key - pub pk: PublicKey, // static public key - macs: macs::Validator, // validator for the mac fields + pub(super) sk: StaticSecret, // static secret key + pub(super) pk: PublicKey, // static public key + macs: macs::Validator, // validator for the mac fields } -pub struct Device { - keyst: Option<KeyState>, // secret/public key - pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state - id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key +/// The device is generic over an "opaque" type +/// which can be used to associate the public key with this value. +/// (the instance is a Peer object in the parent module) +pub struct Device<O> { + keyst: Option<KeyState>, + id_map: RwLock<HashMap<u32, [u8; 32]>>, + pk_map: HashMap<[u8; 32], Peer<O>>, limiter: Mutex<RateLimiter>, } +pub struct Iter<'a, O> { + iter: hash_map::Iter<'a, [u8; 32], Peer<O>>, +} + +impl<'a, O> Iterator for Iter<'a, O> { + type Item = (PublicKey, &'a O); + + fn next(&mut self) -> Option<Self::Item> { + self.iter + .next() + .map(|(pk, peer)| (PublicKey::from(*pk), &peer.opaque)) + } +} + +/* These methods enable the Device to act as a map + * from public keys to the set of contained opaque values. + * + * It also abstracts away the problem of PublicKey not being hashable. + */ +impl<O> Device<O> { + pub fn clear(&mut self) { + self.id_map.write().clear(); + self.pk_map.clear(); + } + + pub fn len(&self) -> usize { + self.pk_map.len() + } + + /// Enables enumeration of (public key, opaque) pairs + /// without exposing internal peer type. + pub fn iter(&self) -> Iter<O> { + Iter { + iter: self.pk_map.iter(), + } + } + + /// Enables lookup by public key without exposing internal peer type. + pub fn get(&self, pk: &PublicKey) -> Option<&O> { + self.pk_map.get(pk.as_bytes()).map(|peer| &peer.opaque) + } + + pub fn contains_key(&self, pk: &PublicKey) -> bool { + self.pk_map.contains_key(pk.as_bytes()) + } +} + /* A mutable reference to the device needs to be held during configuration. * Wrapping the device in a RwLock enables peer config after "configuration time" */ -impl Device { +impl<O> Device<O> { /// Initialize a new handshake state machine - pub fn new() -> Device { + pub fn new() -> Device<O> { Device { keyst: None, - pk_map: HashMap::new(), id_map: RwLock::new(HashMap::new()), + pk_map: HashMap::new(), limiter: Mutex::new(RateLimiter::new()), } } - fn update_ss(&self, peer: &mut Peer) -> Option<PublicKey> { - if let Some(key) = self.keyst.as_ref() { - if *peer.pk.as_bytes() == *key.pk.as_bytes() { - return Some(peer.pk); + fn update_ss(&mut self) -> (Vec<u32>, Option<PublicKey>) { + let mut same = None; + let mut ids = Vec::with_capacity(self.pk_map.len()); + for (pk, peer) in self.pk_map.iter_mut() { + if let Some(key) = self.keyst.as_ref() { + if key.pk.as_bytes() == pk { + same = Some(PublicKey::from(*pk)); + peer.ss.clear() + } else { + let pk = PublicKey::from(*pk); + peer.ss = *key.sk.diffie_hellman(&pk).as_bytes(); + } + } else { + peer.ss.clear(); } - peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes(); - } else { - peer.ss = [0u8; 32]; - }; - None + peer.reset_state().map(|id| ids.push(id)); + } + + (ids, same) } /// Update the secret key of the device @@ -74,29 +137,15 @@ impl Device { }); // recalculate / erase the shared secrets for every peer - let mut ids = vec![]; - let mut same = None; - for mut peer in self.pk_map.values_mut() { - // clear any existing handshake state - peer.reset_state().map(|id| ids.push(id)); - - // update precomputed shared secret - if let Some(key) = self.keyst.as_ref() { - peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes(); - if *peer.pk.as_bytes() == *key.pk.as_bytes() { - same = Some(peer.pk) - } - } else { - peer.ss = [0u8; 32]; - }; - } + let (ids, same) = self.update_ss(); // release ids from aborted handshakes for id in ids { self.release(id) } - // if we found a peer matching the device public key, remove it. + // if we found a peer matching the device public key + // remove it and return its value to the caller same.map(|pk| { self.pk_map.remove(pk.as_bytes()); pk @@ -119,29 +168,32 @@ impl Device { /// /// * `pk` - The public key to add /// * `identifier` - Associated identifier which can be used to distinguish the peers - pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> { + pub fn add(&mut self, pk: PublicKey, opaque: O) -> Result<(), ConfigError> { // ensure less than 2^20 peers if self.pk_map.len() > MAX_PEER_PER_DEVICE { return Err(ConfigError::new("Too many peers for device")); } - // create peer and precompute static secret - let mut peer = Peer::new( - pk, - self.keyst - .as_ref() - .map(|key| *key.sk.diffie_hellman(&pk).as_bytes()) - .unwrap_or([0u8; 32]), - ); - - // add peer to device - match self.update_ss(&mut peer) { - Some(_) => Err(ConfigError::new("Public key of peer matches the device")), - None => { - self.pk_map.insert(*pk.as_bytes(), peer); - Ok(()) + // error if public key matches device + if let Some(key) = self.keyst.as_ref() { + if pk.as_bytes() == key.pk.as_bytes() { + return Err(ConfigError::new("Public key of peer matches the device")); } } + + // pre-compute shared secret and add to pk_map + self.pk_map.insert( + *pk.as_bytes(), + Peer::new( + pk, + self.keyst + .as_ref() + .map(|key| *key.sk.diffie_hellman(&pk).as_bytes()) + .unwrap_or([0u8; 32]), + opaque, + ), + ); + Ok(()) } /// Remove a peer by public key @@ -163,7 +215,7 @@ impl Device { .remove(pk.as_bytes()) .ok_or(ConfigError::new("Public key not in device"))?; - // pruge the id map (linear scan) + // purge the id map (linear scan) id_map.retain(|_, v| v != pk.as_bytes()); Ok(()) } @@ -231,11 +283,11 @@ impl Device { (_, None) => Err(HandshakeError::UnknownPublicKey), (None, _) => Err(HandshakeError::UnknownPublicKey), (Some(keyst), Some(peer)) => { - let local = self.allocate(rng, peer); + let local = self.allocate(rng, pk); let mut msg = Initiation::default(); // create noise part of initation - noise::create_initiation(rng, keyst, peer, local, &mut msg.noise)?; + noise::create_initiation(rng, keyst, peer, pk, local, &mut msg.noise)?; // add macs to initation peer.macs @@ -253,11 +305,11 @@ impl Device { /// /// * `msg` - Byte slice containing the message (untrusted input) pub fn process<'a, R: RngCore + CryptoRng>( - &self, - rng: &mut R, // rng instance to sample randomness from - msg: &[u8], // message buffer + &'a self, + rng: &mut R, // rng instance to sample randomness from + msg: &[u8], // message buffer src: Option<SocketAddr>, // optional source endpoint, set when "under load" - ) -> Result<Output, HandshakeError> { + ) -> Result<Output<'a, O>, HandshakeError> { // ensure type read in-range if msg.len() < 4 { return Err(HandshakeError::InvalidMessageFormat); @@ -303,17 +355,17 @@ impl Device { } // consume the initiation - let (peer, st) = noise::consume_initiation(self, keyst, &msg.noise)?; + let (peer, pk, st) = noise::consume_initiation(self, keyst, &msg.noise)?; // allocate new index for response - let local = self.allocate(rng, peer); + let local = self.allocate(rng, &pk); // prepare memory for response, TODO: take slice for zero allocation let mut resp = Response::default(); // create response (release id on error) - let keys = - noise::create_response(rng, peer, local, st, &mut resp.noise).map_err(|e| { + let keys = noise::create_response(rng, peer, &pk, local, st, &mut resp.noise) + .map_err(|e| { self.release(local); e })?; @@ -324,7 +376,11 @@ impl Device { .generate(resp.noise.as_bytes(), &mut resp.macs); // return unconfirmed keypair and the response as vector - Ok((Some(peer.pk), Some(resp.as_bytes().to_owned()), Some(keys))) + Ok(( + Some(&peer.opaque), + Some(resp.as_bytes().to_owned()), + Some(keys), + )) } TYPE_RESPONSE => { let msg = Response::parse(msg)?; @@ -363,7 +419,7 @@ impl Device { let msg = CookieReply::parse(msg)?; // lookup peer - let peer = self.lookup_id(msg.f_receiver.get())?; + let (peer, _) = self.lookup_id(msg.f_receiver.get())?; // validate cookie reply peer.macs.lock().process(&msg)?; @@ -379,7 +435,7 @@ impl Device { // Internal function // // Return the peer associated with the public key - pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer, HandshakeError> { + pub(super) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer<O>, HandshakeError> { self.pk_map .get(pk.as_bytes()) .ok_or(HandshakeError::UnknownPublicKey) @@ -388,11 +444,11 @@ impl Device { // Internal function // // Return the peer currently associated with the receiver identifier - pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer, HandshakeError> { + pub(super) fn lookup_id(&self, id: u32) -> Result<(&Peer<O>, PublicKey), HandshakeError> { let im = self.id_map.read(); let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?; match self.pk_map.get(pk) { - Some(peer) => Ok(peer), + Some(peer) => Ok((peer, PublicKey::from(*pk))), _ => unreachable!(), // if the id-lookup succeeded, the peer should exist } } @@ -400,7 +456,7 @@ impl Device { // Internal function // // Allocated a new receiver identifier for the peer - fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, peer: &Peer) -> u32 { + fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, pk: &PublicKey) -> u32 { loop { let id = rng.gen(); @@ -412,7 +468,7 @@ impl Device { // take write lock and add index let mut m = self.id_map.write(); if !m.contains_key(&id) { - m.insert(id, *peer.pk.as_bytes()); + m.insert(id, *pk.as_bytes()); return id; } } |