diff options
Diffstat (limited to 'src/wireguard/handshake/device.rs')
-rw-r--r-- | src/wireguard/handshake/device.rs | 147 |
1 files changed, 93 insertions, 54 deletions
diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs index c2e3a6e..85c2e45 100644 --- a/src/wireguard/handshake/device.rs +++ b/src/wireguard/handshake/device.rs @@ -21,10 +21,14 @@ use super::types::*; const MAX_PEER_PER_DEVICE: usize = 1 << 20; +pub struct KeyState { + pub sk: StaticSecret, // static secret key + pub pk: PublicKey, // static public key + macs: macs::Validator, // validator for the mac fields +} + pub struct Device { - pub sk: StaticSecret, // static secret key - pub pk: PublicKey, // static public key - macs: macs::Validator, // validator for the mac fields + keyst: Option<KeyState>, // secret/public key pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key limiter: Mutex<RateLimiter>, @@ -35,45 +39,68 @@ pub struct Device { */ impl Device { /// Initialize a new handshake state machine - /// - /// # Arguments - /// - /// * `sk` - x25519 scalar representing the local private key - pub fn new(sk: StaticSecret) -> Device { - let pk = PublicKey::from(&sk); + pub fn new() -> Device { Device { - pk, - sk, - macs: macs::Validator::new(pk), + keyst: None, pk_map: HashMap::new(), id_map: RwLock::new(HashMap::new()), limiter: Mutex::new(RateLimiter::new()), } } + fn update_ss(&self, peer: &mut Peer) -> Option<PublicKey> { + if let Some(key) = self.keyst.as_ref() { + if *peer.pk.as_bytes() == *key.pk.as_bytes() { + return Some(peer.pk); + } + peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes(); + } else { + peer.ss = [0u8; 32]; + }; + None + } + /// Update the secret key of the device /// /// # Arguments /// /// * `sk` - x25519 scalar representing the local private key - pub fn set_sk(&mut self, sk: StaticSecret) { + pub fn set_sk(&mut self, sk: Option<StaticSecret>) -> Option<PublicKey> { // update secret and public key - let pk = PublicKey::from(&sk); - self.sk = sk; - self.pk = pk; - self.macs = macs::Validator::new(pk); + self.keyst = sk.map(|sk| { + let pk = PublicKey::from(&sk); + let macs = macs::Validator::new(pk); + KeyState { pk, sk, macs } + }); - // recalculate the shared secrets for every peer + // recalculate / erase the shared secrets for every peer let mut ids = vec![]; + let mut same = None; for mut peer in self.pk_map.values_mut() { + // clear any existing handshake state peer.reset_state().map(|id| ids.push(id)); - peer.ss = self.sk.diffie_hellman(&peer.pk) + + // update precomputed shared secret + if let Some(key) = self.keyst.as_ref() { + peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes(); + if *peer.pk.as_bytes() == *key.pk.as_bytes() { + same = Some(peer.pk) + } + } else { + peer.ss = [0u8; 32]; + }; } // release ids from aborted handshakes for id in ids { self.release(id) } + + // if we found a peer matching the device public key, remove it. + same.map(|pk| { + self.pk_map.remove(pk.as_bytes()); + pk + }) } /// Return the secret key of the device @@ -81,8 +108,8 @@ impl Device { /// # Returns /// /// A secret key (x25519 scalar) - pub fn get_sk(&self) -> StaticSecret { - StaticSecret::from(self.sk.to_bytes()) + pub fn get_sk(&self) -> Option<&StaticSecret> { + self.keyst.as_ref().map(|key| &key.sk) } /// Add a new public key to the state machine @@ -93,28 +120,28 @@ impl Device { /// * `pk` - The public key to add /// * `identifier` - Associated identifier which can be used to distinguish the peers pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> { - // check that the pk is not added twice - if let Some(_) = self.pk_map.get(pk.as_bytes()) { - return Ok(()); - }; - - // check that the pk is not that of the device - if *self.pk.as_bytes() == *pk.as_bytes() { - return Err(ConfigError::new( - "Public key corresponds to secret key of interface", - )); - } - // ensure less than 2^20 peers if self.pk_map.len() > MAX_PEER_PER_DEVICE { return Err(ConfigError::new("Too many peers for device")); } - // map the public key to the peer state - self.pk_map - .insert(*pk.as_bytes(), Peer::new(pk, self.sk.diffie_hellman(&pk))); - - Ok(()) + // create peer and precompute static secret + let mut peer = Peer::new( + pk, + self.keyst + .as_ref() + .map(|key| *key.sk.diffie_hellman(&pk).as_bytes()) + .unwrap_or([0u8; 32]), + ); + + // add peer to device + match self.update_ss(&mut peer) { + Some(_) => Err(ConfigError::new("Public key of peer matches the device")), + None => { + self.pk_map.insert(*pk.as_bytes(), peer); + Ok(()) + } + } } /// Remove a peer by public key @@ -203,17 +230,17 @@ impl Device { rng: &mut R, pk: &PublicKey, ) -> Result<Vec<u8>, HandshakeError> { - match self.pk_map.get(pk.as_bytes()) { - None => Err(HandshakeError::UnknownPublicKey), - Some(peer) => { + match (self.keyst.as_ref(), self.pk_map.get(pk.as_bytes())) { + (_, None) => Err(HandshakeError::UnknownPublicKey), + (None, _) => Err(HandshakeError::UnknownPublicKey), + (Some(keyst), Some(peer)) => { let sender = self.allocate(rng, peer); - let mut msg = Initiation::default(); - noise::create_initiation(rng, self, peer, sender, &mut msg.noise)?; + // create noise part of initation + noise::create_initiation(rng, keyst, peer, sender, &mut msg.noise)?; // add macs to initation - peer.macs .lock() .generate(msg.noise.as_bytes(), &mut msg.macs); @@ -242,6 +269,15 @@ impl Device { return Err(HandshakeError::InvalidMessageFormat); } + // obtain reference to key state + // if no key is configured return a noop. + let keyst = match self.keyst.as_ref() { + Some(key) => key, + None => { + return Ok((None, None, None)); + } + }; + // de-multiplex the message type field match LittleEndian::read_u32(msg) { TYPE_INITIATION => { @@ -249,7 +285,7 @@ impl Device { let msg = Initiation::parse(msg)?; // check mac1 field - self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; + keyst.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; // address validation & DoS mitigation if let Some(src) = src { @@ -257,9 +293,9 @@ impl Device { let src = src.into(); // check mac2 field - if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { + if !keyst.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { let mut reply = Default::default(); - self.macs.create_cookie_reply( + keyst.macs.create_cookie_reply( rng, msg.noise.f_sender.get(), src, @@ -276,7 +312,7 @@ impl Device { } // consume the initiation - let (peer, st) = noise::consume_initiation(self, &msg.noise)?; + let (peer, st) = noise::consume_initiation(self, keyst, &msg.noise)?; // allocate new index for response let sender = self.allocate(rng, peer); @@ -304,7 +340,7 @@ impl Device { let msg = Response::parse(msg)?; // check mac1 field - self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; + keyst.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; // address validation & DoS mitigation if let Some(src) = src { @@ -312,9 +348,9 @@ impl Device { let src = src.into(); // check mac2 field - if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { + if !keyst.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { let mut reply = Default::default(); - self.macs.create_cookie_reply( + keyst.macs.create_cookie_reply( rng, msg.noise.f_sender.get(), src, @@ -331,7 +367,7 @@ impl Device { } // consume inner playload - noise::consume_response(self, &msg.noise) + noise::consume_response(self, keyst, &msg.noise) } TYPE_COOKIE_REPLY => { let msg = CookieReply::parse(msg)?; @@ -421,8 +457,11 @@ mod tests { // intialize devices on both ends - let mut dev1 = Device::new(sk1); - let mut dev2 = Device::new(sk2); + let mut dev1 = Device::new(); + let mut dev2 = Device::new(); + + dev1.set_sk(Some(sk1)); + dev2.set_sk(Some(sk2)); dev1.add(pk2).unwrap(); dev2.add(pk1).unwrap(); |