aboutsummaryrefslogtreecommitdiffstats
path: root/src/wireguard/handshake/device.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/wireguard/handshake/device.rs')
-rw-r--r--src/wireguard/handshake/device.rs198
1 files changed, 127 insertions, 71 deletions
diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs
index edd1a07..4b5d8f6 100644
--- a/src/wireguard/handshake/device.rs
+++ b/src/wireguard/handshake/device.rs
@@ -1,4 +1,5 @@
use spin::RwLock;
+use std::collections::hash_map;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Mutex;
@@ -6,7 +7,10 @@ use zerocopy::AsBytes;
use byteorder::{ByteOrder, LittleEndian};
-use rand::prelude::*;
+use rand::Rng;
+use rand_core::{CryptoRng, RngCore};
+
+use clear_on_drop::clear::Clear;
use x25519_dalek::PublicKey;
use x25519_dalek::StaticSecret;
@@ -22,42 +26,101 @@ use super::types::*;
const MAX_PEER_PER_DEVICE: usize = 1 << 20;
pub struct KeyState {
- pub sk: StaticSecret, // static secret key
- pub pk: PublicKey, // static public key
- macs: macs::Validator, // validator for the mac fields
+ pub(super) sk: StaticSecret, // static secret key
+ pub(super) pk: PublicKey, // static public key
+ macs: macs::Validator, // validator for the mac fields
}
-pub struct Device {
- keyst: Option<KeyState>, // secret/public key
- pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state
- id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key
+/// The device is generic over an "opaque" type
+/// which can be used to associate the public key with this value.
+/// (the instance is a Peer object in the parent module)
+pub struct Device<O> {
+ keyst: Option<KeyState>,
+ id_map: RwLock<HashMap<u32, [u8; 32]>>,
+ pk_map: HashMap<[u8; 32], Peer<O>>,
limiter: Mutex<RateLimiter>,
}
+pub struct Iter<'a, O> {
+ iter: hash_map::Iter<'a, [u8; 32], Peer<O>>,
+}
+
+impl<'a, O> Iterator for Iter<'a, O> {
+ type Item = (PublicKey, &'a O);
+
+ fn next(&mut self) -> Option<Self::Item> {
+ self.iter
+ .next()
+ .map(|(pk, peer)| (PublicKey::from(*pk), &peer.opaque))
+ }
+}
+
+/* These methods enable the Device to act as a map
+ * from public keys to the set of contained opaque values.
+ *
+ * It also abstracts away the problem of PublicKey not being hashable.
+ */
+impl<O> Device<O> {
+ pub fn clear(&mut self) {
+ self.id_map.write().clear();
+ self.pk_map.clear();
+ }
+
+ pub fn len(&self) -> usize {
+ self.pk_map.len()
+ }
+
+ /// Enables enumeration of (public key, opaque) pairs
+ /// without exposing internal peer type.
+ pub fn iter(&self) -> Iter<O> {
+ Iter {
+ iter: self.pk_map.iter(),
+ }
+ }
+
+ /// Enables lookup by public key without exposing internal peer type.
+ pub fn get(&self, pk: &PublicKey) -> Option<&O> {
+ self.pk_map.get(pk.as_bytes()).map(|peer| &peer.opaque)
+ }
+
+ pub fn contains_key(&self, pk: &PublicKey) -> bool {
+ self.pk_map.contains_key(pk.as_bytes())
+ }
+}
+
/* A mutable reference to the device needs to be held during configuration.
* Wrapping the device in a RwLock enables peer config after "configuration time"
*/
-impl Device {
+impl<O> Device<O> {
/// Initialize a new handshake state machine
- pub fn new() -> Device {
+ pub fn new() -> Device<O> {
Device {
keyst: None,
- pk_map: HashMap::new(),
id_map: RwLock::new(HashMap::new()),
+ pk_map: HashMap::new(),
limiter: Mutex::new(RateLimiter::new()),
}
}
- fn update_ss(&self, peer: &mut Peer) -> Option<PublicKey> {
- if let Some(key) = self.keyst.as_ref() {
- if *peer.pk.as_bytes() == *key.pk.as_bytes() {
- return Some(peer.pk);
+ fn update_ss(&mut self) -> (Vec<u32>, Option<PublicKey>) {
+ let mut same = None;
+ let mut ids = Vec::with_capacity(self.pk_map.len());
+ for (pk, peer) in self.pk_map.iter_mut() {
+ if let Some(key) = self.keyst.as_ref() {
+ if key.pk.as_bytes() == pk {
+ same = Some(PublicKey::from(*pk));
+ peer.ss.clear()
+ } else {
+ let pk = PublicKey::from(*pk);
+ peer.ss = *key.sk.diffie_hellman(&pk).as_bytes();
+ }
+ } else {
+ peer.ss.clear();
}
- peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes();
- } else {
- peer.ss = [0u8; 32];
- };
- None
+ peer.reset_state().map(|id| ids.push(id));
+ }
+
+ (ids, same)
}
/// Update the secret key of the device
@@ -74,29 +137,15 @@ impl Device {
});
// recalculate / erase the shared secrets for every peer
- let mut ids = vec![];
- let mut same = None;
- for mut peer in self.pk_map.values_mut() {
- // clear any existing handshake state
- peer.reset_state().map(|id| ids.push(id));
-
- // update precomputed shared secret
- if let Some(key) = self.keyst.as_ref() {
- peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes();
- if *peer.pk.as_bytes() == *key.pk.as_bytes() {
- same = Some(peer.pk)
- }
- } else {
- peer.ss = [0u8; 32];
- };
- }
+ let (ids, same) = self.update_ss();
// release ids from aborted handshakes
for id in ids {
self.release(id)
}
- // if we found a peer matching the device public key, remove it.
+ // if we found a peer matching the device public key
+ // remove it and return its value to the caller
same.map(|pk| {
self.pk_map.remove(pk.as_bytes());
pk
@@ -119,29 +168,32 @@ impl Device {
///
/// * `pk` - The public key to add
/// * `identifier` - Associated identifier which can be used to distinguish the peers
- pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> {
+ pub fn add(&mut self, pk: PublicKey, opaque: O) -> Result<(), ConfigError> {
// ensure less than 2^20 peers
if self.pk_map.len() > MAX_PEER_PER_DEVICE {
return Err(ConfigError::new("Too many peers for device"));
}
- // create peer and precompute static secret
- let mut peer = Peer::new(
- pk,
- self.keyst
- .as_ref()
- .map(|key| *key.sk.diffie_hellman(&pk).as_bytes())
- .unwrap_or([0u8; 32]),
- );
-
- // add peer to device
- match self.update_ss(&mut peer) {
- Some(_) => Err(ConfigError::new("Public key of peer matches the device")),
- None => {
- self.pk_map.insert(*pk.as_bytes(), peer);
- Ok(())
+ // error if public key matches device
+ if let Some(key) = self.keyst.as_ref() {
+ if pk.as_bytes() == key.pk.as_bytes() {
+ return Err(ConfigError::new("Public key of peer matches the device"));
}
}
+
+ // pre-compute shared secret and add to pk_map
+ self.pk_map.insert(
+ *pk.as_bytes(),
+ Peer::new(
+ pk,
+ self.keyst
+ .as_ref()
+ .map(|key| *key.sk.diffie_hellman(&pk).as_bytes())
+ .unwrap_or([0u8; 32]),
+ opaque,
+ ),
+ );
+ Ok(())
}
/// Remove a peer by public key
@@ -163,7 +215,7 @@ impl Device {
.remove(pk.as_bytes())
.ok_or(ConfigError::new("Public key not in device"))?;
- // pruge the id map (linear scan)
+ // purge the id map (linear scan)
id_map.retain(|_, v| v != pk.as_bytes());
Ok(())
}
@@ -231,11 +283,11 @@ impl Device {
(_, None) => Err(HandshakeError::UnknownPublicKey),
(None, _) => Err(HandshakeError::UnknownPublicKey),
(Some(keyst), Some(peer)) => {
- let local = self.allocate(rng, peer);
+ let local = self.allocate(rng, pk);
let mut msg = Initiation::default();
// create noise part of initation
- noise::create_initiation(rng, keyst, peer, local, &mut msg.noise)?;
+ noise::create_initiation(rng, keyst, peer, pk, local, &mut msg.noise)?;
// add macs to initation
peer.macs
@@ -253,11 +305,11 @@ impl Device {
///
/// * `msg` - Byte slice containing the message (untrusted input)
pub fn process<'a, R: RngCore + CryptoRng>(
- &self,
- rng: &mut R, // rng instance to sample randomness from
- msg: &[u8], // message buffer
+ &'a self,
+ rng: &mut R, // rng instance to sample randomness from
+ msg: &[u8], // message buffer
src: Option<SocketAddr>, // optional source endpoint, set when "under load"
- ) -> Result<Output, HandshakeError> {
+ ) -> Result<Output<'a, O>, HandshakeError> {
// ensure type read in-range
if msg.len() < 4 {
return Err(HandshakeError::InvalidMessageFormat);
@@ -303,17 +355,17 @@ impl Device {
}
// consume the initiation
- let (peer, st) = noise::consume_initiation(self, keyst, &msg.noise)?;
+ let (peer, pk, st) = noise::consume_initiation(self, keyst, &msg.noise)?;
// allocate new index for response
- let local = self.allocate(rng, peer);
+ let local = self.allocate(rng, &pk);
// prepare memory for response, TODO: take slice for zero allocation
let mut resp = Response::default();
// create response (release id on error)
- let keys =
- noise::create_response(rng, peer, local, st, &mut resp.noise).map_err(|e| {
+ let keys = noise::create_response(rng, peer, &pk, local, st, &mut resp.noise)
+ .map_err(|e| {
self.release(local);
e
})?;
@@ -324,7 +376,11 @@ impl Device {
.generate(resp.noise.as_bytes(), &mut resp.macs);
// return unconfirmed keypair and the response as vector
- Ok((Some(peer.pk), Some(resp.as_bytes().to_owned()), Some(keys)))
+ Ok((
+ Some(&peer.opaque),
+ Some(resp.as_bytes().to_owned()),
+ Some(keys),
+ ))
}
TYPE_RESPONSE => {
let msg = Response::parse(msg)?;
@@ -363,7 +419,7 @@ impl Device {
let msg = CookieReply::parse(msg)?;
// lookup peer
- let peer = self.lookup_id(msg.f_receiver.get())?;
+ let (peer, _) = self.lookup_id(msg.f_receiver.get())?;
// validate cookie reply
peer.macs.lock().process(&msg)?;
@@ -379,7 +435,7 @@ impl Device {
// Internal function
//
// Return the peer associated with the public key
- pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer, HandshakeError> {
+ pub(super) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer<O>, HandshakeError> {
self.pk_map
.get(pk.as_bytes())
.ok_or(HandshakeError::UnknownPublicKey)
@@ -388,11 +444,11 @@ impl Device {
// Internal function
//
// Return the peer currently associated with the receiver identifier
- pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer, HandshakeError> {
+ pub(super) fn lookup_id(&self, id: u32) -> Result<(&Peer<O>, PublicKey), HandshakeError> {
let im = self.id_map.read();
let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?;
match self.pk_map.get(pk) {
- Some(peer) => Ok(peer),
+ Some(peer) => Ok((peer, PublicKey::from(*pk))),
_ => unreachable!(), // if the id-lookup succeeded, the peer should exist
}
}
@@ -400,7 +456,7 @@ impl Device {
// Internal function
//
// Allocated a new receiver identifier for the peer
- fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, peer: &Peer) -> u32 {
+ fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, pk: &PublicKey) -> u32 {
loop {
let id = rng.gen();
@@ -412,7 +468,7 @@ impl Device {
// take write lock and add index
let mut m = self.id_map.write();
if !m.contains_key(&id) {
- m.insert(id, *peer.pk.as_bytes());
+ m.insert(id, *pk.as_bytes());
return id;
}
}