diff options
Diffstat (limited to 'src/wireguard/handshake')
-rw-r--r-- | src/wireguard/handshake/device.rs | 147 | ||||
-rw-r--r-- | src/wireguard/handshake/noise.rs | 27 | ||||
-rw-r--r-- | src/wireguard/handshake/peer.rs | 15 |
3 files changed, 115 insertions, 74 deletions
diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs index c2e3a6e..85c2e45 100644 --- a/src/wireguard/handshake/device.rs +++ b/src/wireguard/handshake/device.rs @@ -21,10 +21,14 @@ use super::types::*; const MAX_PEER_PER_DEVICE: usize = 1 << 20; +pub struct KeyState { + pub sk: StaticSecret, // static secret key + pub pk: PublicKey, // static public key + macs: macs::Validator, // validator for the mac fields +} + pub struct Device { - pub sk: StaticSecret, // static secret key - pub pk: PublicKey, // static public key - macs: macs::Validator, // validator for the mac fields + keyst: Option<KeyState>, // secret/public key pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key limiter: Mutex<RateLimiter>, @@ -35,45 +39,68 @@ pub struct Device { */ impl Device { /// Initialize a new handshake state machine - /// - /// # Arguments - /// - /// * `sk` - x25519 scalar representing the local private key - pub fn new(sk: StaticSecret) -> Device { - let pk = PublicKey::from(&sk); + pub fn new() -> Device { Device { - pk, - sk, - macs: macs::Validator::new(pk), + keyst: None, pk_map: HashMap::new(), id_map: RwLock::new(HashMap::new()), limiter: Mutex::new(RateLimiter::new()), } } + fn update_ss(&self, peer: &mut Peer) -> Option<PublicKey> { + if let Some(key) = self.keyst.as_ref() { + if *peer.pk.as_bytes() == *key.pk.as_bytes() { + return Some(peer.pk); + } + peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes(); + } else { + peer.ss = [0u8; 32]; + }; + None + } + /// Update the secret key of the device /// /// # Arguments /// /// * `sk` - x25519 scalar representing the local private key - pub fn set_sk(&mut self, sk: StaticSecret) { + pub fn set_sk(&mut self, sk: Option<StaticSecret>) -> Option<PublicKey> { // update secret and public key - let pk = PublicKey::from(&sk); - self.sk = sk; - self.pk = pk; - self.macs = macs::Validator::new(pk); + self.keyst = sk.map(|sk| { + let pk = PublicKey::from(&sk); + let macs = macs::Validator::new(pk); + KeyState { pk, sk, macs } + }); - // recalculate the shared secrets for every peer + // recalculate / erase the shared secrets for every peer let mut ids = vec![]; + let mut same = None; for mut peer in self.pk_map.values_mut() { + // clear any existing handshake state peer.reset_state().map(|id| ids.push(id)); - peer.ss = self.sk.diffie_hellman(&peer.pk) + + // update precomputed shared secret + if let Some(key) = self.keyst.as_ref() { + peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes(); + if *peer.pk.as_bytes() == *key.pk.as_bytes() { + same = Some(peer.pk) + } + } else { + peer.ss = [0u8; 32]; + }; } // release ids from aborted handshakes for id in ids { self.release(id) } + + // if we found a peer matching the device public key, remove it. + same.map(|pk| { + self.pk_map.remove(pk.as_bytes()); + pk + }) } /// Return the secret key of the device @@ -81,8 +108,8 @@ impl Device { /// # Returns /// /// A secret key (x25519 scalar) - pub fn get_sk(&self) -> StaticSecret { - StaticSecret::from(self.sk.to_bytes()) + pub fn get_sk(&self) -> Option<&StaticSecret> { + self.keyst.as_ref().map(|key| &key.sk) } /// Add a new public key to the state machine @@ -93,28 +120,28 @@ impl Device { /// * `pk` - The public key to add /// * `identifier` - Associated identifier which can be used to distinguish the peers pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> { - // check that the pk is not added twice - if let Some(_) = self.pk_map.get(pk.as_bytes()) { - return Ok(()); - }; - - // check that the pk is not that of the device - if *self.pk.as_bytes() == *pk.as_bytes() { - return Err(ConfigError::new( - "Public key corresponds to secret key of interface", - )); - } - // ensure less than 2^20 peers if self.pk_map.len() > MAX_PEER_PER_DEVICE { return Err(ConfigError::new("Too many peers for device")); } - // map the public key to the peer state - self.pk_map - .insert(*pk.as_bytes(), Peer::new(pk, self.sk.diffie_hellman(&pk))); - - Ok(()) + // create peer and precompute static secret + let mut peer = Peer::new( + pk, + self.keyst + .as_ref() + .map(|key| *key.sk.diffie_hellman(&pk).as_bytes()) + .unwrap_or([0u8; 32]), + ); + + // add peer to device + match self.update_ss(&mut peer) { + Some(_) => Err(ConfigError::new("Public key of peer matches the device")), + None => { + self.pk_map.insert(*pk.as_bytes(), peer); + Ok(()) + } + } } /// Remove a peer by public key @@ -203,17 +230,17 @@ impl Device { rng: &mut R, pk: &PublicKey, ) -> Result<Vec<u8>, HandshakeError> { - match self.pk_map.get(pk.as_bytes()) { - None => Err(HandshakeError::UnknownPublicKey), - Some(peer) => { + match (self.keyst.as_ref(), self.pk_map.get(pk.as_bytes())) { + (_, None) => Err(HandshakeError::UnknownPublicKey), + (None, _) => Err(HandshakeError::UnknownPublicKey), + (Some(keyst), Some(peer)) => { let sender = self.allocate(rng, peer); - let mut msg = Initiation::default(); - noise::create_initiation(rng, self, peer, sender, &mut msg.noise)?; + // create noise part of initation + noise::create_initiation(rng, keyst, peer, sender, &mut msg.noise)?; // add macs to initation - peer.macs .lock() .generate(msg.noise.as_bytes(), &mut msg.macs); @@ -242,6 +269,15 @@ impl Device { return Err(HandshakeError::InvalidMessageFormat); } + // obtain reference to key state + // if no key is configured return a noop. + let keyst = match self.keyst.as_ref() { + Some(key) => key, + None => { + return Ok((None, None, None)); + } + }; + // de-multiplex the message type field match LittleEndian::read_u32(msg) { TYPE_INITIATION => { @@ -249,7 +285,7 @@ impl Device { let msg = Initiation::parse(msg)?; // check mac1 field - self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; + keyst.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; // address validation & DoS mitigation if let Some(src) = src { @@ -257,9 +293,9 @@ impl Device { let src = src.into(); // check mac2 field - if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { + if !keyst.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { let mut reply = Default::default(); - self.macs.create_cookie_reply( + keyst.macs.create_cookie_reply( rng, msg.noise.f_sender.get(), src, @@ -276,7 +312,7 @@ impl Device { } // consume the initiation - let (peer, st) = noise::consume_initiation(self, &msg.noise)?; + let (peer, st) = noise::consume_initiation(self, keyst, &msg.noise)?; // allocate new index for response let sender = self.allocate(rng, peer); @@ -304,7 +340,7 @@ impl Device { let msg = Response::parse(msg)?; // check mac1 field - self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; + keyst.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; // address validation & DoS mitigation if let Some(src) = src { @@ -312,9 +348,9 @@ impl Device { let src = src.into(); // check mac2 field - if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { + if !keyst.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { let mut reply = Default::default(); - self.macs.create_cookie_reply( + keyst.macs.create_cookie_reply( rng, msg.noise.f_sender.get(), src, @@ -331,7 +367,7 @@ impl Device { } // consume inner playload - noise::consume_response(self, &msg.noise) + noise::consume_response(self, keyst, &msg.noise) } TYPE_COOKIE_REPLY => { let msg = CookieReply::parse(msg)?; @@ -421,8 +457,11 @@ mod tests { // intialize devices on both ends - let mut dev1 = Device::new(sk1); - let mut dev2 = Device::new(sk2); + let mut dev1 = Device::new(); + let mut dev2 = Device::new(); + + dev1.set_sk(Some(sk1)); + dev2.set_sk(Some(sk2)); dev1.add(pk2).unwrap(); dev2.add(pk1).unwrap(); diff --git a/src/wireguard/handshake/noise.rs b/src/wireguard/handshake/noise.rs index 68e738d..6db300a 100644 --- a/src/wireguard/handshake/noise.rs +++ b/src/wireguard/handshake/noise.rs @@ -22,7 +22,7 @@ use clear_on_drop::clear_stack_on_return; use subtle::ConstantTimeEq; -use super::device::Device; +use super::device::{Device, KeyState}; use super::messages::{NoiseInitiation, NoiseResponse}; use super::messages::{TYPE_INITIATION, TYPE_RESPONSE}; use super::peer::{Peer, State}; @@ -219,7 +219,7 @@ mod tests { pub fn create_initiation<R: RngCore + CryptoRng>( rng: &mut R, - device: &Device, + keyst: &KeyState, peer: &Peer, sender: u32, msg: &mut NoiseInitiation, @@ -260,9 +260,9 @@ pub fn create_initiation<R: RngCore + CryptoRng>( SEAL!( &key, - &hs, // ad - device.pk.as_bytes(), // pt - &mut msg.f_static // ct || tag + &hs, // ad + keyst.pk.as_bytes(), // pt + &mut msg.f_static // ct || tag ); // H := Hash(H || msg.static) @@ -271,7 +271,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>( // (C, k) := Kdf2(C, DH(S_priv, S_pub)) - let (ck, key) = KDF2!(&ck, peer.ss.as_bytes()); + let (ck, key) = KDF2!(&ck, &peer.ss); // msg.timestamp := Aead(k, 0, Timestamp(), H) @@ -301,6 +301,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>( pub fn consume_initiation<'a>( device: &'a Device, + keyst: &KeyState, msg: &NoiseInitiation, ) -> Result<(&'a Peer, TemporaryState), HandshakeError> { debug!("consume initation"); @@ -309,7 +310,7 @@ pub fn consume_initiation<'a>( let ck = INITIAL_CK; let hs = INITIAL_HS; - let hs = HASH!(&hs, device.pk.as_bytes()); + let hs = HASH!(&hs, keyst.pk.as_bytes()); // C := Kdf(C, E_pub) @@ -322,7 +323,7 @@ pub fn consume_initiation<'a>( // (C, k) := Kdf2(C, DH(E_priv, S_pub)) let eph_r_pk = PublicKey::from(msg.f_ephemeral); - let (ck, key) = KDF2!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes()); + let (ck, key) = KDF2!(&ck, keyst.sk.diffie_hellman(&eph_r_pk).as_bytes()); // msg.static := Aead(k, 0, S_pub, H) @@ -347,7 +348,7 @@ pub fn consume_initiation<'a>( // (C, k) := Kdf2(C, DH(S_priv, S_pub)) - let (ck, key) = KDF2!(&ck, peer.ss.as_bytes()); + let (ck, key) = KDF2!(&ck, &peer.ss); // msg.timestamp := Aead(k, 0, Timestamp(), H) @@ -461,7 +462,11 @@ pub fn create_response<R: RngCore + CryptoRng>( * allow concurrent processing of potential responses to the initiation, * in order to better mitigate DoS from malformed response messages. */ -pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result<Output, HandshakeError> { +pub fn consume_response( + device: &Device, + keyst: &KeyState, + msg: &NoiseResponse, +) -> Result<Output, HandshakeError> { debug!("consume response"); clear_stack_on_return(CLEAR_PAGES, || { // retrieve peer and copy initiation state @@ -492,7 +497,7 @@ pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result<Output, // C := Kdf1(C, DH(E_priv, S_pub)) - let ck = KDF1!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes()); + let ck = KDF1!(&ck, keyst.sk.diffie_hellman(&eph_r_pk).as_bytes()); // (C, tau, k) := Kdf3(C, Q) diff --git a/src/wireguard/handshake/peer.rs b/src/wireguard/handshake/peer.rs index c9e1c40..abb36eb 100644 --- a/src/wireguard/handshake/peer.rs +++ b/src/wireguard/handshake/peer.rs @@ -33,9 +33,9 @@ pub struct Peer { pub(crate) macs: Mutex<macs::Generator>, // constant state - pub(crate) pk: PublicKey, // public key of peer - pub(crate) ss: SharedSecret, // precomputed DH(static, static) - pub(crate) psk: Psk, // psk of peer + pub(crate) pk: PublicKey, // public key of peer + pub(crate) ss: [u8; 32], // precomputed DH(static, static) + pub(crate) psk: Psk, // psk of peer } pub enum State { @@ -62,17 +62,14 @@ impl Drop for State { } impl Peer { - pub fn new( - pk: PublicKey, // public key of peer - ss: SharedSecret, // precomputed DH(static, static) - ) -> Self { + pub fn new(pk: PublicKey, ss: [u8; 32]) -> Self { Self { macs: Mutex::new(macs::Generator::new(pk)), state: Mutex::new(State::Reset), timestamp: Mutex::new(None), last_initiation_consumption: Mutex::new(None), - pk: pk, - ss: ss, + pk, + ss, psk: [0u8; 32], } } |