summaryrefslogtreecommitdiffstats
path: root/src/wireguard/handshake/device.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/wireguard/handshake/device.rs')
-rw-r--r--src/wireguard/handshake/device.rs147
1 files changed, 93 insertions, 54 deletions
diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs
index c2e3a6e..85c2e45 100644
--- a/src/wireguard/handshake/device.rs
+++ b/src/wireguard/handshake/device.rs
@@ -21,10 +21,14 @@ use super::types::*;
const MAX_PEER_PER_DEVICE: usize = 1 << 20;
+pub struct KeyState {
+ pub sk: StaticSecret, // static secret key
+ pub pk: PublicKey, // static public key
+ macs: macs::Validator, // validator for the mac fields
+}
+
pub struct Device {
- pub sk: StaticSecret, // static secret key
- pub pk: PublicKey, // static public key
- macs: macs::Validator, // validator for the mac fields
+ keyst: Option<KeyState>, // secret/public key
pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state
id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key
limiter: Mutex<RateLimiter>,
@@ -35,45 +39,68 @@ pub struct Device {
*/
impl Device {
/// Initialize a new handshake state machine
- ///
- /// # Arguments
- ///
- /// * `sk` - x25519 scalar representing the local private key
- pub fn new(sk: StaticSecret) -> Device {
- let pk = PublicKey::from(&sk);
+ pub fn new() -> Device {
Device {
- pk,
- sk,
- macs: macs::Validator::new(pk),
+ keyst: None,
pk_map: HashMap::new(),
id_map: RwLock::new(HashMap::new()),
limiter: Mutex::new(RateLimiter::new()),
}
}
+ fn update_ss(&self, peer: &mut Peer) -> Option<PublicKey> {
+ if let Some(key) = self.keyst.as_ref() {
+ if *peer.pk.as_bytes() == *key.pk.as_bytes() {
+ return Some(peer.pk);
+ }
+ peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes();
+ } else {
+ peer.ss = [0u8; 32];
+ };
+ None
+ }
+
/// Update the secret key of the device
///
/// # Arguments
///
/// * `sk` - x25519 scalar representing the local private key
- pub fn set_sk(&mut self, sk: StaticSecret) {
+ pub fn set_sk(&mut self, sk: Option<StaticSecret>) -> Option<PublicKey> {
// update secret and public key
- let pk = PublicKey::from(&sk);
- self.sk = sk;
- self.pk = pk;
- self.macs = macs::Validator::new(pk);
+ self.keyst = sk.map(|sk| {
+ let pk = PublicKey::from(&sk);
+ let macs = macs::Validator::new(pk);
+ KeyState { pk, sk, macs }
+ });
- // recalculate the shared secrets for every peer
+ // recalculate / erase the shared secrets for every peer
let mut ids = vec![];
+ let mut same = None;
for mut peer in self.pk_map.values_mut() {
+ // clear any existing handshake state
peer.reset_state().map(|id| ids.push(id));
- peer.ss = self.sk.diffie_hellman(&peer.pk)
+
+ // update precomputed shared secret
+ if let Some(key) = self.keyst.as_ref() {
+ peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes();
+ if *peer.pk.as_bytes() == *key.pk.as_bytes() {
+ same = Some(peer.pk)
+ }
+ } else {
+ peer.ss = [0u8; 32];
+ };
}
// release ids from aborted handshakes
for id in ids {
self.release(id)
}
+
+ // if we found a peer matching the device public key, remove it.
+ same.map(|pk| {
+ self.pk_map.remove(pk.as_bytes());
+ pk
+ })
}
/// Return the secret key of the device
@@ -81,8 +108,8 @@ impl Device {
/// # Returns
///
/// A secret key (x25519 scalar)
- pub fn get_sk(&self) -> StaticSecret {
- StaticSecret::from(self.sk.to_bytes())
+ pub fn get_sk(&self) -> Option<&StaticSecret> {
+ self.keyst.as_ref().map(|key| &key.sk)
}
/// Add a new public key to the state machine
@@ -93,28 +120,28 @@ impl Device {
/// * `pk` - The public key to add
/// * `identifier` - Associated identifier which can be used to distinguish the peers
pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> {
- // check that the pk is not added twice
- if let Some(_) = self.pk_map.get(pk.as_bytes()) {
- return Ok(());
- };
-
- // check that the pk is not that of the device
- if *self.pk.as_bytes() == *pk.as_bytes() {
- return Err(ConfigError::new(
- "Public key corresponds to secret key of interface",
- ));
- }
-
// ensure less than 2^20 peers
if self.pk_map.len() > MAX_PEER_PER_DEVICE {
return Err(ConfigError::new("Too many peers for device"));
}
- // map the public key to the peer state
- self.pk_map
- .insert(*pk.as_bytes(), Peer::new(pk, self.sk.diffie_hellman(&pk)));
-
- Ok(())
+ // create peer and precompute static secret
+ let mut peer = Peer::new(
+ pk,
+ self.keyst
+ .as_ref()
+ .map(|key| *key.sk.diffie_hellman(&pk).as_bytes())
+ .unwrap_or([0u8; 32]),
+ );
+
+ // add peer to device
+ match self.update_ss(&mut peer) {
+ Some(_) => Err(ConfigError::new("Public key of peer matches the device")),
+ None => {
+ self.pk_map.insert(*pk.as_bytes(), peer);
+ Ok(())
+ }
+ }
}
/// Remove a peer by public key
@@ -203,17 +230,17 @@ impl Device {
rng: &mut R,
pk: &PublicKey,
) -> Result<Vec<u8>, HandshakeError> {
- match self.pk_map.get(pk.as_bytes()) {
- None => Err(HandshakeError::UnknownPublicKey),
- Some(peer) => {
+ match (self.keyst.as_ref(), self.pk_map.get(pk.as_bytes())) {
+ (_, None) => Err(HandshakeError::UnknownPublicKey),
+ (None, _) => Err(HandshakeError::UnknownPublicKey),
+ (Some(keyst), Some(peer)) => {
let sender = self.allocate(rng, peer);
-
let mut msg = Initiation::default();
- noise::create_initiation(rng, self, peer, sender, &mut msg.noise)?;
+ // create noise part of initation
+ noise::create_initiation(rng, keyst, peer, sender, &mut msg.noise)?;
// add macs to initation
-
peer.macs
.lock()
.generate(msg.noise.as_bytes(), &mut msg.macs);
@@ -242,6 +269,15 @@ impl Device {
return Err(HandshakeError::InvalidMessageFormat);
}
+ // obtain reference to key state
+ // if no key is configured return a noop.
+ let keyst = match self.keyst.as_ref() {
+ Some(key) => key,
+ None => {
+ return Ok((None, None, None));
+ }
+ };
+
// de-multiplex the message type field
match LittleEndian::read_u32(msg) {
TYPE_INITIATION => {
@@ -249,7 +285,7 @@ impl Device {
let msg = Initiation::parse(msg)?;
// check mac1 field
- self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
+ keyst.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
// address validation & DoS mitigation
if let Some(src) = src {
@@ -257,9 +293,9 @@ impl Device {
let src = src.into();
// check mac2 field
- if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
+ if !keyst.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
let mut reply = Default::default();
- self.macs.create_cookie_reply(
+ keyst.macs.create_cookie_reply(
rng,
msg.noise.f_sender.get(),
src,
@@ -276,7 +312,7 @@ impl Device {
}
// consume the initiation
- let (peer, st) = noise::consume_initiation(self, &msg.noise)?;
+ let (peer, st) = noise::consume_initiation(self, keyst, &msg.noise)?;
// allocate new index for response
let sender = self.allocate(rng, peer);
@@ -304,7 +340,7 @@ impl Device {
let msg = Response::parse(msg)?;
// check mac1 field
- self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
+ keyst.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
// address validation & DoS mitigation
if let Some(src) = src {
@@ -312,9 +348,9 @@ impl Device {
let src = src.into();
// check mac2 field
- if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
+ if !keyst.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
let mut reply = Default::default();
- self.macs.create_cookie_reply(
+ keyst.macs.create_cookie_reply(
rng,
msg.noise.f_sender.get(),
src,
@@ -331,7 +367,7 @@ impl Device {
}
// consume inner playload
- noise::consume_response(self, &msg.noise)
+ noise::consume_response(self, keyst, &msg.noise)
}
TYPE_COOKIE_REPLY => {
let msg = CookieReply::parse(msg)?;
@@ -421,8 +457,11 @@ mod tests {
// intialize devices on both ends
- let mut dev1 = Device::new(sk1);
- let mut dev2 = Device::new(sk2);
+ let mut dev1 = Device::new();
+ let mut dev2 = Device::new();
+
+ dev1.set_sk(Some(sk1));
+ dev2.set_sk(Some(sk2));
dev1.add(pk2).unwrap();
dev2.add(pk1).unwrap();