From a08fd4002bfae92072f64f8d5e0084e6f248f139 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sun, 13 Oct 2019 22:26:12 +0200 Subject: Work on Linux platform code --- src/wireguard/config.rs | 186 ++++++++++ src/wireguard/constants.rs | 20 ++ src/wireguard/handshake/device.rs | 574 +++++++++++++++++++++++++++++++ src/wireguard/handshake/macs.rs | 327 ++++++++++++++++++ src/wireguard/handshake/messages.rs | 363 ++++++++++++++++++++ src/wireguard/handshake/mod.rs | 21 ++ src/wireguard/handshake/noise.rs | 549 +++++++++++++++++++++++++++++ src/wireguard/handshake/peer.rs | 142 ++++++++ src/wireguard/handshake/ratelimiter.rs | 199 +++++++++++ src/wireguard/handshake/timestamp.rs | 32 ++ src/wireguard/handshake/types.rs | 90 +++++ src/wireguard/mod.rs | 23 ++ src/wireguard/router/anti_replay.rs | 157 +++++++++ src/wireguard/router/constants.rs | 7 + src/wireguard/router/device.rs | 243 +++++++++++++ src/wireguard/router/ip.rs | 26 ++ src/wireguard/router/messages.rs | 13 + src/wireguard/router/mod.rs | 22 ++ src/wireguard/router/peer.rs | 611 +++++++++++++++++++++++++++++++++ src/wireguard/router/tests.rs | 432 +++++++++++++++++++++++ src/wireguard/router/types.rs | 65 ++++ src/wireguard/router/workers.rs | 305 ++++++++++++++++ src/wireguard/tests.rs | 46 +++ src/wireguard/timers.rs | 234 +++++++++++++ src/wireguard/types/bind.rs | 23 ++ src/wireguard/types/dummy.rs | 323 +++++++++++++++++ src/wireguard/types/endpoint.rs | 7 + src/wireguard/types/keys.rs | 36 ++ src/wireguard/types/mod.rs | 10 + src/wireguard/types/tun.rs | 56 +++ src/wireguard/wireguard.rs | 407 ++++++++++++++++++++++ 31 files changed, 5549 insertions(+) create mode 100644 src/wireguard/config.rs create mode 100644 src/wireguard/constants.rs create mode 100644 src/wireguard/handshake/device.rs create mode 100644 src/wireguard/handshake/macs.rs create mode 100644 src/wireguard/handshake/messages.rs create mode 100644 src/wireguard/handshake/mod.rs create mode 100644 src/wireguard/handshake/noise.rs create mode 100644 src/wireguard/handshake/peer.rs create mode 100644 src/wireguard/handshake/ratelimiter.rs create mode 100644 src/wireguard/handshake/timestamp.rs create mode 100644 src/wireguard/handshake/types.rs create mode 100644 src/wireguard/mod.rs create mode 100644 src/wireguard/router/anti_replay.rs create mode 100644 src/wireguard/router/constants.rs create mode 100644 src/wireguard/router/device.rs create mode 100644 src/wireguard/router/ip.rs create mode 100644 src/wireguard/router/messages.rs create mode 100644 src/wireguard/router/mod.rs create mode 100644 src/wireguard/router/peer.rs create mode 100644 src/wireguard/router/tests.rs create mode 100644 src/wireguard/router/types.rs create mode 100644 src/wireguard/router/workers.rs create mode 100644 src/wireguard/tests.rs create mode 100644 src/wireguard/timers.rs create mode 100644 src/wireguard/types/bind.rs create mode 100644 src/wireguard/types/dummy.rs create mode 100644 src/wireguard/types/endpoint.rs create mode 100644 src/wireguard/types/keys.rs create mode 100644 src/wireguard/types/mod.rs create mode 100644 src/wireguard/types/tun.rs create mode 100644 src/wireguard/wireguard.rs (limited to 'src/wireguard') diff --git a/src/wireguard/config.rs b/src/wireguard/config.rs new file mode 100644 index 0000000..0f2953d --- /dev/null +++ b/src/wireguard/config.rs @@ -0,0 +1,186 @@ +use std::net::{IpAddr, SocketAddr}; +use x25519_dalek::{PublicKey, StaticSecret}; + +use super::wireguard::Wireguard; +use super::types::bind::Bind; +use super::types::tun::Tun; + +/// The goal of the configuration interface is, among others, +/// to hide the IO implementations (over which the WG device is generic), +/// from the configuration and UAPI code. + +/// Describes a snapshot of the state of a peer +pub struct PeerState { + rx_bytes: u64, + tx_bytes: u64, + last_handshake_time_sec: u64, + last_handshake_time_nsec: u64, + public_key: PublicKey, + allowed_ips: Vec<(IpAddr, u32)>, +} + +pub enum ConfigError { + NoSuchPeer +} + +impl ConfigError { + + fn errno(&self) -> i32 { + match self { + NoSuchPeer => 1, + } + } +} + +/// Exposed configuration interface +pub trait Configuration { + /// Updates the private key of the device + /// + /// # Arguments + /// + /// - `sk`: The new private key (or None, if the private key should be cleared) + fn set_private_key(&self, sk: Option); + + /// Returns the private key of the device + /// + /// # Returns + /// + /// The private if set, otherwise None. + fn get_private_key(&self) -> Option; + + /// Returns the protocol version of the device + /// + /// # Returns + /// + /// An integer indicating the protocol version + fn get_protocol_version(&self) -> usize; + + fn set_listen_port(&self, port: u16) -> Option; + + /// Set the firewall mark (or similar, depending on platform) + /// + /// # Arguments + /// + /// - `mark`: The fwmark value + /// + /// # Returns + /// + /// An error if this operation is not supported by the underlying + /// "bind" implementation. + fn set_fwmark(&self, mark: Option) -> Option; + + /// Removes all peers from the device + fn replace_peers(&self); + + /// Remove the peer from the + /// + /// # Arguments + /// + /// - `peer`: The public key of the peer to remove + /// + /// # Returns + /// + /// If the peer does not exists this operation is a noop + fn remove_peer(&self, peer: PublicKey); + + /// Adds a new peer to the device + /// + /// # Arguments + /// + /// - `peer`: The public key of the peer to add + /// + /// # Returns + /// + /// A bool indicating if the peer was added. + /// + /// If the peer already exists this operation is a noop + fn add_peer(&self, peer: PublicKey) -> bool; + + /// Update the psk of a peer + /// + /// # Arguments + /// + /// - `peer`: The public key of the peer + /// - `psk`: The new psk or None if the psk should be unset + /// + /// # Returns + /// + /// An error if no such peer exists + fn set_preshared_key(&self, peer: PublicKey, psk: Option<[u8; 32]>) -> Option; + + /// Update the endpoint of the + /// + /// # Arguments + /// + /// - `peer': The public key of the peer + /// - `psk` + fn set_endpoint(&self, peer: PublicKey, addr: SocketAddr) -> Option; + + /// Update the endpoint of the + /// + /// # Arguments + /// + /// - `peer': The public key of the peer + /// - `psk` + fn set_persistent_keepalive_interval(&self, peer: PublicKey) -> Option; + + /// Remove all allowed IPs from the peer + /// + /// # Arguments + /// + /// - `peer': The public key of the peer + /// + /// # Returns + /// + /// An error if no such peer exists + fn replace_allowed_ips(&self, peer: PublicKey) -> Option; + + /// Add a new allowed subnet to the peer + /// + /// # Arguments + /// + /// - `peer`: The public key of the peer + /// - `ip`: Subnet mask + /// - `masklen`: + /// + /// # Returns + /// + /// An error if the peer does not exist + /// + /// # Note: + /// + /// The API must itself sanitize the (ip, masklen) set: + /// The ip should be masked to remove any set bits right of the first "masklen" bits. + fn add_allowed_ip(&self, peer: PublicKey, ip: IpAddr, masklen: u32) -> Option; + + /// Returns the state of all peers + /// + /// # Returns + /// + /// A list of structures describing the state of each peer + fn get_peers(&self) -> Vec; +} + +impl Configuration for Wireguard { + + fn set_private_key(&self, sk : Option) { + self.set_key(sk) + } + + fn get_private_key(&self) -> Option { + self.get_sk() + } + + fn get_protocol_version(&self) -> usize { + 1 + } + + fn set_listen_port(&self, port : u16) -> Option { + None + } + + fn set_fwmark(&self, mark: Option) -> Option { + None + } + +} \ No newline at end of file diff --git a/src/wireguard/constants.rs b/src/wireguard/constants.rs new file mode 100644 index 0000000..72de8d9 --- /dev/null +++ b/src/wireguard/constants.rs @@ -0,0 +1,20 @@ +use std::time::Duration; +use std::u64; + +pub const REKEY_AFTER_MESSAGES: u64 = u64::MAX - (1 << 16); +pub const REJECT_AFTER_MESSAGES: u64 = u64::MAX - (1 << 4); + +pub const REKEY_AFTER_TIME: Duration = Duration::from_secs(120); +pub const REJECT_AFTER_TIME: Duration = Duration::from_secs(180); +pub const REKEY_ATTEMPT_TIME: Duration = Duration::from_secs(90); +pub const REKEY_TIMEOUT: Duration = Duration::from_secs(5); +pub const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10); + +pub const MAX_TIMER_HANDSHAKES: usize = 18; + +pub const TIMER_MAX_DURATION: Duration = Duration::from_secs(200); +pub const TIMERS_TICK: Duration = Duration::from_millis(100); +pub const TIMERS_SLOTS: usize = (TIMER_MAX_DURATION.as_micros() / TIMERS_TICK.as_micros()) as usize; +pub const TIMERS_CAPACITY: usize = 1024; + +pub const MESSAGE_PADDING_MULTIPLE: usize = 16; diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs new file mode 100644 index 0000000..6a55f6e --- /dev/null +++ b/src/wireguard/handshake/device.rs @@ -0,0 +1,574 @@ +use spin::RwLock; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::Mutex; +use zerocopy::AsBytes; + +use byteorder::{ByteOrder, LittleEndian}; + +use rand::prelude::*; + +use x25519_dalek::PublicKey; +use x25519_dalek::StaticSecret; + +use super::macs; +use super::messages::{CookieReply, Initiation, Response}; +use super::messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE}; +use super::noise; +use super::peer::Peer; +use super::ratelimiter::RateLimiter; +use super::types::*; + +const MAX_PEER_PER_DEVICE: usize = 1 << 20; + +pub struct Device { + pub sk: StaticSecret, // static secret key + pub pk: PublicKey, // static public key + macs: macs::Validator, // validator for the mac fields + pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state + id_map: RwLock>, // receiver ids -> public key + limiter: Mutex, +} + +/* A mutable reference to the device needs to be held during configuration. + * Wrapping the device in a RwLock enables peer config after "configuration time" + */ +impl Device { + /// Initialize a new handshake state machine + /// + /// # Arguments + /// + /// * `sk` - x25519 scalar representing the local private key + pub fn new(sk: StaticSecret) -> Device { + let pk = PublicKey::from(&sk); + Device { + pk, + sk, + macs: macs::Validator::new(pk), + pk_map: HashMap::new(), + id_map: RwLock::new(HashMap::new()), + limiter: Mutex::new(RateLimiter::new()), + } + } + + /// Update the secret key of the device + /// + /// # Arguments + /// + /// * `sk` - x25519 scalar representing the local private key + pub fn set_sk(&mut self, sk: StaticSecret) { + // update secret and public key + let pk = PublicKey::from(&sk); + self.sk = sk; + self.pk = pk; + self.macs = macs::Validator::new(pk); + + // recalculate the shared secrets for every peer + let mut ids = vec![]; + for mut peer in self.pk_map.values_mut() { + peer.reset_state().map(|id| ids.push(id)); + peer.ss = self.sk.diffie_hellman(&peer.pk) + } + + // release ids from aborted handshakes + for id in ids { + self.release(id) + } + } + + /// Return the secret key of the device + /// + /// # Returns + /// + /// A secret key (x25519 scalar) + pub fn get_sk(&self) -> StaticSecret { + StaticSecret::from(self.sk.to_bytes()) + } + + /// Add a new public key to the state machine + /// To remove public keys, you must create a new machine instance + /// + /// # Arguments + /// + /// * `pk` - The public key to add + /// * `identifier` - Associated identifier which can be used to distinguish the peers + pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> { + // check that the pk is not added twice + if let Some(_) = self.pk_map.get(pk.as_bytes()) { + return Err(ConfigError::new("Duplicate public key")); + }; + + // check that the pk is not that of the device + if *self.pk.as_bytes() == *pk.as_bytes() { + return Err(ConfigError::new( + "Public key corresponds to secret key of interface", + )); + } + + // ensure less than 2^20 peers + if self.pk_map.len() > MAX_PEER_PER_DEVICE { + return Err(ConfigError::new("Too many peers for device")); + } + + // map the public key to the peer state + self.pk_map + .insert(*pk.as_bytes(), Peer::new(pk, self.sk.diffie_hellman(&pk))); + + Ok(()) + } + + /// Remove a peer by public key + /// To remove public keys, you must create a new machine instance + /// + /// # Arguments + /// + /// * `pk` - The public key of the peer to remove + /// + /// # Returns + /// + /// The call might fail if the public key is not found + pub fn remove(&mut self, pk: PublicKey) -> Result<(), ConfigError> { + // take write-lock on receive id table + let mut id_map = self.id_map.write(); + + // remove the peer + self.pk_map + .remove(pk.as_bytes()) + .ok_or(ConfigError::new("Public key not in device"))?; + + // pruge the id map (linear scan) + id_map.retain(|_, v| v != pk.as_bytes()); + Ok(()) + } + + /// Add a psk to the peer + /// + /// # Arguments + /// + /// * `pk` - The public key of the peer + /// * `psk` - The psk to set / unset + /// + /// # Returns + /// + /// The call might fail if the public key is not found + pub fn set_psk(&mut self, pk: PublicKey, psk: Option) -> Result<(), ConfigError> { + match self.pk_map.get_mut(pk.as_bytes()) { + Some(mut peer) => { + peer.psk = match psk { + Some(v) => v, + None => [0u8; 32], + }; + Ok(()) + } + _ => Err(ConfigError::new("No such public key")), + } + } + + /// Return the psk for the peer + /// + /// # Arguments + /// + /// * `pk` - The public key of the peer + /// + /// # Returns + /// + /// A 32 byte array holding the PSK + /// + /// The call might fail if the public key is not found + pub fn get_psk(&self, pk: PublicKey) -> Result { + match self.pk_map.get(pk.as_bytes()) { + Some(peer) => Ok(peer.psk), + _ => Err(ConfigError::new("No such public key")), + } + } + + /// Release an id back to the pool + /// + /// # Arguments + /// + /// * `id` - The (sender) id to release + pub fn release(&self, id: u32) { + let mut m = self.id_map.write(); + debug_assert!(m.contains_key(&id), "Releasing id not allocated"); + m.remove(&id); + } + + /// Begin a new handshake + /// + /// # Arguments + /// + /// * `pk` - Public key of peer to initiate handshake for + pub fn begin( + &self, + rng: &mut R, + pk: &PublicKey, + ) -> Result, HandshakeError> { + match self.pk_map.get(pk.as_bytes()) { + None => Err(HandshakeError::UnknownPublicKey), + Some(peer) => { + let sender = self.allocate(rng, peer); + + let mut msg = Initiation::default(); + + noise::create_initiation(rng, self, peer, sender, &mut msg.noise)?; + + // add macs to initation + + peer.macs + .lock() + .generate(msg.noise.as_bytes(), &mut msg.macs); + + Ok(msg.as_bytes().to_owned()) + } + } + } + + /// Process a handshake message. + /// + /// # Arguments + /// + /// * `msg` - Byte slice containing the message (untrusted input) + pub fn process<'a, R: RngCore + CryptoRng, S>( + &self, + rng: &mut R, // rng instance to sample randomness from + msg: &[u8], // message buffer + src: Option<&'a S>, // optional source endpoint, set when "under load" + ) -> Result + where + &'a S: Into<&'a SocketAddr>, + { + // ensure type read in-range + if msg.len() < 4 { + return Err(HandshakeError::InvalidMessageFormat); + } + + // de-multiplex the message type field + match LittleEndian::read_u32(msg) { + TYPE_INITIATION => { + // parse message + let msg = Initiation::parse(msg)?; + + // check mac1 field + self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; + + // address validation & DoS mitigation + if let Some(src) = src { + // obtain ref to socket addr + let src = src.into(); + + // check mac2 field + if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { + let mut reply = Default::default(); + self.macs.create_cookie_reply( + rng, + msg.noise.f_sender.get(), + src, + &msg.macs, + &mut reply, + ); + return Ok((None, Some(reply.as_bytes().to_owned()), None)); + } + + // check ratelimiter + if !self.limiter.lock().unwrap().allow(&src.ip()) { + return Err(HandshakeError::RateLimited); + } + } + + // consume the initiation + let (peer, st) = noise::consume_initiation(self, &msg.noise)?; + + // allocate new index for response + let sender = self.allocate(rng, peer); + + // prepare memory for response, TODO: take slice for zero allocation + let mut resp = Response::default(); + + // create response (release id on error) + let keys = noise::create_response(rng, peer, sender, st, &mut resp.noise).map_err( + |e| { + self.release(sender); + e + }, + )?; + + // add macs to response + peer.macs + .lock() + .generate(resp.noise.as_bytes(), &mut resp.macs); + + // return unconfirmed keypair and the response as vector + Ok((Some(peer.pk), Some(resp.as_bytes().to_owned()), Some(keys))) + } + TYPE_RESPONSE => { + let msg = Response::parse(msg)?; + + // check mac1 field + self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; + + // address validation & DoS mitigation + if let Some(src) = src { + // obtain ref to socket addr + let src = src.into(); + + // check mac2 field + if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { + let mut reply = Default::default(); + self.macs.create_cookie_reply( + rng, + msg.noise.f_sender.get(), + src, + &msg.macs, + &mut reply, + ); + return Ok((None, Some(reply.as_bytes().to_owned()), None)); + } + + // check ratelimiter + if !self.limiter.lock().unwrap().allow(&src.ip()) { + return Err(HandshakeError::RateLimited); + } + } + + // consume inner playload + noise::consume_response(self, &msg.noise) + } + TYPE_COOKIE_REPLY => { + let msg = CookieReply::parse(msg)?; + + // lookup peer + let peer = self.lookup_id(msg.f_receiver.get())?; + + // validate cookie reply + peer.macs.lock().process(&msg)?; + + // this prompts no new message and + // DOES NOT cryptographically verify the peer + Ok((None, None, None)) + } + _ => Err(HandshakeError::InvalidMessageFormat), + } + } + + // Internal function + // + // Return the peer associated with the public key + pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer, HandshakeError> { + self.pk_map + .get(pk.as_bytes()) + .ok_or(HandshakeError::UnknownPublicKey) + } + + // Internal function + // + // Return the peer currently associated with the receiver identifier + pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer, HandshakeError> { + let im = self.id_map.read(); + let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?; + match self.pk_map.get(pk) { + Some(peer) => Ok(peer), + _ => unreachable!(), // if the id-lookup succeeded, the peer should exist + } + } + + // Internal function + // + // Allocated a new receiver identifier for the peer + fn allocate(&self, rng: &mut R, peer: &Peer) -> u32 { + loop { + let id = rng.gen(); + + // check membership with read lock + if self.id_map.read().contains_key(&id) { + continue; + } + + // take write lock and add index + let mut m = self.id_map.write(); + if !m.contains_key(&id) { + m.insert(id, *peer.pk.as_bytes()); + return id; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::super::messages::*; + use super::*; + use hex; + use rand::rngs::OsRng; + use std::net::SocketAddr; + use std::thread; + use std::time::Duration; + + fn setup_devices( + rng: &mut R, + ) -> (PublicKey, Device, PublicKey, Device) { + // generate new keypairs + + let sk1 = StaticSecret::new(rng); + let pk1 = PublicKey::from(&sk1); + + let sk2 = StaticSecret::new(rng); + let pk2 = PublicKey::from(&sk2); + + // pick random psk + + let mut psk = [0u8; 32]; + rng.fill_bytes(&mut psk[..]); + + // intialize devices on both ends + + let mut dev1 = Device::new(sk1); + let mut dev2 = Device::new(sk2); + + dev1.add(pk2).unwrap(); + dev2.add(pk1).unwrap(); + + dev1.set_psk(pk2, Some(psk)).unwrap(); + dev2.set_psk(pk1, Some(psk)).unwrap(); + + (pk1, dev1, pk2, dev2) + } + + /* Test longest possible handshake interaction (7 messages): + * + * 1. I -> R (initation) + * 2. I <- R (cookie reply) + * 3. I -> R (initation) + * 4. I <- R (response) + * 5. I -> R (cookie reply) + * 6. I -> R (initation) + * 7. I <- R (response) + */ + #[test] + fn handshake_under_load() { + let mut rng = OsRng::new().unwrap(); + let (_pk1, dev1, pk2, dev2) = setup_devices(&mut rng); + + let src1: SocketAddr = "172.16.0.1:8080".parse().unwrap(); + let src2: SocketAddr = "172.16.0.2:7070".parse().unwrap(); + + // 1. device-1 : create first initation + let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + + // 2. device-2 : responds with CookieReply + let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() { + (None, Some(msg), None) => msg, + _ => panic!("unexpected response"), + }; + + // device-1 : processes CookieReply (no response) + match dev1.process(&mut rng, &msg_cookie, Some(&src2)).unwrap() { + (None, None, None) => (), + _ => panic!("unexpected response"), + } + + // avoid initation flood + thread::sleep(Duration::from_millis(20)); + + // 3. device-1 : create second initation + let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + + // 4. device-2 : responds with noise response + let msg_response = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() { + (Some(_), Some(msg), Some(kp)) => { + assert_eq!(kp.initiator, false); + msg + } + _ => panic!("unexpected response"), + }; + + // 5. device-1 : responds with CookieReply + let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() { + (None, Some(msg), None) => msg, + _ => panic!("unexpected response"), + }; + + // device-2 : processes CookieReply (no response) + match dev2.process(&mut rng, &msg_cookie, Some(&src1)).unwrap() { + (None, None, None) => (), + _ => panic!("unexpected response"), + } + + // avoid initation flood + thread::sleep(Duration::from_millis(20)); + + // 6. device-1 : create third initation + let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + + // 7. device-2 : responds with noise response + let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() { + (Some(_), Some(msg), Some(kp)) => { + assert_eq!(kp.initiator, false); + (msg, kp) + } + _ => panic!("unexpected response"), + }; + + // device-1 : process noise response + let kp2 = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() { + (Some(_), None, Some(kp)) => { + assert_eq!(kp.initiator, true); + kp + } + _ => panic!("unexpected response"), + }; + + assert_eq!(kp1.send, kp2.recv); + assert_eq!(kp1.recv, kp2.send); + } + + #[test] + fn handshake_no_load() { + let mut rng = OsRng::new().unwrap(); + let (pk1, mut dev1, pk2, mut dev2) = setup_devices(&mut rng); + + // do a few handshakes (every handshake should succeed) + + for i in 0..10 { + println!("handshake : {}", i); + + // create initiation + + let msg1 = dev1.begin(&mut rng, &pk2).unwrap(); + + println!("msg1 = {} : {} bytes", hex::encode(&msg1[..]), msg1.len()); + println!("msg1 = {:?}", Initiation::parse(&msg1[..]).unwrap()); + + // process initiation and create response + + let (_, msg2, ks_r) = dev2.process(&mut rng, &msg1, None).unwrap(); + + let ks_r = ks_r.unwrap(); + let msg2 = msg2.unwrap(); + + println!("msg2 = {} : {} bytes", hex::encode(&msg2[..]), msg2.len()); + println!("msg2 = {:?}", Response::parse(&msg2[..]).unwrap()); + + assert!(!ks_r.initiator, "Responders key-pair is confirmed"); + + // process response and obtain confirmed key-pair + + let (_, msg3, ks_i) = dev1.process(&mut rng, &msg2, None).unwrap(); + let ks_i = ks_i.unwrap(); + + assert!(msg3.is_none(), "Returned message after response"); + assert!(ks_i.initiator, "Initiators key-pair is not confirmed"); + + assert_eq!(ks_i.send, ks_r.recv, "KeyI.send != KeyR.recv"); + assert_eq!(ks_i.recv, ks_r.send, "KeyI.recv != KeyR.send"); + + dev1.release(ks_i.send.id); + dev2.release(ks_r.send.id); + + // to avoid flood detection + thread::sleep(Duration::from_millis(20)); + } + + dev1.remove(pk2).unwrap(); + dev2.remove(pk1).unwrap(); + } +} diff --git a/src/wireguard/handshake/macs.rs b/src/wireguard/handshake/macs.rs new file mode 100644 index 0000000..689826b --- /dev/null +++ b/src/wireguard/handshake/macs.rs @@ -0,0 +1,327 @@ +use generic_array::GenericArray; +use rand::{CryptoRng, RngCore}; +use spin::RwLock; +use std::time::{Duration, Instant}; + +// types to coalesce into bytes +use std::net::SocketAddr; +use x25519_dalek::PublicKey; + +// AEAD +use aead::{Aead, NewAead, Payload}; +use chacha20poly1305::XChaCha20Poly1305; + +// MAC +use blake2::Blake2s; +use subtle::ConstantTimeEq; + +use super::messages::{CookieReply, MacsFooter, TYPE_COOKIE_REPLY}; +use super::types::HandshakeError; + +const LABEL_MAC1: &[u8] = b"mac1----"; +const LABEL_COOKIE: &[u8] = b"cookie--"; + +const SIZE_COOKIE: usize = 16; +const SIZE_SECRET: usize = 32; +const SIZE_MAC: usize = 16; // blake2s-mac128 +const SIZE_TAG: usize = 16; // xchacha20poly1305 tag + +const COOKIE_UPDATE_INTERVAL: Duration = Duration::from_secs(120); + +macro_rules! HASH { + ( $($input:expr),* ) => {{ + use blake2::Digest; + let mut hsh = Blake2s::new(); + $( + hsh.input($input); + )* + hsh.result() + }}; +} + +macro_rules! MAC { + ( $key:expr, $($input:expr),* ) => {{ + use blake2::VarBlake2s; + use digest::Input; + use digest::VariableOutput; + let mut tag = [0u8; SIZE_MAC]; + let mut mac = VarBlake2s::new_keyed($key, SIZE_MAC); + $( + mac.input($input); + )* + mac.variable_result(|buf| tag.copy_from_slice(buf)); + tag + }}; +} + +macro_rules! XSEAL { + ($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{ + let ct = XChaCha20Poly1305::new(*GenericArray::from_slice($key)) + .encrypt( + GenericArray::from_slice($nonce), + Payload { msg: $pt, aad: $ad }, + ) + .unwrap(); + debug_assert_eq!(ct.len(), $pt.len() + SIZE_TAG); + $ct.copy_from_slice(&ct); + }}; +} + +macro_rules! XOPEN { + ($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{ + debug_assert_eq!($ct.len(), $pt.len() + SIZE_TAG); + XChaCha20Poly1305::new(*GenericArray::from_slice($key)) + .decrypt( + GenericArray::from_slice($nonce), + Payload { msg: $ct, aad: $ad }, + ) + .map_err(|_| HandshakeError::DecryptionFailure) + .map(|pt| $pt.copy_from_slice(&pt)) + }}; +} + +struct Cookie { + value: [u8; 16], + birth: Instant, +} + +pub struct Generator { + mac1_key: [u8; 32], + cookie_key: [u8; 32], // xchacha20poly key for opening cookie response + last_mac1: Option<[u8; 16]>, + cookie: Option, +} + +fn addr_to_mac_bytes(addr: &SocketAddr) -> Vec { + match addr { + SocketAddr::V4(addr) => { + let mut res = Vec::with_capacity(4 + 2); + res.extend(&addr.ip().octets()); + res.extend(&addr.port().to_le_bytes()); + res + } + SocketAddr::V6(addr) => { + let mut res = Vec::with_capacity(16 + 2); + res.extend(&addr.ip().octets()); + res.extend(&addr.port().to_le_bytes()); + res + } + } +} + +impl Generator { + /// Initalize a new mac field generator + /// + /// # Arguments + /// + /// - pk: The public key of the peer to which the generator is associated + /// + /// # Returns + /// + /// A freshly initated generator + pub fn new(pk: PublicKey) -> Generator { + Generator { + mac1_key: HASH!(LABEL_MAC1, pk.as_bytes()).into(), + cookie_key: HASH!(LABEL_COOKIE, pk.as_bytes()).into(), + last_mac1: None, + cookie: None, + } + } + + /// Process a CookieReply message + /// + /// # Arguments + /// + /// - reply: CookieReply to process + /// + /// # Returns + /// + /// Can fail if the cookie reply fails to validate + /// (either indicating that it is outdated or malformed) + pub fn process(&mut self, reply: &CookieReply) -> Result<(), HandshakeError> { + let mac1 = self.last_mac1.ok_or(HandshakeError::InvalidState)?; + let mut tau = [0u8; SIZE_COOKIE]; + XOPEN!( + &self.cookie_key, // key + &reply.f_nonce, // nonce + &mac1, // ad + &mut tau, // pt + &reply.f_cookie // ct || tag + )?; + self.cookie = Some(Cookie { + birth: Instant::now(), + value: tau, + }); + Ok(()) + } + + /// Generate both mac fields for an inner message + /// + /// # Arguments + /// + /// - inner: A byteslice representing the inner message to be covered + /// - macs: The destination mac footer for the resulting macs + pub fn generate(&mut self, inner: &[u8], macs: &mut MacsFooter) { + macs.f_mac1 = MAC!(&self.mac1_key, inner); + macs.f_mac2 = match &self.cookie { + Some(cookie) => { + if cookie.birth.elapsed() > COOKIE_UPDATE_INTERVAL { + self.cookie = None; + [0u8; SIZE_MAC] + } else { + MAC!(&cookie.value, inner, macs.f_mac1) + } + } + None => [0u8; SIZE_MAC], + }; + self.last_mac1 = Some(macs.f_mac1); + } +} + +struct Secret { + value: [u8; 32], + birth: Instant, +} + +pub struct Validator { + mac1_key: [u8; 32], // mac1 key, derived from device public key + cookie_key: [u8; 32], // xchacha20poly key for sealing cookie response + secret: RwLock, +} + +impl Validator { + pub fn new(pk: PublicKey) -> Validator { + Validator { + mac1_key: HASH!(LABEL_MAC1, pk.as_bytes()).into(), + cookie_key: HASH!(LABEL_COOKIE, pk.as_bytes()).into(), + secret: RwLock::new(Secret { + value: [0u8; SIZE_SECRET], + birth: Instant::now() - Duration::new(86400, 0), + }), + } + } + + fn get_tau(&self, src: &[u8]) -> Option<[u8; SIZE_COOKIE]> { + let secret = self.secret.read(); + if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL { + Some(MAC!(&secret.value, src)) + } else { + None + } + } + + fn get_set_tau(&self, rng: &mut R, src: &[u8]) -> [u8; SIZE_COOKIE] { + // check if current value is still valid + { + let secret = self.secret.read(); + if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL { + return MAC!(&secret.value, src); + }; + } + + // take write lock, check again + { + let mut secret = self.secret.write(); + if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL { + return MAC!(&secret.value, src); + }; + + // set new random cookie secret + rng.fill_bytes(&mut secret.value); + secret.birth = Instant::now(); + MAC!(&secret.value, src) + } + } + + pub fn create_cookie_reply( + &self, + rng: &mut R, + receiver: u32, // receiver id of incoming message + src: &SocketAddr, // source address of incoming message + macs: &MacsFooter, // footer of incoming message + msg: &mut CookieReply, // resulting cookie reply + ) { + let src = addr_to_mac_bytes(src); + msg.f_type.set(TYPE_COOKIE_REPLY as u32); + msg.f_receiver.set(receiver); + rng.fill_bytes(&mut msg.f_nonce); + XSEAL!( + &self.cookie_key, // key + &msg.f_nonce, // nonce + &macs.f_mac1, // ad + &self.get_set_tau(rng, &src), // pt + &mut msg.f_cookie // ct || tag + ); + } + + /// Check the mac1 field against the inner message + /// + /// # Arguments + /// + /// - inner: The inner message covered by the mac1 field + /// - macs: The mac footer + pub fn check_mac1(&self, inner: &[u8], macs: &MacsFooter) -> Result<(), HandshakeError> { + let valid_mac1: bool = MAC!(&self.mac1_key, inner).ct_eq(&macs.f_mac1).into(); + if !valid_mac1 { + Err(HandshakeError::InvalidMac1) + } else { + Ok(()) + } + } + + pub fn check_mac2(&self, inner: &[u8], src: &SocketAddr, macs: &MacsFooter) -> bool { + let src = addr_to_mac_bytes(src); + match self.get_tau(&src) { + Some(tau) => MAC!(&tau, inner, macs.f_mac1).ct_eq(&macs.f_mac2).into(), + None => false, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proptest::prelude::*; + use rand::rngs::OsRng; + use x25519_dalek::StaticSecret; + + fn new_validator_generator() -> (Validator, Generator) { + let mut rng = OsRng::new().unwrap(); + let sk = StaticSecret::new(&mut rng); + let pk = PublicKey::from(&sk); + (Validator::new(pk), Generator::new(pk)) + } + + proptest! { + #[test] + fn test_cookie_reply(inner1 : Vec, inner2 : Vec, receiver : u32) { + let mut msg = CookieReply::default(); + let mut rng = OsRng::new().expect("failed to create rng"); + let mut macs = MacsFooter::default(); + let src = "192.0.2.16:8080".parse().unwrap(); + let (validator, mut generator) = new_validator_generator(); + + // generate mac1 for first message + generator.generate(&inner1[..], &mut macs); + assert_ne!(macs.f_mac1, [0u8; SIZE_MAC], "mac1 should be set"); + assert_eq!(macs.f_mac2, [0u8; SIZE_MAC], "mac2 should not be set"); + + // check validity of mac1 + validator.check_mac1(&inner1[..], &macs).expect("mac1 of inner1 did not validate"); + assert_eq!(validator.check_mac2(&inner1[..], &src, &macs), false, "mac2 of inner2 did not validate"); + validator.create_cookie_reply(&mut rng, receiver, &src, &macs, &mut msg); + + // consume cookie reply + generator.process(&msg).expect("failed to process CookieReply"); + + // generate mac2 & mac2 for second message + generator.generate(&inner2[..], &mut macs); + assert_ne!(macs.f_mac1, [0u8; SIZE_MAC], "mac1 should be set"); + assert_ne!(macs.f_mac2, [0u8; SIZE_MAC], "mac2 should be set"); + + // check validity of mac1 and mac2 + validator.check_mac1(&inner2[..], &macs).expect("mac1 of inner2 did not validate"); + assert!(validator.check_mac2(&inner2[..], &src, &macs), "mac2 of inner2 did not validate"); + } + } +} diff --git a/src/wireguard/handshake/messages.rs b/src/wireguard/handshake/messages.rs new file mode 100644 index 0000000..29d80af --- /dev/null +++ b/src/wireguard/handshake/messages.rs @@ -0,0 +1,363 @@ +#[cfg(test)] +use hex; + +#[cfg(test)] +use std::fmt; + +use std::mem; + +use byteorder::LittleEndian; +use zerocopy::byteorder::U32; +use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified}; + +use super::types::*; + +const SIZE_MAC: usize = 16; +const SIZE_TAG: usize = 16; // poly1305 tag +const SIZE_XNONCE: usize = 24; // xchacha20 nonce +const SIZE_COOKIE: usize = 16; // +const SIZE_X25519_POINT: usize = 32; // x25519 public key +const SIZE_TIMESTAMP: usize = 12; + +pub const TYPE_INITIATION: u32 = 1; +pub const TYPE_RESPONSE: u32 = 2; +pub const TYPE_COOKIE_REPLY: u32 = 3; + +const fn max(a: usize, b: usize) -> usize { + let m: usize = (a > b) as usize; + m * a + (1 - m) * b +} + +pub const MAX_HANDSHAKE_MSG_SIZE: usize = max( + max(mem::size_of::(), mem::size_of::()), + mem::size_of::(), +); + +/* Handshake messsages */ + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct Response { + pub noise: NoiseResponse, // inner message covered by macs + pub macs: MacsFooter, +} + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct Initiation { + pub noise: NoiseInitiation, // inner message covered by macs + pub macs: MacsFooter, +} + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct CookieReply { + pub f_type: U32, + pub f_receiver: U32, + pub f_nonce: [u8; SIZE_XNONCE], + pub f_cookie: [u8; SIZE_COOKIE + SIZE_TAG], +} + +/* Inner sub-messages */ + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct MacsFooter { + pub f_mac1: [u8; SIZE_MAC], + pub f_mac2: [u8; SIZE_MAC], +} + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct NoiseInitiation { + pub f_type: U32, + pub f_sender: U32, + pub f_ephemeral: [u8; SIZE_X25519_POINT], + pub f_static: [u8; SIZE_X25519_POINT + SIZE_TAG], + pub f_timestamp: [u8; SIZE_TIMESTAMP + SIZE_TAG], +} + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct NoiseResponse { + pub f_type: U32, + pub f_sender: U32, + pub f_receiver: U32, + pub f_ephemeral: [u8; SIZE_X25519_POINT], + pub f_empty: [u8; SIZE_TAG], +} + +/* Zero copy parsing of handshake messages */ + +impl Initiation { + pub fn parse(bytes: B) -> Result, HandshakeError> { + let msg: LayoutVerified = + LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?; + + if msg.noise.f_type.get() != (TYPE_INITIATION as u32) { + return Err(HandshakeError::InvalidMessageFormat); + } + + Ok(msg) + } +} + +impl Response { + pub fn parse(bytes: B) -> Result, HandshakeError> { + let msg: LayoutVerified = + LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?; + + if msg.noise.f_type.get() != (TYPE_RESPONSE as u32) { + return Err(HandshakeError::InvalidMessageFormat); + } + + Ok(msg) + } +} + +impl CookieReply { + pub fn parse(bytes: B) -> Result, HandshakeError> { + let msg: LayoutVerified = + LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?; + + if msg.f_type.get() != (TYPE_COOKIE_REPLY as u32) { + return Err(HandshakeError::InvalidMessageFormat); + } + + Ok(msg) + } +} + +/* Default values */ + +impl Default for Response { + fn default() -> Self { + Self { + noise: Default::default(), + macs: Default::default(), + } + } +} + +impl Default for Initiation { + fn default() -> Self { + Self { + noise: Default::default(), + macs: Default::default(), + } + } +} + +impl Default for CookieReply { + fn default() -> Self { + Self { + f_type: >::new(TYPE_COOKIE_REPLY as u32), + f_receiver: >::ZERO, + f_nonce: [0u8; SIZE_XNONCE], + f_cookie: [0u8; SIZE_COOKIE + SIZE_TAG], + } + } +} + +impl Default for MacsFooter { + fn default() -> Self { + Self { + f_mac1: [0u8; SIZE_MAC], + f_mac2: [0u8; SIZE_MAC], + } + } +} + +impl Default for NoiseInitiation { + fn default() -> Self { + Self { + f_type: >::new(TYPE_INITIATION as u32), + f_sender: >::ZERO, + f_ephemeral: [0u8; SIZE_X25519_POINT], + f_static: [0u8; SIZE_X25519_POINT + SIZE_TAG], + f_timestamp: [0u8; SIZE_TIMESTAMP + SIZE_TAG], + } + } +} + +impl Default for NoiseResponse { + fn default() -> Self { + Self { + f_type: >::new(TYPE_RESPONSE as u32), + f_sender: >::ZERO, + f_receiver: >::ZERO, + f_ephemeral: [0u8; SIZE_X25519_POINT], + f_empty: [0u8; SIZE_TAG], + } + } +} + +/* Debug formatting (for testing purposes) */ + +#[cfg(test)] +impl fmt::Debug for Initiation { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Initiation {{ {:?} || {:?} }}", self.noise, self.macs) + } +} + +#[cfg(test)] +impl fmt::Debug for Response { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Response {{ {:?} || {:?} }}", self.noise, self.macs) + } +} + +#[cfg(test)] +impl fmt::Debug for CookieReply { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "CookieReply {{ type = {}, receiver = {}, nonce = {}, cookie = {} }}", + self.f_type, + self.f_receiver, + hex::encode(&self.f_nonce[..]), + hex::encode(&self.f_cookie[..]), + ) + } +} + +#[cfg(test)] +impl fmt::Debug for NoiseInitiation { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, + "NoiseInitiation {{ type = {}, sender = {}, ephemeral = {}, static = {}, timestamp = {} }}", + self.f_type.get(), + self.f_sender.get(), + hex::encode(&self.f_ephemeral[..]), + hex::encode(&self.f_static[..]), + hex::encode(&self.f_timestamp[..]), + ) + } +} + +#[cfg(test)] +impl fmt::Debug for NoiseResponse { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, + "NoiseResponse {{ type = {}, sender = {}, receiver = {}, ephemeral = {}, empty = |{} }}", + self.f_type, + self.f_sender, + self.f_receiver, + hex::encode(&self.f_ephemeral[..]), + hex::encode(&self.f_empty[..]) + ) + } +} + +#[cfg(test)] +impl fmt::Debug for MacsFooter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Macs {{ mac1 = {}, mac2 = {} }}", + hex::encode(&self.f_mac1[..]), + hex::encode(&self.f_mac2[..]) + ) + } +} + +/* Equality (for testing purposes) */ + +#[cfg(test)] +macro_rules! eq_as_bytes { + ($type:path) => { + impl PartialEq for $type { + fn eq(&self, other: &Self) -> bool { + self.as_bytes() == other.as_bytes() + } + } + impl Eq for $type {} + }; +} + +#[cfg(test)] +eq_as_bytes!(Initiation); + +#[cfg(test)] +eq_as_bytes!(Response); + +#[cfg(test)] +eq_as_bytes!(CookieReply); + +#[cfg(test)] +eq_as_bytes!(MacsFooter); + +#[cfg(test)] +eq_as_bytes!(NoiseInitiation); + +#[cfg(test)] +eq_as_bytes!(NoiseResponse); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn message_response_identity() { + let mut msg: Response = Default::default(); + + msg.noise.f_sender.set(146252); + msg.noise.f_receiver.set(554442); + msg.noise.f_ephemeral = [ + 0xc1, 0x66, 0x0a, 0x0c, 0xdc, 0x0f, 0x6c, 0x51, 0x0f, 0xc2, 0xcc, 0x51, 0x52, 0x0c, + 0xde, 0x1e, 0xf7, 0xf1, 0xca, 0x90, 0x86, 0x72, 0xad, 0x67, 0xea, 0x89, 0x45, 0x44, + 0x13, 0x56, 0x52, 0x1f, + ]; + msg.noise.f_empty = [ + 0x60, 0x0e, 0x1e, 0x95, 0x41, 0x6b, 0x52, 0x05, 0xa2, 0x09, 0xe1, 0xbf, 0x40, 0x05, + 0x2f, 0xde, + ]; + msg.macs.f_mac1 = [ + 0xf2, 0xad, 0x40, 0xb5, 0xf7, 0xde, 0x77, 0x35, 0x89, 0x19, 0xb7, 0x5c, 0xf9, 0x54, + 0x69, 0x29, + ]; + msg.macs.f_mac2 = [ + 0x4f, 0xd2, 0x1b, 0xfe, 0x77, 0xe6, 0x2e, 0xc9, 0x07, 0xe2, 0x87, 0x17, 0xbb, 0xe5, + 0xdf, 0xbb, + ]; + + let buf: Vec = msg.as_bytes().to_vec(); + let msg_p = Response::parse(&buf[..]).unwrap(); + assert_eq!(msg, *msg_p.into_ref()); + } + + #[test] + fn message_initiate_identity() { + let mut msg: Initiation = Default::default(); + + msg.noise.f_sender.set(575757); + msg.noise.f_ephemeral = [ + 0xc1, 0x66, 0x0a, 0x0c, 0xdc, 0x0f, 0x6c, 0x51, 0x0f, 0xc2, 0xcc, 0x51, 0x52, 0x0c, + 0xde, 0x1e, 0xf7, 0xf1, 0xca, 0x90, 0x86, 0x72, 0xad, 0x67, 0xea, 0x89, 0x45, 0x44, + 0x13, 0x56, 0x52, 0x1f, + ]; + msg.noise.f_static = [ + 0xdc, 0x33, 0x90, 0x15, 0x8f, 0x82, 0x3e, 0x06, 0x44, 0xa0, 0xde, 0x4c, 0x15, 0x6c, + 0x5d, 0xa4, 0x65, 0x99, 0xf6, 0x6c, 0xa1, 0x14, 0x77, 0xf9, 0xeb, 0x6a, 0xec, 0xc3, + 0x3c, 0xda, 0x47, 0xe1, 0x45, 0xac, 0x8d, 0x43, 0xea, 0x1b, 0x2f, 0x02, 0x45, 0x5d, + 0x86, 0x37, 0xee, 0x83, 0x6b, 0x42, + ]; + msg.noise.f_timestamp = [ + 0x4f, 0x1c, 0x60, 0xec, 0x0e, 0xf6, 0x36, 0xf0, 0x78, 0x28, 0x57, 0x42, 0x60, 0x0e, + 0x1e, 0x95, 0x41, 0x6b, 0x52, 0x05, 0xa2, 0x09, 0xe1, 0xbf, 0x40, 0x05, 0x2f, 0xde, + ]; + msg.macs.f_mac1 = [ + 0xf2, 0xad, 0x40, 0xb5, 0xf7, 0xde, 0x77, 0x35, 0x89, 0x19, 0xb7, 0x5c, 0xf9, 0x54, + 0x69, 0x29, + ]; + msg.macs.f_mac2 = [ + 0x4f, 0xd2, 0x1b, 0xfe, 0x77, 0xe6, 0x2e, 0xc9, 0x07, 0xe2, 0x87, 0x17, 0xbb, 0xe5, + 0xdf, 0xbb, + ]; + + let buf: Vec = msg.as_bytes().to_vec(); + let msg_p = Initiation::parse(&buf[..]).unwrap(); + assert_eq!(msg, *msg_p.into_ref()); + } +} diff --git a/src/wireguard/handshake/mod.rs b/src/wireguard/handshake/mod.rs new file mode 100644 index 0000000..071a41f --- /dev/null +++ b/src/wireguard/handshake/mod.rs @@ -0,0 +1,21 @@ +/* Implementation of the: + * + * Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s + * + * Protocol pattern, see: http://www.noiseprotocol.org/noise.html. + * For documentation. + */ + +mod device; +mod macs; +mod messages; +mod noise; +mod peer; +mod ratelimiter; +mod timestamp; +mod types; + +// publicly exposed interface + +pub use device::Device; +pub use messages::{MAX_HANDSHAKE_MSG_SIZE, TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE}; diff --git a/src/wireguard/handshake/noise.rs b/src/wireguard/handshake/noise.rs new file mode 100644 index 0000000..a2a84b0 --- /dev/null +++ b/src/wireguard/handshake/noise.rs @@ -0,0 +1,549 @@ +// DH +use x25519_dalek::PublicKey; +use x25519_dalek::StaticSecret; + +// HASH & MAC +use blake2::Blake2s; +use hmac::Hmac; + +// AEAD +use aead::{Aead, NewAead, Payload}; +use chacha20poly1305::ChaCha20Poly1305; + +use rand::{CryptoRng, RngCore}; + +use generic_array::typenum::*; +use generic_array::*; + +use clear_on_drop::clear::Clear; +use clear_on_drop::clear_stack_on_return; + +use subtle::ConstantTimeEq; + +use super::device::Device; +use super::messages::{NoiseInitiation, NoiseResponse}; +use super::messages::{TYPE_INITIATION, TYPE_RESPONSE}; +use super::peer::{Peer, State}; +use super::timestamp; +use super::types::*; + +use super::super::types::{KeyPair, Key}; + +use std::time::Instant; + +// HMAC hasher (generic construction) + +type HMACBlake2s = Hmac; + +// convenient alias to pass state temporarily into device.rs and back + +type TemporaryState = (u32, PublicKey, GenericArray, GenericArray); + +const SIZE_CK: usize = 32; +const SIZE_HS: usize = 32; +const SIZE_NONCE: usize = 8; +const SIZE_TAG: usize = 16; + +// number of pages to clear after sensitive call +const CLEAR_PAGES: usize = 1; + +// C := Hash(Construction) +const INITIAL_CK: [u8; SIZE_CK] = [ + 0x60, 0xe2, 0x6d, 0xae, 0xf3, 0x27, 0xef, 0xc0, 0x2e, 0xc3, 0x35, 0xe2, 0xa0, 0x25, 0xd2, 0xd0, + 0x16, 0xeb, 0x42, 0x06, 0xf8, 0x72, 0x77, 0xf5, 0x2d, 0x38, 0xd1, 0x98, 0x8b, 0x78, 0xcd, 0x36, +]; + +// H := Hash(C || Identifier) +const INITIAL_HS: [u8; SIZE_HS] = [ + 0x22, 0x11, 0xb3, 0x61, 0x08, 0x1a, 0xc5, 0x66, 0x69, 0x12, 0x43, 0xdb, 0x45, 0x8a, 0xd5, 0x32, + 0x2d, 0x9c, 0x6c, 0x66, 0x22, 0x93, 0xe8, 0xb7, 0x0e, 0xe1, 0x9c, 0x65, 0xba, 0x07, 0x9e, 0xf3, +]; + +const ZERO_NONCE: [u8; 12] = [0u8; 12]; + +macro_rules! HASH { + ( $($input:expr),* ) => {{ + use blake2::Digest; + let mut hsh = Blake2s::new(); + $( + hsh.input($input); + )* + hsh.result() + }}; +} + +macro_rules! HMAC { + ($key:expr, $($input:expr),*) => {{ + use hmac::Mac; + let mut mac = HMACBlake2s::new_varkey($key).unwrap(); + $( + mac.input($input); + )* + mac.result().code() + }}; +} + +macro_rules! KDF1 { + ($ck:expr, $input:expr) => {{ + let mut t0 = HMAC!($ck, $input); + let t1 = HMAC!(&t0, &[0x1]); + t0.clear(); + t1 + }}; +} + +macro_rules! KDF2 { + ($ck:expr, $input:expr) => {{ + let mut t0 = HMAC!($ck, $input); + let t1 = HMAC!(&t0, &[0x1]); + let t2 = HMAC!(&t0, &t1, &[0x2]); + t0.clear(); + (t1, t2) + }}; +} + +macro_rules! KDF3 { + ($ck:expr, $input:expr) => {{ + let mut t0 = HMAC!($ck, $input); + let t1 = HMAC!(&t0, &[0x1]); + let t2 = HMAC!(&t0, &t1, &[0x2]); + let t3 = HMAC!(&t0, &t2, &[0x3]); + t0.clear(); + (t1, t2, t3) + }}; +} + +macro_rules! SEAL { + ($key:expr, $ad:expr, $pt:expr, $ct:expr) => { + ChaCha20Poly1305::new(*GenericArray::from_slice($key)) + .encrypt(&ZERO_NONCE.into(), Payload { msg: $pt, aad: $ad }) + .map(|ct| $ct.copy_from_slice(&ct)) + .unwrap() + }; +} + +macro_rules! OPEN { + ($key:expr, $ad:expr, $pt:expr, $ct:expr) => { + ChaCha20Poly1305::new(*GenericArray::from_slice($key)) + .decrypt(&ZERO_NONCE.into(), Payload { msg: $ct, aad: $ad }) + .map_err(|_| HandshakeError::DecryptionFailure) + .map(|pt| $pt.copy_from_slice(&pt)) + }; +} + +#[cfg(test)] +mod tests { + use super::*; + + const IDENTIFIER: &[u8] = b"WireGuard v1 zx2c4 Jason@zx2c4.com"; + const CONSTRUCTION: &[u8] = b"Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"; + + /* Sanity check precomputed initial chain key + */ + #[test] + fn precomputed_chain_key() { + assert_eq!(INITIAL_CK[..], HASH!(CONSTRUCTION)[..]); + } + + /* Sanity check precomputed initial hash transcript + */ + #[test] + fn precomputed_hash() { + assert_eq!(INITIAL_HS[..], HASH!(INITIAL_CK, IDENTIFIER)[..]); + } + + /* Sanity check the HKDF macro + * + * Test vectors generated using WireGuard-Go + */ + #[test] + fn hkdf() { + let tests: Vec<(Vec, Vec, [u8; 32], [u8; 32], [u8; 32])> = vec![ + ( + vec![], + vec![], + [ + 0x83, 0x87, 0xb4, 0x6b, 0xf4, 0x3e, 0xcc, 0xfc, 0xf3, 0x49, 0x55, 0x2a, 0x09, + 0x5d, 0x83, 0x15, 0xc4, 0x05, 0x5b, 0xeb, 0x90, 0x20, 0x8f, 0xb1, 0xbe, 0x23, + 0xb8, 0x94, 0xbc, 0x2e, 0xd5, 0xd0, + ], + [ + 0x58, 0xa0, 0xe5, 0xf6, 0xfa, 0xef, 0xcc, 0xf4, 0x80, 0x7b, 0xff, 0x1f, 0x05, + 0xfa, 0x8a, 0x92, 0x17, 0x94, 0x57, 0x62, 0x04, 0x0b, 0xce, 0xc2, 0xf4, 0xb4, + 0xa6, 0x2b, 0xdf, 0xe0, 0xe8, 0x6e, + ], + [ + 0x0c, 0xe6, 0xea, 0x98, 0xec, 0x54, 0x8f, 0x8e, 0x28, 0x1e, 0x93, 0xe3, 0x2d, + 0xb6, 0x56, 0x21, 0xc4, 0x5e, 0xb1, 0x8d, 0xc6, 0xf0, 0xa7, 0xad, 0x94, 0x17, + 0x86, 0x10, 0xa2, 0xf7, 0x33, 0x8e, + ], + ), + ( + vec![0xde, 0xad, 0xbe, 0xef], + vec![], + [ + 0x55, 0x32, 0x9d, 0xc8, 0x0e, 0x69, 0x0f, 0xd8, 0x6b, 0xd9, 0x66, 0x1f, 0x08, + 0x51, 0xc9, 0xb3, 0x68, 0x6d, 0xf2, 0xb1, 0xfd, 0xa0, 0x34, 0x7b, 0xc3, 0xd2, + 0x79, 0x58, 0x25, 0x4b, 0x32, 0xc6, + ], + [ + 0x8d, 0xfc, 0x6d, 0x33, 0xa8, 0x11, 0x8f, 0xfe, 0x40, 0x8b, 0x31, 0xdd, 0xac, + 0x25, 0xf7, 0x2a, 0xee, 0x91, 0x15, 0xa4, 0x5b, 0x69, 0xba, 0x17, 0x6a, 0xd0, + 0x12, 0xb2, 0x43, 0x83, 0x4f, 0xee, + ], + [ + 0xd6, 0x9e, 0x85, 0x2a, 0x28, 0x96, 0x56, 0x9e, 0xa5, 0x4a, 0x67, 0x96, 0x9a, + 0xa1, 0x80, 0x02, 0x87, 0x92, 0x1d, 0xac, 0x53, 0xce, 0x6d, 0xb4, 0xb4, 0xe1, + 0x21, 0x92, 0xf2, 0x63, 0xc4, 0xc4, + ], + ), + ]; + + for (key, input, t0, t1, t2) in &tests { + let tt0 = KDF1!(key, input); + debug_assert_eq!(tt0[..], t0[..]); + + let (tt0, tt1) = KDF2!(key, input); + debug_assert_eq!(tt0[..], t0[..]); + debug_assert_eq!(tt1[..], t1[..]); + + let (tt0, tt1, tt2) = KDF3!(key, input); + debug_assert_eq!(tt0[..], t0[..]); + debug_assert_eq!(tt1[..], t1[..]); + debug_assert_eq!(tt2[..], t2[..]); + } + } +} + +pub fn create_initiation( + rng: &mut R, + device: &Device, + peer: &Peer, + sender: u32, + msg: &mut NoiseInitiation, +) -> Result<(), HandshakeError> { + clear_stack_on_return(CLEAR_PAGES, || { + // initialize state + + let ck = INITIAL_CK; + let hs = INITIAL_HS; + let hs = HASH!(&hs, peer.pk.as_bytes()); + + msg.f_type.set(TYPE_INITIATION as u32); + msg.f_sender.set(sender); + + // (E_priv, E_pub) := DH-Generate() + + let eph_sk = StaticSecret::new(rng); + let eph_pk = PublicKey::from(&eph_sk); + + // C := Kdf(C, E_pub) + + let ck = KDF1!(&ck, eph_pk.as_bytes()); + + // msg.ephemeral := E_pub + + msg.f_ephemeral = *eph_pk.as_bytes(); + + // H := HASH(H, msg.ephemeral) + + let hs = HASH!(&hs, msg.f_ephemeral); + + // (C, k) := Kdf2(C, DH(E_priv, S_pub)) + + let (ck, key) = KDF2!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes()); + + // msg.static := Aead(k, 0, S_pub, H) + + SEAL!( + &key, + &hs, // ad + device.pk.as_bytes(), // pt + &mut msg.f_static // ct || tag + ); + + // H := Hash(H || msg.static) + + let hs = HASH!(&hs, &msg.f_static[..]); + + // (C, k) := Kdf2(C, DH(S_priv, S_pub)) + + let (ck, key) = KDF2!(&ck, peer.ss.as_bytes()); + + // msg.timestamp := Aead(k, 0, Timestamp(), H) + + SEAL!( + &key, + &hs, // ad + ×tamp::now(), // pt + &mut msg.f_timestamp // ct || tag + ); + + // H := Hash(H || msg.timestamp) + + let hs = HASH!(&hs, &msg.f_timestamp); + + // update state of peer + + *peer.state.lock() = State::InitiationSent { + hs, + ck, + eph_sk, + sender, + }; + + Ok(()) + }) +} + +pub fn consume_initiation<'a>( + device: &'a Device, + msg: &NoiseInitiation, +) -> Result<(&'a Peer, TemporaryState), HandshakeError> { + clear_stack_on_return(CLEAR_PAGES, || { + // initialize new state + + let ck = INITIAL_CK; + let hs = INITIAL_HS; + let hs = HASH!(&hs, device.pk.as_bytes()); + + // C := Kdf(C, E_pub) + + let ck = KDF1!(&ck, &msg.f_ephemeral); + + // H := HASH(H, msg.ephemeral) + + let hs = HASH!(&hs, &msg.f_ephemeral); + + // (C, k) := Kdf2(C, DH(E_priv, S_pub)) + + let eph_r_pk = PublicKey::from(msg.f_ephemeral); + let (ck, key) = KDF2!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes()); + + // msg.static := Aead(k, 0, S_pub, H) + + let mut pk = [0u8; 32]; + + OPEN!( + &key, + &hs, // ad + &mut pk, // pt + &msg.f_static // ct || tag + )?; + + let peer = device.lookup_pk(&PublicKey::from(pk))?; + + // reset initiation state + + *peer.state.lock() = State::Reset; + + // H := Hash(H || msg.static) + + let hs = HASH!(&hs, &msg.f_static[..]); + + // (C, k) := Kdf2(C, DH(S_priv, S_pub)) + + let (ck, key) = KDF2!(&ck, peer.ss.as_bytes()); + + // msg.timestamp := Aead(k, 0, Timestamp(), H) + + let mut ts = timestamp::ZERO; + + OPEN!( + &key, + &hs, // ad + &mut ts, // pt + &msg.f_timestamp // ct || tag + )?; + + // check and update timestamp + + peer.check_replay_flood(device, &ts)?; + + // H := Hash(H || msg.timestamp) + + let hs = HASH!(&hs, &msg.f_timestamp); + + // return state (to create response) + + Ok((peer, (msg.f_sender.get(), eph_r_pk, hs, ck))) + }) +} + +pub fn create_response( + rng: &mut R, + peer: &Peer, + sender: u32, // sending identifier + state: TemporaryState, // state from "consume_initiation" + msg: &mut NoiseResponse, // resulting response +) -> Result { + clear_stack_on_return(CLEAR_PAGES, || { + // unpack state + + let (receiver, eph_r_pk, hs, ck) = state; + + msg.f_type.set(TYPE_RESPONSE as u32); + msg.f_sender.set(sender); + msg.f_receiver.set(receiver); + + // (E_priv, E_pub) := DH-Generate() + + let eph_sk = StaticSecret::new(rng); + let eph_pk = PublicKey::from(&eph_sk); + + // C := Kdf1(C, E_pub) + + let ck = KDF1!(&ck, eph_pk.as_bytes()); + + // msg.ephemeral := E_pub + + msg.f_ephemeral = *eph_pk.as_bytes(); + + // H := Hash(H || msg.ephemeral) + + let hs = HASH!(&hs, &msg.f_ephemeral); + + // C := Kdf1(C, DH(E_priv, E_pub)) + + let ck = KDF1!(&ck, eph_sk.diffie_hellman(&eph_r_pk).as_bytes()); + + // C := Kdf1(C, DH(E_priv, S_pub)) + + let ck = KDF1!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes()); + + // (C, tau, k) := Kdf3(C, Q) + + let (ck, tau, key) = KDF3!(&ck, &peer.psk); + + // H := Hash(H || tau) + + let hs = HASH!(&hs, tau); + + // msg.empty := Aead(k, 0, [], H) + + SEAL!( + &key, + &hs, // ad + &[], // pt + &mut msg.f_empty // \epsilon || tag + ); + + // Not strictly needed + // let hs = HASH!(&hs, &msg.f_empty_tag); + + // derive key-pair + + let (key_recv, key_send) = KDF2!(&ck, &[]); + + // return unconfirmed key-pair + + Ok(KeyPair { + birth: Instant::now(), + initiator: false, + send: Key { + id: sender, + key: key_send.into(), + }, + recv: Key { + id: receiver, + key: key_recv.into(), + }, + }) + }) +} + +/* The state lock is released while processing the message to + * allow concurrent processing of potential responses to the initiation, + * in order to better mitigate DoS from malformed response messages. + */ +pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result { + clear_stack_on_return(CLEAR_PAGES, || { + // retrieve peer and copy initiation state + let peer = device.lookup_id(msg.f_receiver.get())?; + + let (hs, ck, sender, eph_sk) = match *peer.state.lock() { + State::InitiationSent { + hs, + ck, + sender, + ref eph_sk, + } => Ok((hs, ck, sender, StaticSecret::from(eph_sk.to_bytes()))), + _ => Err(HandshakeError::InvalidState), + }?; + + // C := Kdf1(C, E_pub) + + let ck = KDF1!(&ck, &msg.f_ephemeral); + + // H := Hash(H || msg.ephemeral) + + let hs = HASH!(&hs, &msg.f_ephemeral); + + // C := Kdf1(C, DH(E_priv, E_pub)) + + let eph_r_pk = PublicKey::from(msg.f_ephemeral); + let ck = KDF1!(&ck, eph_sk.diffie_hellman(&eph_r_pk).as_bytes()); + + // C := Kdf1(C, DH(E_priv, S_pub)) + + let ck = KDF1!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes()); + + // (C, tau, k) := Kdf3(C, Q) + + let (ck, tau, key) = KDF3!(&ck, &peer.psk); + + // H := Hash(H || tau) + + let hs = HASH!(&hs, tau); + + // msg.empty := Aead(k, 0, [], H) + + OPEN!( + &key, + &hs, // ad + &mut [], // pt + &msg.f_empty // \epsilon || tag + )?; + + // derive key-pair + + let birth = Instant::now(); + let (key_send, key_recv) = KDF2!(&ck, &[]); + + // check for new initiation sent while lock released + + let mut state = peer.state.lock(); + let update = match *state { + State::InitiationSent { + eph_sk: ref old, .. + } => old.to_bytes().ct_eq(&eph_sk.to_bytes()).into(), + _ => false, + }; + + if update { + // null the initiation state + // (to avoid replay of this response message) + *state = State::Reset; + + // return confirmed key-pair + Ok(( + Some(peer.pk), + None, + Some(KeyPair { + birth, + initiator: true, + send: Key { + id: sender, + key: key_send.into(), + }, + recv: Key { + id: msg.f_sender.get(), + key: key_recv.into(), + }, + }), + )) + } else { + Err(HandshakeError::InvalidState) + } + }) +} diff --git a/src/wireguard/handshake/peer.rs b/src/wireguard/handshake/peer.rs new file mode 100644 index 0000000..c9e1c40 --- /dev/null +++ b/src/wireguard/handshake/peer.rs @@ -0,0 +1,142 @@ +use spin::Mutex; + +use std::mem; +use std::time::{Duration, Instant}; + +use generic_array::typenum::U32; +use generic_array::GenericArray; + +use x25519_dalek::PublicKey; +use x25519_dalek::SharedSecret; +use x25519_dalek::StaticSecret; + +use clear_on_drop::clear::Clear; + +use super::device::Device; +use super::macs; +use super::timestamp; +use super::types::*; + +const TIME_BETWEEN_INITIATIONS: Duration = Duration::from_millis(20); + +/* Represents the recomputation and state of a peer. + * + * This type is only for internal use and not exposed. + */ +pub struct Peer { + // mutable state + pub(crate) state: Mutex, + pub(crate) timestamp: Mutex>, + pub(crate) last_initiation_consumption: Mutex>, + + // state related to DoS mitigation fields + pub(crate) macs: Mutex, + + // constant state + pub(crate) pk: PublicKey, // public key of peer + pub(crate) ss: SharedSecret, // precomputed DH(static, static) + pub(crate) psk: Psk, // psk of peer +} + +pub enum State { + Reset, + InitiationSent { + sender: u32, // assigned sender id + eph_sk: StaticSecret, + hs: GenericArray, + ck: GenericArray, + }, +} + +impl Drop for State { + fn drop(&mut self) { + match self { + State::InitiationSent { hs, ck, .. } => { + // eph_sk already cleared by dalek-x25519 + hs.clear(); + ck.clear(); + } + _ => (), + } + } +} + +impl Peer { + pub fn new( + pk: PublicKey, // public key of peer + ss: SharedSecret, // precomputed DH(static, static) + ) -> Self { + Self { + macs: Mutex::new(macs::Generator::new(pk)), + state: Mutex::new(State::Reset), + timestamp: Mutex::new(None), + last_initiation_consumption: Mutex::new(None), + pk: pk, + ss: ss, + psk: [0u8; 32], + } + } + + /// Set the state of the peer unconditionally + /// + /// # Arguments + /// + pub fn set_state(&self, state_new: State) { + *self.state.lock() = state_new; + } + + pub fn reset_state(&self) -> Option { + match mem::replace(&mut *self.state.lock(), State::Reset) { + State::InitiationSent { sender, .. } => Some(sender), + _ => None, + } + } + + /// Set the mutable state of the peer conditioned on the timestamp being newer + /// + /// # Arguments + /// + /// * st_new - The updated state of the peer + /// * ts_new - The associated timestamp + pub fn check_replay_flood( + &self, + device: &Device, + timestamp_new: ×tamp::TAI64N, + ) -> Result<(), HandshakeError> { + let mut state = self.state.lock(); + let mut timestamp = self.timestamp.lock(); + let mut last_initiation_consumption = self.last_initiation_consumption.lock(); + + // check replay attack + match *timestamp { + Some(timestamp_old) => { + if !timestamp::compare(×tamp_old, ×tamp_new) { + return Err(HandshakeError::OldTimestamp); + } + } + _ => (), + }; + + // check flood attack + match *last_initiation_consumption { + Some(last) => { + if last.elapsed() < TIME_BETWEEN_INITIATIONS { + return Err(HandshakeError::InitiationFlood); + } + } + _ => (), + } + + // reset state + match *state { + State::InitiationSent { sender, .. } => device.release(sender), + _ => (), + } + + // update replay & flood protection + *state = State::Reset; + *timestamp = Some(*timestamp_new); + *last_initiation_consumption = Some(Instant::now()); + Ok(()) + } +} diff --git a/src/wireguard/handshake/ratelimiter.rs b/src/wireguard/handshake/ratelimiter.rs new file mode 100644 index 0000000..63d728c --- /dev/null +++ b/src/wireguard/handshake/ratelimiter.rs @@ -0,0 +1,199 @@ +use spin; +use std::collections::HashMap; +use std::net::IpAddr; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Condvar, Mutex}; +use std::thread; +use std::time::{Duration, Instant}; + +const PACKETS_PER_SECOND: u64 = 20; +const PACKETS_BURSTABLE: u64 = 5; +const PACKET_COST: u64 = 1_000_000_000 / PACKETS_PER_SECOND; +const MAX_TOKENS: u64 = PACKET_COST * PACKETS_BURSTABLE; + +const GC_INTERVAL: Duration = Duration::from_secs(1); + +struct Entry { + pub last_time: Instant, + pub tokens: u64, +} + +pub struct RateLimiter(Arc); + +struct RateLimiterInner { + gc_running: AtomicBool, + gc_dropped: (Mutex, Condvar), + table: spin::RwLock>>, +} + +impl Drop for RateLimiter { + fn drop(&mut self) { + // wake up & terminate any lingering GC thread + let &(ref lock, ref cvar) = &self.0.gc_dropped; + let mut dropped = lock.lock().unwrap(); + *dropped = true; + cvar.notify_all(); + } +} + +impl RateLimiter { + pub fn new() -> Self { + RateLimiter(Arc::new(RateLimiterInner { + gc_dropped: (Mutex::new(false), Condvar::new()), + gc_running: AtomicBool::from(false), + table: spin::RwLock::new(HashMap::new()), + })) + } + + pub fn allow(&self, addr: &IpAddr) -> bool { + // check if allowed + let allowed = { + // check for existing entry (only requires read lock) + if let Some(entry) = self.0.table.read().get(addr) { + // update existing entry + let mut entry = entry.lock(); + + // add tokens earned since last time + entry.tokens = MAX_TOKENS + .min(entry.tokens + u64::from(entry.last_time.elapsed().subsec_nanos())); + entry.last_time = Instant::now(); + + // subtract cost of packet + if entry.tokens > PACKET_COST { + entry.tokens -= PACKET_COST; + return true; + } else { + return false; + } + } + + // add new entry (write lock) + self.0.table.write().insert( + *addr, + spin::Mutex::new(Entry { + last_time: Instant::now(), + tokens: MAX_TOKENS - PACKET_COST, + }), + ); + true + }; + + // check that GC thread is scheduled + if !self.0.gc_running.swap(true, Ordering::Relaxed) { + let limiter = self.0.clone(); + thread::spawn(move || { + let &(ref lock, ref cvar) = &limiter.gc_dropped; + let mut dropped = lock.lock().unwrap(); + while !*dropped { + // garbage collect + { + let mut tw = limiter.table.write(); + tw.retain(|_, ref mut entry| { + entry.lock().last_time.elapsed() <= GC_INTERVAL + }); + if tw.len() == 0 { + limiter.gc_running.store(false, Ordering::Relaxed); + return; + } + } + + // wait until stopped or new GC (~1 every sec) + let res = cvar.wait_timeout(dropped, GC_INTERVAL).unwrap(); + dropped = res.0; + } + }); + } + + allowed + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std; + + struct Result { + allowed: bool, + text: &'static str, + wait: Duration, + } + + #[test] + fn test_ratelimiter() { + let ratelimiter = RateLimiter::new(); + let mut expected = vec![]; + let ips = vec![ + "127.0.0.1".parse().unwrap(), + "192.168.1.1".parse().unwrap(), + "172.167.2.3".parse().unwrap(), + "97.231.252.215".parse().unwrap(), + "248.97.91.167".parse().unwrap(), + "188.208.233.47".parse().unwrap(), + "104.2.183.179".parse().unwrap(), + "72.129.46.120".parse().unwrap(), + "2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(), + "f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(), + "b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(), + "a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(), + "ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(), + "3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(), + ]; + + for _ in 0..PACKETS_BURSTABLE { + expected.push(Result { + allowed: true, + wait: Duration::new(0, 0), + text: "inital burst", + }); + } + + expected.push(Result { + allowed: false, + wait: Duration::new(0, 0), + text: "after burst", + }); + + expected.push(Result { + allowed: true, + wait: Duration::new(0, PACKET_COST as u32), + text: "filling tokens for single packet", + }); + + expected.push(Result { + allowed: false, + wait: Duration::new(0, 0), + text: "not having refilled enough", + }); + + expected.push(Result { + allowed: true, + wait: Duration::new(0, 2 * PACKET_COST as u32), + text: "filling tokens for 2 * packet burst", + }); + + expected.push(Result { + allowed: true, + wait: Duration::new(0, 0), + text: "second packet in 2 packet burst", + }); + + expected.push(Result { + allowed: false, + wait: Duration::new(0, 0), + text: "packet following 2 packet burst", + }); + + for item in expected { + std::thread::sleep(item.wait); + for ip in ips.iter() { + if ratelimiter.allow(&ip) != item.allowed { + panic!( + "test failed for {} on {}. expected: {}, got: {}", + ip, item.text, item.allowed, !item.allowed + ) + } + } + } + } +} diff --git a/src/wireguard/handshake/timestamp.rs b/src/wireguard/handshake/timestamp.rs new file mode 100644 index 0000000..b5bd9f0 --- /dev/null +++ b/src/wireguard/handshake/timestamp.rs @@ -0,0 +1,32 @@ +use std::time::{SystemTime, UNIX_EPOCH}; + +pub type TAI64N = [u8; 12]; + +const TAI64_EPOCH: u64 = 0x400000000000000a; + +pub const ZERO: TAI64N = [0u8; 12]; + +pub fn now() -> TAI64N { + // get system time as duration + let sysnow = SystemTime::now(); + let delta = sysnow.duration_since(UNIX_EPOCH).unwrap(); + + // convert to tai64n + let tai64_secs = delta.as_secs() + TAI64_EPOCH; + let tai64_nano = delta.subsec_nanos(); + + // serialize + let mut res = [0u8; 12]; + res[..8].copy_from_slice(&tai64_secs.to_be_bytes()[..]); + res[8..].copy_from_slice(&tai64_nano.to_be_bytes()[..]); + res +} + +pub fn compare(old: &TAI64N, new: &TAI64N) -> bool { + for i in 0..12 { + if new[i] > old[i] { + return true; + } + } + return false; +} diff --git a/src/wireguard/handshake/types.rs b/src/wireguard/handshake/types.rs new file mode 100644 index 0000000..5f984cc --- /dev/null +++ b/src/wireguard/handshake/types.rs @@ -0,0 +1,90 @@ +use std::error::Error; +use std::fmt; + +use x25519_dalek::PublicKey; + +use super::super::types::KeyPair; + +/* Internal types for the noise IKpsk2 implementation */ + +// config error + +#[derive(Debug)] +pub struct ConfigError(String); + +impl ConfigError { + pub fn new(s: &str) -> Self { + ConfigError(s.to_string()) + } +} + +impl fmt::Display for ConfigError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ConfigError({})", self.0) + } +} + +impl Error for ConfigError { + fn description(&self) -> &str { + &self.0 + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} + +// handshake error + +#[derive(Debug)] +pub enum HandshakeError { + DecryptionFailure, + UnknownPublicKey, + UnknownReceiverId, + InvalidMessageFormat, + OldTimestamp, + InvalidState, + InvalidMac1, + RateLimited, + InitiationFlood, +} + +impl fmt::Display for HandshakeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + HandshakeError::DecryptionFailure => write!(f, "Failed to AEAD:OPEN"), + HandshakeError::UnknownPublicKey => write!(f, "Unknown public key"), + HandshakeError::UnknownReceiverId => { + write!(f, "Receiver id not allocated to any handshake") + } + HandshakeError::InvalidMessageFormat => write!(f, "Invalid handshake message format"), + HandshakeError::OldTimestamp => write!(f, "Timestamp is less/equal to the newest"), + HandshakeError::InvalidState => write!(f, "Message does not apply to handshake state"), + HandshakeError::InvalidMac1 => write!(f, "Message has invalid mac1 field"), + HandshakeError::RateLimited => write!(f, "Message was dropped by rate limiter"), + HandshakeError::InitiationFlood => { + write!(f, "Message was dropped because of initiation flood") + } + } + } +} + +impl Error for HandshakeError { + fn description(&self) -> &str { + "Generic Handshake Error" + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} + +pub type Output = ( + Option, // external identifier associated with peer + Option>, // message to send + Option, // resulting key-pair of successful handshake +); + +// preshared key + +pub type Psk = [u8; 32]; diff --git a/src/wireguard/mod.rs b/src/wireguard/mod.rs new file mode 100644 index 0000000..9417e57 --- /dev/null +++ b/src/wireguard/mod.rs @@ -0,0 +1,23 @@ +mod wireguard; +// mod config; +mod constants; +mod timers; + +mod handshake; +mod router; +mod types; + +#[cfg(test)] +mod tests; + +/// The WireGuard sub-module contains a pure, configurable implementation of WireGuard. +/// The implementation is generic over: +/// +/// - TUN type, specifying how packets are received on the interface side: a reader/writer and MTU reporting interface. +/// - Bind type, specifying how WireGuard messages are sent/received from the internet and what constitutes an "endpoint" + +pub use wireguard::{Wireguard, Peer}; + +pub use types::bind; +pub use types::tun; +pub use types::Endpoint; \ No newline at end of file diff --git a/src/wireguard/router/anti_replay.rs b/src/wireguard/router/anti_replay.rs new file mode 100644 index 0000000..b0838bd --- /dev/null +++ b/src/wireguard/router/anti_replay.rs @@ -0,0 +1,157 @@ +use std::mem; + +// Implementation of RFC 6479. +// https://tools.ietf.org/html/rfc6479 + +#[cfg(target_pointer_width = "64")] +type Word = u64; + +#[cfg(target_pointer_width = "64")] +const REDUNDANT_BIT_SHIFTS: usize = 6; + +#[cfg(target_pointer_width = "32")] +type Word = u32; + +#[cfg(target_pointer_width = "32")] +const REDUNDANT_BIT_SHIFTS: usize = 5; + +const SIZE_OF_WORD: usize = mem::size_of::() * 8; + +const BITMAP_BITLEN: usize = 2048; +const BITMAP_LEN: usize = (BITMAP_BITLEN / SIZE_OF_WORD); +const BITMAP_INDEX_MASK: u64 = BITMAP_LEN as u64 - 1; +const BITMAP_LOC_MASK: u64 = (SIZE_OF_WORD - 1) as u64; +const WINDOW_SIZE: u64 = (BITMAP_BITLEN - SIZE_OF_WORD) as u64; + +pub struct AntiReplay { + bitmap: [Word; BITMAP_LEN], + last: u64, +} + +impl Default for AntiReplay { + fn default() -> Self { + AntiReplay::new() + } +} + +impl AntiReplay { + pub fn new() -> Self { + debug_assert_eq!(1 << REDUNDANT_BIT_SHIFTS, SIZE_OF_WORD); + debug_assert_eq!(BITMAP_BITLEN % SIZE_OF_WORD, 0); + AntiReplay { + last: 0, + bitmap: [0; BITMAP_LEN], + } + } + + // Returns true if check is passed, i.e., not a replay or too old. + // + // Unlike RFC 6479, zero is allowed. + fn check(&self, seq: u64) -> bool { + // Larger is always good. + if seq > self.last { + return true; + } + + if self.last - seq > WINDOW_SIZE { + return false; + } + + let bit_location = seq & BITMAP_LOC_MASK; + let index = (seq >> REDUNDANT_BIT_SHIFTS) & BITMAP_INDEX_MASK; + + self.bitmap[index as usize] & (1 << bit_location) == 0 + } + + // Should only be called if check returns true. + fn update_store(&mut self, seq: u64) { + debug_assert!(self.check(seq)); + + let index = seq >> REDUNDANT_BIT_SHIFTS; + + if seq > self.last { + let index_cur = self.last >> REDUNDANT_BIT_SHIFTS; + let diff = index - index_cur; + + if diff >= BITMAP_LEN as u64 { + self.bitmap = [0; BITMAP_LEN]; + } else { + for i in 0..diff { + let real_index = (index_cur + i + 1) & BITMAP_INDEX_MASK; + self.bitmap[real_index as usize] = 0; + } + } + + self.last = seq; + } + + let index = index & BITMAP_INDEX_MASK; + let bit_location = seq & BITMAP_LOC_MASK; + self.bitmap[index as usize] |= 1 << bit_location; + } + + /// Checks and marks a sequence number in the replay filter + /// + /// # Arguments + /// + /// - seq: Sequence number check for replay and add to filter + /// + /// # Returns + /// + /// Ok(()) if sequence number is valid (not marked and not behind the moving window). + /// Err if the sequence number is invalid (already marked or "too old"). + pub fn update(&mut self, seq: u64) -> bool { + if self.check(seq) { + self.update_store(seq); + true + } else { + false + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn anti_replay() { + let mut ar = AntiReplay::new(); + + for i in 0..20000 { + assert!(ar.update(i)); + } + + for i in (0..20000).rev() { + assert!(!ar.check(i)); + } + + assert!(ar.update(65536)); + for i in (65536 - WINDOW_SIZE)..65535 { + assert!(ar.update(i)); + } + + for i in (65536 - 10 * WINDOW_SIZE)..65535 { + assert!(!ar.check(i)); + } + + assert!(ar.update(66000)); + for i in 65537..66000 { + assert!(ar.update(i)); + } + for i in 65537..66000 { + assert_eq!(ar.update(i), false); + } + + // Test max u64. + let next = u64::max_value(); + assert!(ar.update(next)); + assert!(!ar.check(next)); + for i in (next - WINDOW_SIZE)..next { + assert!(ar.update(i)); + } + for i in (next - 20 * WINDOW_SIZE)..next { + assert!(!ar.check(i)); + } + } +} diff --git a/src/wireguard/router/constants.rs b/src/wireguard/router/constants.rs new file mode 100644 index 0000000..0ca824a --- /dev/null +++ b/src/wireguard/router/constants.rs @@ -0,0 +1,7 @@ +// WireGuard semantics constants + +pub const MAX_STAGED_PACKETS: usize = 128; + +// performance constants + +pub const WORKER_QUEUE_SIZE: usize = MAX_STAGED_PACKETS; diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs new file mode 100644 index 0000000..455020c --- /dev/null +++ b/src/wireguard/router/device.rs @@ -0,0 +1,243 @@ +use std::collections::HashMap; +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::mpsc::sync_channel; +use std::sync::mpsc::SyncSender; +use std::sync::Arc; +use std::thread; +use std::time::Instant; + +use log::debug; +use spin::{Mutex, RwLock}; +use treebitmap::IpLookupTable; +use zerocopy::LayoutVerified; + +use super::anti_replay::AntiReplay; +use super::constants::*; +use super::ip::*; +use super::messages::{TransportHeader, TYPE_TRANSPORT}; +use super::peer::{new_peer, Peer, PeerInner}; +use super::types::{Callbacks, RouterError}; +use super::workers::{worker_parallel, JobParallel, Operation}; +use super::SIZE_MESSAGE_PREFIX; + +use super::super::types::{bind, tun, Endpoint, KeyPair}; + +pub struct DeviceInner> { + // inbound writer (TUN) + pub inbound: T, + + // outbound writer (Bind) + pub outbound: RwLock>, + + // routing + pub recv: RwLock>>>, // receiver id -> decryption state + pub ipv4: RwLock>>>, // ipv4 cryptkey routing + pub ipv6: RwLock>>>, // ipv6 cryptkey routing + + // work queues + pub queue_next: AtomicUsize, // next round-robin index + pub queues: Mutex>>, // work queues (1 per thread) +} + +pub struct EncryptionState { + pub key: [u8; 32], // encryption key + pub id: u32, // receiver id + pub nonce: u64, // next available nonce + pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout) +} + +pub struct DecryptionState> { + pub keypair: Arc, + pub confirmed: AtomicBool, + pub protector: Mutex, + pub peer: Arc>, + pub death: Instant, // time when the key can no longer be used for decryption +} + +pub struct Device> { + state: Arc>, // reference to device state + handles: Vec>, // join handles for workers +} + +impl> Drop for Device { + fn drop(&mut self) { + debug!("router: dropping device"); + + // drop all queues + { + let mut queues = self.state.queues.lock(); + while queues.pop().is_some() {} + } + + // join all worker threads + while match self.handles.pop() { + Some(handle) => { + handle.thread().unpark(); + handle.join().unwrap(); + true + } + _ => false, + } {} + + debug!("router: device dropped"); + } +} + +#[inline(always)] +fn get_route>( + device: &Arc>, + packet: &[u8], +) -> Option>> { + // ensure version access within bounds + if packet.len() < 1 { + return None; + }; + + // cast to correct IP header + match packet[0] >> 4 { + VERSION_IP4 => { + // check length and cast to IPv4 header + let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + // lookup destination address + device + .ipv4 + .read() + .longest_match(Ipv4Addr::from(header.f_destination)) + .and_then(|(_, _, p)| Some(p.clone())) + } + VERSION_IP6 => { + // check length and cast to IPv6 header + let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + // lookup destination address + device + .ipv6 + .read() + .longest_match(Ipv6Addr::from(header.f_destination)) + .and_then(|(_, _, p)| Some(p.clone())) + } + _ => None, + } +} + +impl> Device { + pub fn new(num_workers: usize, tun: T) -> Device { + // allocate shared device state + let inner = DeviceInner { + inbound: tun, + outbound: RwLock::new(None), + queues: Mutex::new(Vec::with_capacity(num_workers)), + queue_next: AtomicUsize::new(0), + recv: RwLock::new(HashMap::new()), + ipv4: RwLock::new(IpLookupTable::new()), + ipv6: RwLock::new(IpLookupTable::new()), + }; + + // start worker threads + let mut threads = Vec::with_capacity(num_workers); + for _ in 0..num_workers { + let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE); + inner.queues.lock().push(tx); + threads.push(thread::spawn(move || worker_parallel(rx))); + } + + // return exported device handle + Device { + state: Arc::new(inner), + handles: threads, + } + } + + /// A new secret key has been set for the device. + /// According to WireGuard semantics, this should cause all "sending" keys to be discarded. + pub fn new_sk(&self) {} + + /// Adds a new peer to the device + /// + /// # Returns + /// + /// A atomic ref. counted peer (with liftime matching the device) + pub fn new_peer(&self, opaque: C::Opaque) -> Peer { + new_peer(self.state.clone(), opaque) + } + + /// Cryptkey routes and sends a plaintext message (IP packet) + /// + /// # Arguments + /// + /// - msg: IP packet to crypt-key route + /// + pub fn send(&self, msg: Vec) -> Result<(), RouterError> { + // ignore header prefix (for in-place transport message construction) + let packet = &msg[SIZE_MESSAGE_PREFIX..]; + + // lookup peer based on IP packet destination address + let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?; + + // schedule for encryption and transmission to peer + if let Some(job) = peer.send_job(msg, true) { + debug_assert_eq!(job.1.op, Operation::Encryption); + + // add job to worker queue + let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); + let queues = self.state.queues.lock(); + queues[idx % queues.len()].send(job).unwrap(); + } + + Ok(()) + } + + /// Receive an encrypted transport message + /// + /// # Arguments + /// + /// - src: Source address of the packet + /// - msg: Encrypted transport message + /// + /// # Returns + /// + /// + pub fn recv(&self, src: E, msg: Vec) -> Result<(), RouterError> { + // parse / cast + let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) { + Some(v) => v, + None => { + return Err(RouterError::MalformedTransportMessage); + } + }; + let header: LayoutVerified<&[u8], TransportHeader> = header; + debug_assert!( + header.f_type.get() == TYPE_TRANSPORT as u32, + "this should be checked by the message type multiplexer" + ); + + // lookup peer based on receiver id + let dec = self.state.recv.read(); + let dec = dec + .get(&header.f_receiver.get()) + .ok_or(RouterError::UnknownReceiverId)?; + + // schedule for decryption and TUN write + if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) { + debug_assert_eq!(job.1.op, Operation::Decryption); + + // add job to worker queue + let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); + let queues = self.state.queues.lock(); + queues[idx % queues.len()].send(job).unwrap(); + } + + Ok(()) + } + + /// Set outbound writer + /// + /// + pub fn set_outbound_writer(&self, new: B) { + *self.state.outbound.write() = Some(new); + } +} diff --git a/src/wireguard/router/ip.rs b/src/wireguard/router/ip.rs new file mode 100644 index 0000000..e66144f --- /dev/null +++ b/src/wireguard/router/ip.rs @@ -0,0 +1,26 @@ +use byteorder::BigEndian; +use zerocopy::byteorder::U16; +use zerocopy::{AsBytes, FromBytes}; + +pub const VERSION_IP4: u8 = 4; +pub const VERSION_IP6: u8 = 6; + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct IPv4Header { + _f_space1: [u8; 2], + pub f_total_len: U16, + _f_space2: [u8; 8], + pub f_source: [u8; 4], + pub f_destination: [u8; 4], +} + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct IPv6Header { + _f_space1: [u8; 4], + pub f_len: U16, + _f_space2: [u8; 2], + pub f_source: [u8; 16], + pub f_destination: [u8; 16], +} diff --git a/src/wireguard/router/messages.rs b/src/wireguard/router/messages.rs new file mode 100644 index 0000000..bf4d13b --- /dev/null +++ b/src/wireguard/router/messages.rs @@ -0,0 +1,13 @@ +use byteorder::LittleEndian; +use zerocopy::byteorder::{U32, U64}; +use zerocopy::{AsBytes, FromBytes}; + +pub const TYPE_TRANSPORT: u32 = 4; + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct TransportHeader { + pub f_type: U32, + pub f_receiver: U32, + pub f_counter: U64, +} diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs new file mode 100644 index 0000000..7a29cd9 --- /dev/null +++ b/src/wireguard/router/mod.rs @@ -0,0 +1,22 @@ +mod anti_replay; +mod constants; +mod device; +mod ip; +mod messages; +mod peer; +mod types; +mod workers; + +#[cfg(test)] +mod tests; + +use messages::TransportHeader; +use std::mem; + +pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::(); +pub const CAPACITY_MESSAGE_POSTFIX: usize = 16; + +pub use messages::TYPE_TRANSPORT; +pub use device::Device; +pub use peer::Peer; +pub use types::Callbacks; diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs new file mode 100644 index 0000000..4f47604 --- /dev/null +++ b/src/wireguard/router/peer.rs @@ -0,0 +1,611 @@ +use std::mem; +use std::net::{IpAddr, SocketAddr}; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::mpsc::{sync_channel, SyncSender}; +use std::sync::Arc; +use std::thread; + +use arraydeque::{ArrayDeque, Wrapping}; +use log::debug; +use spin::Mutex; +use treebitmap::address::Address; +use treebitmap::IpLookupTable; +use zerocopy::LayoutVerified; + +use super::super::constants::*; +use super::super::types::{bind, tun, Endpoint, KeyPair}; + +use super::anti_replay::AntiReplay; +use super::device::DecryptionState; +use super::device::DeviceInner; +use super::device::EncryptionState; +use super::messages::TransportHeader; + +use futures::*; + +use super::workers::Operation; +use super::workers::{worker_inbound, worker_outbound}; +use super::workers::{JobBuffer, JobInbound, JobOutbound, JobParallel}; +use super::SIZE_MESSAGE_PREFIX; + +use super::constants::*; +use super::types::{Callbacks, RouterError}; + +pub struct KeyWheel { + next: Option>, // next key state (unconfirmed) + current: Option>, // current key state (used for encryption) + previous: Option>, // old key state (used for decryption) + retired: Vec, // retired ids +} + +pub struct PeerInner> { + pub device: Arc>, + pub opaque: C::Opaque, + pub outbound: Mutex>, + pub inbound: Mutex>>, + pub staged_packets: Mutex; MAX_STAGED_PACKETS], Wrapping>>, + pub keys: Mutex, + pub ekey: Mutex>, + pub endpoint: Mutex>, +} + +pub struct Peer> { + state: Arc>, + thread_outbound: Option>, + thread_inbound: Option>, +} + +fn treebit_list>( + peer: &Arc>, + table: &spin::RwLock>>>, + callback: Box R>, +) -> Vec +where + A: Address, +{ + let mut res = Vec::new(); + for subnet in table.read().iter() { + let (ip, masklen, p) = subnet; + if Arc::ptr_eq(&p, &peer) { + res.push(callback(ip, masklen)) + } + } + res +} + +fn treebit_remove>( + peer: &Peer, + table: &spin::RwLock>>>, +) { + let mut m = table.write(); + + // collect keys for value + let mut subnets = vec![]; + for subnet in m.iter() { + let (ip, masklen, p) = subnet; + if Arc::ptr_eq(&p, &peer.state) { + subnets.push((ip, masklen)) + } + } + + // remove all key mappings + for (ip, masklen) in subnets { + let r = m.remove(ip, masklen); + debug_assert!(r.is_some()); + } +} + +impl EncryptionState { + fn new(keypair: &Arc) -> EncryptionState { + EncryptionState { + id: keypair.send.id, + key: keypair.send.key, + nonce: 0, + death: keypair.birth + REJECT_AFTER_TIME, + } + } +} + +impl> DecryptionState { + fn new( + peer: &Arc>, + keypair: &Arc, + ) -> DecryptionState { + DecryptionState { + confirmed: AtomicBool::new(keypair.initiator), + keypair: keypair.clone(), + protector: spin::Mutex::new(AntiReplay::new()), + peer: peer.clone(), + death: keypair.birth + REJECT_AFTER_TIME, + } + } +} + +impl> Drop for Peer { + fn drop(&mut self) { + let peer = &self.state; + + // remove from cryptkey router + + treebit_remove(self, &peer.device.ipv4); + treebit_remove(self, &peer.device.ipv6); + + // drop channels + + mem::replace(&mut *peer.inbound.lock(), sync_channel(0).0); + mem::replace(&mut *peer.outbound.lock(), sync_channel(0).0); + + // join with workers + + mem::replace(&mut self.thread_inbound, None).map(|v| v.join()); + mem::replace(&mut self.thread_outbound, None).map(|v| v.join()); + + // release ids from the receiver map + + let mut keys = peer.keys.lock(); + let mut release = Vec::with_capacity(3); + + keys.next.as_ref().map(|k| release.push(k.recv.id)); + keys.current.as_ref().map(|k| release.push(k.recv.id)); + keys.previous.as_ref().map(|k| release.push(k.recv.id)); + + if release.len() > 0 { + let mut recv = peer.device.recv.write(); + for id in &release { + recv.remove(id); + } + } + + // null key-material + + keys.next = None; + keys.current = None; + keys.previous = None; + + *peer.ekey.lock() = None; + *peer.endpoint.lock() = None; + + debug!("peer dropped & removed from device"); + } +} + +pub fn new_peer>( + device: Arc>, + opaque: C::Opaque, +) -> Peer { + let (out_tx, out_rx) = sync_channel(128); + let (in_tx, in_rx) = sync_channel(128); + + // allocate peer object + let peer = { + let device = device.clone(); + Arc::new(PeerInner { + opaque, + device, + inbound: Mutex::new(in_tx), + outbound: Mutex::new(out_tx), + ekey: spin::Mutex::new(None), + endpoint: spin::Mutex::new(None), + keys: spin::Mutex::new(KeyWheel { + next: None, + current: None, + previous: None, + retired: vec![], + }), + staged_packets: spin::Mutex::new(ArrayDeque::new()), + }) + }; + + // spawn outbound thread + let thread_inbound = { + let peer = peer.clone(); + let device = device.clone(); + thread::spawn(move || worker_outbound(device, peer, out_rx)) + }; + + // spawn inbound thread + let thread_outbound = { + let peer = peer.clone(); + let device = device.clone(); + thread::spawn(move || worker_inbound(device, peer, in_rx)) + }; + + Peer { + state: peer, + thread_inbound: Some(thread_inbound), + thread_outbound: Some(thread_outbound), + } +} + +impl> PeerInner { + fn send_staged(&self) -> bool { + debug!("peer.send_staged"); + let mut sent = false; + let mut staged = self.staged_packets.lock(); + loop { + match staged.pop_front() { + Some(msg) => { + sent = true; + self.send_raw(msg); + } + None => break sent, + } + } + } + + // Treat the msg as the payload of a transport message + // Unlike device.send, peer.send_raw does not buffer messages when a key is not available. + fn send_raw(&self, msg: Vec) -> bool { + debug!("peer.send_raw"); + match self.send_job(msg, false) { + Some(job) => { + debug!("send_raw: got obtained send_job"); + let index = self.device.queue_next.fetch_add(1, Ordering::SeqCst); + let queues = self.device.queues.lock(); + match queues[index % queues.len()].send(job) { + Ok(_) => true, + Err(_) => false, + } + } + None => false, + } + } + + pub fn confirm_key(&self, keypair: &Arc) { + debug!("peer.confirm_key"); + { + // take lock and check keypair = keys.next + let mut keys = self.keys.lock(); + let next = match keys.next.as_ref() { + Some(next) => next, + None => { + return; + } + }; + if !Arc::ptr_eq(&next, keypair) { + return; + } + + // allocate new encryption state + let ekey = Some(EncryptionState::new(&next)); + + // rotate key-wheel + let mut swap = None; + mem::swap(&mut keys.next, &mut swap); + mem::swap(&mut keys.current, &mut swap); + mem::swap(&mut keys.previous, &mut swap); + + // tell the world outside the router that a key was confirmed + C::key_confirmed(&self.opaque); + + // set new key for encryption + *self.ekey.lock() = ekey; + } + + // start transmission of staged packets + self.send_staged(); + } + + pub fn recv_job( + &self, + src: E, + dec: Arc>, + msg: Vec, + ) -> Option { + let (tx, rx) = oneshot(); + let key = dec.keypair.recv.key; + match self.inbound.lock().try_send((dec, src, rx)) { + Ok(_) => Some(( + tx, + JobBuffer { + msg, + key: key, + okay: false, + op: Operation::Decryption, + }, + )), + Err(_) => None, + } + } + + pub fn send_job(&self, mut msg: Vec, stage: bool) -> Option { + debug!("peer.send_job"); + debug_assert!( + msg.len() >= mem::size_of::(), + "received message with size: {:}", + msg.len() + ); + + // parse / cast + let (header, _) = LayoutVerified::new_from_prefix(&mut msg[..]).unwrap(); + let mut header: LayoutVerified<&mut [u8], TransportHeader> = header; + + // check if has key + let key = { + let mut ekey = self.ekey.lock(); + let key = match ekey.as_mut() { + None => None, + Some(mut state) => { + // avoid integer overflow in nonce + if state.nonce >= REJECT_AFTER_MESSAGES - 1 { + *ekey = None; + None + } else { + // there should be no stacked packets lingering around + debug!("encryption state available, nonce = {}", state.nonce); + + // set transport message fields + header.f_counter.set(state.nonce); + header.f_receiver.set(state.id); + state.nonce += 1; + Some(state.key) + } + } + }; + + // If not suitable key was found: + // 1. Stage packet for later transmission + // 2. Request new key + if key.is_none() && stage { + self.staged_packets.lock().push_back(msg); + C::need_key(&self.opaque); + return None; + }; + + key + }?; + + // add job to in-order queue and return sendeer to device for inclusion in worker pool + let (tx, rx) = oneshot(); + match self.outbound.lock().try_send(rx) { + Ok(_) => Some(( + tx, + JobBuffer { + msg, + key, + okay: false, + op: Operation::Encryption, + }, + )), + Err(_) => None, + } + } +} + +impl> Peer { + /// Set the endpoint of the peer + /// + /// # Arguments + /// + /// - `endpoint`, socket address converted to bind endpoint + /// + /// # Note + /// + /// This API still permits support for the "sticky socket" behavior, + /// as sockets should be "unsticked" when manually updating the endpoint + pub fn set_endpoint(&self, endpoint: E) { + debug!("peer.set_endpoint"); + *self.state.endpoint.lock() = Some(endpoint); + } + + /// Returns the current endpoint of the peer (for configuration) + /// + /// # Note + /// + /// Does not convey potential "sticky socket" information + pub fn get_endpoint(&self) -> Option { + debug!("peer.get_endpoint"); + self.state + .endpoint + .lock() + .as_ref() + .map(|e| e.into_address()) + } + + /// Zero all key-material related to the peer + pub fn zero_keys(&self) { + debug!("peer.zero_keys"); + + let mut release: Vec = Vec::with_capacity(3); + let mut keys = self.state.keys.lock(); + + // update key-wheel + + mem::replace(&mut keys.next, None).map(|k| release.push(k.local_id())); + mem::replace(&mut keys.current, None).map(|k| release.push(k.local_id())); + mem::replace(&mut keys.previous, None).map(|k| release.push(k.local_id())); + keys.retired.extend(&release[..]); + + // update inbound "recv" map + { + let mut recv = self.state.device.recv.write(); + for id in release { + recv.remove(&id); + } + } + + // clear encryption state + *self.state.ekey.lock() = None; + } + + /// Add a new keypair + /// + /// # Arguments + /// + /// - new: The new confirmed/unconfirmed key pair + /// + /// # Returns + /// + /// A vector of ids which has been released. + /// These should be released in the handshake module. + /// + /// # Note + /// + /// The number of ids to be released can be at most 3, + /// since the only way to add additional keys to the peer is by using this method + /// and a peer can have at most 3 keys allocated in the router at any time. + pub fn add_keypair(&self, new: KeyPair) -> Vec { + debug!("peer.add_keypair"); + + let initiator = new.initiator; + let release = { + let new = Arc::new(new); + let mut keys = self.state.keys.lock(); + let mut release = mem::replace(&mut keys.retired, vec![]); + + // update key-wheel + if new.initiator { + // start using key for encryption + *self.state.ekey.lock() = Some(EncryptionState::new(&new)); + + // move current into previous + keys.previous = keys.current.as_ref().map(|v| v.clone()); + keys.current = Some(new.clone()); + } else { + // store the key and await confirmation + keys.previous = keys.next.as_ref().map(|v| v.clone()); + keys.next = Some(new.clone()); + }; + + // update incoming packet id map + { + debug!("peer.add_keypair: updating inbound id map"); + let mut recv = self.state.device.recv.write(); + + // purge recv map of previous id + keys.previous.as_ref().map(|k| { + recv.remove(&k.local_id()); + release.push(k.local_id()); + }); + + // map new id to decryption state + debug_assert!(!recv.contains_key(&new.recv.id)); + recv.insert( + new.recv.id, + Arc::new(DecryptionState::new(&self.state, &new)), + ); + } + release + }; + + // schedule confirmation + if initiator { + debug_assert!(self.state.ekey.lock().is_some()); + debug!("peer.add_keypair: is initiator, must confirm the key"); + // attempt to confirm using staged packets + if !self.state.send_staged() { + // fall back to keepalive packet + let ok = self.send_keepalive(); + debug!( + "peer.add_keypair: keepalive for confirmation, sent = {}", + ok + ); + } + debug!("peer.add_keypair: key attempted confirmed"); + } + + debug_assert!( + release.len() <= 3, + "since the key-wheel contains at most 3 keys" + ); + release + } + + pub fn send_keepalive(&self) -> bool { + debug!("peer.send_keepalive"); + self.state.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX]) + } + + /// Map a subnet to the peer + /// + /// # Arguments + /// + /// - `ip`, the mask of the subnet + /// - `masklen`, the length of the mask + /// + /// # Note + /// + /// The `ip` must not have any bits set right of `masklen`. + /// e.g. `192.168.1.0/24` is valid, while `192.168.1.128/24` is not. + /// + /// If an identical value already exists as part of a prior peer, + /// the allowed IP entry will be removed from that peer and added to this peer. + pub fn add_subnet(&self, ip: IpAddr, masklen: u32) { + debug!("peer.add_subnet"); + match ip { + IpAddr::V4(v4) => { + self.state + .device + .ipv4 + .write() + .insert(v4, masklen, self.state.clone()) + } + IpAddr::V6(v6) => { + self.state + .device + .ipv6 + .write() + .insert(v6, masklen, self.state.clone()) + } + }; + } + + /// List subnets mapped to the peer + /// + /// # Returns + /// + /// A vector of subnets, represented by as mask/size + pub fn list_subnets(&self) -> Vec<(IpAddr, u32)> { + debug!("peer.list_subnets"); + let mut res = Vec::new(); + res.append(&mut treebit_list( + &self.state, + &self.state.device.ipv4, + Box::new(|ip, masklen| (IpAddr::V4(ip), masklen)), + )); + res.append(&mut treebit_list( + &self.state, + &self.state.device.ipv6, + Box::new(|ip, masklen| (IpAddr::V6(ip), masklen)), + )); + res + } + + /// Clear subnets mapped to the peer. + /// After the call, no subnets will be cryptkey routed to the peer. + /// Used for the UAPI command "replace_allowed_ips=true" + pub fn remove_subnets(&self) { + debug!("peer.remove_subnets"); + treebit_remove(self, &self.state.device.ipv4); + treebit_remove(self, &self.state.device.ipv6); + } + + /// Send a raw message to the peer (used for handshake messages) + /// + /// # Arguments + /// + /// - `msg`, message body to send to peer + /// + /// # Returns + /// + /// Unit if packet was sent, or an error indicating why sending failed + pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> { + debug!("peer.send"); + let inner = &self.state; + match inner.endpoint.lock().as_ref() { + Some(endpoint) => inner + .device + .outbound + .read() + .as_ref() + .ok_or(RouterError::SendError) + .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError)), + None => Err(RouterError::NoEndpoint), + } + } + + pub fn purge_staged_packets(&self) { + self.state.staged_packets.lock().clear(); + } +} diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs new file mode 100644 index 0000000..fbee39e --- /dev/null +++ b/src/wireguard/router/tests.rs @@ -0,0 +1,432 @@ +use std::net::IpAddr; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::sync::Mutex; +use std::thread; +use std::time::Duration; + +use num_cpus; +use pnet::packet::ipv4::MutableIpv4Packet; +use pnet::packet::ipv6::MutableIpv6Packet; + +use super::super::types::bind::*; +use super::super::types::*; + +use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX}; + +extern crate test; + +const SIZE_KEEPALIVE: usize = 32; + +#[cfg(test)] +mod tests { + use super::*; + use env_logger; + use log::debug; + use std::sync::atomic::AtomicUsize; + use test::Bencher; + + // type for tracking events inside the router module + struct Flags { + send: Mutex>, + recv: Mutex>, + need_key: Mutex>, + key_confirmed: Mutex>, + } + + #[derive(Clone)] + struct Opaque(Arc); + + struct TestCallbacks(); + + impl Opaque { + fn new() -> Opaque { + Opaque(Arc::new(Flags { + send: Mutex::new(vec![]), + recv: Mutex::new(vec![]), + need_key: Mutex::new(vec![]), + key_confirmed: Mutex::new(vec![]), + })) + } + + fn reset(&self) { + self.0.send.lock().unwrap().clear(); + self.0.recv.lock().unwrap().clear(); + self.0.need_key.lock().unwrap().clear(); + self.0.key_confirmed.lock().unwrap().clear(); + } + + fn send(&self) -> Option<(usize, bool, bool)> { + self.0.send.lock().unwrap().pop() + } + + fn recv(&self) -> Option<(usize, bool, bool)> { + self.0.recv.lock().unwrap().pop() + } + + fn need_key(&self) -> Option<()> { + self.0.need_key.lock().unwrap().pop() + } + + fn key_confirmed(&self) -> Option<()> { + self.0.key_confirmed.lock().unwrap().pop() + } + + // has all events been accounted for by assertions? + fn is_empty(&self) -> bool { + let send = self.0.send.lock().unwrap(); + let recv = self.0.recv.lock().unwrap(); + let need_key = self.0.need_key.lock().unwrap(); + let key_confirmed = self.0.key_confirmed.lock().unwrap(); + send.is_empty() && recv.is_empty() && need_key.is_empty() & key_confirmed.is_empty() + } + } + + impl Callbacks for TestCallbacks { + type Opaque = Opaque; + + fn send(t: &Self::Opaque, size: usize, data: bool, sent: bool) { + t.0.send.lock().unwrap().push((size, data, sent)) + } + + fn recv(t: &Self::Opaque, size: usize, data: bool, sent: bool) { + t.0.recv.lock().unwrap().push((size, data, sent)) + } + + fn need_key(t: &Self::Opaque) { + t.0.need_key.lock().unwrap().push(()); + } + + fn key_confirmed(t: &Self::Opaque) { + t.0.key_confirmed.lock().unwrap().push(()); + } + } + + // wait for scheduling + fn wait() { + thread::sleep(Duration::from_millis(50)); + } + + fn init() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + fn make_packet(size: usize, ip: IpAddr) -> Vec { + // create "IP packet" + let mut msg = Vec::with_capacity(SIZE_MESSAGE_PREFIX + size + 16); + msg.resize(SIZE_MESSAGE_PREFIX + size, 0); + match ip { + IpAddr::V4(ip) => { + let mut packet = MutableIpv4Packet::new(&mut msg[SIZE_MESSAGE_PREFIX..]).unwrap(); + packet.set_destination(ip); + packet.set_version(4); + } + IpAddr::V6(ip) => { + let mut packet = MutableIpv6Packet::new(&mut msg[SIZE_MESSAGE_PREFIX..]).unwrap(); + packet.set_destination(ip); + packet.set_version(6); + } + } + msg + } + + #[bench] + fn bench_outbound(b: &mut Bencher) { + struct BencherCallbacks {} + impl Callbacks for BencherCallbacks { + type Opaque = Arc; + fn send(t: &Self::Opaque, size: usize, _data: bool, _sent: bool) { + t.fetch_add(size, Ordering::SeqCst); + } + fn recv(_: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {} + fn need_key(_: &Self::Opaque) {} + fn key_confirmed(_: &Self::Opaque) {} + } + + // create device + let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false); + let router: Device<_, BencherCallbacks, dummy::TunWriter, dummy::VoidBind> = + Device::new(num_cpus::get(), tun_writer); + + // add new peer + let opaque = Arc::new(AtomicUsize::new(0)); + let peer = router.new_peer(opaque.clone()); + peer.add_keypair(dummy::keypair(true)); + + // add subnet to peer + let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20"); + let mask: IpAddr = mask.parse().unwrap(); + let ip1: IpAddr = ip.parse().unwrap(); + peer.add_subnet(mask, len); + + // every iteration sends 10 GB + b.iter(|| { + opaque.store(0, Ordering::SeqCst); + let msg = make_packet(1024, ip1); + while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 { + router.send(msg.to_vec()).unwrap(); + } + }); + } + + #[test] + fn test_outbound() { + init(); + + // create device + let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false); + let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer); + router.set_outbound_writer(dummy::VoidBind::new()); + + let tests = vec![ + ("192.168.1.0", 24, "192.168.1.20", true), + ("172.133.133.133", 32, "172.133.133.133", true), + ("172.133.133.133", 32, "172.133.133.132", false), + ( + "2001:db8::ff00:42:0000", + 112, + "2001:db8::ff00:42:3242", + true, + ), + ( + "2001:db8::ff00:42:8000", + 113, + "2001:db8::ff00:42:0660", + false, + ), + ( + "2001:db8::ff00:42:8000", + 113, + "2001:db8::ff00:42:ffff", + true, + ), + ]; + + for (num, (mask, len, ip, okay)) in tests.iter().enumerate() { + for set_key in vec![true, false] { + debug!("index = {}, set_key = {}", num, set_key); + + // add new peer + let opaque = Opaque::new(); + let peer = router.new_peer(opaque.clone()); + let mask: IpAddr = mask.parse().unwrap(); + if set_key { + peer.add_keypair(dummy::keypair(true)); + } + + // map subnet to peer + peer.add_subnet(mask, *len); + + // create "IP packet" + let msg = make_packet(1024, ip.parse().unwrap()); + + // cryptkey route the IP packet + let res = router.send(msg); + + // allow some scheduling + wait(); + + if *okay { + // cryptkey routing succeeded + assert!(res.is_ok(), "crypt-key routing should succeed"); + assert_eq!( + opaque.need_key().is_some(), + !set_key, + "should have requested a new key, if no encryption state was set" + ); + assert_eq!( + opaque.send().is_some(), + set_key, + "transmission should have been attempted" + ); + assert!( + opaque.recv().is_none(), + "no messages should have been marked as received" + ); + } else { + // no such cryptkey route + assert!(res.is_err(), "crypt-key routing should fail"); + assert!( + opaque.need_key().is_none(), + "should not request a new-key if crypt-key routing failed" + ); + assert_eq!( + opaque.send(), + if set_key { + Some((SIZE_KEEPALIVE, false, false)) + } else { + None + }, + "transmission should only happen if key was set (keepalive)", + ); + assert!( + opaque.recv().is_none(), + "no messages should have been marked as received", + ); + } + } + } + } + + #[test] + fn test_bidirectional() { + init(); + + let tests = [ + ( + false, // confirm with keepalive + ("192.168.1.0", 24, "192.168.1.20", true), + ("172.133.133.133", 32, "172.133.133.133", true), + ), + ( + true, // confirm with staged packet + ("192.168.1.0", 24, "192.168.1.20", true), + ("172.133.133.133", 32, "172.133.133.133", true), + ), + ( + false, // confirm with keepalive + ( + "2001:db8::ff00:42:8000", + 113, + "2001:db8::ff00:42:ffff", + true, + ), + ( + "2001:db8::ff40:42:8000", + 113, + "2001:db8::ff40:42:ffff", + true, + ), + ), + ( + false, // confirm with staged packet + ( + "2001:db8::ff00:42:8000", + 113, + "2001:db8::ff00:42:ffff", + true, + ), + ( + "2001:db8::ff40:42:8000", + 113, + "2001:db8::ff40:42:ffff", + true, + ), + ), + ]; + + for (stage, p1, p2) in tests.iter() { + let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) = + dummy::PairBind::pair(); + + // create matching device + let (_fake, _, tun_writer1, _) = dummy::TunTest::create(1500, false); + let (_fake, _, tun_writer2, _) = dummy::TunTest::create(1500, false); + + let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1); + router1.set_outbound_writer(bind_writer1); + + let router2: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer2); + router2.set_outbound_writer(bind_writer2); + + // prepare opaque values for tracing callbacks + + let opaq1 = Opaque::new(); + let opaq2 = Opaque::new(); + + // create peers with matching keypairs and assign subnets + + let (mask, len, _ip, _okay) = p1; + let peer1 = router1.new_peer(opaq1.clone()); + let mask: IpAddr = mask.parse().unwrap(); + peer1.add_subnet(mask, *len); + peer1.add_keypair(dummy::keypair(false)); + + let (mask, len, _ip, _okay) = p2; + let peer2 = router2.new_peer(opaq2.clone()); + let mask: IpAddr = mask.parse().unwrap(); + peer2.add_subnet(mask, *len); + peer2.set_endpoint(dummy::UnitEndpoint::new()); + + if *stage { + // stage a packet which can be used for confirmation (in place of a keepalive) + let (_mask, _len, ip, _okay) = p2; + let msg = make_packet(1024, ip.parse().unwrap()); + router2.send(msg).expect("failed to sent staged packet"); + + wait(); + assert!(opaq2.recv().is_none()); + assert!( + opaq2.send().is_none(), + "sending should fail as not key is set" + ); + assert!( + opaq2.need_key().is_some(), + "a new key should be requested since a packet was attempted transmitted" + ); + assert!(opaq2.is_empty(), "callbacks should only run once"); + } + + // this should cause a key-confirmation packet (keepalive or staged packet) + // this also causes peer1 to learn the "endpoint" for peer2 + assert!(peer1.get_endpoint().is_none()); + peer2.add_keypair(dummy::keypair(true)); + + wait(); + assert!(opaq2.send().is_some()); + assert!(opaq2.is_empty(), "events on peer2 should be 'send'"); + assert!(opaq1.is_empty(), "nothing should happened on peer1"); + + // read confirming message received by the other end ("across the internet") + let mut buf = vec![0u8; 2048]; + let (len, from) = bind_reader1.read(&mut buf).unwrap(); + buf.truncate(len); + router1.recv(from, buf).unwrap(); + + wait(); + assert!(opaq1.recv().is_some()); + assert!(opaq1.key_confirmed().is_some()); + assert!( + opaq1.is_empty(), + "events on peer1 should be 'recv' and 'key_confirmed'" + ); + assert!(peer1.get_endpoint().is_some()); + assert!(opaq2.is_empty(), "nothing should happened on peer2"); + + // now that peer1 has an endpoint + // route packets : peer1 -> peer2 + + for _ in 0..10 { + assert!( + opaq1.is_empty(), + "we should have asserted a value for every callback on peer1" + ); + assert!( + opaq2.is_empty(), + "we should have asserted a value for every callback on peer2" + ); + + // pass IP packet to router + let (_mask, _len, ip, _okay) = p1; + let msg = make_packet(1024, ip.parse().unwrap()); + router1.send(msg).unwrap(); + + wait(); + assert!(opaq1.send().is_some()); + assert!(opaq1.recv().is_none()); + assert!(opaq1.need_key().is_none()); + + // receive ("across the internet") on the other end + let mut buf = vec![0u8; 2048]; + let (len, from) = bind_reader2.read(&mut buf).unwrap(); + buf.truncate(len); + router2.recv(from, buf).unwrap(); + + wait(); + assert!(opaq2.send().is_none()); + assert!(opaq2.recv().is_some()); + assert!(opaq2.need_key().is_none()); + } + } + } +} diff --git a/src/wireguard/router/types.rs b/src/wireguard/router/types.rs new file mode 100644 index 0000000..b7c3ae0 --- /dev/null +++ b/src/wireguard/router/types.rs @@ -0,0 +1,65 @@ +use std::error::Error; +use std::fmt; + +pub trait Opaque: Send + Sync + 'static {} + +impl Opaque for T where T: Send + Sync + 'static {} + +/// A send/recv callback takes 3 arguments: +/// +/// * `0`, a reference to the opaque value assigned to the peer +/// * `1`, a bool indicating whether the message contained data (not just keepalive) +/// * `2`, a bool indicating whether the message was transmitted (i.e. did the peer have an associated endpoint?) +pub trait Callback: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {} + +impl Callback for F where F: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {} + +/// A key callback takes 1 argument +/// +/// * `0`, a reference to the opaque value assigned to the peer +pub trait KeyCallback: Fn(&T) -> () + Sync + Send + 'static {} + +impl KeyCallback for F where F: Fn(&T) -> () + Sync + Send + 'static {} + +pub trait Callbacks: Send + Sync + 'static { + type Opaque: Opaque; + fn send(opaque: &Self::Opaque, size: usize, data: bool, sent: bool); + fn recv(opaque: &Self::Opaque, size: usize, data: bool, sent: bool); + fn need_key(opaque: &Self::Opaque); + fn key_confirmed(opaque: &Self::Opaque); +} + +#[derive(Debug)] +pub enum RouterError { + NoCryptKeyRoute, + MalformedIPHeader, + MalformedTransportMessage, + UnknownReceiverId, + NoEndpoint, + SendError, +} + +impl fmt::Display for RouterError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"), + RouterError::MalformedIPHeader => write!(f, "IP header is malformed"), + RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"), + RouterError::UnknownReceiverId => { + write!(f, "No decryption state associated with receiver id") + } + RouterError::NoEndpoint => write!(f, "No endpoint for peer"), + RouterError::SendError => write!(f, "Failed to send packet on bind"), + } + } +} + +impl Error for RouterError { + fn description(&self) -> &str { + "Generic Handshake Error" + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs new file mode 100644 index 0000000..2e89bb0 --- /dev/null +++ b/src/wireguard/router/workers.rs @@ -0,0 +1,305 @@ +use std::mem; +use std::sync::mpsc::Receiver; +use std::sync::Arc; + +use futures::sync::oneshot; +use futures::*; + +use log::debug; + +use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::sync::atomic::Ordering; +use zerocopy::{AsBytes, LayoutVerified}; + +use super::device::{DecryptionState, DeviceInner}; +use super::messages::{TransportHeader, TYPE_TRANSPORT}; +use super::peer::PeerInner; +use super::types::Callbacks; + +use super::super::types::{Endpoint, tun, bind}; +use super::ip::*; + +const SIZE_TAG: usize = 16; + +#[derive(PartialEq, Debug)] +pub enum Operation { + Encryption, + Decryption, +} + +pub struct JobBuffer { + pub msg: Vec, // message buffer (nonce and receiver id set) + pub key: [u8; 32], // chacha20poly1305 key + pub okay: bool, // state of the job + pub op: Operation, // should be buffer be encrypted / decrypted? +} + +pub type JobParallel = (oneshot::Sender, JobBuffer); + +#[allow(type_alias_bounds)] +pub type JobInbound> = ( + Arc>, + E, + oneshot::Receiver, +); + +pub type JobOutbound = oneshot::Receiver; + +#[inline(always)] +fn check_route>( + device: &Arc>, + peer: &Arc>, + packet: &[u8], +) -> Option { + match packet[0] >> 4 { + VERSION_IP4 => { + // check length and cast to IPv4 header + let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + // check IPv4 source address + device + .ipv4 + .read() + .longest_match(Ipv4Addr::from(header.f_source)) + .and_then(|(_, _, p)| { + if Arc::ptr_eq(p, &peer) { + Some(header.f_total_len.get() as usize) + } else { + None + } + }) + } + VERSION_IP6 => { + // check length and cast to IPv6 header + let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + // check IPv6 source address + device + .ipv6 + .read() + .longest_match(Ipv6Addr::from(header.f_source)) + .and_then(|(_, _, p)| { + if Arc::ptr_eq(p, &peer) { + Some(header.f_len.get() as usize + mem::size_of::()) + } else { + None + } + }) + } + _ => None, + } +} + +pub fn worker_inbound>( + device: Arc>, // related device + peer: Arc>, // related peer + receiver: Receiver>, +) { + loop { + // fetch job + let (state, endpoint, rx) = match receiver.recv() { + Ok(v) => v, + _ => { + return; + } + }; + debug!("inbound worker: obtained job"); + + // wait for job to complete + let _ = rx + .map(|buf| { + debug!("inbound worker: job complete"); + if buf.okay { + // cast transport header + let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = + match LayoutVerified::new_from_prefix(&buf.msg[..]) { + Some(v) => v, + None => { + debug!("inbound worker: failed to parse message"); + return; + } + }; + + debug_assert!( + packet.len() >= CHACHA20_POLY1305.tag_len(), + "this should be checked earlier in the pipeline (decryption should fail)" + ); + + // check for replay + if !state.protector.lock().update(header.f_counter.get()) { + debug!("inbound worker: replay detected"); + return; + } + + // check for confirms key + if !state.confirmed.swap(true, Ordering::SeqCst) { + debug!("inbound worker: message confirms key"); + peer.confirm_key(&state.keypair); + } + + // update endpoint + *peer.endpoint.lock() = Some(endpoint); + + // calculate length of IP packet + padding + let length = packet.len() - SIZE_TAG; + debug!("inbound worker: plaintext length = {}", length); + + // check if should be written to TUN + let mut sent = false; + if length > 0 { + if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) { + debug_assert!(inner_len <= length, "should be validated"); + if inner_len <= length { + sent = match device.inbound.write(&packet[..inner_len]) { + Err(e) => { + debug!("failed to write inbound packet to TUN: {:?}", e); + false + } + Ok(_) => true, + } + } + } + } else { + debug!("inbound worker: received keepalive") + } + + // trigger callback + C::recv(&peer.opaque, buf.msg.len(), length == 0, sent); + } else { + debug!("inbound worker: authentication failure") + } + }) + .wait(); + } +} + +pub fn worker_outbound>( + device: Arc>, // related device + peer: Arc>, // related peer + receiver: Receiver, +) { + loop { + // fetch job + let rx = match receiver.recv() { + Ok(v) => v, + _ => { + return; + } + }; + debug!("outbound worker: obtained job"); + + // wait for job to complete + let _ = rx + .map(|buf| { + debug!("outbound worker: job complete"); + if buf.okay { + // write to UDP bind + let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() { + let send : &Option = &*device.outbound.read(); + if let Some(writer) = send.as_ref() { + match writer.write(&buf.msg[..], dst) { + Err(e) => { + debug!("failed to send outbound packet: {:?}", e); + false + } + Ok(_) => true, + } + } else { + false + } + } else { + false + }; + + // trigger callback + C::send( + &peer.opaque, + buf.msg.len(), + buf.msg.len() > SIZE_TAG + mem::size_of::(), + xmit, + ); + } + }) + .wait(); + } +} + +pub fn worker_parallel(receiver: Receiver) { + loop { + // fetch next job + let (tx, mut buf) = match receiver.recv() { + Err(_) => { + return; + } + Ok(val) => val, + }; + debug!("parallel worker: obtained job"); + + // make space for tag (TODO: consider moving this out) + if buf.op == Operation::Encryption { + buf.msg.extend([0u8; SIZE_TAG].iter()); + } + + // cast and check size of packet + let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = + match LayoutVerified::new_from_prefix(&mut buf.msg[..]) { + Some(v) => v, + None => { + debug_assert!( + false, + "parallel worker: failed to parse message (insufficient size)" + ); + continue; + } + }; + debug_assert!(packet.len() >= CHACHA20_POLY1305.tag_len()); + + // do the weird ring AEAD dance + let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &buf.key[..]).unwrap()); + + // create a nonce object + let mut nonce = [0u8; 12]; + debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); + nonce[4..].copy_from_slice(header.f_counter.as_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + match buf.op { + Operation::Encryption => { + debug!("parallel worker: process encryption"); + + // set the type field + header.f_type.set(TYPE_TRANSPORT); + + // encrypt content of transport message in-place + let end = packet.len() - SIZE_TAG; + let tag = key + .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end]) + .unwrap(); + + // append tag + packet[end..].copy_from_slice(tag.as_ref()); + + buf.okay = true; + } + Operation::Decryption => { + debug!("parallel worker: process decryption"); + + // opening failure is signaled by fault state + buf.okay = match key.open_in_place(nonce, Aad::empty(), packet) { + Ok(_) => true, + Err(_) => false, + }; + } + } + + // pass ownership to consumer + let okay = tx.send(buf); + debug!( + "parallel worker: passing ownership to sequential worker: {}", + okay.is_ok() + ); + } +} diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs new file mode 100644 index 0000000..0148d5d --- /dev/null +++ b/src/wireguard/tests.rs @@ -0,0 +1,46 @@ +use super::types::tun::Tun; +use super::types::{bind, dummy, tun}; +use super::wireguard::Wireguard; + +use std::thread; +use std::time::Duration; + +fn init() { + let _ = env_logger::builder().is_test(true).try_init(); +} + +/* Create and configure two matching pure instances of WireGuard + * + */ +#[test] +fn test_pure_wireguard() { + init(); + + // create WG instances for fake TUN devices + + let (fake1, tun_reader1, tun_writer1, mtu1) = dummy::TunTest::create(1500, true); + let wg1: Wireguard = + Wireguard::new(vec![tun_reader1], tun_writer1, mtu1); + + let (fake2, tun_reader2, tun_writer2, mtu2) = dummy::TunTest::create(1500, true); + let wg2: Wireguard = + Wireguard::new(vec![tun_reader2], tun_writer2, mtu2); + + // create pair bind to connect the interfaces "over the internet" + + let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) = dummy::PairBind::pair(); + + wg1.set_writer(bind_writer1); + wg2.set_writer(bind_writer2); + + wg1.add_reader(bind_reader1); + wg2.add_reader(bind_reader2); + + // generate (public, pivate) key pairs + + // configure cryptkey router + + // create IP packets + + thread::sleep(Duration::from_millis(500)); +} diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs new file mode 100644 index 0000000..2792c7b --- /dev/null +++ b/src/wireguard/timers.rs @@ -0,0 +1,234 @@ +use std::marker::PhantomData; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +use log::info; + +use hjul::{Runner, Timer}; + +use super::constants::*; +use super::router::Callbacks; +use super::types::{bind, tun}; +use super::wireguard::{Peer, PeerInner}; + +pub struct Timers { + handshake_pending: AtomicBool, + handshake_attempts: AtomicUsize, + + retransmit_handshake: Timer, + send_keepalive: Timer, + send_persistent_keepalive: Timer, + sent_lastminute_handshake: AtomicBool, + zero_key_material: Timer, + new_handshake: Timer, + need_another_keepalive: AtomicBool, +} + +impl Timers { + #[inline(always)] + fn need_another_keepalive(&self) -> bool { + self.need_another_keepalive.swap(false, Ordering::SeqCst) + } +} + +impl Peer { + /* should be called after an authenticated data packet is sent */ + pub fn timers_data_sent(&self) { + self.timers().new_handshake.start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT); + } + + /* should be called after an authenticated data packet is received */ + pub fn timers_data_received(&self) { + if !self.timers().send_keepalive.start(KEEPALIVE_TIMEOUT) { + self.timers().need_another_keepalive.store(true, Ordering::SeqCst) + } + } + + /* Should be called after any type of authenticated packet is sent, whether: + * - keepalive + * - data + * - handshake + */ + pub fn timers_any_authenticated_packet_sent(&self) { + self.timers().send_keepalive.stop() + } + + /* Should be called after any type of authenticated packet is received, whether: + * - keepalive + * - data + * - handshake + */ + pub fn timers_any_authenticated_packet_received(&self) { + self.timers().new_handshake.stop(); + } + + /* Should be called after a handshake initiation message is sent. */ + pub fn timers_handshake_initiated(&self) { + self.timers().send_keepalive.stop(); + self.timers().retransmit_handshake.reset(REKEY_TIMEOUT); + } + + /* Should be called after a handshake response message is received and processed + * or when getting key confirmation via the first data message. + */ + pub fn timers_handshake_complete(&self) { + self.timers().handshake_attempts.store(0, Ordering::SeqCst); + self.timers().sent_lastminute_handshake.store(false, Ordering::SeqCst); + // TODO: Store time in peer for config + // self.walltime_last_handshake + } + + /* Should be called after an ephemeral key is created, which is before sending a + * handshake response or after receiving a handshake response. + */ + pub fn timers_session_derived(&self) { + self.timers().zero_key_material.reset(REJECT_AFTER_TIME * 3); + } + + /* Should be called before a packet with authentication, whether + * keepalive, data, or handshake is sent, or after one is received. + */ + pub fn timers_any_authenticated_packet_traversal(&self) { + let keepalive = self.state.keepalive.load(Ordering::Acquire); + if keepalive > 0 { + self.timers().send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64)); + } + } +} + +impl Timers { + pub fn new(runner: &Runner, peer: Peer) -> Timers + where + T: tun::Tun, + B: bind::Bind, + { + // create a timer instance for the provided peer + Timers { + handshake_pending: AtomicBool::new(false), + need_another_keepalive: AtomicBool::new(false), + sent_lastminute_handshake: AtomicBool::new(false), + handshake_attempts: AtomicUsize::new(0), + retransmit_handshake: { + let peer = peer.clone(); + runner.timer(move || { + if peer.timers().handshake_retry() { + info!("Retransmit handshake for {}", peer); + peer.new_handshake(); + } else { + info!("Failed to complete handshake for {}", peer); + peer.router.purge_staged_packets(); + peer.timers().send_keepalive.stop(); + peer.timers().zero_key_material.start(REJECT_AFTER_TIME * 3); + } + }) + }, + send_keepalive: { + let peer = peer.clone(); + runner.timer(move || { + peer.router.send_keepalive(); + if peer.timers().need_another_keepalive() { + peer.timers().send_keepalive.start(KEEPALIVE_TIMEOUT); + } + }) + }, + new_handshake: { + let peer = peer.clone(); + runner.timer(move || { + info!( + "Retrying handshake with {}, because we stopped hearing back after {} seconds", + peer, + (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT).as_secs() + ); + peer.new_handshake(); + peer.timers.read().handshake_begun(); + }) + }, + zero_key_material: { + let peer = peer.clone(); + runner.timer(move || { + peer.router.zero_keys(); + }) + }, + send_persistent_keepalive: { + let peer = peer.clone(); + runner.timer(move || { + let keepalive = peer.state.keepalive.load(Ordering::Acquire); + if keepalive > 0 { + peer.router.send_keepalive(); + peer.timers().send_keepalive.stop(); + peer.timers().send_persistent_keepalive.start(Duration::from_secs(keepalive as u64)); + } + }) + } + } + } + + fn handshake_begun(&self) { + self.handshake_pending.store(true, Ordering::SeqCst); + self.handshake_attempts.store(0, Ordering::SeqCst); + self.retransmit_handshake.reset(REKEY_TIMEOUT); + } + + fn handshake_retry(&self) -> bool { + if self.handshake_attempts.fetch_add(1, Ordering::SeqCst) <= MAX_TIMER_HANDSHAKES { + self.retransmit_handshake.reset(REKEY_TIMEOUT); + true + } else { + self.handshake_pending.store(false, Ordering::SeqCst); + false + } + } + + pub fn updated_persistent_keepalive(&self, keepalive: usize) { + if keepalive > 0 { + self.send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64)); + } + } + + pub fn dummy(runner: &Runner) -> Timers { + Timers { + handshake_pending: AtomicBool::new(false), + need_another_keepalive: AtomicBool::new(false), + sent_lastminute_handshake: AtomicBool::new(false), + handshake_attempts: AtomicUsize::new(0), + retransmit_handshake: runner.timer(|| {}), + new_handshake: runner.timer(|| {}), + send_keepalive: runner.timer(|| {}), + send_persistent_keepalive: runner.timer(|| {}), + zero_key_material: runner.timer(|| {}) + } + } + + pub fn handshake_sent(&self) { + self.send_keepalive.stop(); + } +} + +/* Instance of the router callbacks */ + +pub struct Events(PhantomData<(T, B)>); + +impl Callbacks for Events { + type Opaque = Arc>; + + fn send(peer: &Self::Opaque, size: usize, data: bool, sent: bool) { + peer.tx_bytes.fetch_add(size as u64, Ordering::Relaxed); + } + + fn recv(peer: &Self::Opaque, size: usize, data: bool, sent: bool) { + peer.rx_bytes.fetch_add(size as u64, Ordering::Relaxed); + } + + fn need_key(peer: &Self::Opaque) { + let timers = peer.timers(); + if !timers.handshake_pending.swap(true, Ordering::SeqCst) { + timers.handshake_attempts.store(0, Ordering::SeqCst); + timers.new_handshake.fire(); + } + } + + fn key_confirmed(peer: &Self::Opaque) { + peer.timers().retransmit_handshake.stop(); + } +} diff --git a/src/wireguard/types/bind.rs b/src/wireguard/types/bind.rs new file mode 100644 index 0000000..3d3f187 --- /dev/null +++ b/src/wireguard/types/bind.rs @@ -0,0 +1,23 @@ +use super::Endpoint; +use std::error::Error; + +pub trait Reader: Send + Sync { + type Error: Error; + + fn read(&self, buf: &mut [u8]) -> Result<(usize, E), Self::Error>; +} + +pub trait Writer: Send + Sync + Clone + 'static { + type Error: Error; + + fn write(&self, buf: &[u8], dst: &E) -> Result<(), Self::Error>; +} + +pub trait Bind: Send + Sync + 'static { + type Error: Error; + type Endpoint: Endpoint; + + /* Until Rust gets type equality constraints these have to be generic */ + type Writer: Writer; + type Reader: Reader; +} diff --git a/src/wireguard/types/dummy.rs b/src/wireguard/types/dummy.rs new file mode 100644 index 0000000..2403c9b --- /dev/null +++ b/src/wireguard/types/dummy.rs @@ -0,0 +1,323 @@ +use std::error::Error; +use std::fmt; +use std::marker; +use std::net::SocketAddr; +use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; +use std::sync::Arc; +use std::sync::Mutex; +use std::time::Instant; +use std::sync::atomic::{Ordering, AtomicUsize}; + +use super::*; + +/* This submodule provides pure/dummy implementations of the IO interfaces + * for use in unit tests thoughout the project. + */ + +/* Error implementation */ + +#[derive(Debug)] +pub enum BindError { + Disconnected, +} + +impl Error for BindError { + fn description(&self) -> &str { + "Generic Bind Error" + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} + +impl fmt::Display for BindError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BindError::Disconnected => write!(f, "PairBind disconnected"), + } + } +} + +/* TUN implementation */ + +#[derive(Debug)] +pub enum TunError { + Disconnected +} + +impl Error for TunError { + fn description(&self) -> &str { + "Generic Tun Error" + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} + +impl fmt::Display for TunError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Not Possible") + } +} + +/* Endpoint implementation */ + +#[derive(Clone, Copy)] +pub struct UnitEndpoint {} + +impl Endpoint for UnitEndpoint { + fn from_address(_: SocketAddr) -> UnitEndpoint { + UnitEndpoint {} + } + + fn into_address(&self) -> SocketAddr { + "127.0.0.1:8080".parse().unwrap() + } + + fn clear_src(&self) {} +} + +impl UnitEndpoint { + pub fn new() -> UnitEndpoint { + UnitEndpoint {} + } +} + +/* */ + +pub struct TunTest {} + +pub struct TunFakeIO { + store: bool, + tx: SyncSender>, + rx: Receiver> +} + +pub struct TunReader { + rx: Receiver> +} + +pub struct TunWriter { + store: bool, + tx: Mutex>> +} + +#[derive(Clone)] +pub struct TunMTU { + mtu: Arc +} + +impl tun::Reader for TunReader { + type Error = TunError; + + fn read(&self, buf: &mut [u8], offset: usize) -> Result { + match self.rx.recv() { + Ok(m) => { + buf[offset..].copy_from_slice(&m[..]); + Ok(m.len()) + } + Err(_) => Err(TunError::Disconnected) + } + } +} + +impl tun::Writer for TunWriter { + type Error = TunError; + + fn write(&self, src: &[u8]) -> Result<(), Self::Error> { + if self.store { + let m = src.to_owned(); + match self.tx.lock().unwrap().send(m) { + Ok(_) => Ok(()), + Err(_) => Err(TunError::Disconnected) + } + } else { + Ok(()) + } + } +} + +impl tun::MTU for TunMTU { + fn mtu(&self) -> usize { + self.mtu.load(Ordering::Acquire) + } +} + +impl tun::Tun for TunTest { + type Writer = TunWriter; + type Reader = TunReader; + type MTU = TunMTU; + type Error = TunError; +} + +impl TunFakeIO { + pub fn write(&self, msg : Vec) { + if self.store { + self.tx.send(msg).unwrap(); + } + } + + pub fn read(&self) -> Vec { + self.rx.recv().unwrap() + } +} + +impl TunTest { + pub fn create(mtu : usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunMTU) { + + let (tx1, rx1) = if store { sync_channel(32) } else { sync_channel(1) }; + let (tx2, rx2) = if store { sync_channel(32) } else { sync_channel(1) }; + + let fake = TunFakeIO{tx: tx1, rx: rx2, store}; + let reader = TunReader{rx : rx1}; + let writer = TunWriter{tx : Mutex::new(tx2), store}; + let mtu = TunMTU{mtu : Arc::new(AtomicUsize::new(mtu))}; + + (fake, reader, writer, mtu) + } +} + +/* Void Bind */ + +#[derive(Clone, Copy)] +pub struct VoidBind {} + +impl bind::Reader for VoidBind { + type Error = BindError; + + fn read(&self, _buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> { + Ok((0, UnitEndpoint {})) + } +} + +impl bind::Writer for VoidBind { + type Error = BindError; + + fn write(&self, _buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { + Ok(()) + } +} + +impl bind::Bind for VoidBind { + type Error = BindError; + type Endpoint = UnitEndpoint; + + type Reader = VoidBind; + type Writer = VoidBind; +} + +impl VoidBind { + pub fn new() -> VoidBind { + VoidBind {} + } +} + +/* Pair Bind */ + +#[derive(Clone)] +pub struct PairReader { + recv: Arc>>>, + _marker: marker::PhantomData, +} + +impl bind::Reader for PairReader { + type Error = BindError; + fn read(&self, buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> { + let vec = self + .recv + .lock() + .unwrap() + .recv() + .map_err(|_| BindError::Disconnected)?; + let len = vec.len(); + buf[..len].copy_from_slice(&vec[..]); + Ok((vec.len(), UnitEndpoint {})) + } +} + +impl bind::Writer for PairWriter { + type Error = BindError; + fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { + let owned = buf.to_owned(); + match self.send.lock().unwrap().send(owned) { + Err(_) => Err(BindError::Disconnected), + Ok(_) => Ok(()), + } + } +} + +#[derive(Clone)] +pub struct PairWriter { + send: Arc>>>, + _marker: marker::PhantomData, +} + +#[derive(Clone)] +pub struct PairBind {} + +impl PairBind { + pub fn pair() -> ( + (PairReader, PairWriter), + (PairReader, PairWriter), + ) { + let (tx1, rx1) = sync_channel(128); + let (tx2, rx2) = sync_channel(128); + ( + ( + PairReader { + recv: Arc::new(Mutex::new(rx1)), + _marker: marker::PhantomData, + }, + PairWriter { + send: Arc::new(Mutex::new(tx2)), + _marker: marker::PhantomData, + }, + ), + ( + PairReader { + recv: Arc::new(Mutex::new(rx2)), + _marker: marker::PhantomData, + }, + PairWriter { + send: Arc::new(Mutex::new(tx1)), + _marker: marker::PhantomData, + }, + ), + ) + } +} + +impl bind::Bind for PairBind { + type Error = BindError; + type Endpoint = UnitEndpoint; + type Reader = PairReader; + type Writer = PairWriter; +} + +pub fn keypair(initiator: bool) -> KeyPair { + let k1 = Key { + key: [0x53u8; 32], + id: 0x646e6573, + }; + let k2 = Key { + key: [0x52u8; 32], + id: 0x76636572, + }; + if initiator { + KeyPair { + birth: Instant::now(), + initiator: true, + send: k1, + recv: k2, + } + } else { + KeyPair { + birth: Instant::now(), + initiator: false, + send: k2, + recv: k1, + } + } +} diff --git a/src/wireguard/types/endpoint.rs b/src/wireguard/types/endpoint.rs new file mode 100644 index 0000000..f4f93da --- /dev/null +++ b/src/wireguard/types/endpoint.rs @@ -0,0 +1,7 @@ +use std::net::SocketAddr; + +pub trait Endpoint: Send + 'static { + fn from_address(addr: SocketAddr) -> Self; + fn into_address(&self) -> SocketAddr; + fn clear_src(&self); +} diff --git a/src/wireguard/types/keys.rs b/src/wireguard/types/keys.rs new file mode 100644 index 0000000..282c4ae --- /dev/null +++ b/src/wireguard/types/keys.rs @@ -0,0 +1,36 @@ +use clear_on_drop::clear::Clear; +use std::time::Instant; + +#[derive(Debug, Clone)] +pub struct Key { + pub key: [u8; 32], + pub id: u32, +} + +// zero key on drop +impl Drop for Key { + fn drop(&mut self) { + self.key.clear() + } +} + +#[cfg(test)] +impl PartialEq for Key { + fn eq(&self, other: &Self) -> bool { + self.id == other.id && self.key[..] == other.key[..] + } +} + +#[derive(Debug, Clone)] +pub struct KeyPair { + pub birth: Instant, // when was the key-pair created + pub initiator: bool, // has the key-pair been confirmed? + pub send: Key, // key for outbound messages + pub recv: Key, // key for inbound messages +} + +impl KeyPair { + pub fn local_id(&self) -> u32 { + self.recv.id + } +} diff --git a/src/wireguard/types/mod.rs b/src/wireguard/types/mod.rs new file mode 100644 index 0000000..e0725f3 --- /dev/null +++ b/src/wireguard/types/mod.rs @@ -0,0 +1,10 @@ +mod endpoint; +mod keys; +pub mod tun; +pub mod bind; + +#[cfg(test)] +pub mod dummy; + +pub use endpoint::Endpoint; +pub use keys::{Key, KeyPair}; \ No newline at end of file diff --git a/src/wireguard/types/tun.rs b/src/wireguard/types/tun.rs new file mode 100644 index 0000000..2ba16ff --- /dev/null +++ b/src/wireguard/types/tun.rs @@ -0,0 +1,56 @@ +use std::error::Error; + +pub trait Writer: Send + Sync + 'static { + type Error: Error; + + /// Receive a cryptkey routed IP packet + /// + /// # Arguments + /// + /// - src: Buffer containing the IP packet to be written + /// + /// # Returns + /// + /// Unit type or an error + fn write(&self, src: &[u8]) -> Result<(), Self::Error>; +} + +pub trait Reader: Send + 'static { + type Error: Error; + + /// Reads an IP packet into dst[offset:] from the tunnel device + /// + /// The reason for providing space for a prefix + /// is to efficiently accommodate platforms on which the packet is prefaced by a header. + /// This space is later used to construct the transport message inplace. + /// + /// # Arguments + /// + /// - buf: Destination buffer (enough space for MTU bytes + header) + /// - offset: Offset for the beginning of the IP packet + /// + /// # Returns + /// + /// The size of the IP packet (ignoring the header) or an std::error::Error instance: + fn read(&self, buf: &mut [u8], offset: usize) -> Result; +} + +pub trait MTU: Send + Sync + Clone + 'static { + /// Returns the MTU of the device + /// + /// This function needs to be efficient (called for every read). + /// The goto implementation strategy is to .load an atomic variable, + /// then use e.g. netlink to update the variable in a separate thread. + /// + /// # Returns + /// + /// The MTU of the interface in bytes + fn mtu(&self) -> usize; +} + +pub trait Tun: Send + Sync + 'static { + type Writer: Writer; + type Reader: Reader; + type MTU: MTU; + type Error: Error; +} diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs new file mode 100644 index 0000000..7a22280 --- /dev/null +++ b/src/wireguard/wireguard.rs @@ -0,0 +1,407 @@ +use super::constants::*; +use super::handshake; +use super::router; +use super::timers::{Events, Timers}; + +use super::types::bind::Reader as BindReader; +use super::types::bind::{Bind, Writer}; +use super::types::tun::{Reader, Tun, MTU}; +use super::types::Endpoint; + +use hjul::Runner; + +use std::fmt; +use std::ops::Deref; +use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::thread; +use std::time::{Duration, Instant}; + +use std::collections::HashMap; + +use log::debug; +use rand::rngs::OsRng; +use spin::{Mutex, RwLock, RwLockReadGuard}; + +use byteorder::{ByteOrder, LittleEndian}; +use crossbeam_channel::{bounded, Sender}; +use x25519_dalek::{PublicKey, StaticSecret}; + +const SIZE_HANDSHAKE_QUEUE: usize = 128; +const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4; +const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000); + +pub struct Peer { + pub router: Arc, T::Writer, B::Writer>>, + pub state: Arc>, +} + +impl Clone for Peer { + fn clone(&self) -> Peer { + Peer { + router: self.router.clone(), + state: self.state.clone(), + } + } +} + +pub struct PeerInner { + pub keepalive: AtomicUsize, // keepalive interval + pub rx_bytes: AtomicU64, + pub tx_bytes: AtomicU64, + pub queue: Mutex>>, // handshake queue + pub pk: PublicKey, // DISCUSS: Change layout in handshake module (adopt pattern of router), to avoid this. + pub timers: RwLock, // +} + +impl PeerInner { + #[inline(always)] + pub fn timers(&self) -> RwLockReadGuard { + self.timers.read() + } +} + +impl fmt::Display for Peer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "peer()") + } +} + +impl Deref for Peer { + type Target = PeerInner; + fn deref(&self) -> &Self::Target { + &self.state + } +} + +impl PeerInner { + pub fn new_handshake(&self) { + // TODO: clear endpoint source address ("unsticky") + self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap(); + } +} + +struct Handshake { + device: handshake::Device, + active: bool, +} + +pub enum HandshakeJob { + Message(Vec, E), + New(PublicKey), +} + +struct WireguardInner { + // provides access to the MTU value of the tun device + // (otherwise owned solely by the router and a dedicated read IO thread) + mtu: T::MTU, + send: RwLock>, + + // identify and configuration map + peers: RwLock>>, + + // cryptkey router + router: router::Device, T::Writer, B::Writer>, + + // handshake related state + handshake: RwLock, + under_load: AtomicBool, + pending: AtomicUsize, // num of pending handshake packets in queue + queue: Mutex>>, +} + +pub struct Wireguard { + runner: Runner, + state: Arc>, +} + +/* Returns the padded length of a message: + * + * # Arguments + * + * - `size` : Size of unpadded message + * - `mtu` : Maximum transmission unit of the device + * + * # Returns + * + * The padded length (always less than or equal to the MTU) + */ +#[inline(always)] +const fn padding(size: usize, mtu: usize) -> usize { + #[inline(always)] + const fn min(a: usize, b: usize) -> usize { + let m = (a > b) as usize; + a * m + (1 - m) * b + } + let pad = MESSAGE_PADDING_MULTIPLE; + min(mtu, size + (pad - size % pad) % pad) +} + +impl Wireguard { + pub fn set_key(&self, sk: Option) { + let mut handshake = self.state.handshake.write(); + match sk { + None => { + let mut rng = OsRng::new().unwrap(); + handshake.device.set_sk(StaticSecret::new(&mut rng)); + handshake.active = false; + } + Some(sk) => { + handshake.device.set_sk(sk); + handshake.active = true; + } + } + } + + pub fn get_sk(&self) -> Option { + let handshake = self.state.handshake.read(); + if handshake.active { + Some(handshake.device.get_sk()) + } else { + None + } + } + + pub fn new_peer(&self, pk: PublicKey) -> Peer { + let state = Arc::new(PeerInner { + pk, + queue: Mutex::new(self.state.queue.lock().clone()), + keepalive: AtomicUsize::new(0), + rx_bytes: AtomicU64::new(0), + tx_bytes: AtomicU64::new(0), + timers: RwLock::new(Timers::dummy(&self.runner)), + }); + + let router = Arc::new(self.state.router.new_peer(state.clone())); + + let peer = Peer { router, state }; + + /* The need for dummy timers arises from the chicken-egg + * problem of the timer callbacks being able to set timers themselves. + * + * This is in fact the only place where the write lock is ever taken. + */ + *peer.timers.write() = Timers::new(&self.runner, peer.clone()); + peer + } + + /* Begin consuming messages from the reader. + * + * Any previous reader thread is stopped by closing the previous reader, + * which unblocks the thread and causes an error on reader.read + */ + pub fn add_reader(&self, reader: B::Reader) { + let wg = self.state.clone(); + thread::spawn(move || { + let mut last_under_load = + Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); + + loop { + // create vector big enough for any message given current MTU + let size = wg.mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE; + let mut msg: Vec = Vec::with_capacity(size); + msg.resize(size, 0); + + // read UDP packet into vector + let (size, src) = match reader.read(&mut msg) { + Err(e) => { + debug!("Bind reader closed with {}", e); + return; + } + Ok(v) => v, + }; + msg.truncate(size); + + // message type de-multiplexer + if msg.len() < std::mem::size_of::() { + continue; + } + match LittleEndian::read_u32(&msg[..]) { + handshake::TYPE_COOKIE_REPLY + | handshake::TYPE_INITIATION + | handshake::TYPE_RESPONSE => { + // update under_load flag + if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD { + last_under_load = Instant::now(); + wg.under_load.store(true, Ordering::SeqCst); + } else if last_under_load.elapsed() > DURATION_UNDER_LOAD { + wg.under_load.store(false, Ordering::SeqCst); + } + + wg.queue + .lock() + .send(HandshakeJob::Message(msg, src)) + .unwrap(); + } + router::TYPE_TRANSPORT => { + // transport message + let _ = wg.router.recv(src, msg).map_err(|e| { + debug!("Failed to handle incoming transport message: {}", e); + }); + } + _ => (), + } + } + }); + } + + pub fn set_writer(&self, writer: B::Writer) { + // TODO: Consider unifying these and avoid Clone requirement on writer + *self.state.send.write() = Some(writer.clone()); + self.state.router.set_outbound_writer(writer); + } + + pub fn new(mut readers: Vec, writer: T::Writer, mtu: T::MTU) -> Wireguard { + // create device state + let mut rng = OsRng::new().unwrap(); + let (tx, rx): (Sender>, _) = bounded(SIZE_HANDSHAKE_QUEUE); + let wg = Arc::new(WireguardInner { + mtu: mtu.clone(), + peers: RwLock::new(HashMap::new()), + send: RwLock::new(None), + router: router::Device::new(num_cpus::get(), writer), // router owns the writing half + pending: AtomicUsize::new(0), + handshake: RwLock::new(Handshake { + device: handshake::Device::new(StaticSecret::new(&mut rng)), + active: false, + }), + under_load: AtomicBool::new(false), + queue: Mutex::new(tx), + }); + + // start handshake workers + for _ in 0..num_cpus::get() { + let wg = wg.clone(); + let rx = rx.clone(); + thread::spawn(move || { + // prepare OsRng instance for this thread + let mut rng = OsRng::new().unwrap(); + + // process elements from the handshake queue + for job in rx { + wg.pending.fetch_sub(1, Ordering::SeqCst); + let state = wg.handshake.read(); + if !state.active { + continue; + } + + match job { + HandshakeJob::Message(msg, src) => { + // feed message to handshake device + let src_validate = (&src).into_address(); // TODO avoid + + // process message + match state.device.process( + &mut rng, + &msg[..], + if wg.under_load.load(Ordering::Relaxed) { + Some(&src_validate) + } else { + None + }, + ) { + Ok((pk, resp, keypair)) => { + // send response + let mut resp_len: u64 = 0; + if let Some(msg) = resp { + resp_len = msg.len() as u64; + let send: &Option = &*wg.send.read(); + if let Some(writer) = send.as_ref() { + let _ = writer.write(&msg[..], &src).map_err(|e| { + debug!( + "handshake worker, failed to send response, error = {}", + e + ) + }); + } + } + + // update timers + if let Some(pk) = pk { + // authenticated handshake packet received + if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { + // add to rx_bytes and tx_bytes + let req_len = msg.len() as u64; + peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed); + peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed); + + // update endpoint + peer.router.set_endpoint(src); + + // add keypair to peer + keypair.map(|kp| { + // free any unused ids + for id in peer.router.add_keypair(kp) { + state.device.release(id); + } + }); + } + } + } + Err(e) => debug!("handshake worker, error = {:?}", e), + } + } + HandshakeJob::New(pk) => { + let _ = state.device.begin(&mut rng, &pk).map(|msg| { + if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { + let _ = peer.router.send(&msg[..]).map_err(|e| { + debug!("handshake worker, failed to send handshake initiation, error = {}", e) + }); + } + }); + } + } + } + }); + } + + // start TUN read IO threads (multiple threads to support multi-queue interfaces) + debug_assert!( + readers.len() > 0, + "attempted to create WG device without TUN readers" + ); + while let Some(reader) = readers.pop() { + let wg = wg.clone(); + let mtu = mtu.clone(); + thread::spawn(move || loop { + // create vector big enough for any transport message (based on MTU) + let mtu = mtu.mtu(); + let size = mtu + router::SIZE_MESSAGE_PREFIX; + let mut msg: Vec = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); + msg.resize(size, 0); + + // read a new IP packet + let payload = match reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX) { + Ok(payload) => payload, + Err(e) => { + debug!("TUN worker, failed to read from tun device: {}", e); + return; + } + }; + debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu); + + // truncate padding + let payload = padding(payload, mtu); + msg.truncate(router::SIZE_MESSAGE_PREFIX + payload); + debug_assert!(payload <= mtu); + debug_assert_eq!( + if payload < mtu { + (msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE + } else { + 0 + }, + 0 + ); + + // crypt-key route + let e = wg.router.send(msg); + debug!("TUN worker, router returned {:?}", e); + }); + } + + Wireguard { + state: wg, + runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY), + } + } +} -- cgit v1.2.3-59-g8ed1b