From a08fd4002bfae92072f64f8d5e0084e6f248f139 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sun, 13 Oct 2019 22:26:12 +0200 Subject: Work on Linux platform code --- src/wireguard/handshake/device.rs | 574 +++++++++++++++++++++++++++++++++ src/wireguard/handshake/macs.rs | 327 +++++++++++++++++++ src/wireguard/handshake/messages.rs | 363 +++++++++++++++++++++ src/wireguard/handshake/mod.rs | 21 ++ src/wireguard/handshake/noise.rs | 549 +++++++++++++++++++++++++++++++ src/wireguard/handshake/peer.rs | 142 ++++++++ src/wireguard/handshake/ratelimiter.rs | 199 ++++++++++++ src/wireguard/handshake/timestamp.rs | 32 ++ src/wireguard/handshake/types.rs | 90 ++++++ 9 files changed, 2297 insertions(+) create mode 100644 src/wireguard/handshake/device.rs create mode 100644 src/wireguard/handshake/macs.rs create mode 100644 src/wireguard/handshake/messages.rs create mode 100644 src/wireguard/handshake/mod.rs create mode 100644 src/wireguard/handshake/noise.rs create mode 100644 src/wireguard/handshake/peer.rs create mode 100644 src/wireguard/handshake/ratelimiter.rs create mode 100644 src/wireguard/handshake/timestamp.rs create mode 100644 src/wireguard/handshake/types.rs (limited to 'src/wireguard/handshake') diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs new file mode 100644 index 0000000..6a55f6e --- /dev/null +++ b/src/wireguard/handshake/device.rs @@ -0,0 +1,574 @@ +use spin::RwLock; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::Mutex; +use zerocopy::AsBytes; + +use byteorder::{ByteOrder, LittleEndian}; + +use rand::prelude::*; + +use x25519_dalek::PublicKey; +use x25519_dalek::StaticSecret; + +use super::macs; +use super::messages::{CookieReply, Initiation, Response}; +use super::messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE}; +use super::noise; +use super::peer::Peer; +use super::ratelimiter::RateLimiter; +use super::types::*; + +const MAX_PEER_PER_DEVICE: usize = 1 << 20; + +pub struct Device { + pub sk: StaticSecret, // static secret key + pub pk: PublicKey, // static public key + macs: macs::Validator, // validator for the mac fields + pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state + id_map: RwLock>, // receiver ids -> public key + limiter: Mutex, +} + +/* A mutable reference to the device needs to be held during configuration. + * Wrapping the device in a RwLock enables peer config after "configuration time" + */ +impl Device { + /// Initialize a new handshake state machine + /// + /// # Arguments + /// + /// * `sk` - x25519 scalar representing the local private key + pub fn new(sk: StaticSecret) -> Device { + let pk = PublicKey::from(&sk); + Device { + pk, + sk, + macs: macs::Validator::new(pk), + pk_map: HashMap::new(), + id_map: RwLock::new(HashMap::new()), + limiter: Mutex::new(RateLimiter::new()), + } + } + + /// Update the secret key of the device + /// + /// # Arguments + /// + /// * `sk` - x25519 scalar representing the local private key + pub fn set_sk(&mut self, sk: StaticSecret) { + // update secret and public key + let pk = PublicKey::from(&sk); + self.sk = sk; + self.pk = pk; + self.macs = macs::Validator::new(pk); + + // recalculate the shared secrets for every peer + let mut ids = vec![]; + for mut peer in self.pk_map.values_mut() { + peer.reset_state().map(|id| ids.push(id)); + peer.ss = self.sk.diffie_hellman(&peer.pk) + } + + // release ids from aborted handshakes + for id in ids { + self.release(id) + } + } + + /// Return the secret key of the device + /// + /// # Returns + /// + /// A secret key (x25519 scalar) + pub fn get_sk(&self) -> StaticSecret { + StaticSecret::from(self.sk.to_bytes()) + } + + /// Add a new public key to the state machine + /// To remove public keys, you must create a new machine instance + /// + /// # Arguments + /// + /// * `pk` - The public key to add + /// * `identifier` - Associated identifier which can be used to distinguish the peers + pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> { + // check that the pk is not added twice + if let Some(_) = self.pk_map.get(pk.as_bytes()) { + return Err(ConfigError::new("Duplicate public key")); + }; + + // check that the pk is not that of the device + if *self.pk.as_bytes() == *pk.as_bytes() { + return Err(ConfigError::new( + "Public key corresponds to secret key of interface", + )); + } + + // ensure less than 2^20 peers + if self.pk_map.len() > MAX_PEER_PER_DEVICE { + return Err(ConfigError::new("Too many peers for device")); + } + + // map the public key to the peer state + self.pk_map + .insert(*pk.as_bytes(), Peer::new(pk, self.sk.diffie_hellman(&pk))); + + Ok(()) + } + + /// Remove a peer by public key + /// To remove public keys, you must create a new machine instance + /// + /// # Arguments + /// + /// * `pk` - The public key of the peer to remove + /// + /// # Returns + /// + /// The call might fail if the public key is not found + pub fn remove(&mut self, pk: PublicKey) -> Result<(), ConfigError> { + // take write-lock on receive id table + let mut id_map = self.id_map.write(); + + // remove the peer + self.pk_map + .remove(pk.as_bytes()) + .ok_or(ConfigError::new("Public key not in device"))?; + + // pruge the id map (linear scan) + id_map.retain(|_, v| v != pk.as_bytes()); + Ok(()) + } + + /// Add a psk to the peer + /// + /// # Arguments + /// + /// * `pk` - The public key of the peer + /// * `psk` - The psk to set / unset + /// + /// # Returns + /// + /// The call might fail if the public key is not found + pub fn set_psk(&mut self, pk: PublicKey, psk: Option) -> Result<(), ConfigError> { + match self.pk_map.get_mut(pk.as_bytes()) { + Some(mut peer) => { + peer.psk = match psk { + Some(v) => v, + None => [0u8; 32], + }; + Ok(()) + } + _ => Err(ConfigError::new("No such public key")), + } + } + + /// Return the psk for the peer + /// + /// # Arguments + /// + /// * `pk` - The public key of the peer + /// + /// # Returns + /// + /// A 32 byte array holding the PSK + /// + /// The call might fail if the public key is not found + pub fn get_psk(&self, pk: PublicKey) -> Result { + match self.pk_map.get(pk.as_bytes()) { + Some(peer) => Ok(peer.psk), + _ => Err(ConfigError::new("No such public key")), + } + } + + /// Release an id back to the pool + /// + /// # Arguments + /// + /// * `id` - The (sender) id to release + pub fn release(&self, id: u32) { + let mut m = self.id_map.write(); + debug_assert!(m.contains_key(&id), "Releasing id not allocated"); + m.remove(&id); + } + + /// Begin a new handshake + /// + /// # Arguments + /// + /// * `pk` - Public key of peer to initiate handshake for + pub fn begin( + &self, + rng: &mut R, + pk: &PublicKey, + ) -> Result, HandshakeError> { + match self.pk_map.get(pk.as_bytes()) { + None => Err(HandshakeError::UnknownPublicKey), + Some(peer) => { + let sender = self.allocate(rng, peer); + + let mut msg = Initiation::default(); + + noise::create_initiation(rng, self, peer, sender, &mut msg.noise)?; + + // add macs to initation + + peer.macs + .lock() + .generate(msg.noise.as_bytes(), &mut msg.macs); + + Ok(msg.as_bytes().to_owned()) + } + } + } + + /// Process a handshake message. + /// + /// # Arguments + /// + /// * `msg` - Byte slice containing the message (untrusted input) + pub fn process<'a, R: RngCore + CryptoRng, S>( + &self, + rng: &mut R, // rng instance to sample randomness from + msg: &[u8], // message buffer + src: Option<&'a S>, // optional source endpoint, set when "under load" + ) -> Result + where + &'a S: Into<&'a SocketAddr>, + { + // ensure type read in-range + if msg.len() < 4 { + return Err(HandshakeError::InvalidMessageFormat); + } + + // de-multiplex the message type field + match LittleEndian::read_u32(msg) { + TYPE_INITIATION => { + // parse message + let msg = Initiation::parse(msg)?; + + // check mac1 field + self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; + + // address validation & DoS mitigation + if let Some(src) = src { + // obtain ref to socket addr + let src = src.into(); + + // check mac2 field + if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { + let mut reply = Default::default(); + self.macs.create_cookie_reply( + rng, + msg.noise.f_sender.get(), + src, + &msg.macs, + &mut reply, + ); + return Ok((None, Some(reply.as_bytes().to_owned()), None)); + } + + // check ratelimiter + if !self.limiter.lock().unwrap().allow(&src.ip()) { + return Err(HandshakeError::RateLimited); + } + } + + // consume the initiation + let (peer, st) = noise::consume_initiation(self, &msg.noise)?; + + // allocate new index for response + let sender = self.allocate(rng, peer); + + // prepare memory for response, TODO: take slice for zero allocation + let mut resp = Response::default(); + + // create response (release id on error) + let keys = noise::create_response(rng, peer, sender, st, &mut resp.noise).map_err( + |e| { + self.release(sender); + e + }, + )?; + + // add macs to response + peer.macs + .lock() + .generate(resp.noise.as_bytes(), &mut resp.macs); + + // return unconfirmed keypair and the response as vector + Ok((Some(peer.pk), Some(resp.as_bytes().to_owned()), Some(keys))) + } + TYPE_RESPONSE => { + let msg = Response::parse(msg)?; + + // check mac1 field + self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; + + // address validation & DoS mitigation + if let Some(src) = src { + // obtain ref to socket addr + let src = src.into(); + + // check mac2 field + if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { + let mut reply = Default::default(); + self.macs.create_cookie_reply( + rng, + msg.noise.f_sender.get(), + src, + &msg.macs, + &mut reply, + ); + return Ok((None, Some(reply.as_bytes().to_owned()), None)); + } + + // check ratelimiter + if !self.limiter.lock().unwrap().allow(&src.ip()) { + return Err(HandshakeError::RateLimited); + } + } + + // consume inner playload + noise::consume_response(self, &msg.noise) + } + TYPE_COOKIE_REPLY => { + let msg = CookieReply::parse(msg)?; + + // lookup peer + let peer = self.lookup_id(msg.f_receiver.get())?; + + // validate cookie reply + peer.macs.lock().process(&msg)?; + + // this prompts no new message and + // DOES NOT cryptographically verify the peer + Ok((None, None, None)) + } + _ => Err(HandshakeError::InvalidMessageFormat), + } + } + + // Internal function + // + // Return the peer associated with the public key + pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer, HandshakeError> { + self.pk_map + .get(pk.as_bytes()) + .ok_or(HandshakeError::UnknownPublicKey) + } + + // Internal function + // + // Return the peer currently associated with the receiver identifier + pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer, HandshakeError> { + let im = self.id_map.read(); + let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?; + match self.pk_map.get(pk) { + Some(peer) => Ok(peer), + _ => unreachable!(), // if the id-lookup succeeded, the peer should exist + } + } + + // Internal function + // + // Allocated a new receiver identifier for the peer + fn allocate(&self, rng: &mut R, peer: &Peer) -> u32 { + loop { + let id = rng.gen(); + + // check membership with read lock + if self.id_map.read().contains_key(&id) { + continue; + } + + // take write lock and add index + let mut m = self.id_map.write(); + if !m.contains_key(&id) { + m.insert(id, *peer.pk.as_bytes()); + return id; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::super::messages::*; + use super::*; + use hex; + use rand::rngs::OsRng; + use std::net::SocketAddr; + use std::thread; + use std::time::Duration; + + fn setup_devices( + rng: &mut R, + ) -> (PublicKey, Device, PublicKey, Device) { + // generate new keypairs + + let sk1 = StaticSecret::new(rng); + let pk1 = PublicKey::from(&sk1); + + let sk2 = StaticSecret::new(rng); + let pk2 = PublicKey::from(&sk2); + + // pick random psk + + let mut psk = [0u8; 32]; + rng.fill_bytes(&mut psk[..]); + + // intialize devices on both ends + + let mut dev1 = Device::new(sk1); + let mut dev2 = Device::new(sk2); + + dev1.add(pk2).unwrap(); + dev2.add(pk1).unwrap(); + + dev1.set_psk(pk2, Some(psk)).unwrap(); + dev2.set_psk(pk1, Some(psk)).unwrap(); + + (pk1, dev1, pk2, dev2) + } + + /* Test longest possible handshake interaction (7 messages): + * + * 1. I -> R (initation) + * 2. I <- R (cookie reply) + * 3. I -> R (initation) + * 4. I <- R (response) + * 5. I -> R (cookie reply) + * 6. I -> R (initation) + * 7. I <- R (response) + */ + #[test] + fn handshake_under_load() { + let mut rng = OsRng::new().unwrap(); + let (_pk1, dev1, pk2, dev2) = setup_devices(&mut rng); + + let src1: SocketAddr = "172.16.0.1:8080".parse().unwrap(); + let src2: SocketAddr = "172.16.0.2:7070".parse().unwrap(); + + // 1. device-1 : create first initation + let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + + // 2. device-2 : responds with CookieReply + let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() { + (None, Some(msg), None) => msg, + _ => panic!("unexpected response"), + }; + + // device-1 : processes CookieReply (no response) + match dev1.process(&mut rng, &msg_cookie, Some(&src2)).unwrap() { + (None, None, None) => (), + _ => panic!("unexpected response"), + } + + // avoid initation flood + thread::sleep(Duration::from_millis(20)); + + // 3. device-1 : create second initation + let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + + // 4. device-2 : responds with noise response + let msg_response = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() { + (Some(_), Some(msg), Some(kp)) => { + assert_eq!(kp.initiator, false); + msg + } + _ => panic!("unexpected response"), + }; + + // 5. device-1 : responds with CookieReply + let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() { + (None, Some(msg), None) => msg, + _ => panic!("unexpected response"), + }; + + // device-2 : processes CookieReply (no response) + match dev2.process(&mut rng, &msg_cookie, Some(&src1)).unwrap() { + (None, None, None) => (), + _ => panic!("unexpected response"), + } + + // avoid initation flood + thread::sleep(Duration::from_millis(20)); + + // 6. device-1 : create third initation + let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + + // 7. device-2 : responds with noise response + let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() { + (Some(_), Some(msg), Some(kp)) => { + assert_eq!(kp.initiator, false); + (msg, kp) + } + _ => panic!("unexpected response"), + }; + + // device-1 : process noise response + let kp2 = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() { + (Some(_), None, Some(kp)) => { + assert_eq!(kp.initiator, true); + kp + } + _ => panic!("unexpected response"), + }; + + assert_eq!(kp1.send, kp2.recv); + assert_eq!(kp1.recv, kp2.send); + } + + #[test] + fn handshake_no_load() { + let mut rng = OsRng::new().unwrap(); + let (pk1, mut dev1, pk2, mut dev2) = setup_devices(&mut rng); + + // do a few handshakes (every handshake should succeed) + + for i in 0..10 { + println!("handshake : {}", i); + + // create initiation + + let msg1 = dev1.begin(&mut rng, &pk2).unwrap(); + + println!("msg1 = {} : {} bytes", hex::encode(&msg1[..]), msg1.len()); + println!("msg1 = {:?}", Initiation::parse(&msg1[..]).unwrap()); + + // process initiation and create response + + let (_, msg2, ks_r) = dev2.process(&mut rng, &msg1, None).unwrap(); + + let ks_r = ks_r.unwrap(); + let msg2 = msg2.unwrap(); + + println!("msg2 = {} : {} bytes", hex::encode(&msg2[..]), msg2.len()); + println!("msg2 = {:?}", Response::parse(&msg2[..]).unwrap()); + + assert!(!ks_r.initiator, "Responders key-pair is confirmed"); + + // process response and obtain confirmed key-pair + + let (_, msg3, ks_i) = dev1.process(&mut rng, &msg2, None).unwrap(); + let ks_i = ks_i.unwrap(); + + assert!(msg3.is_none(), "Returned message after response"); + assert!(ks_i.initiator, "Initiators key-pair is not confirmed"); + + assert_eq!(ks_i.send, ks_r.recv, "KeyI.send != KeyR.recv"); + assert_eq!(ks_i.recv, ks_r.send, "KeyI.recv != KeyR.send"); + + dev1.release(ks_i.send.id); + dev2.release(ks_r.send.id); + + // to avoid flood detection + thread::sleep(Duration::from_millis(20)); + } + + dev1.remove(pk2).unwrap(); + dev2.remove(pk1).unwrap(); + } +} diff --git a/src/wireguard/handshake/macs.rs b/src/wireguard/handshake/macs.rs new file mode 100644 index 0000000..689826b --- /dev/null +++ b/src/wireguard/handshake/macs.rs @@ -0,0 +1,327 @@ +use generic_array::GenericArray; +use rand::{CryptoRng, RngCore}; +use spin::RwLock; +use std::time::{Duration, Instant}; + +// types to coalesce into bytes +use std::net::SocketAddr; +use x25519_dalek::PublicKey; + +// AEAD +use aead::{Aead, NewAead, Payload}; +use chacha20poly1305::XChaCha20Poly1305; + +// MAC +use blake2::Blake2s; +use subtle::ConstantTimeEq; + +use super::messages::{CookieReply, MacsFooter, TYPE_COOKIE_REPLY}; +use super::types::HandshakeError; + +const LABEL_MAC1: &[u8] = b"mac1----"; +const LABEL_COOKIE: &[u8] = b"cookie--"; + +const SIZE_COOKIE: usize = 16; +const SIZE_SECRET: usize = 32; +const SIZE_MAC: usize = 16; // blake2s-mac128 +const SIZE_TAG: usize = 16; // xchacha20poly1305 tag + +const COOKIE_UPDATE_INTERVAL: Duration = Duration::from_secs(120); + +macro_rules! HASH { + ( $($input:expr),* ) => {{ + use blake2::Digest; + let mut hsh = Blake2s::new(); + $( + hsh.input($input); + )* + hsh.result() + }}; +} + +macro_rules! MAC { + ( $key:expr, $($input:expr),* ) => {{ + use blake2::VarBlake2s; + use digest::Input; + use digest::VariableOutput; + let mut tag = [0u8; SIZE_MAC]; + let mut mac = VarBlake2s::new_keyed($key, SIZE_MAC); + $( + mac.input($input); + )* + mac.variable_result(|buf| tag.copy_from_slice(buf)); + tag + }}; +} + +macro_rules! XSEAL { + ($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{ + let ct = XChaCha20Poly1305::new(*GenericArray::from_slice($key)) + .encrypt( + GenericArray::from_slice($nonce), + Payload { msg: $pt, aad: $ad }, + ) + .unwrap(); + debug_assert_eq!(ct.len(), $pt.len() + SIZE_TAG); + $ct.copy_from_slice(&ct); + }}; +} + +macro_rules! XOPEN { + ($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{ + debug_assert_eq!($ct.len(), $pt.len() + SIZE_TAG); + XChaCha20Poly1305::new(*GenericArray::from_slice($key)) + .decrypt( + GenericArray::from_slice($nonce), + Payload { msg: $ct, aad: $ad }, + ) + .map_err(|_| HandshakeError::DecryptionFailure) + .map(|pt| $pt.copy_from_slice(&pt)) + }}; +} + +struct Cookie { + value: [u8; 16], + birth: Instant, +} + +pub struct Generator { + mac1_key: [u8; 32], + cookie_key: [u8; 32], // xchacha20poly key for opening cookie response + last_mac1: Option<[u8; 16]>, + cookie: Option, +} + +fn addr_to_mac_bytes(addr: &SocketAddr) -> Vec { + match addr { + SocketAddr::V4(addr) => { + let mut res = Vec::with_capacity(4 + 2); + res.extend(&addr.ip().octets()); + res.extend(&addr.port().to_le_bytes()); + res + } + SocketAddr::V6(addr) => { + let mut res = Vec::with_capacity(16 + 2); + res.extend(&addr.ip().octets()); + res.extend(&addr.port().to_le_bytes()); + res + } + } +} + +impl Generator { + /// Initalize a new mac field generator + /// + /// # Arguments + /// + /// - pk: The public key of the peer to which the generator is associated + /// + /// # Returns + /// + /// A freshly initated generator + pub fn new(pk: PublicKey) -> Generator { + Generator { + mac1_key: HASH!(LABEL_MAC1, pk.as_bytes()).into(), + cookie_key: HASH!(LABEL_COOKIE, pk.as_bytes()).into(), + last_mac1: None, + cookie: None, + } + } + + /// Process a CookieReply message + /// + /// # Arguments + /// + /// - reply: CookieReply to process + /// + /// # Returns + /// + /// Can fail if the cookie reply fails to validate + /// (either indicating that it is outdated or malformed) + pub fn process(&mut self, reply: &CookieReply) -> Result<(), HandshakeError> { + let mac1 = self.last_mac1.ok_or(HandshakeError::InvalidState)?; + let mut tau = [0u8; SIZE_COOKIE]; + XOPEN!( + &self.cookie_key, // key + &reply.f_nonce, // nonce + &mac1, // ad + &mut tau, // pt + &reply.f_cookie // ct || tag + )?; + self.cookie = Some(Cookie { + birth: Instant::now(), + value: tau, + }); + Ok(()) + } + + /// Generate both mac fields for an inner message + /// + /// # Arguments + /// + /// - inner: A byteslice representing the inner message to be covered + /// - macs: The destination mac footer for the resulting macs + pub fn generate(&mut self, inner: &[u8], macs: &mut MacsFooter) { + macs.f_mac1 = MAC!(&self.mac1_key, inner); + macs.f_mac2 = match &self.cookie { + Some(cookie) => { + if cookie.birth.elapsed() > COOKIE_UPDATE_INTERVAL { + self.cookie = None; + [0u8; SIZE_MAC] + } else { + MAC!(&cookie.value, inner, macs.f_mac1) + } + } + None => [0u8; SIZE_MAC], + }; + self.last_mac1 = Some(macs.f_mac1); + } +} + +struct Secret { + value: [u8; 32], + birth: Instant, +} + +pub struct Validator { + mac1_key: [u8; 32], // mac1 key, derived from device public key + cookie_key: [u8; 32], // xchacha20poly key for sealing cookie response + secret: RwLock, +} + +impl Validator { + pub fn new(pk: PublicKey) -> Validator { + Validator { + mac1_key: HASH!(LABEL_MAC1, pk.as_bytes()).into(), + cookie_key: HASH!(LABEL_COOKIE, pk.as_bytes()).into(), + secret: RwLock::new(Secret { + value: [0u8; SIZE_SECRET], + birth: Instant::now() - Duration::new(86400, 0), + }), + } + } + + fn get_tau(&self, src: &[u8]) -> Option<[u8; SIZE_COOKIE]> { + let secret = self.secret.read(); + if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL { + Some(MAC!(&secret.value, src)) + } else { + None + } + } + + fn get_set_tau(&self, rng: &mut R, src: &[u8]) -> [u8; SIZE_COOKIE] { + // check if current value is still valid + { + let secret = self.secret.read(); + if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL { + return MAC!(&secret.value, src); + }; + } + + // take write lock, check again + { + let mut secret = self.secret.write(); + if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL { + return MAC!(&secret.value, src); + }; + + // set new random cookie secret + rng.fill_bytes(&mut secret.value); + secret.birth = Instant::now(); + MAC!(&secret.value, src) + } + } + + pub fn create_cookie_reply( + &self, + rng: &mut R, + receiver: u32, // receiver id of incoming message + src: &SocketAddr, // source address of incoming message + macs: &MacsFooter, // footer of incoming message + msg: &mut CookieReply, // resulting cookie reply + ) { + let src = addr_to_mac_bytes(src); + msg.f_type.set(TYPE_COOKIE_REPLY as u32); + msg.f_receiver.set(receiver); + rng.fill_bytes(&mut msg.f_nonce); + XSEAL!( + &self.cookie_key, // key + &msg.f_nonce, // nonce + &macs.f_mac1, // ad + &self.get_set_tau(rng, &src), // pt + &mut msg.f_cookie // ct || tag + ); + } + + /// Check the mac1 field against the inner message + /// + /// # Arguments + /// + /// - inner: The inner message covered by the mac1 field + /// - macs: The mac footer + pub fn check_mac1(&self, inner: &[u8], macs: &MacsFooter) -> Result<(), HandshakeError> { + let valid_mac1: bool = MAC!(&self.mac1_key, inner).ct_eq(&macs.f_mac1).into(); + if !valid_mac1 { + Err(HandshakeError::InvalidMac1) + } else { + Ok(()) + } + } + + pub fn check_mac2(&self, inner: &[u8], src: &SocketAddr, macs: &MacsFooter) -> bool { + let src = addr_to_mac_bytes(src); + match self.get_tau(&src) { + Some(tau) => MAC!(&tau, inner, macs.f_mac1).ct_eq(&macs.f_mac2).into(), + None => false, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proptest::prelude::*; + use rand::rngs::OsRng; + use x25519_dalek::StaticSecret; + + fn new_validator_generator() -> (Validator, Generator) { + let mut rng = OsRng::new().unwrap(); + let sk = StaticSecret::new(&mut rng); + let pk = PublicKey::from(&sk); + (Validator::new(pk), Generator::new(pk)) + } + + proptest! { + #[test] + fn test_cookie_reply(inner1 : Vec, inner2 : Vec, receiver : u32) { + let mut msg = CookieReply::default(); + let mut rng = OsRng::new().expect("failed to create rng"); + let mut macs = MacsFooter::default(); + let src = "192.0.2.16:8080".parse().unwrap(); + let (validator, mut generator) = new_validator_generator(); + + // generate mac1 for first message + generator.generate(&inner1[..], &mut macs); + assert_ne!(macs.f_mac1, [0u8; SIZE_MAC], "mac1 should be set"); + assert_eq!(macs.f_mac2, [0u8; SIZE_MAC], "mac2 should not be set"); + + // check validity of mac1 + validator.check_mac1(&inner1[..], &macs).expect("mac1 of inner1 did not validate"); + assert_eq!(validator.check_mac2(&inner1[..], &src, &macs), false, "mac2 of inner2 did not validate"); + validator.create_cookie_reply(&mut rng, receiver, &src, &macs, &mut msg); + + // consume cookie reply + generator.process(&msg).expect("failed to process CookieReply"); + + // generate mac2 & mac2 for second message + generator.generate(&inner2[..], &mut macs); + assert_ne!(macs.f_mac1, [0u8; SIZE_MAC], "mac1 should be set"); + assert_ne!(macs.f_mac2, [0u8; SIZE_MAC], "mac2 should be set"); + + // check validity of mac1 and mac2 + validator.check_mac1(&inner2[..], &macs).expect("mac1 of inner2 did not validate"); + assert!(validator.check_mac2(&inner2[..], &src, &macs), "mac2 of inner2 did not validate"); + } + } +} diff --git a/src/wireguard/handshake/messages.rs b/src/wireguard/handshake/messages.rs new file mode 100644 index 0000000..29d80af --- /dev/null +++ b/src/wireguard/handshake/messages.rs @@ -0,0 +1,363 @@ +#[cfg(test)] +use hex; + +#[cfg(test)] +use std::fmt; + +use std::mem; + +use byteorder::LittleEndian; +use zerocopy::byteorder::U32; +use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified}; + +use super::types::*; + +const SIZE_MAC: usize = 16; +const SIZE_TAG: usize = 16; // poly1305 tag +const SIZE_XNONCE: usize = 24; // xchacha20 nonce +const SIZE_COOKIE: usize = 16; // +const SIZE_X25519_POINT: usize = 32; // x25519 public key +const SIZE_TIMESTAMP: usize = 12; + +pub const TYPE_INITIATION: u32 = 1; +pub const TYPE_RESPONSE: u32 = 2; +pub const TYPE_COOKIE_REPLY: u32 = 3; + +const fn max(a: usize, b: usize) -> usize { + let m: usize = (a > b) as usize; + m * a + (1 - m) * b +} + +pub const MAX_HANDSHAKE_MSG_SIZE: usize = max( + max(mem::size_of::(), mem::size_of::()), + mem::size_of::(), +); + +/* Handshake messsages */ + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct Response { + pub noise: NoiseResponse, // inner message covered by macs + pub macs: MacsFooter, +} + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct Initiation { + pub noise: NoiseInitiation, // inner message covered by macs + pub macs: MacsFooter, +} + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct CookieReply { + pub f_type: U32, + pub f_receiver: U32, + pub f_nonce: [u8; SIZE_XNONCE], + pub f_cookie: [u8; SIZE_COOKIE + SIZE_TAG], +} + +/* Inner sub-messages */ + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct MacsFooter { + pub f_mac1: [u8; SIZE_MAC], + pub f_mac2: [u8; SIZE_MAC], +} + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct NoiseInitiation { + pub f_type: U32, + pub f_sender: U32, + pub f_ephemeral: [u8; SIZE_X25519_POINT], + pub f_static: [u8; SIZE_X25519_POINT + SIZE_TAG], + pub f_timestamp: [u8; SIZE_TIMESTAMP + SIZE_TAG], +} + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct NoiseResponse { + pub f_type: U32, + pub f_sender: U32, + pub f_receiver: U32, + pub f_ephemeral: [u8; SIZE_X25519_POINT], + pub f_empty: [u8; SIZE_TAG], +} + +/* Zero copy parsing of handshake messages */ + +impl Initiation { + pub fn parse(bytes: B) -> Result, HandshakeError> { + let msg: LayoutVerified = + LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?; + + if msg.noise.f_type.get() != (TYPE_INITIATION as u32) { + return Err(HandshakeError::InvalidMessageFormat); + } + + Ok(msg) + } +} + +impl Response { + pub fn parse(bytes: B) -> Result, HandshakeError> { + let msg: LayoutVerified = + LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?; + + if msg.noise.f_type.get() != (TYPE_RESPONSE as u32) { + return Err(HandshakeError::InvalidMessageFormat); + } + + Ok(msg) + } +} + +impl CookieReply { + pub fn parse(bytes: B) -> Result, HandshakeError> { + let msg: LayoutVerified = + LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?; + + if msg.f_type.get() != (TYPE_COOKIE_REPLY as u32) { + return Err(HandshakeError::InvalidMessageFormat); + } + + Ok(msg) + } +} + +/* Default values */ + +impl Default for Response { + fn default() -> Self { + Self { + noise: Default::default(), + macs: Default::default(), + } + } +} + +impl Default for Initiation { + fn default() -> Self { + Self { + noise: Default::default(), + macs: Default::default(), + } + } +} + +impl Default for CookieReply { + fn default() -> Self { + Self { + f_type: >::new(TYPE_COOKIE_REPLY as u32), + f_receiver: >::ZERO, + f_nonce: [0u8; SIZE_XNONCE], + f_cookie: [0u8; SIZE_COOKIE + SIZE_TAG], + } + } +} + +impl Default for MacsFooter { + fn default() -> Self { + Self { + f_mac1: [0u8; SIZE_MAC], + f_mac2: [0u8; SIZE_MAC], + } + } +} + +impl Default for NoiseInitiation { + fn default() -> Self { + Self { + f_type: >::new(TYPE_INITIATION as u32), + f_sender: >::ZERO, + f_ephemeral: [0u8; SIZE_X25519_POINT], + f_static: [0u8; SIZE_X25519_POINT + SIZE_TAG], + f_timestamp: [0u8; SIZE_TIMESTAMP + SIZE_TAG], + } + } +} + +impl Default for NoiseResponse { + fn default() -> Self { + Self { + f_type: >::new(TYPE_RESPONSE as u32), + f_sender: >::ZERO, + f_receiver: >::ZERO, + f_ephemeral: [0u8; SIZE_X25519_POINT], + f_empty: [0u8; SIZE_TAG], + } + } +} + +/* Debug formatting (for testing purposes) */ + +#[cfg(test)] +impl fmt::Debug for Initiation { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Initiation {{ {:?} || {:?} }}", self.noise, self.macs) + } +} + +#[cfg(test)] +impl fmt::Debug for Response { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Response {{ {:?} || {:?} }}", self.noise, self.macs) + } +} + +#[cfg(test)] +impl fmt::Debug for CookieReply { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "CookieReply {{ type = {}, receiver = {}, nonce = {}, cookie = {} }}", + self.f_type, + self.f_receiver, + hex::encode(&self.f_nonce[..]), + hex::encode(&self.f_cookie[..]), + ) + } +} + +#[cfg(test)] +impl fmt::Debug for NoiseInitiation { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, + "NoiseInitiation {{ type = {}, sender = {}, ephemeral = {}, static = {}, timestamp = {} }}", + self.f_type.get(), + self.f_sender.get(), + hex::encode(&self.f_ephemeral[..]), + hex::encode(&self.f_static[..]), + hex::encode(&self.f_timestamp[..]), + ) + } +} + +#[cfg(test)] +impl fmt::Debug for NoiseResponse { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, + "NoiseResponse {{ type = {}, sender = {}, receiver = {}, ephemeral = {}, empty = |{} }}", + self.f_type, + self.f_sender, + self.f_receiver, + hex::encode(&self.f_ephemeral[..]), + hex::encode(&self.f_empty[..]) + ) + } +} + +#[cfg(test)] +impl fmt::Debug for MacsFooter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Macs {{ mac1 = {}, mac2 = {} }}", + hex::encode(&self.f_mac1[..]), + hex::encode(&self.f_mac2[..]) + ) + } +} + +/* Equality (for testing purposes) */ + +#[cfg(test)] +macro_rules! eq_as_bytes { + ($type:path) => { + impl PartialEq for $type { + fn eq(&self, other: &Self) -> bool { + self.as_bytes() == other.as_bytes() + } + } + impl Eq for $type {} + }; +} + +#[cfg(test)] +eq_as_bytes!(Initiation); + +#[cfg(test)] +eq_as_bytes!(Response); + +#[cfg(test)] +eq_as_bytes!(CookieReply); + +#[cfg(test)] +eq_as_bytes!(MacsFooter); + +#[cfg(test)] +eq_as_bytes!(NoiseInitiation); + +#[cfg(test)] +eq_as_bytes!(NoiseResponse); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn message_response_identity() { + let mut msg: Response = Default::default(); + + msg.noise.f_sender.set(146252); + msg.noise.f_receiver.set(554442); + msg.noise.f_ephemeral = [ + 0xc1, 0x66, 0x0a, 0x0c, 0xdc, 0x0f, 0x6c, 0x51, 0x0f, 0xc2, 0xcc, 0x51, 0x52, 0x0c, + 0xde, 0x1e, 0xf7, 0xf1, 0xca, 0x90, 0x86, 0x72, 0xad, 0x67, 0xea, 0x89, 0x45, 0x44, + 0x13, 0x56, 0x52, 0x1f, + ]; + msg.noise.f_empty = [ + 0x60, 0x0e, 0x1e, 0x95, 0x41, 0x6b, 0x52, 0x05, 0xa2, 0x09, 0xe1, 0xbf, 0x40, 0x05, + 0x2f, 0xde, + ]; + msg.macs.f_mac1 = [ + 0xf2, 0xad, 0x40, 0xb5, 0xf7, 0xde, 0x77, 0x35, 0x89, 0x19, 0xb7, 0x5c, 0xf9, 0x54, + 0x69, 0x29, + ]; + msg.macs.f_mac2 = [ + 0x4f, 0xd2, 0x1b, 0xfe, 0x77, 0xe6, 0x2e, 0xc9, 0x07, 0xe2, 0x87, 0x17, 0xbb, 0xe5, + 0xdf, 0xbb, + ]; + + let buf: Vec = msg.as_bytes().to_vec(); + let msg_p = Response::parse(&buf[..]).unwrap(); + assert_eq!(msg, *msg_p.into_ref()); + } + + #[test] + fn message_initiate_identity() { + let mut msg: Initiation = Default::default(); + + msg.noise.f_sender.set(575757); + msg.noise.f_ephemeral = [ + 0xc1, 0x66, 0x0a, 0x0c, 0xdc, 0x0f, 0x6c, 0x51, 0x0f, 0xc2, 0xcc, 0x51, 0x52, 0x0c, + 0xde, 0x1e, 0xf7, 0xf1, 0xca, 0x90, 0x86, 0x72, 0xad, 0x67, 0xea, 0x89, 0x45, 0x44, + 0x13, 0x56, 0x52, 0x1f, + ]; + msg.noise.f_static = [ + 0xdc, 0x33, 0x90, 0x15, 0x8f, 0x82, 0x3e, 0x06, 0x44, 0xa0, 0xde, 0x4c, 0x15, 0x6c, + 0x5d, 0xa4, 0x65, 0x99, 0xf6, 0x6c, 0xa1, 0x14, 0x77, 0xf9, 0xeb, 0x6a, 0xec, 0xc3, + 0x3c, 0xda, 0x47, 0xe1, 0x45, 0xac, 0x8d, 0x43, 0xea, 0x1b, 0x2f, 0x02, 0x45, 0x5d, + 0x86, 0x37, 0xee, 0x83, 0x6b, 0x42, + ]; + msg.noise.f_timestamp = [ + 0x4f, 0x1c, 0x60, 0xec, 0x0e, 0xf6, 0x36, 0xf0, 0x78, 0x28, 0x57, 0x42, 0x60, 0x0e, + 0x1e, 0x95, 0x41, 0x6b, 0x52, 0x05, 0xa2, 0x09, 0xe1, 0xbf, 0x40, 0x05, 0x2f, 0xde, + ]; + msg.macs.f_mac1 = [ + 0xf2, 0xad, 0x40, 0xb5, 0xf7, 0xde, 0x77, 0x35, 0x89, 0x19, 0xb7, 0x5c, 0xf9, 0x54, + 0x69, 0x29, + ]; + msg.macs.f_mac2 = [ + 0x4f, 0xd2, 0x1b, 0xfe, 0x77, 0xe6, 0x2e, 0xc9, 0x07, 0xe2, 0x87, 0x17, 0xbb, 0xe5, + 0xdf, 0xbb, + ]; + + let buf: Vec = msg.as_bytes().to_vec(); + let msg_p = Initiation::parse(&buf[..]).unwrap(); + assert_eq!(msg, *msg_p.into_ref()); + } +} diff --git a/src/wireguard/handshake/mod.rs b/src/wireguard/handshake/mod.rs new file mode 100644 index 0000000..071a41f --- /dev/null +++ b/src/wireguard/handshake/mod.rs @@ -0,0 +1,21 @@ +/* Implementation of the: + * + * Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s + * + * Protocol pattern, see: http://www.noiseprotocol.org/noise.html. + * For documentation. + */ + +mod device; +mod macs; +mod messages; +mod noise; +mod peer; +mod ratelimiter; +mod timestamp; +mod types; + +// publicly exposed interface + +pub use device::Device; +pub use messages::{MAX_HANDSHAKE_MSG_SIZE, TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE}; diff --git a/src/wireguard/handshake/noise.rs b/src/wireguard/handshake/noise.rs new file mode 100644 index 0000000..a2a84b0 --- /dev/null +++ b/src/wireguard/handshake/noise.rs @@ -0,0 +1,549 @@ +// DH +use x25519_dalek::PublicKey; +use x25519_dalek::StaticSecret; + +// HASH & MAC +use blake2::Blake2s; +use hmac::Hmac; + +// AEAD +use aead::{Aead, NewAead, Payload}; +use chacha20poly1305::ChaCha20Poly1305; + +use rand::{CryptoRng, RngCore}; + +use generic_array::typenum::*; +use generic_array::*; + +use clear_on_drop::clear::Clear; +use clear_on_drop::clear_stack_on_return; + +use subtle::ConstantTimeEq; + +use super::device::Device; +use super::messages::{NoiseInitiation, NoiseResponse}; +use super::messages::{TYPE_INITIATION, TYPE_RESPONSE}; +use super::peer::{Peer, State}; +use super::timestamp; +use super::types::*; + +use super::super::types::{KeyPair, Key}; + +use std::time::Instant; + +// HMAC hasher (generic construction) + +type HMACBlake2s = Hmac; + +// convenient alias to pass state temporarily into device.rs and back + +type TemporaryState = (u32, PublicKey, GenericArray, GenericArray); + +const SIZE_CK: usize = 32; +const SIZE_HS: usize = 32; +const SIZE_NONCE: usize = 8; +const SIZE_TAG: usize = 16; + +// number of pages to clear after sensitive call +const CLEAR_PAGES: usize = 1; + +// C := Hash(Construction) +const INITIAL_CK: [u8; SIZE_CK] = [ + 0x60, 0xe2, 0x6d, 0xae, 0xf3, 0x27, 0xef, 0xc0, 0x2e, 0xc3, 0x35, 0xe2, 0xa0, 0x25, 0xd2, 0xd0, + 0x16, 0xeb, 0x42, 0x06, 0xf8, 0x72, 0x77, 0xf5, 0x2d, 0x38, 0xd1, 0x98, 0x8b, 0x78, 0xcd, 0x36, +]; + +// H := Hash(C || Identifier) +const INITIAL_HS: [u8; SIZE_HS] = [ + 0x22, 0x11, 0xb3, 0x61, 0x08, 0x1a, 0xc5, 0x66, 0x69, 0x12, 0x43, 0xdb, 0x45, 0x8a, 0xd5, 0x32, + 0x2d, 0x9c, 0x6c, 0x66, 0x22, 0x93, 0xe8, 0xb7, 0x0e, 0xe1, 0x9c, 0x65, 0xba, 0x07, 0x9e, 0xf3, +]; + +const ZERO_NONCE: [u8; 12] = [0u8; 12]; + +macro_rules! HASH { + ( $($input:expr),* ) => {{ + use blake2::Digest; + let mut hsh = Blake2s::new(); + $( + hsh.input($input); + )* + hsh.result() + }}; +} + +macro_rules! HMAC { + ($key:expr, $($input:expr),*) => {{ + use hmac::Mac; + let mut mac = HMACBlake2s::new_varkey($key).unwrap(); + $( + mac.input($input); + )* + mac.result().code() + }}; +} + +macro_rules! KDF1 { + ($ck:expr, $input:expr) => {{ + let mut t0 = HMAC!($ck, $input); + let t1 = HMAC!(&t0, &[0x1]); + t0.clear(); + t1 + }}; +} + +macro_rules! KDF2 { + ($ck:expr, $input:expr) => {{ + let mut t0 = HMAC!($ck, $input); + let t1 = HMAC!(&t0, &[0x1]); + let t2 = HMAC!(&t0, &t1, &[0x2]); + t0.clear(); + (t1, t2) + }}; +} + +macro_rules! KDF3 { + ($ck:expr, $input:expr) => {{ + let mut t0 = HMAC!($ck, $input); + let t1 = HMAC!(&t0, &[0x1]); + let t2 = HMAC!(&t0, &t1, &[0x2]); + let t3 = HMAC!(&t0, &t2, &[0x3]); + t0.clear(); + (t1, t2, t3) + }}; +} + +macro_rules! SEAL { + ($key:expr, $ad:expr, $pt:expr, $ct:expr) => { + ChaCha20Poly1305::new(*GenericArray::from_slice($key)) + .encrypt(&ZERO_NONCE.into(), Payload { msg: $pt, aad: $ad }) + .map(|ct| $ct.copy_from_slice(&ct)) + .unwrap() + }; +} + +macro_rules! OPEN { + ($key:expr, $ad:expr, $pt:expr, $ct:expr) => { + ChaCha20Poly1305::new(*GenericArray::from_slice($key)) + .decrypt(&ZERO_NONCE.into(), Payload { msg: $ct, aad: $ad }) + .map_err(|_| HandshakeError::DecryptionFailure) + .map(|pt| $pt.copy_from_slice(&pt)) + }; +} + +#[cfg(test)] +mod tests { + use super::*; + + const IDENTIFIER: &[u8] = b"WireGuard v1 zx2c4 Jason@zx2c4.com"; + const CONSTRUCTION: &[u8] = b"Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"; + + /* Sanity check precomputed initial chain key + */ + #[test] + fn precomputed_chain_key() { + assert_eq!(INITIAL_CK[..], HASH!(CONSTRUCTION)[..]); + } + + /* Sanity check precomputed initial hash transcript + */ + #[test] + fn precomputed_hash() { + assert_eq!(INITIAL_HS[..], HASH!(INITIAL_CK, IDENTIFIER)[..]); + } + + /* Sanity check the HKDF macro + * + * Test vectors generated using WireGuard-Go + */ + #[test] + fn hkdf() { + let tests: Vec<(Vec, Vec, [u8; 32], [u8; 32], [u8; 32])> = vec![ + ( + vec![], + vec![], + [ + 0x83, 0x87, 0xb4, 0x6b, 0xf4, 0x3e, 0xcc, 0xfc, 0xf3, 0x49, 0x55, 0x2a, 0x09, + 0x5d, 0x83, 0x15, 0xc4, 0x05, 0x5b, 0xeb, 0x90, 0x20, 0x8f, 0xb1, 0xbe, 0x23, + 0xb8, 0x94, 0xbc, 0x2e, 0xd5, 0xd0, + ], + [ + 0x58, 0xa0, 0xe5, 0xf6, 0xfa, 0xef, 0xcc, 0xf4, 0x80, 0x7b, 0xff, 0x1f, 0x05, + 0xfa, 0x8a, 0x92, 0x17, 0x94, 0x57, 0x62, 0x04, 0x0b, 0xce, 0xc2, 0xf4, 0xb4, + 0xa6, 0x2b, 0xdf, 0xe0, 0xe8, 0x6e, + ], + [ + 0x0c, 0xe6, 0xea, 0x98, 0xec, 0x54, 0x8f, 0x8e, 0x28, 0x1e, 0x93, 0xe3, 0x2d, + 0xb6, 0x56, 0x21, 0xc4, 0x5e, 0xb1, 0x8d, 0xc6, 0xf0, 0xa7, 0xad, 0x94, 0x17, + 0x86, 0x10, 0xa2, 0xf7, 0x33, 0x8e, + ], + ), + ( + vec![0xde, 0xad, 0xbe, 0xef], + vec![], + [ + 0x55, 0x32, 0x9d, 0xc8, 0x0e, 0x69, 0x0f, 0xd8, 0x6b, 0xd9, 0x66, 0x1f, 0x08, + 0x51, 0xc9, 0xb3, 0x68, 0x6d, 0xf2, 0xb1, 0xfd, 0xa0, 0x34, 0x7b, 0xc3, 0xd2, + 0x79, 0x58, 0x25, 0x4b, 0x32, 0xc6, + ], + [ + 0x8d, 0xfc, 0x6d, 0x33, 0xa8, 0x11, 0x8f, 0xfe, 0x40, 0x8b, 0x31, 0xdd, 0xac, + 0x25, 0xf7, 0x2a, 0xee, 0x91, 0x15, 0xa4, 0x5b, 0x69, 0xba, 0x17, 0x6a, 0xd0, + 0x12, 0xb2, 0x43, 0x83, 0x4f, 0xee, + ], + [ + 0xd6, 0x9e, 0x85, 0x2a, 0x28, 0x96, 0x56, 0x9e, 0xa5, 0x4a, 0x67, 0x96, 0x9a, + 0xa1, 0x80, 0x02, 0x87, 0x92, 0x1d, 0xac, 0x53, 0xce, 0x6d, 0xb4, 0xb4, 0xe1, + 0x21, 0x92, 0xf2, 0x63, 0xc4, 0xc4, + ], + ), + ]; + + for (key, input, t0, t1, t2) in &tests { + let tt0 = KDF1!(key, input); + debug_assert_eq!(tt0[..], t0[..]); + + let (tt0, tt1) = KDF2!(key, input); + debug_assert_eq!(tt0[..], t0[..]); + debug_assert_eq!(tt1[..], t1[..]); + + let (tt0, tt1, tt2) = KDF3!(key, input); + debug_assert_eq!(tt0[..], t0[..]); + debug_assert_eq!(tt1[..], t1[..]); + debug_assert_eq!(tt2[..], t2[..]); + } + } +} + +pub fn create_initiation( + rng: &mut R, + device: &Device, + peer: &Peer, + sender: u32, + msg: &mut NoiseInitiation, +) -> Result<(), HandshakeError> { + clear_stack_on_return(CLEAR_PAGES, || { + // initialize state + + let ck = INITIAL_CK; + let hs = INITIAL_HS; + let hs = HASH!(&hs, peer.pk.as_bytes()); + + msg.f_type.set(TYPE_INITIATION as u32); + msg.f_sender.set(sender); + + // (E_priv, E_pub) := DH-Generate() + + let eph_sk = StaticSecret::new(rng); + let eph_pk = PublicKey::from(&eph_sk); + + // C := Kdf(C, E_pub) + + let ck = KDF1!(&ck, eph_pk.as_bytes()); + + // msg.ephemeral := E_pub + + msg.f_ephemeral = *eph_pk.as_bytes(); + + // H := HASH(H, msg.ephemeral) + + let hs = HASH!(&hs, msg.f_ephemeral); + + // (C, k) := Kdf2(C, DH(E_priv, S_pub)) + + let (ck, key) = KDF2!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes()); + + // msg.static := Aead(k, 0, S_pub, H) + + SEAL!( + &key, + &hs, // ad + device.pk.as_bytes(), // pt + &mut msg.f_static // ct || tag + ); + + // H := Hash(H || msg.static) + + let hs = HASH!(&hs, &msg.f_static[..]); + + // (C, k) := Kdf2(C, DH(S_priv, S_pub)) + + let (ck, key) = KDF2!(&ck, peer.ss.as_bytes()); + + // msg.timestamp := Aead(k, 0, Timestamp(), H) + + SEAL!( + &key, + &hs, // ad + ×tamp::now(), // pt + &mut msg.f_timestamp // ct || tag + ); + + // H := Hash(H || msg.timestamp) + + let hs = HASH!(&hs, &msg.f_timestamp); + + // update state of peer + + *peer.state.lock() = State::InitiationSent { + hs, + ck, + eph_sk, + sender, + }; + + Ok(()) + }) +} + +pub fn consume_initiation<'a>( + device: &'a Device, + msg: &NoiseInitiation, +) -> Result<(&'a Peer, TemporaryState), HandshakeError> { + clear_stack_on_return(CLEAR_PAGES, || { + // initialize new state + + let ck = INITIAL_CK; + let hs = INITIAL_HS; + let hs = HASH!(&hs, device.pk.as_bytes()); + + // C := Kdf(C, E_pub) + + let ck = KDF1!(&ck, &msg.f_ephemeral); + + // H := HASH(H, msg.ephemeral) + + let hs = HASH!(&hs, &msg.f_ephemeral); + + // (C, k) := Kdf2(C, DH(E_priv, S_pub)) + + let eph_r_pk = PublicKey::from(msg.f_ephemeral); + let (ck, key) = KDF2!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes()); + + // msg.static := Aead(k, 0, S_pub, H) + + let mut pk = [0u8; 32]; + + OPEN!( + &key, + &hs, // ad + &mut pk, // pt + &msg.f_static // ct || tag + )?; + + let peer = device.lookup_pk(&PublicKey::from(pk))?; + + // reset initiation state + + *peer.state.lock() = State::Reset; + + // H := Hash(H || msg.static) + + let hs = HASH!(&hs, &msg.f_static[..]); + + // (C, k) := Kdf2(C, DH(S_priv, S_pub)) + + let (ck, key) = KDF2!(&ck, peer.ss.as_bytes()); + + // msg.timestamp := Aead(k, 0, Timestamp(), H) + + let mut ts = timestamp::ZERO; + + OPEN!( + &key, + &hs, // ad + &mut ts, // pt + &msg.f_timestamp // ct || tag + )?; + + // check and update timestamp + + peer.check_replay_flood(device, &ts)?; + + // H := Hash(H || msg.timestamp) + + let hs = HASH!(&hs, &msg.f_timestamp); + + // return state (to create response) + + Ok((peer, (msg.f_sender.get(), eph_r_pk, hs, ck))) + }) +} + +pub fn create_response( + rng: &mut R, + peer: &Peer, + sender: u32, // sending identifier + state: TemporaryState, // state from "consume_initiation" + msg: &mut NoiseResponse, // resulting response +) -> Result { + clear_stack_on_return(CLEAR_PAGES, || { + // unpack state + + let (receiver, eph_r_pk, hs, ck) = state; + + msg.f_type.set(TYPE_RESPONSE as u32); + msg.f_sender.set(sender); + msg.f_receiver.set(receiver); + + // (E_priv, E_pub) := DH-Generate() + + let eph_sk = StaticSecret::new(rng); + let eph_pk = PublicKey::from(&eph_sk); + + // C := Kdf1(C, E_pub) + + let ck = KDF1!(&ck, eph_pk.as_bytes()); + + // msg.ephemeral := E_pub + + msg.f_ephemeral = *eph_pk.as_bytes(); + + // H := Hash(H || msg.ephemeral) + + let hs = HASH!(&hs, &msg.f_ephemeral); + + // C := Kdf1(C, DH(E_priv, E_pub)) + + let ck = KDF1!(&ck, eph_sk.diffie_hellman(&eph_r_pk).as_bytes()); + + // C := Kdf1(C, DH(E_priv, S_pub)) + + let ck = KDF1!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes()); + + // (C, tau, k) := Kdf3(C, Q) + + let (ck, tau, key) = KDF3!(&ck, &peer.psk); + + // H := Hash(H || tau) + + let hs = HASH!(&hs, tau); + + // msg.empty := Aead(k, 0, [], H) + + SEAL!( + &key, + &hs, // ad + &[], // pt + &mut msg.f_empty // \epsilon || tag + ); + + // Not strictly needed + // let hs = HASH!(&hs, &msg.f_empty_tag); + + // derive key-pair + + let (key_recv, key_send) = KDF2!(&ck, &[]); + + // return unconfirmed key-pair + + Ok(KeyPair { + birth: Instant::now(), + initiator: false, + send: Key { + id: sender, + key: key_send.into(), + }, + recv: Key { + id: receiver, + key: key_recv.into(), + }, + }) + }) +} + +/* The state lock is released while processing the message to + * allow concurrent processing of potential responses to the initiation, + * in order to better mitigate DoS from malformed response messages. + */ +pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result { + clear_stack_on_return(CLEAR_PAGES, || { + // retrieve peer and copy initiation state + let peer = device.lookup_id(msg.f_receiver.get())?; + + let (hs, ck, sender, eph_sk) = match *peer.state.lock() { + State::InitiationSent { + hs, + ck, + sender, + ref eph_sk, + } => Ok((hs, ck, sender, StaticSecret::from(eph_sk.to_bytes()))), + _ => Err(HandshakeError::InvalidState), + }?; + + // C := Kdf1(C, E_pub) + + let ck = KDF1!(&ck, &msg.f_ephemeral); + + // H := Hash(H || msg.ephemeral) + + let hs = HASH!(&hs, &msg.f_ephemeral); + + // C := Kdf1(C, DH(E_priv, E_pub)) + + let eph_r_pk = PublicKey::from(msg.f_ephemeral); + let ck = KDF1!(&ck, eph_sk.diffie_hellman(&eph_r_pk).as_bytes()); + + // C := Kdf1(C, DH(E_priv, S_pub)) + + let ck = KDF1!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes()); + + // (C, tau, k) := Kdf3(C, Q) + + let (ck, tau, key) = KDF3!(&ck, &peer.psk); + + // H := Hash(H || tau) + + let hs = HASH!(&hs, tau); + + // msg.empty := Aead(k, 0, [], H) + + OPEN!( + &key, + &hs, // ad + &mut [], // pt + &msg.f_empty // \epsilon || tag + )?; + + // derive key-pair + + let birth = Instant::now(); + let (key_send, key_recv) = KDF2!(&ck, &[]); + + // check for new initiation sent while lock released + + let mut state = peer.state.lock(); + let update = match *state { + State::InitiationSent { + eph_sk: ref old, .. + } => old.to_bytes().ct_eq(&eph_sk.to_bytes()).into(), + _ => false, + }; + + if update { + // null the initiation state + // (to avoid replay of this response message) + *state = State::Reset; + + // return confirmed key-pair + Ok(( + Some(peer.pk), + None, + Some(KeyPair { + birth, + initiator: true, + send: Key { + id: sender, + key: key_send.into(), + }, + recv: Key { + id: msg.f_sender.get(), + key: key_recv.into(), + }, + }), + )) + } else { + Err(HandshakeError::InvalidState) + } + }) +} diff --git a/src/wireguard/handshake/peer.rs b/src/wireguard/handshake/peer.rs new file mode 100644 index 0000000..c9e1c40 --- /dev/null +++ b/src/wireguard/handshake/peer.rs @@ -0,0 +1,142 @@ +use spin::Mutex; + +use std::mem; +use std::time::{Duration, Instant}; + +use generic_array::typenum::U32; +use generic_array::GenericArray; + +use x25519_dalek::PublicKey; +use x25519_dalek::SharedSecret; +use x25519_dalek::StaticSecret; + +use clear_on_drop::clear::Clear; + +use super::device::Device; +use super::macs; +use super::timestamp; +use super::types::*; + +const TIME_BETWEEN_INITIATIONS: Duration = Duration::from_millis(20); + +/* Represents the recomputation and state of a peer. + * + * This type is only for internal use and not exposed. + */ +pub struct Peer { + // mutable state + pub(crate) state: Mutex, + pub(crate) timestamp: Mutex>, + pub(crate) last_initiation_consumption: Mutex>, + + // state related to DoS mitigation fields + pub(crate) macs: Mutex, + + // constant state + pub(crate) pk: PublicKey, // public key of peer + pub(crate) ss: SharedSecret, // precomputed DH(static, static) + pub(crate) psk: Psk, // psk of peer +} + +pub enum State { + Reset, + InitiationSent { + sender: u32, // assigned sender id + eph_sk: StaticSecret, + hs: GenericArray, + ck: GenericArray, + }, +} + +impl Drop for State { + fn drop(&mut self) { + match self { + State::InitiationSent { hs, ck, .. } => { + // eph_sk already cleared by dalek-x25519 + hs.clear(); + ck.clear(); + } + _ => (), + } + } +} + +impl Peer { + pub fn new( + pk: PublicKey, // public key of peer + ss: SharedSecret, // precomputed DH(static, static) + ) -> Self { + Self { + macs: Mutex::new(macs::Generator::new(pk)), + state: Mutex::new(State::Reset), + timestamp: Mutex::new(None), + last_initiation_consumption: Mutex::new(None), + pk: pk, + ss: ss, + psk: [0u8; 32], + } + } + + /// Set the state of the peer unconditionally + /// + /// # Arguments + /// + pub fn set_state(&self, state_new: State) { + *self.state.lock() = state_new; + } + + pub fn reset_state(&self) -> Option { + match mem::replace(&mut *self.state.lock(), State::Reset) { + State::InitiationSent { sender, .. } => Some(sender), + _ => None, + } + } + + /// Set the mutable state of the peer conditioned on the timestamp being newer + /// + /// # Arguments + /// + /// * st_new - The updated state of the peer + /// * ts_new - The associated timestamp + pub fn check_replay_flood( + &self, + device: &Device, + timestamp_new: ×tamp::TAI64N, + ) -> Result<(), HandshakeError> { + let mut state = self.state.lock(); + let mut timestamp = self.timestamp.lock(); + let mut last_initiation_consumption = self.last_initiation_consumption.lock(); + + // check replay attack + match *timestamp { + Some(timestamp_old) => { + if !timestamp::compare(×tamp_old, ×tamp_new) { + return Err(HandshakeError::OldTimestamp); + } + } + _ => (), + }; + + // check flood attack + match *last_initiation_consumption { + Some(last) => { + if last.elapsed() < TIME_BETWEEN_INITIATIONS { + return Err(HandshakeError::InitiationFlood); + } + } + _ => (), + } + + // reset state + match *state { + State::InitiationSent { sender, .. } => device.release(sender), + _ => (), + } + + // update replay & flood protection + *state = State::Reset; + *timestamp = Some(*timestamp_new); + *last_initiation_consumption = Some(Instant::now()); + Ok(()) + } +} diff --git a/src/wireguard/handshake/ratelimiter.rs b/src/wireguard/handshake/ratelimiter.rs new file mode 100644 index 0000000..63d728c --- /dev/null +++ b/src/wireguard/handshake/ratelimiter.rs @@ -0,0 +1,199 @@ +use spin; +use std::collections::HashMap; +use std::net::IpAddr; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Condvar, Mutex}; +use std::thread; +use std::time::{Duration, Instant}; + +const PACKETS_PER_SECOND: u64 = 20; +const PACKETS_BURSTABLE: u64 = 5; +const PACKET_COST: u64 = 1_000_000_000 / PACKETS_PER_SECOND; +const MAX_TOKENS: u64 = PACKET_COST * PACKETS_BURSTABLE; + +const GC_INTERVAL: Duration = Duration::from_secs(1); + +struct Entry { + pub last_time: Instant, + pub tokens: u64, +} + +pub struct RateLimiter(Arc); + +struct RateLimiterInner { + gc_running: AtomicBool, + gc_dropped: (Mutex, Condvar), + table: spin::RwLock>>, +} + +impl Drop for RateLimiter { + fn drop(&mut self) { + // wake up & terminate any lingering GC thread + let &(ref lock, ref cvar) = &self.0.gc_dropped; + let mut dropped = lock.lock().unwrap(); + *dropped = true; + cvar.notify_all(); + } +} + +impl RateLimiter { + pub fn new() -> Self { + RateLimiter(Arc::new(RateLimiterInner { + gc_dropped: (Mutex::new(false), Condvar::new()), + gc_running: AtomicBool::from(false), + table: spin::RwLock::new(HashMap::new()), + })) + } + + pub fn allow(&self, addr: &IpAddr) -> bool { + // check if allowed + let allowed = { + // check for existing entry (only requires read lock) + if let Some(entry) = self.0.table.read().get(addr) { + // update existing entry + let mut entry = entry.lock(); + + // add tokens earned since last time + entry.tokens = MAX_TOKENS + .min(entry.tokens + u64::from(entry.last_time.elapsed().subsec_nanos())); + entry.last_time = Instant::now(); + + // subtract cost of packet + if entry.tokens > PACKET_COST { + entry.tokens -= PACKET_COST; + return true; + } else { + return false; + } + } + + // add new entry (write lock) + self.0.table.write().insert( + *addr, + spin::Mutex::new(Entry { + last_time: Instant::now(), + tokens: MAX_TOKENS - PACKET_COST, + }), + ); + true + }; + + // check that GC thread is scheduled + if !self.0.gc_running.swap(true, Ordering::Relaxed) { + let limiter = self.0.clone(); + thread::spawn(move || { + let &(ref lock, ref cvar) = &limiter.gc_dropped; + let mut dropped = lock.lock().unwrap(); + while !*dropped { + // garbage collect + { + let mut tw = limiter.table.write(); + tw.retain(|_, ref mut entry| { + entry.lock().last_time.elapsed() <= GC_INTERVAL + }); + if tw.len() == 0 { + limiter.gc_running.store(false, Ordering::Relaxed); + return; + } + } + + // wait until stopped or new GC (~1 every sec) + let res = cvar.wait_timeout(dropped, GC_INTERVAL).unwrap(); + dropped = res.0; + } + }); + } + + allowed + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std; + + struct Result { + allowed: bool, + text: &'static str, + wait: Duration, + } + + #[test] + fn test_ratelimiter() { + let ratelimiter = RateLimiter::new(); + let mut expected = vec![]; + let ips = vec![ + "127.0.0.1".parse().unwrap(), + "192.168.1.1".parse().unwrap(), + "172.167.2.3".parse().unwrap(), + "97.231.252.215".parse().unwrap(), + "248.97.91.167".parse().unwrap(), + "188.208.233.47".parse().unwrap(), + "104.2.183.179".parse().unwrap(), + "72.129.46.120".parse().unwrap(), + "2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(), + "f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(), + "b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(), + "a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(), + "ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(), + "3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(), + ]; + + for _ in 0..PACKETS_BURSTABLE { + expected.push(Result { + allowed: true, + wait: Duration::new(0, 0), + text: "inital burst", + }); + } + + expected.push(Result { + allowed: false, + wait: Duration::new(0, 0), + text: "after burst", + }); + + expected.push(Result { + allowed: true, + wait: Duration::new(0, PACKET_COST as u32), + text: "filling tokens for single packet", + }); + + expected.push(Result { + allowed: false, + wait: Duration::new(0, 0), + text: "not having refilled enough", + }); + + expected.push(Result { + allowed: true, + wait: Duration::new(0, 2 * PACKET_COST as u32), + text: "filling tokens for 2 * packet burst", + }); + + expected.push(Result { + allowed: true, + wait: Duration::new(0, 0), + text: "second packet in 2 packet burst", + }); + + expected.push(Result { + allowed: false, + wait: Duration::new(0, 0), + text: "packet following 2 packet burst", + }); + + for item in expected { + std::thread::sleep(item.wait); + for ip in ips.iter() { + if ratelimiter.allow(&ip) != item.allowed { + panic!( + "test failed for {} on {}. expected: {}, got: {}", + ip, item.text, item.allowed, !item.allowed + ) + } + } + } + } +} diff --git a/src/wireguard/handshake/timestamp.rs b/src/wireguard/handshake/timestamp.rs new file mode 100644 index 0000000..b5bd9f0 --- /dev/null +++ b/src/wireguard/handshake/timestamp.rs @@ -0,0 +1,32 @@ +use std::time::{SystemTime, UNIX_EPOCH}; + +pub type TAI64N = [u8; 12]; + +const TAI64_EPOCH: u64 = 0x400000000000000a; + +pub const ZERO: TAI64N = [0u8; 12]; + +pub fn now() -> TAI64N { + // get system time as duration + let sysnow = SystemTime::now(); + let delta = sysnow.duration_since(UNIX_EPOCH).unwrap(); + + // convert to tai64n + let tai64_secs = delta.as_secs() + TAI64_EPOCH; + let tai64_nano = delta.subsec_nanos(); + + // serialize + let mut res = [0u8; 12]; + res[..8].copy_from_slice(&tai64_secs.to_be_bytes()[..]); + res[8..].copy_from_slice(&tai64_nano.to_be_bytes()[..]); + res +} + +pub fn compare(old: &TAI64N, new: &TAI64N) -> bool { + for i in 0..12 { + if new[i] > old[i] { + return true; + } + } + return false; +} diff --git a/src/wireguard/handshake/types.rs b/src/wireguard/handshake/types.rs new file mode 100644 index 0000000..5f984cc --- /dev/null +++ b/src/wireguard/handshake/types.rs @@ -0,0 +1,90 @@ +use std::error::Error; +use std::fmt; + +use x25519_dalek::PublicKey; + +use super::super::types::KeyPair; + +/* Internal types for the noise IKpsk2 implementation */ + +// config error + +#[derive(Debug)] +pub struct ConfigError(String); + +impl ConfigError { + pub fn new(s: &str) -> Self { + ConfigError(s.to_string()) + } +} + +impl fmt::Display for ConfigError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ConfigError({})", self.0) + } +} + +impl Error for ConfigError { + fn description(&self) -> &str { + &self.0 + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} + +// handshake error + +#[derive(Debug)] +pub enum HandshakeError { + DecryptionFailure, + UnknownPublicKey, + UnknownReceiverId, + InvalidMessageFormat, + OldTimestamp, + InvalidState, + InvalidMac1, + RateLimited, + InitiationFlood, +} + +impl fmt::Display for HandshakeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + HandshakeError::DecryptionFailure => write!(f, "Failed to AEAD:OPEN"), + HandshakeError::UnknownPublicKey => write!(f, "Unknown public key"), + HandshakeError::UnknownReceiverId => { + write!(f, "Receiver id not allocated to any handshake") + } + HandshakeError::InvalidMessageFormat => write!(f, "Invalid handshake message format"), + HandshakeError::OldTimestamp => write!(f, "Timestamp is less/equal to the newest"), + HandshakeError::InvalidState => write!(f, "Message does not apply to handshake state"), + HandshakeError::InvalidMac1 => write!(f, "Message has invalid mac1 field"), + HandshakeError::RateLimited => write!(f, "Message was dropped by rate limiter"), + HandshakeError::InitiationFlood => { + write!(f, "Message was dropped because of initiation flood") + } + } + } +} + +impl Error for HandshakeError { + fn description(&self) -> &str { + "Generic Handshake Error" + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} + +pub type Output = ( + Option, // external identifier associated with peer + Option>, // message to send + Option, // resulting key-pair of successful handshake +); + +// preshared key + +pub type Psk = [u8; 32]; -- cgit v1.2.3-59-g8ed1b