diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-11-18 12:04:20 +0100 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-11-18 12:04:20 +0100 |
commit | b1fbd7fbbaa92dde20d292307f4f4347e4c01450 (patch) | |
tree | 3ad79a99ff36568aa801121fad4b065cb819b1ea /src/wireguard/handshake | |
parent | Update configuration API (diff) | |
download | wireguard-rs-b1fbd7fbbaa92dde20d292307f4f4347e4c01450.tar.xz wireguard-rs-b1fbd7fbbaa92dde20d292307f4f4347e4c01450.zip |
Bug fixes from compliance tests with WireGuard
Diffstat (limited to 'src/wireguard/handshake')
-rw-r--r-- | src/wireguard/handshake/device.rs | 202 | ||||
-rw-r--r-- | src/wireguard/handshake/mod.rs | 3 | ||||
-rw-r--r-- | src/wireguard/handshake/noise.rs | 27 | ||||
-rw-r--r-- | src/wireguard/handshake/peer.rs | 6 | ||||
-rw-r--r-- | src/wireguard/handshake/tests.rs | 197 |
5 files changed, 224 insertions, 211 deletions
diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs index 030c0f8..8e16248 100644 --- a/src/wireguard/handshake/device.rs +++ b/src/wireguard/handshake/device.rs @@ -231,11 +231,11 @@ impl Device { (_, None) => Err(HandshakeError::UnknownPublicKey), (None, _) => Err(HandshakeError::UnknownPublicKey), (Some(keyst), Some(peer)) => { - let sender = self.allocate(rng, peer); + let local = self.allocate(rng, peer); let mut msg = Initiation::default(); // create noise part of initation - noise::create_initiation(rng, keyst, peer, sender, &mut msg.noise)?; + noise::create_initiation(rng, keyst, peer, local, &mut msg.noise)?; // add macs to initation peer.macs @@ -312,18 +312,17 @@ impl Device { let (peer, st) = noise::consume_initiation(self, keyst, &msg.noise)?; // allocate new index for response - let sender = self.allocate(rng, peer); + let local = self.allocate(rng, peer); // prepare memory for response, TODO: take slice for zero allocation let mut resp = Response::default(); // create response (release id on error) - let keys = noise::create_response(rng, peer, sender, st, &mut resp.noise).map_err( - |e| { - self.release(sender); + let keys = + noise::create_response(rng, peer, local, st, &mut resp.noise).map_err(|e| { + self.release(local); e - }, - )?; + })?; // add macs to response peer.macs @@ -425,190 +424,3 @@ impl Device { } } } - -#[cfg(test)] -mod tests { - use super::super::messages::*; - use super::*; - use hex; - use rand::rngs::OsRng; - use std::net::SocketAddr; - use std::thread; - use std::time::Duration; - - fn setup_devices<R: RngCore + CryptoRng>( - rng: &mut R, - ) -> (PublicKey, Device, PublicKey, Device) { - // generate new keypairs - - let sk1 = StaticSecret::new(rng); - let pk1 = PublicKey::from(&sk1); - - let sk2 = StaticSecret::new(rng); - let pk2 = PublicKey::from(&sk2); - - // pick random psk - - let mut psk = [0u8; 32]; - rng.fill_bytes(&mut psk[..]); - - // intialize devices on both ends - - let mut dev1 = Device::new(); - let mut dev2 = Device::new(); - - dev1.set_sk(Some(sk1)); - dev2.set_sk(Some(sk2)); - - dev1.add(pk2).unwrap(); - dev2.add(pk1).unwrap(); - - dev1.set_psk(pk2, psk).unwrap(); - dev2.set_psk(pk1, psk).unwrap(); - - (pk1, dev1, pk2, dev2) - } - - fn wait() { - thread::sleep(Duration::from_millis(20)); - } - - /* Test longest possible handshake interaction (7 messages): - * - * 1. I -> R (initation) - * 2. I <- R (cookie reply) - * 3. I -> R (initation) - * 4. I <- R (response) - * 5. I -> R (cookie reply) - * 6. I -> R (initation) - * 7. I <- R (response) - */ - #[test] - fn handshake_under_load() { - let mut rng = OsRng::new().unwrap(); - let (_pk1, dev1, pk2, dev2) = setup_devices(&mut rng); - - let src1: SocketAddr = "172.16.0.1:8080".parse().unwrap(); - let src2: SocketAddr = "172.16.0.2:7070".parse().unwrap(); - - // 1. device-1 : create first initation - let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); - - // 2. device-2 : responds with CookieReply - let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() { - (None, Some(msg), None) => msg, - _ => panic!("unexpected response"), - }; - - // device-1 : processes CookieReply (no response) - match dev1.process(&mut rng, &msg_cookie, Some(&src2)).unwrap() { - (None, None, None) => (), - _ => panic!("unexpected response"), - } - - // avoid initation flood detection - wait(); - - // 3. device-1 : create second initation - let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); - - // 4. device-2 : responds with noise response - let msg_response = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() { - (Some(_), Some(msg), Some(kp)) => { - assert_eq!(kp.initiator, false); - msg - } - _ => panic!("unexpected response"), - }; - - // 5. device-1 : responds with CookieReply - let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() { - (None, Some(msg), None) => msg, - _ => panic!("unexpected response"), - }; - - // device-2 : processes CookieReply (no response) - match dev2.process(&mut rng, &msg_cookie, Some(&src1)).unwrap() { - (None, None, None) => (), - _ => panic!("unexpected response"), - } - - // avoid initation flood detection - wait(); - - // 6. device-1 : create third initation - let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); - - // 7. device-2 : responds with noise response - let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() { - (Some(_), Some(msg), Some(kp)) => { - assert_eq!(kp.initiator, false); - (msg, kp) - } - _ => panic!("unexpected response"), - }; - - // device-1 : process noise response - let kp2 = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() { - (Some(_), None, Some(kp)) => { - assert_eq!(kp.initiator, true); - kp - } - _ => panic!("unexpected response"), - }; - - assert_eq!(kp1.send, kp2.recv); - assert_eq!(kp1.recv, kp2.send); - } - - #[test] - fn handshake_no_load() { - let mut rng = OsRng::new().unwrap(); - let (pk1, mut dev1, pk2, mut dev2) = setup_devices(&mut rng); - - // do a few handshakes (every handshake should succeed) - - for i in 0..10 { - println!("handshake : {}", i); - - // create initiation - - let msg1 = dev1.begin(&mut rng, &pk2).unwrap(); - - println!("msg1 = {} : {} bytes", hex::encode(&msg1[..]), msg1.len()); - println!("msg1 = {:?}", Initiation::parse(&msg1[..]).unwrap()); - - // process initiation and create response - - let (_, msg2, ks_r) = dev2.process(&mut rng, &msg1, None).unwrap(); - - let ks_r = ks_r.unwrap(); - let msg2 = msg2.unwrap(); - - println!("msg2 = {} : {} bytes", hex::encode(&msg2[..]), msg2.len()); - println!("msg2 = {:?}", Response::parse(&msg2[..]).unwrap()); - - assert!(!ks_r.initiator, "Responders key-pair is confirmed"); - - // process response and obtain confirmed key-pair - - let (_, msg3, ks_i) = dev1.process(&mut rng, &msg2, None).unwrap(); - let ks_i = ks_i.unwrap(); - - assert!(msg3.is_none(), "Returned message after response"); - assert!(ks_i.initiator, "Initiators key-pair is not confirmed"); - - assert_eq!(ks_i.send, ks_r.recv, "KeyI.send != KeyR.recv"); - assert_eq!(ks_i.recv, ks_r.send, "KeyI.recv != KeyR.send"); - - dev1.release(ks_i.send.id); - dev2.release(ks_r.send.id); - - // avoid initation flood detection - wait(); - } - - dev1.remove(pk2).unwrap(); - dev2.remove(pk1).unwrap(); - } -} diff --git a/src/wireguard/handshake/mod.rs b/src/wireguard/handshake/mod.rs index 071a41f..3a95817 100644 --- a/src/wireguard/handshake/mod.rs +++ b/src/wireguard/handshake/mod.rs @@ -15,6 +15,9 @@ mod ratelimiter; mod timestamp; mod types; +#[cfg(test)] +mod tests; + // publicly exposed interface pub use device::Device; diff --git a/src/wireguard/handshake/noise.rs b/src/wireguard/handshake/noise.rs index 6db300a..46188b4 100644 --- a/src/wireguard/handshake/noise.rs +++ b/src/wireguard/handshake/noise.rs @@ -221,7 +221,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>( rng: &mut R, keyst: &KeyState, peer: &Peer, - sender: u32, + local: u32, msg: &mut NoiseInitiation, ) -> Result<(), HandshakeError> { debug!("create initation"); @@ -233,7 +233,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>( let hs = HASH!(&hs, peer.pk.as_bytes()); msg.f_type.set(TYPE_INITIATION as u32); - msg.f_sender.set(sender); + msg.f_sender.set(local); // from us // (E_priv, E_pub) := DH-Generate() @@ -292,7 +292,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>( hs, ck, eph_sk, - sender, + local, }; Ok(()) @@ -378,7 +378,7 @@ pub fn consume_initiation<'a>( pub fn create_response<R: RngCore + CryptoRng>( rng: &mut R, peer: &Peer, - sender: u32, // sending identifier + local: u32, // sending identifier state: TemporaryState, // state from "consume_initiation" msg: &mut NoiseResponse, // resulting response ) -> Result<KeyPair, HandshakeError> { @@ -389,8 +389,8 @@ pub fn create_response<R: RngCore + CryptoRng>( let (receiver, eph_r_pk, hs, ck) = state; msg.f_type.set(TYPE_RESPONSE as u32); - msg.f_sender.set(sender); - msg.f_receiver.set(receiver); + msg.f_sender.set(local); // from us + msg.f_receiver.set(receiver); // to the sender of the initation // (E_priv, E_pub) := DH-Generate() @@ -447,11 +447,11 @@ pub fn create_response<R: RngCore + CryptoRng>( birth: Instant::now(), initiator: false, send: Key { - id: sender, + id: receiver, key: key_send.into(), }, recv: Key { - id: receiver, + id: local, key: key_recv.into(), }, }) @@ -472,13 +472,13 @@ pub fn consume_response( // retrieve peer and copy initiation state let peer = device.lookup_id(msg.f_receiver.get())?; - let (hs, ck, sender, eph_sk) = match *peer.state.lock() { + let (hs, ck, local, eph_sk) = match *peer.state.lock() { State::InitiationSent { hs, ck, - sender, + local, ref eph_sk, - } => Ok((hs, ck, sender, StaticSecret::from(eph_sk.to_bytes()))), + } => Ok((hs, ck, local, StaticSecret::from(eph_sk.to_bytes()))), _ => Err(HandshakeError::InvalidState), }?; @@ -535,6 +535,7 @@ pub fn consume_response( // null the initiation state // (to avoid replay of this response message) *state = State::Reset; + let remote = msg.f_sender.get(); // return confirmed key-pair Ok(( @@ -544,11 +545,11 @@ pub fn consume_response( birth, initiator: true, send: Key { - id: sender, + id: remote, key: key_send.into(), }, recv: Key { - id: msg.f_sender.get(), + id: local, key: key_recv.into(), }, }), diff --git a/src/wireguard/handshake/peer.rs b/src/wireguard/handshake/peer.rs index 2d69244..b7d8740 100644 --- a/src/wireguard/handshake/peer.rs +++ b/src/wireguard/handshake/peer.rs @@ -40,7 +40,7 @@ pub struct Peer { pub enum State { Reset, InitiationSent { - sender: u32, // assigned sender id + local: u32, // local id assigned eph_sk: StaticSecret, hs: GenericArray<u8, U32>, ck: GenericArray<u8, U32>, @@ -83,7 +83,7 @@ impl Peer { pub fn reset_state(&self) -> Option<u32> { match mem::replace(&mut *self.state.lock(), State::Reset) { - State::InitiationSent { sender, .. } => Some(sender), + State::InitiationSent { local, .. } => Some(local), _ => None, } } @@ -125,7 +125,7 @@ impl Peer { // reset state match *state { - State::InitiationSent { sender, .. } => device.release(sender), + State::InitiationSent { local, .. } => device.release(local), _ => (), } diff --git a/src/wireguard/handshake/tests.rs b/src/wireguard/handshake/tests.rs new file mode 100644 index 0000000..6be7b51 --- /dev/null +++ b/src/wireguard/handshake/tests.rs @@ -0,0 +1,197 @@ +use super::*; +use hex; +use rand::rngs::OsRng; +use std::net::SocketAddr; +use std::thread; +use std::time::Duration; + +use rand::prelude::*; + +use x25519_dalek::PublicKey; +use x25519_dalek::StaticSecret; + +use super::messages::{Initiation, Response}; + +fn setup_devices<R: RngCore + CryptoRng>(rng: &mut R) -> (PublicKey, Device, PublicKey, Device) { + // generate new keypairs + + let sk1 = StaticSecret::new(rng); + let pk1 = PublicKey::from(&sk1); + + let sk2 = StaticSecret::new(rng); + let pk2 = PublicKey::from(&sk2); + + // pick random psk + + let mut psk = [0u8; 32]; + rng.fill_bytes(&mut psk[..]); + + // intialize devices on both ends + + let mut dev1 = Device::new(); + let mut dev2 = Device::new(); + + dev1.set_sk(Some(sk1)); + dev2.set_sk(Some(sk2)); + + dev1.add(pk2).unwrap(); + dev2.add(pk1).unwrap(); + + dev1.set_psk(pk2, psk).unwrap(); + dev2.set_psk(pk1, psk).unwrap(); + + (pk1, dev1, pk2, dev2) +} + +fn wait() { + thread::sleep(Duration::from_millis(20)); +} + +/* Test longest possible handshake interaction (7 messages): + * + * 1. I -> R (initation) + * 2. I <- R (cookie reply) + * 3. I -> R (initation) + * 4. I <- R (response) + * 5. I -> R (cookie reply) + * 6. I -> R (initation) + * 7. I <- R (response) + */ +#[test] +fn handshake_under_load() { + let mut rng = OsRng::new().unwrap(); + let (_pk1, dev1, pk2, dev2) = setup_devices(&mut rng); + + let src1: SocketAddr = "172.16.0.1:8080".parse().unwrap(); + let src2: SocketAddr = "172.16.0.2:7070".parse().unwrap(); + + // 1. device-1 : create first initation + let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + + // 2. device-2 : responds with CookieReply + let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() { + (None, Some(msg), None) => msg, + _ => panic!("unexpected response"), + }; + + // device-1 : processes CookieReply (no response) + match dev1.process(&mut rng, &msg_cookie, Some(&src2)).unwrap() { + (None, None, None) => (), + _ => panic!("unexpected response"), + } + + // avoid initation flood detection + wait(); + + // 3. device-1 : create second initation + let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + + // 4. device-2 : responds with noise response + let msg_response = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() { + (Some(_), Some(msg), Some(kp)) => { + assert_eq!(kp.initiator, false); + msg + } + _ => panic!("unexpected response"), + }; + + // 5. device-1 : responds with CookieReply + let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() { + (None, Some(msg), None) => msg, + _ => panic!("unexpected response"), + }; + + // device-2 : processes CookieReply (no response) + match dev2.process(&mut rng, &msg_cookie, Some(&src1)).unwrap() { + (None, None, None) => (), + _ => panic!("unexpected response"), + } + + // avoid initation flood detection + wait(); + + // 6. device-1 : create third initation + let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + + // 7. device-2 : responds with noise response + let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() { + (Some(_), Some(msg), Some(kp)) => { + assert_eq!(kp.initiator, false); + (msg, kp) + } + _ => panic!("unexpected response"), + }; + + // device-1 : process noise response + let kp2 = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() { + (Some(_), None, Some(kp)) => { + assert_eq!(kp.initiator, true); + kp + } + _ => panic!("unexpected response"), + }; + + assert_eq!(kp1.send, kp2.recv); + assert_eq!(kp1.recv, kp2.send); +} + +#[test] +fn handshake_no_load() { + let mut rng = OsRng::new().unwrap(); + let (pk1, mut dev1, pk2, mut dev2) = setup_devices(&mut rng); + + // do a few handshakes (every handshake should succeed) + + for i in 0..10 { + println!("handshake : {}", i); + + // create initiation + + let msg1 = dev1.begin(&mut rng, &pk2).unwrap(); + + println!("msg1 = {} : {} bytes", hex::encode(&msg1[..]), msg1.len()); + println!( + "msg1 = {:?}", + Initiation::parse(&msg1[..]).expect("failed to parse initiation") + ); + + // process initiation and create response + + let (_, msg2, ks_r) = dev2 + .process(&mut rng, &msg1, None) + .expect("failed to process initiation"); + + let ks_r = ks_r.unwrap(); + let msg2 = msg2.unwrap(); + + println!("msg2 = {} : {} bytes", hex::encode(&msg2[..]), msg2.len()); + println!( + "msg2 = {:?}", + Response::parse(&msg2[..]).expect("failed to parse response") + ); + + assert!(!ks_r.initiator, "Responders key-pair is confirmed"); + + // process response and obtain confirmed key-pair + + let (_, msg3, ks_i) = dev1 + .process(&mut rng, &msg2, None) + .expect("failed to process response"); + let ks_i = ks_i.unwrap(); + + assert!(msg3.is_none(), "Returned message after response"); + assert!(ks_i.initiator, "Initiators key-pair is not confirmed"); + + assert_eq!(ks_i.send, ks_r.recv, "KeyI.send != KeyR.recv"); + assert_eq!(ks_i.recv, ks_r.send, "KeyI.recv != KeyR.send"); + + dev1.release(ks_i.local_id()); + dev2.release(ks_r.local_id()); + + // avoid initation flood detection + wait(); + } + + dev1.remove(pk2).unwrap(); + dev2.remove(pk1).unwrap(); +} |