aboutsummaryrefslogtreecommitdiffstats
path: root/src/wireguard/handshake
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-11-18 12:04:20 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-11-18 12:04:20 +0100
commitb1fbd7fbbaa92dde20d292307f4f4347e4c01450 (patch)
tree3ad79a99ff36568aa801121fad4b065cb819b1ea /src/wireguard/handshake
parentUpdate configuration API (diff)
downloadwireguard-rs-b1fbd7fbbaa92dde20d292307f4f4347e4c01450.tar.xz
wireguard-rs-b1fbd7fbbaa92dde20d292307f4f4347e4c01450.zip
Bug fixes from compliance tests with WireGuard
Diffstat (limited to 'src/wireguard/handshake')
-rw-r--r--src/wireguard/handshake/device.rs202
-rw-r--r--src/wireguard/handshake/mod.rs3
-rw-r--r--src/wireguard/handshake/noise.rs27
-rw-r--r--src/wireguard/handshake/peer.rs6
-rw-r--r--src/wireguard/handshake/tests.rs197
5 files changed, 224 insertions, 211 deletions
diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs
index 030c0f8..8e16248 100644
--- a/src/wireguard/handshake/device.rs
+++ b/src/wireguard/handshake/device.rs
@@ -231,11 +231,11 @@ impl Device {
(_, None) => Err(HandshakeError::UnknownPublicKey),
(None, _) => Err(HandshakeError::UnknownPublicKey),
(Some(keyst), Some(peer)) => {
- let sender = self.allocate(rng, peer);
+ let local = self.allocate(rng, peer);
let mut msg = Initiation::default();
// create noise part of initation
- noise::create_initiation(rng, keyst, peer, sender, &mut msg.noise)?;
+ noise::create_initiation(rng, keyst, peer, local, &mut msg.noise)?;
// add macs to initation
peer.macs
@@ -312,18 +312,17 @@ impl Device {
let (peer, st) = noise::consume_initiation(self, keyst, &msg.noise)?;
// allocate new index for response
- let sender = self.allocate(rng, peer);
+ let local = self.allocate(rng, peer);
// prepare memory for response, TODO: take slice for zero allocation
let mut resp = Response::default();
// create response (release id on error)
- let keys = noise::create_response(rng, peer, sender, st, &mut resp.noise).map_err(
- |e| {
- self.release(sender);
+ let keys =
+ noise::create_response(rng, peer, local, st, &mut resp.noise).map_err(|e| {
+ self.release(local);
e
- },
- )?;
+ })?;
// add macs to response
peer.macs
@@ -425,190 +424,3 @@ impl Device {
}
}
}
-
-#[cfg(test)]
-mod tests {
- use super::super::messages::*;
- use super::*;
- use hex;
- use rand::rngs::OsRng;
- use std::net::SocketAddr;
- use std::thread;
- use std::time::Duration;
-
- fn setup_devices<R: RngCore + CryptoRng>(
- rng: &mut R,
- ) -> (PublicKey, Device, PublicKey, Device) {
- // generate new keypairs
-
- let sk1 = StaticSecret::new(rng);
- let pk1 = PublicKey::from(&sk1);
-
- let sk2 = StaticSecret::new(rng);
- let pk2 = PublicKey::from(&sk2);
-
- // pick random psk
-
- let mut psk = [0u8; 32];
- rng.fill_bytes(&mut psk[..]);
-
- // intialize devices on both ends
-
- let mut dev1 = Device::new();
- let mut dev2 = Device::new();
-
- dev1.set_sk(Some(sk1));
- dev2.set_sk(Some(sk2));
-
- dev1.add(pk2).unwrap();
- dev2.add(pk1).unwrap();
-
- dev1.set_psk(pk2, psk).unwrap();
- dev2.set_psk(pk1, psk).unwrap();
-
- (pk1, dev1, pk2, dev2)
- }
-
- fn wait() {
- thread::sleep(Duration::from_millis(20));
- }
-
- /* Test longest possible handshake interaction (7 messages):
- *
- * 1. I -> R (initation)
- * 2. I <- R (cookie reply)
- * 3. I -> R (initation)
- * 4. I <- R (response)
- * 5. I -> R (cookie reply)
- * 6. I -> R (initation)
- * 7. I <- R (response)
- */
- #[test]
- fn handshake_under_load() {
- let mut rng = OsRng::new().unwrap();
- let (_pk1, dev1, pk2, dev2) = setup_devices(&mut rng);
-
- let src1: SocketAddr = "172.16.0.1:8080".parse().unwrap();
- let src2: SocketAddr = "172.16.0.2:7070".parse().unwrap();
-
- // 1. device-1 : create first initation
- let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
-
- // 2. device-2 : responds with CookieReply
- let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
- (None, Some(msg), None) => msg,
- _ => panic!("unexpected response"),
- };
-
- // device-1 : processes CookieReply (no response)
- match dev1.process(&mut rng, &msg_cookie, Some(&src2)).unwrap() {
- (None, None, None) => (),
- _ => panic!("unexpected response"),
- }
-
- // avoid initation flood detection
- wait();
-
- // 3. device-1 : create second initation
- let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
-
- // 4. device-2 : responds with noise response
- let msg_response = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
- (Some(_), Some(msg), Some(kp)) => {
- assert_eq!(kp.initiator, false);
- msg
- }
- _ => panic!("unexpected response"),
- };
-
- // 5. device-1 : responds with CookieReply
- let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
- (None, Some(msg), None) => msg,
- _ => panic!("unexpected response"),
- };
-
- // device-2 : processes CookieReply (no response)
- match dev2.process(&mut rng, &msg_cookie, Some(&src1)).unwrap() {
- (None, None, None) => (),
- _ => panic!("unexpected response"),
- }
-
- // avoid initation flood detection
- wait();
-
- // 6. device-1 : create third initation
- let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
-
- // 7. device-2 : responds with noise response
- let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
- (Some(_), Some(msg), Some(kp)) => {
- assert_eq!(kp.initiator, false);
- (msg, kp)
- }
- _ => panic!("unexpected response"),
- };
-
- // device-1 : process noise response
- let kp2 = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
- (Some(_), None, Some(kp)) => {
- assert_eq!(kp.initiator, true);
- kp
- }
- _ => panic!("unexpected response"),
- };
-
- assert_eq!(kp1.send, kp2.recv);
- assert_eq!(kp1.recv, kp2.send);
- }
-
- #[test]
- fn handshake_no_load() {
- let mut rng = OsRng::new().unwrap();
- let (pk1, mut dev1, pk2, mut dev2) = setup_devices(&mut rng);
-
- // do a few handshakes (every handshake should succeed)
-
- for i in 0..10 {
- println!("handshake : {}", i);
-
- // create initiation
-
- let msg1 = dev1.begin(&mut rng, &pk2).unwrap();
-
- println!("msg1 = {} : {} bytes", hex::encode(&msg1[..]), msg1.len());
- println!("msg1 = {:?}", Initiation::parse(&msg1[..]).unwrap());
-
- // process initiation and create response
-
- let (_, msg2, ks_r) = dev2.process(&mut rng, &msg1, None).unwrap();
-
- let ks_r = ks_r.unwrap();
- let msg2 = msg2.unwrap();
-
- println!("msg2 = {} : {} bytes", hex::encode(&msg2[..]), msg2.len());
- println!("msg2 = {:?}", Response::parse(&msg2[..]).unwrap());
-
- assert!(!ks_r.initiator, "Responders key-pair is confirmed");
-
- // process response and obtain confirmed key-pair
-
- let (_, msg3, ks_i) = dev1.process(&mut rng, &msg2, None).unwrap();
- let ks_i = ks_i.unwrap();
-
- assert!(msg3.is_none(), "Returned message after response");
- assert!(ks_i.initiator, "Initiators key-pair is not confirmed");
-
- assert_eq!(ks_i.send, ks_r.recv, "KeyI.send != KeyR.recv");
- assert_eq!(ks_i.recv, ks_r.send, "KeyI.recv != KeyR.send");
-
- dev1.release(ks_i.send.id);
- dev2.release(ks_r.send.id);
-
- // avoid initation flood detection
- wait();
- }
-
- dev1.remove(pk2).unwrap();
- dev2.remove(pk1).unwrap();
- }
-}
diff --git a/src/wireguard/handshake/mod.rs b/src/wireguard/handshake/mod.rs
index 071a41f..3a95817 100644
--- a/src/wireguard/handshake/mod.rs
+++ b/src/wireguard/handshake/mod.rs
@@ -15,6 +15,9 @@ mod ratelimiter;
mod timestamp;
mod types;
+#[cfg(test)]
+mod tests;
+
// publicly exposed interface
pub use device::Device;
diff --git a/src/wireguard/handshake/noise.rs b/src/wireguard/handshake/noise.rs
index 6db300a..46188b4 100644
--- a/src/wireguard/handshake/noise.rs
+++ b/src/wireguard/handshake/noise.rs
@@ -221,7 +221,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>(
rng: &mut R,
keyst: &KeyState,
peer: &Peer,
- sender: u32,
+ local: u32,
msg: &mut NoiseInitiation,
) -> Result<(), HandshakeError> {
debug!("create initation");
@@ -233,7 +233,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>(
let hs = HASH!(&hs, peer.pk.as_bytes());
msg.f_type.set(TYPE_INITIATION as u32);
- msg.f_sender.set(sender);
+ msg.f_sender.set(local); // from us
// (E_priv, E_pub) := DH-Generate()
@@ -292,7 +292,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>(
hs,
ck,
eph_sk,
- sender,
+ local,
};
Ok(())
@@ -378,7 +378,7 @@ pub fn consume_initiation<'a>(
pub fn create_response<R: RngCore + CryptoRng>(
rng: &mut R,
peer: &Peer,
- sender: u32, // sending identifier
+ local: u32, // sending identifier
state: TemporaryState, // state from "consume_initiation"
msg: &mut NoiseResponse, // resulting response
) -> Result<KeyPair, HandshakeError> {
@@ -389,8 +389,8 @@ pub fn create_response<R: RngCore + CryptoRng>(
let (receiver, eph_r_pk, hs, ck) = state;
msg.f_type.set(TYPE_RESPONSE as u32);
- msg.f_sender.set(sender);
- msg.f_receiver.set(receiver);
+ msg.f_sender.set(local); // from us
+ msg.f_receiver.set(receiver); // to the sender of the initation
// (E_priv, E_pub) := DH-Generate()
@@ -447,11 +447,11 @@ pub fn create_response<R: RngCore + CryptoRng>(
birth: Instant::now(),
initiator: false,
send: Key {
- id: sender,
+ id: receiver,
key: key_send.into(),
},
recv: Key {
- id: receiver,
+ id: local,
key: key_recv.into(),
},
})
@@ -472,13 +472,13 @@ pub fn consume_response(
// retrieve peer and copy initiation state
let peer = device.lookup_id(msg.f_receiver.get())?;
- let (hs, ck, sender, eph_sk) = match *peer.state.lock() {
+ let (hs, ck, local, eph_sk) = match *peer.state.lock() {
State::InitiationSent {
hs,
ck,
- sender,
+ local,
ref eph_sk,
- } => Ok((hs, ck, sender, StaticSecret::from(eph_sk.to_bytes()))),
+ } => Ok((hs, ck, local, StaticSecret::from(eph_sk.to_bytes()))),
_ => Err(HandshakeError::InvalidState),
}?;
@@ -535,6 +535,7 @@ pub fn consume_response(
// null the initiation state
// (to avoid replay of this response message)
*state = State::Reset;
+ let remote = msg.f_sender.get();
// return confirmed key-pair
Ok((
@@ -544,11 +545,11 @@ pub fn consume_response(
birth,
initiator: true,
send: Key {
- id: sender,
+ id: remote,
key: key_send.into(),
},
recv: Key {
- id: msg.f_sender.get(),
+ id: local,
key: key_recv.into(),
},
}),
diff --git a/src/wireguard/handshake/peer.rs b/src/wireguard/handshake/peer.rs
index 2d69244..b7d8740 100644
--- a/src/wireguard/handshake/peer.rs
+++ b/src/wireguard/handshake/peer.rs
@@ -40,7 +40,7 @@ pub struct Peer {
pub enum State {
Reset,
InitiationSent {
- sender: u32, // assigned sender id
+ local: u32, // local id assigned
eph_sk: StaticSecret,
hs: GenericArray<u8, U32>,
ck: GenericArray<u8, U32>,
@@ -83,7 +83,7 @@ impl Peer {
pub fn reset_state(&self) -> Option<u32> {
match mem::replace(&mut *self.state.lock(), State::Reset) {
- State::InitiationSent { sender, .. } => Some(sender),
+ State::InitiationSent { local, .. } => Some(local),
_ => None,
}
}
@@ -125,7 +125,7 @@ impl Peer {
// reset state
match *state {
- State::InitiationSent { sender, .. } => device.release(sender),
+ State::InitiationSent { local, .. } => device.release(local),
_ => (),
}
diff --git a/src/wireguard/handshake/tests.rs b/src/wireguard/handshake/tests.rs
new file mode 100644
index 0000000..6be7b51
--- /dev/null
+++ b/src/wireguard/handshake/tests.rs
@@ -0,0 +1,197 @@
+use super::*;
+use hex;
+use rand::rngs::OsRng;
+use std::net::SocketAddr;
+use std::thread;
+use std::time::Duration;
+
+use rand::prelude::*;
+
+use x25519_dalek::PublicKey;
+use x25519_dalek::StaticSecret;
+
+use super::messages::{Initiation, Response};
+
+fn setup_devices<R: RngCore + CryptoRng>(rng: &mut R) -> (PublicKey, Device, PublicKey, Device) {
+ // generate new keypairs
+
+ let sk1 = StaticSecret::new(rng);
+ let pk1 = PublicKey::from(&sk1);
+
+ let sk2 = StaticSecret::new(rng);
+ let pk2 = PublicKey::from(&sk2);
+
+ // pick random psk
+
+ let mut psk = [0u8; 32];
+ rng.fill_bytes(&mut psk[..]);
+
+ // intialize devices on both ends
+
+ let mut dev1 = Device::new();
+ let mut dev2 = Device::new();
+
+ dev1.set_sk(Some(sk1));
+ dev2.set_sk(Some(sk2));
+
+ dev1.add(pk2).unwrap();
+ dev2.add(pk1).unwrap();
+
+ dev1.set_psk(pk2, psk).unwrap();
+ dev2.set_psk(pk1, psk).unwrap();
+
+ (pk1, dev1, pk2, dev2)
+}
+
+fn wait() {
+ thread::sleep(Duration::from_millis(20));
+}
+
+/* Test longest possible handshake interaction (7 messages):
+ *
+ * 1. I -> R (initation)
+ * 2. I <- R (cookie reply)
+ * 3. I -> R (initation)
+ * 4. I <- R (response)
+ * 5. I -> R (cookie reply)
+ * 6. I -> R (initation)
+ * 7. I <- R (response)
+ */
+#[test]
+fn handshake_under_load() {
+ let mut rng = OsRng::new().unwrap();
+ let (_pk1, dev1, pk2, dev2) = setup_devices(&mut rng);
+
+ let src1: SocketAddr = "172.16.0.1:8080".parse().unwrap();
+ let src2: SocketAddr = "172.16.0.2:7070".parse().unwrap();
+
+ // 1. device-1 : create first initation
+ let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
+
+ // 2. device-2 : responds with CookieReply
+ let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
+ (None, Some(msg), None) => msg,
+ _ => panic!("unexpected response"),
+ };
+
+ // device-1 : processes CookieReply (no response)
+ match dev1.process(&mut rng, &msg_cookie, Some(&src2)).unwrap() {
+ (None, None, None) => (),
+ _ => panic!("unexpected response"),
+ }
+
+ // avoid initation flood detection
+ wait();
+
+ // 3. device-1 : create second initation
+ let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
+
+ // 4. device-2 : responds with noise response
+ let msg_response = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
+ (Some(_), Some(msg), Some(kp)) => {
+ assert_eq!(kp.initiator, false);
+ msg
+ }
+ _ => panic!("unexpected response"),
+ };
+
+ // 5. device-1 : responds with CookieReply
+ let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
+ (None, Some(msg), None) => msg,
+ _ => panic!("unexpected response"),
+ };
+
+ // device-2 : processes CookieReply (no response)
+ match dev2.process(&mut rng, &msg_cookie, Some(&src1)).unwrap() {
+ (None, None, None) => (),
+ _ => panic!("unexpected response"),
+ }
+
+ // avoid initation flood detection
+ wait();
+
+ // 6. device-1 : create third initation
+ let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
+
+ // 7. device-2 : responds with noise response
+ let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
+ (Some(_), Some(msg), Some(kp)) => {
+ assert_eq!(kp.initiator, false);
+ (msg, kp)
+ }
+ _ => panic!("unexpected response"),
+ };
+
+ // device-1 : process noise response
+ let kp2 = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
+ (Some(_), None, Some(kp)) => {
+ assert_eq!(kp.initiator, true);
+ kp
+ }
+ _ => panic!("unexpected response"),
+ };
+
+ assert_eq!(kp1.send, kp2.recv);
+ assert_eq!(kp1.recv, kp2.send);
+}
+
+#[test]
+fn handshake_no_load() {
+ let mut rng = OsRng::new().unwrap();
+ let (pk1, mut dev1, pk2, mut dev2) = setup_devices(&mut rng);
+
+ // do a few handshakes (every handshake should succeed)
+
+ for i in 0..10 {
+ println!("handshake : {}", i);
+
+ // create initiation
+
+ let msg1 = dev1.begin(&mut rng, &pk2).unwrap();
+
+ println!("msg1 = {} : {} bytes", hex::encode(&msg1[..]), msg1.len());
+ println!(
+ "msg1 = {:?}",
+ Initiation::parse(&msg1[..]).expect("failed to parse initiation")
+ );
+
+ // process initiation and create response
+
+ let (_, msg2, ks_r) = dev2
+ .process(&mut rng, &msg1, None)
+ .expect("failed to process initiation");
+
+ let ks_r = ks_r.unwrap();
+ let msg2 = msg2.unwrap();
+
+ println!("msg2 = {} : {} bytes", hex::encode(&msg2[..]), msg2.len());
+ println!(
+ "msg2 = {:?}",
+ Response::parse(&msg2[..]).expect("failed to parse response")
+ );
+
+ assert!(!ks_r.initiator, "Responders key-pair is confirmed");
+
+ // process response and obtain confirmed key-pair
+
+ let (_, msg3, ks_i) = dev1
+ .process(&mut rng, &msg2, None)
+ .expect("failed to process response");
+ let ks_i = ks_i.unwrap();
+
+ assert!(msg3.is_none(), "Returned message after response");
+ assert!(ks_i.initiator, "Initiators key-pair is not confirmed");
+
+ assert_eq!(ks_i.send, ks_r.recv, "KeyI.send != KeyR.recv");
+ assert_eq!(ks_i.recv, ks_r.send, "KeyI.recv != KeyR.send");
+
+ dev1.release(ks_i.local_id());
+ dev2.release(ks_r.local_id());
+
+ // avoid initation flood detection
+ wait();
+ }
+
+ dev1.remove(pk2).unwrap();
+ dev2.remove(pk1).unwrap();
+}