summaryrefslogtreecommitdiffstats
path: root/src/wireguard/handshake
diff options
context:
space:
mode:
Diffstat (limited to 'src/wireguard/handshake')
-rw-r--r--src/wireguard/handshake/device.rs574
-rw-r--r--src/wireguard/handshake/macs.rs327
-rw-r--r--src/wireguard/handshake/messages.rs363
-rw-r--r--src/wireguard/handshake/mod.rs21
-rw-r--r--src/wireguard/handshake/noise.rs549
-rw-r--r--src/wireguard/handshake/peer.rs142
-rw-r--r--src/wireguard/handshake/ratelimiter.rs199
-rw-r--r--src/wireguard/handshake/timestamp.rs32
-rw-r--r--src/wireguard/handshake/types.rs90
9 files changed, 2297 insertions, 0 deletions
diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs
new file mode 100644
index 0000000..6a55f6e
--- /dev/null
+++ b/src/wireguard/handshake/device.rs
@@ -0,0 +1,574 @@
+use spin::RwLock;
+use std::collections::HashMap;
+use std::net::SocketAddr;
+use std::sync::Mutex;
+use zerocopy::AsBytes;
+
+use byteorder::{ByteOrder, LittleEndian};
+
+use rand::prelude::*;
+
+use x25519_dalek::PublicKey;
+use x25519_dalek::StaticSecret;
+
+use super::macs;
+use super::messages::{CookieReply, Initiation, Response};
+use super::messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};
+use super::noise;
+use super::peer::Peer;
+use super::ratelimiter::RateLimiter;
+use super::types::*;
+
+const MAX_PEER_PER_DEVICE: usize = 1 << 20;
+
+pub struct Device {
+ pub sk: StaticSecret, // static secret key
+ pub pk: PublicKey, // static public key
+ macs: macs::Validator, // validator for the mac fields
+ pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state
+ id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key
+ limiter: Mutex<RateLimiter>,
+}
+
+/* A mutable reference to the device needs to be held during configuration.
+ * Wrapping the device in a RwLock enables peer config after "configuration time"
+ */
+impl Device {
+ /// Initialize a new handshake state machine
+ ///
+ /// # Arguments
+ ///
+ /// * `sk` - x25519 scalar representing the local private key
+ pub fn new(sk: StaticSecret) -> Device {
+ let pk = PublicKey::from(&sk);
+ Device {
+ pk,
+ sk,
+ macs: macs::Validator::new(pk),
+ pk_map: HashMap::new(),
+ id_map: RwLock::new(HashMap::new()),
+ limiter: Mutex::new(RateLimiter::new()),
+ }
+ }
+
+ /// Update the secret key of the device
+ ///
+ /// # Arguments
+ ///
+ /// * `sk` - x25519 scalar representing the local private key
+ pub fn set_sk(&mut self, sk: StaticSecret) {
+ // update secret and public key
+ let pk = PublicKey::from(&sk);
+ self.sk = sk;
+ self.pk = pk;
+ self.macs = macs::Validator::new(pk);
+
+ // recalculate the shared secrets for every peer
+ let mut ids = vec![];
+ for mut peer in self.pk_map.values_mut() {
+ peer.reset_state().map(|id| ids.push(id));
+ peer.ss = self.sk.diffie_hellman(&peer.pk)
+ }
+
+ // release ids from aborted handshakes
+ for id in ids {
+ self.release(id)
+ }
+ }
+
+ /// Return the secret key of the device
+ ///
+ /// # Returns
+ ///
+ /// A secret key (x25519 scalar)
+ pub fn get_sk(&self) -> StaticSecret {
+ StaticSecret::from(self.sk.to_bytes())
+ }
+
+ /// Add a new public key to the state machine
+ /// To remove public keys, you must create a new machine instance
+ ///
+ /// # Arguments
+ ///
+ /// * `pk` - The public key to add
+ /// * `identifier` - Associated identifier which can be used to distinguish the peers
+ pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> {
+ // check that the pk is not added twice
+ if let Some(_) = self.pk_map.get(pk.as_bytes()) {
+ return Err(ConfigError::new("Duplicate public key"));
+ };
+
+ // check that the pk is not that of the device
+ if *self.pk.as_bytes() == *pk.as_bytes() {
+ return Err(ConfigError::new(
+ "Public key corresponds to secret key of interface",
+ ));
+ }
+
+ // ensure less than 2^20 peers
+ if self.pk_map.len() > MAX_PEER_PER_DEVICE {
+ return Err(ConfigError::new("Too many peers for device"));
+ }
+
+ // map the public key to the peer state
+ self.pk_map
+ .insert(*pk.as_bytes(), Peer::new(pk, self.sk.diffie_hellman(&pk)));
+
+ Ok(())
+ }
+
+ /// Remove a peer by public key
+ /// To remove public keys, you must create a new machine instance
+ ///
+ /// # Arguments
+ ///
+ /// * `pk` - The public key of the peer to remove
+ ///
+ /// # Returns
+ ///
+ /// The call might fail if the public key is not found
+ pub fn remove(&mut self, pk: PublicKey) -> Result<(), ConfigError> {
+ // take write-lock on receive id table
+ let mut id_map = self.id_map.write();
+
+ // remove the peer
+ self.pk_map
+ .remove(pk.as_bytes())
+ .ok_or(ConfigError::new("Public key not in device"))?;
+
+ // pruge the id map (linear scan)
+ id_map.retain(|_, v| v != pk.as_bytes());
+ Ok(())
+ }
+
+ /// Add a psk to the peer
+ ///
+ /// # Arguments
+ ///
+ /// * `pk` - The public key of the peer
+ /// * `psk` - The psk to set / unset
+ ///
+ /// # Returns
+ ///
+ /// The call might fail if the public key is not found
+ pub fn set_psk(&mut self, pk: PublicKey, psk: Option<Psk>) -> Result<(), ConfigError> {
+ match self.pk_map.get_mut(pk.as_bytes()) {
+ Some(mut peer) => {
+ peer.psk = match psk {
+ Some(v) => v,
+ None => [0u8; 32],
+ };
+ Ok(())
+ }
+ _ => Err(ConfigError::new("No such public key")),
+ }
+ }
+
+ /// Return the psk for the peer
+ ///
+ /// # Arguments
+ ///
+ /// * `pk` - The public key of the peer
+ ///
+ /// # Returns
+ ///
+ /// A 32 byte array holding the PSK
+ ///
+ /// The call might fail if the public key is not found
+ pub fn get_psk(&self, pk: PublicKey) -> Result<Psk, ConfigError> {
+ match self.pk_map.get(pk.as_bytes()) {
+ Some(peer) => Ok(peer.psk),
+ _ => Err(ConfigError::new("No such public key")),
+ }
+ }
+
+ /// Release an id back to the pool
+ ///
+ /// # Arguments
+ ///
+ /// * `id` - The (sender) id to release
+ pub fn release(&self, id: u32) {
+ let mut m = self.id_map.write();
+ debug_assert!(m.contains_key(&id), "Releasing id not allocated");
+ m.remove(&id);
+ }
+
+ /// Begin a new handshake
+ ///
+ /// # Arguments
+ ///
+ /// * `pk` - Public key of peer to initiate handshake for
+ pub fn begin<R: RngCore + CryptoRng>(
+ &self,
+ rng: &mut R,
+ pk: &PublicKey,
+ ) -> Result<Vec<u8>, HandshakeError> {
+ match self.pk_map.get(pk.as_bytes()) {
+ None => Err(HandshakeError::UnknownPublicKey),
+ Some(peer) => {
+ let sender = self.allocate(rng, peer);
+
+ let mut msg = Initiation::default();
+
+ noise::create_initiation(rng, self, peer, sender, &mut msg.noise)?;
+
+ // add macs to initation
+
+ peer.macs
+ .lock()
+ .generate(msg.noise.as_bytes(), &mut msg.macs);
+
+ Ok(msg.as_bytes().to_owned())
+ }
+ }
+ }
+
+ /// Process a handshake message.
+ ///
+ /// # Arguments
+ ///
+ /// * `msg` - Byte slice containing the message (untrusted input)
+ pub fn process<'a, R: RngCore + CryptoRng, S>(
+ &self,
+ rng: &mut R, // rng instance to sample randomness from
+ msg: &[u8], // message buffer
+ src: Option<&'a S>, // optional source endpoint, set when "under load"
+ ) -> Result<Output, HandshakeError>
+ where
+ &'a S: Into<&'a SocketAddr>,
+ {
+ // ensure type read in-range
+ if msg.len() < 4 {
+ return Err(HandshakeError::InvalidMessageFormat);
+ }
+
+ // de-multiplex the message type field
+ match LittleEndian::read_u32(msg) {
+ TYPE_INITIATION => {
+ // parse message
+ let msg = Initiation::parse(msg)?;
+
+ // check mac1 field
+ self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
+
+ // address validation & DoS mitigation
+ if let Some(src) = src {
+ // obtain ref to socket addr
+ let src = src.into();
+
+ // check mac2 field
+ if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
+ let mut reply = Default::default();
+ self.macs.create_cookie_reply(
+ rng,
+ msg.noise.f_sender.get(),
+ src,
+ &msg.macs,
+ &mut reply,
+ );
+ return Ok((None, Some(reply.as_bytes().to_owned()), None));
+ }
+
+ // check ratelimiter
+ if !self.limiter.lock().unwrap().allow(&src.ip()) {
+ return Err(HandshakeError::RateLimited);
+ }
+ }
+
+ // consume the initiation
+ let (peer, st) = noise::consume_initiation(self, &msg.noise)?;
+
+ // allocate new index for response
+ let sender = self.allocate(rng, peer);
+
+ // prepare memory for response, TODO: take slice for zero allocation
+ let mut resp = Response::default();
+
+ // create response (release id on error)
+ let keys = noise::create_response(rng, peer, sender, st, &mut resp.noise).map_err(
+ |e| {
+ self.release(sender);
+ e
+ },
+ )?;
+
+ // add macs to response
+ peer.macs
+ .lock()
+ .generate(resp.noise.as_bytes(), &mut resp.macs);
+
+ // return unconfirmed keypair and the response as vector
+ Ok((Some(peer.pk), Some(resp.as_bytes().to_owned()), Some(keys)))
+ }
+ TYPE_RESPONSE => {
+ let msg = Response::parse(msg)?;
+
+ // check mac1 field
+ self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
+
+ // address validation & DoS mitigation
+ if let Some(src) = src {
+ // obtain ref to socket addr
+ let src = src.into();
+
+ // check mac2 field
+ if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
+ let mut reply = Default::default();
+ self.macs.create_cookie_reply(
+ rng,
+ msg.noise.f_sender.get(),
+ src,
+ &msg.macs,
+ &mut reply,
+ );
+ return Ok((None, Some(reply.as_bytes().to_owned()), None));
+ }
+
+ // check ratelimiter
+ if !self.limiter.lock().unwrap().allow(&src.ip()) {
+ return Err(HandshakeError::RateLimited);
+ }
+ }
+
+ // consume inner playload
+ noise::consume_response(self, &msg.noise)
+ }
+ TYPE_COOKIE_REPLY => {
+ let msg = CookieReply::parse(msg)?;
+
+ // lookup peer
+ let peer = self.lookup_id(msg.f_receiver.get())?;
+
+ // validate cookie reply
+ peer.macs.lock().process(&msg)?;
+
+ // this prompts no new message and
+ // DOES NOT cryptographically verify the peer
+ Ok((None, None, None))
+ }
+ _ => Err(HandshakeError::InvalidMessageFormat),
+ }
+ }
+
+ // Internal function
+ //
+ // Return the peer associated with the public key
+ pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer, HandshakeError> {
+ self.pk_map
+ .get(pk.as_bytes())
+ .ok_or(HandshakeError::UnknownPublicKey)
+ }
+
+ // Internal function
+ //
+ // Return the peer currently associated with the receiver identifier
+ pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer, HandshakeError> {
+ let im = self.id_map.read();
+ let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?;
+ match self.pk_map.get(pk) {
+ Some(peer) => Ok(peer),
+ _ => unreachable!(), // if the id-lookup succeeded, the peer should exist
+ }
+ }
+
+ // Internal function
+ //
+ // Allocated a new receiver identifier for the peer
+ fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, peer: &Peer) -> u32 {
+ loop {
+ let id = rng.gen();
+
+ // check membership with read lock
+ if self.id_map.read().contains_key(&id) {
+ continue;
+ }
+
+ // take write lock and add index
+ let mut m = self.id_map.write();
+ if !m.contains_key(&id) {
+ m.insert(id, *peer.pk.as_bytes());
+ return id;
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::super::messages::*;
+ use super::*;
+ use hex;
+ use rand::rngs::OsRng;
+ use std::net::SocketAddr;
+ use std::thread;
+ use std::time::Duration;
+
+ fn setup_devices<R: RngCore + CryptoRng>(
+ rng: &mut R,
+ ) -> (PublicKey, Device, PublicKey, Device) {
+ // generate new keypairs
+
+ let sk1 = StaticSecret::new(rng);
+ let pk1 = PublicKey::from(&sk1);
+
+ let sk2 = StaticSecret::new(rng);
+ let pk2 = PublicKey::from(&sk2);
+
+ // pick random psk
+
+ let mut psk = [0u8; 32];
+ rng.fill_bytes(&mut psk[..]);
+
+ // intialize devices on both ends
+
+ let mut dev1 = Device::new(sk1);
+ let mut dev2 = Device::new(sk2);
+
+ dev1.add(pk2).unwrap();
+ dev2.add(pk1).unwrap();
+
+ dev1.set_psk(pk2, Some(psk)).unwrap();
+ dev2.set_psk(pk1, Some(psk)).unwrap();
+
+ (pk1, dev1, pk2, dev2)
+ }
+
+ /* Test longest possible handshake interaction (7 messages):
+ *
+ * 1. I -> R (initation)
+ * 2. I <- R (cookie reply)
+ * 3. I -> R (initation)
+ * 4. I <- R (response)
+ * 5. I -> R (cookie reply)
+ * 6. I -> R (initation)
+ * 7. I <- R (response)
+ */
+ #[test]
+ fn handshake_under_load() {
+ let mut rng = OsRng::new().unwrap();
+ let (_pk1, dev1, pk2, dev2) = setup_devices(&mut rng);
+
+ let src1: SocketAddr = "172.16.0.1:8080".parse().unwrap();
+ let src2: SocketAddr = "172.16.0.2:7070".parse().unwrap();
+
+ // 1. device-1 : create first initation
+ let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
+
+ // 2. device-2 : responds with CookieReply
+ let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
+ (None, Some(msg), None) => msg,
+ _ => panic!("unexpected response"),
+ };
+
+ // device-1 : processes CookieReply (no response)
+ match dev1.process(&mut rng, &msg_cookie, Some(&src2)).unwrap() {
+ (None, None, None) => (),
+ _ => panic!("unexpected response"),
+ }
+
+ // avoid initation flood
+ thread::sleep(Duration::from_millis(20));
+
+ // 3. device-1 : create second initation
+ let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
+
+ // 4. device-2 : responds with noise response
+ let msg_response = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
+ (Some(_), Some(msg), Some(kp)) => {
+ assert_eq!(kp.initiator, false);
+ msg
+ }
+ _ => panic!("unexpected response"),
+ };
+
+ // 5. device-1 : responds with CookieReply
+ let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
+ (None, Some(msg), None) => msg,
+ _ => panic!("unexpected response"),
+ };
+
+ // device-2 : processes CookieReply (no response)
+ match dev2.process(&mut rng, &msg_cookie, Some(&src1)).unwrap() {
+ (None, None, None) => (),
+ _ => panic!("unexpected response"),
+ }
+
+ // avoid initation flood
+ thread::sleep(Duration::from_millis(20));
+
+ // 6. device-1 : create third initation
+ let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
+
+ // 7. device-2 : responds with noise response
+ let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
+ (Some(_), Some(msg), Some(kp)) => {
+ assert_eq!(kp.initiator, false);
+ (msg, kp)
+ }
+ _ => panic!("unexpected response"),
+ };
+
+ // device-1 : process noise response
+ let kp2 = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
+ (Some(_), None, Some(kp)) => {
+ assert_eq!(kp.initiator, true);
+ kp
+ }
+ _ => panic!("unexpected response"),
+ };
+
+ assert_eq!(kp1.send, kp2.recv);
+ assert_eq!(kp1.recv, kp2.send);
+ }
+
+ #[test]
+ fn handshake_no_load() {
+ let mut rng = OsRng::new().unwrap();
+ let (pk1, mut dev1, pk2, mut dev2) = setup_devices(&mut rng);
+
+ // do a few handshakes (every handshake should succeed)
+
+ for i in 0..10 {
+ println!("handshake : {}", i);
+
+ // create initiation
+
+ let msg1 = dev1.begin(&mut rng, &pk2).unwrap();
+
+ println!("msg1 = {} : {} bytes", hex::encode(&msg1[..]), msg1.len());
+ println!("msg1 = {:?}", Initiation::parse(&msg1[..]).unwrap());
+
+ // process initiation and create response
+
+ let (_, msg2, ks_r) = dev2.process(&mut rng, &msg1, None).unwrap();
+
+ let ks_r = ks_r.unwrap();
+ let msg2 = msg2.unwrap();
+
+ println!("msg2 = {} : {} bytes", hex::encode(&msg2[..]), msg2.len());
+ println!("msg2 = {:?}", Response::parse(&msg2[..]).unwrap());
+
+ assert!(!ks_r.initiator, "Responders key-pair is confirmed");
+
+ // process response and obtain confirmed key-pair
+
+ let (_, msg3, ks_i) = dev1.process(&mut rng, &msg2, None).unwrap();
+ let ks_i = ks_i.unwrap();
+
+ assert!(msg3.is_none(), "Returned message after response");
+ assert!(ks_i.initiator, "Initiators key-pair is not confirmed");
+
+ assert_eq!(ks_i.send, ks_r.recv, "KeyI.send != KeyR.recv");
+ assert_eq!(ks_i.recv, ks_r.send, "KeyI.recv != KeyR.send");
+
+ dev1.release(ks_i.send.id);
+ dev2.release(ks_r.send.id);
+
+ // to avoid flood detection
+ thread::sleep(Duration::from_millis(20));
+ }
+
+ dev1.remove(pk2).unwrap();
+ dev2.remove(pk1).unwrap();
+ }
+}
diff --git a/src/wireguard/handshake/macs.rs b/src/wireguard/handshake/macs.rs
new file mode 100644
index 0000000..689826b
--- /dev/null
+++ b/src/wireguard/handshake/macs.rs
@@ -0,0 +1,327 @@
+use generic_array::GenericArray;
+use rand::{CryptoRng, RngCore};
+use spin::RwLock;
+use std::time::{Duration, Instant};
+
+// types to coalesce into bytes
+use std::net::SocketAddr;
+use x25519_dalek::PublicKey;
+
+// AEAD
+use aead::{Aead, NewAead, Payload};
+use chacha20poly1305::XChaCha20Poly1305;
+
+// MAC
+use blake2::Blake2s;
+use subtle::ConstantTimeEq;
+
+use super::messages::{CookieReply, MacsFooter, TYPE_COOKIE_REPLY};
+use super::types::HandshakeError;
+
+const LABEL_MAC1: &[u8] = b"mac1----";
+const LABEL_COOKIE: &[u8] = b"cookie--";
+
+const SIZE_COOKIE: usize = 16;
+const SIZE_SECRET: usize = 32;
+const SIZE_MAC: usize = 16; // blake2s-mac128
+const SIZE_TAG: usize = 16; // xchacha20poly1305 tag
+
+const COOKIE_UPDATE_INTERVAL: Duration = Duration::from_secs(120);
+
+macro_rules! HASH {
+ ( $($input:expr),* ) => {{
+ use blake2::Digest;
+ let mut hsh = Blake2s::new();
+ $(
+ hsh.input($input);
+ )*
+ hsh.result()
+ }};
+}
+
+macro_rules! MAC {
+ ( $key:expr, $($input:expr),* ) => {{
+ use blake2::VarBlake2s;
+ use digest::Input;
+ use digest::VariableOutput;
+ let mut tag = [0u8; SIZE_MAC];
+ let mut mac = VarBlake2s::new_keyed($key, SIZE_MAC);
+ $(
+ mac.input($input);
+ )*
+ mac.variable_result(|buf| tag.copy_from_slice(buf));
+ tag
+ }};
+}
+
+macro_rules! XSEAL {
+ ($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{
+ let ct = XChaCha20Poly1305::new(*GenericArray::from_slice($key))
+ .encrypt(
+ GenericArray::from_slice($nonce),
+ Payload { msg: $pt, aad: $ad },
+ )
+ .unwrap();
+ debug_assert_eq!(ct.len(), $pt.len() + SIZE_TAG);
+ $ct.copy_from_slice(&ct);
+ }};
+}
+
+macro_rules! XOPEN {
+ ($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{
+ debug_assert_eq!($ct.len(), $pt.len() + SIZE_TAG);
+ XChaCha20Poly1305::new(*GenericArray::from_slice($key))
+ .decrypt(
+ GenericArray::from_slice($nonce),
+ Payload { msg: $ct, aad: $ad },
+ )
+ .map_err(|_| HandshakeError::DecryptionFailure)
+ .map(|pt| $pt.copy_from_slice(&pt))
+ }};
+}
+
+struct Cookie {
+ value: [u8; 16],
+ birth: Instant,
+}
+
+pub struct Generator {
+ mac1_key: [u8; 32],
+ cookie_key: [u8; 32], // xchacha20poly key for opening cookie response
+ last_mac1: Option<[u8; 16]>,
+ cookie: Option<Cookie>,
+}
+
+fn addr_to_mac_bytes(addr: &SocketAddr) -> Vec<u8> {
+ match addr {
+ SocketAddr::V4(addr) => {
+ let mut res = Vec::with_capacity(4 + 2);
+ res.extend(&addr.ip().octets());
+ res.extend(&addr.port().to_le_bytes());
+ res
+ }
+ SocketAddr::V6(addr) => {
+ let mut res = Vec::with_capacity(16 + 2);
+ res.extend(&addr.ip().octets());
+ res.extend(&addr.port().to_le_bytes());
+ res
+ }
+ }
+}
+
+impl Generator {
+ /// Initalize a new mac field generator
+ ///
+ /// # Arguments
+ ///
+ /// - pk: The public key of the peer to which the generator is associated
+ ///
+ /// # Returns
+ ///
+ /// A freshly initated generator
+ pub fn new(pk: PublicKey) -> Generator {
+ Generator {
+ mac1_key: HASH!(LABEL_MAC1, pk.as_bytes()).into(),
+ cookie_key: HASH!(LABEL_COOKIE, pk.as_bytes()).into(),
+ last_mac1: None,
+ cookie: None,
+ }
+ }
+
+ /// Process a CookieReply message
+ ///
+ /// # Arguments
+ ///
+ /// - reply: CookieReply to process
+ ///
+ /// # Returns
+ ///
+ /// Can fail if the cookie reply fails to validate
+ /// (either indicating that it is outdated or malformed)
+ pub fn process(&mut self, reply: &CookieReply) -> Result<(), HandshakeError> {
+ let mac1 = self.last_mac1.ok_or(HandshakeError::InvalidState)?;
+ let mut tau = [0u8; SIZE_COOKIE];
+ XOPEN!(
+ &self.cookie_key, // key
+ &reply.f_nonce, // nonce
+ &mac1, // ad
+ &mut tau, // pt
+ &reply.f_cookie // ct || tag
+ )?;
+ self.cookie = Some(Cookie {
+ birth: Instant::now(),
+ value: tau,
+ });
+ Ok(())
+ }
+
+ /// Generate both mac fields for an inner message
+ ///
+ /// # Arguments
+ ///
+ /// - inner: A byteslice representing the inner message to be covered
+ /// - macs: The destination mac footer for the resulting macs
+ pub fn generate(&mut self, inner: &[u8], macs: &mut MacsFooter) {
+ macs.f_mac1 = MAC!(&self.mac1_key, inner);
+ macs.f_mac2 = match &self.cookie {
+ Some(cookie) => {
+ if cookie.birth.elapsed() > COOKIE_UPDATE_INTERVAL {
+ self.cookie = None;
+ [0u8; SIZE_MAC]
+ } else {
+ MAC!(&cookie.value, inner, macs.f_mac1)
+ }
+ }
+ None => [0u8; SIZE_MAC],
+ };
+ self.last_mac1 = Some(macs.f_mac1);
+ }
+}
+
+struct Secret {
+ value: [u8; 32],
+ birth: Instant,
+}
+
+pub struct Validator {
+ mac1_key: [u8; 32], // mac1 key, derived from device public key
+ cookie_key: [u8; 32], // xchacha20poly key for sealing cookie response
+ secret: RwLock<Secret>,
+}
+
+impl Validator {
+ pub fn new(pk: PublicKey) -> Validator {
+ Validator {
+ mac1_key: HASH!(LABEL_MAC1, pk.as_bytes()).into(),
+ cookie_key: HASH!(LABEL_COOKIE, pk.as_bytes()).into(),
+ secret: RwLock::new(Secret {
+ value: [0u8; SIZE_SECRET],
+ birth: Instant::now() - Duration::new(86400, 0),
+ }),
+ }
+ }
+
+ fn get_tau(&self, src: &[u8]) -> Option<[u8; SIZE_COOKIE]> {
+ let secret = self.secret.read();
+ if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL {
+ Some(MAC!(&secret.value, src))
+ } else {
+ None
+ }
+ }
+
+ fn get_set_tau<R: RngCore + CryptoRng>(&self, rng: &mut R, src: &[u8]) -> [u8; SIZE_COOKIE] {
+ // check if current value is still valid
+ {
+ let secret = self.secret.read();
+ if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL {
+ return MAC!(&secret.value, src);
+ };
+ }
+
+ // take write lock, check again
+ {
+ let mut secret = self.secret.write();
+ if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL {
+ return MAC!(&secret.value, src);
+ };
+
+ // set new random cookie secret
+ rng.fill_bytes(&mut secret.value);
+ secret.birth = Instant::now();
+ MAC!(&secret.value, src)
+ }
+ }
+
+ pub fn create_cookie_reply<R: RngCore + CryptoRng>(
+ &self,
+ rng: &mut R,
+ receiver: u32, // receiver id of incoming message
+ src: &SocketAddr, // source address of incoming message
+ macs: &MacsFooter, // footer of incoming message
+ msg: &mut CookieReply, // resulting cookie reply
+ ) {
+ let src = addr_to_mac_bytes(src);
+ msg.f_type.set(TYPE_COOKIE_REPLY as u32);
+ msg.f_receiver.set(receiver);
+ rng.fill_bytes(&mut msg.f_nonce);
+ XSEAL!(
+ &self.cookie_key, // key
+ &msg.f_nonce, // nonce
+ &macs.f_mac1, // ad
+ &self.get_set_tau(rng, &src), // pt
+ &mut msg.f_cookie // ct || tag
+ );
+ }
+
+ /// Check the mac1 field against the inner message
+ ///
+ /// # Arguments
+ ///
+ /// - inner: The inner message covered by the mac1 field
+ /// - macs: The mac footer
+ pub fn check_mac1(&self, inner: &[u8], macs: &MacsFooter) -> Result<(), HandshakeError> {
+ let valid_mac1: bool = MAC!(&self.mac1_key, inner).ct_eq(&macs.f_mac1).into();
+ if !valid_mac1 {
+ Err(HandshakeError::InvalidMac1)
+ } else {
+ Ok(())
+ }
+ }
+
+ pub fn check_mac2(&self, inner: &[u8], src: &SocketAddr, macs: &MacsFooter) -> bool {
+ let src = addr_to_mac_bytes(src);
+ match self.get_tau(&src) {
+ Some(tau) => MAC!(&tau, inner, macs.f_mac1).ct_eq(&macs.f_mac2).into(),
+ None => false,
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use proptest::prelude::*;
+ use rand::rngs::OsRng;
+ use x25519_dalek::StaticSecret;
+
+ fn new_validator_generator() -> (Validator, Generator) {
+ let mut rng = OsRng::new().unwrap();
+ let sk = StaticSecret::new(&mut rng);
+ let pk = PublicKey::from(&sk);
+ (Validator::new(pk), Generator::new(pk))
+ }
+
+ proptest! {
+ #[test]
+ fn test_cookie_reply(inner1 : Vec<u8>, inner2 : Vec<u8>, receiver : u32) {
+ let mut msg = CookieReply::default();
+ let mut rng = OsRng::new().expect("failed to create rng");
+ let mut macs = MacsFooter::default();
+ let src = "192.0.2.16:8080".parse().unwrap();
+ let (validator, mut generator) = new_validator_generator();
+
+ // generate mac1 for first message
+ generator.generate(&inner1[..], &mut macs);
+ assert_ne!(macs.f_mac1, [0u8; SIZE_MAC], "mac1 should be set");
+ assert_eq!(macs.f_mac2, [0u8; SIZE_MAC], "mac2 should not be set");
+
+ // check validity of mac1
+ validator.check_mac1(&inner1[..], &macs).expect("mac1 of inner1 did not validate");
+ assert_eq!(validator.check_mac2(&inner1[..], &src, &macs), false, "mac2 of inner2 did not validate");
+ validator.create_cookie_reply(&mut rng, receiver, &src, &macs, &mut msg);
+
+ // consume cookie reply
+ generator.process(&msg).expect("failed to process CookieReply");
+
+ // generate mac2 & mac2 for second message
+ generator.generate(&inner2[..], &mut macs);
+ assert_ne!(macs.f_mac1, [0u8; SIZE_MAC], "mac1 should be set");
+ assert_ne!(macs.f_mac2, [0u8; SIZE_MAC], "mac2 should be set");
+
+ // check validity of mac1 and mac2
+ validator.check_mac1(&inner2[..], &macs).expect("mac1 of inner2 did not validate");
+ assert!(validator.check_mac2(&inner2[..], &src, &macs), "mac2 of inner2 did not validate");
+ }
+ }
+}
diff --git a/src/wireguard/handshake/messages.rs b/src/wireguard/handshake/messages.rs
new file mode 100644
index 0000000..29d80af
--- /dev/null
+++ b/src/wireguard/handshake/messages.rs
@@ -0,0 +1,363 @@
+#[cfg(test)]
+use hex;
+
+#[cfg(test)]
+use std::fmt;
+
+use std::mem;
+
+use byteorder::LittleEndian;
+use zerocopy::byteorder::U32;
+use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified};
+
+use super::types::*;
+
+const SIZE_MAC: usize = 16;
+const SIZE_TAG: usize = 16; // poly1305 tag
+const SIZE_XNONCE: usize = 24; // xchacha20 nonce
+const SIZE_COOKIE: usize = 16; //
+const SIZE_X25519_POINT: usize = 32; // x25519 public key
+const SIZE_TIMESTAMP: usize = 12;
+
+pub const TYPE_INITIATION: u32 = 1;
+pub const TYPE_RESPONSE: u32 = 2;
+pub const TYPE_COOKIE_REPLY: u32 = 3;
+
+const fn max(a: usize, b: usize) -> usize {
+ let m: usize = (a > b) as usize;
+ m * a + (1 - m) * b
+}
+
+pub const MAX_HANDSHAKE_MSG_SIZE: usize = max(
+ max(mem::size_of::<Response>(), mem::size_of::<Initiation>()),
+ mem::size_of::<CookieReply>(),
+);
+
+/* Handshake messsages */
+
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct Response {
+ pub noise: NoiseResponse, // inner message covered by macs
+ pub macs: MacsFooter,
+}
+
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct Initiation {
+ pub noise: NoiseInitiation, // inner message covered by macs
+ pub macs: MacsFooter,
+}
+
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct CookieReply {
+ pub f_type: U32<LittleEndian>,
+ pub f_receiver: U32<LittleEndian>,
+ pub f_nonce: [u8; SIZE_XNONCE],
+ pub f_cookie: [u8; SIZE_COOKIE + SIZE_TAG],
+}
+
+/* Inner sub-messages */
+
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct MacsFooter {
+ pub f_mac1: [u8; SIZE_MAC],
+ pub f_mac2: [u8; SIZE_MAC],
+}
+
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct NoiseInitiation {
+ pub f_type: U32<LittleEndian>,
+ pub f_sender: U32<LittleEndian>,
+ pub f_ephemeral: [u8; SIZE_X25519_POINT],
+ pub f_static: [u8; SIZE_X25519_POINT + SIZE_TAG],
+ pub f_timestamp: [u8; SIZE_TIMESTAMP + SIZE_TAG],
+}
+
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct NoiseResponse {
+ pub f_type: U32<LittleEndian>,
+ pub f_sender: U32<LittleEndian>,
+ pub f_receiver: U32<LittleEndian>,
+ pub f_ephemeral: [u8; SIZE_X25519_POINT],
+ pub f_empty: [u8; SIZE_TAG],
+}
+
+/* Zero copy parsing of handshake messages */
+
+impl Initiation {
+ pub fn parse<B: ByteSlice>(bytes: B) -> Result<LayoutVerified<B, Self>, HandshakeError> {
+ let msg: LayoutVerified<B, Self> =
+ LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?;
+
+ if msg.noise.f_type.get() != (TYPE_INITIATION as u32) {
+ return Err(HandshakeError::InvalidMessageFormat);
+ }
+
+ Ok(msg)
+ }
+}
+
+impl Response {
+ pub fn parse<B: ByteSlice>(bytes: B) -> Result<LayoutVerified<B, Self>, HandshakeError> {
+ let msg: LayoutVerified<B, Self> =
+ LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?;
+
+ if msg.noise.f_type.get() != (TYPE_RESPONSE as u32) {
+ return Err(HandshakeError::InvalidMessageFormat);
+ }
+
+ Ok(msg)
+ }
+}
+
+impl CookieReply {
+ pub fn parse<B: ByteSlice>(bytes: B) -> Result<LayoutVerified<B, Self>, HandshakeError> {
+ let msg: LayoutVerified<B, Self> =
+ LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?;
+
+ if msg.f_type.get() != (TYPE_COOKIE_REPLY as u32) {
+ return Err(HandshakeError::InvalidMessageFormat);
+ }
+
+ Ok(msg)
+ }
+}
+
+/* Default values */
+
+impl Default for Response {
+ fn default() -> Self {
+ Self {
+ noise: Default::default(),
+ macs: Default::default(),
+ }
+ }
+}
+
+impl Default for Initiation {
+ fn default() -> Self {
+ Self {
+ noise: Default::default(),
+ macs: Default::default(),
+ }
+ }
+}
+
+impl Default for CookieReply {
+ fn default() -> Self {
+ Self {
+ f_type: <U32<LittleEndian>>::new(TYPE_COOKIE_REPLY as u32),
+ f_receiver: <U32<LittleEndian>>::ZERO,
+ f_nonce: [0u8; SIZE_XNONCE],
+ f_cookie: [0u8; SIZE_COOKIE + SIZE_TAG],
+ }
+ }
+}
+
+impl Default for MacsFooter {
+ fn default() -> Self {
+ Self {
+ f_mac1: [0u8; SIZE_MAC],
+ f_mac2: [0u8; SIZE_MAC],
+ }
+ }
+}
+
+impl Default for NoiseInitiation {
+ fn default() -> Self {
+ Self {
+ f_type: <U32<LittleEndian>>::new(TYPE_INITIATION as u32),
+ f_sender: <U32<LittleEndian>>::ZERO,
+ f_ephemeral: [0u8; SIZE_X25519_POINT],
+ f_static: [0u8; SIZE_X25519_POINT + SIZE_TAG],
+ f_timestamp: [0u8; SIZE_TIMESTAMP + SIZE_TAG],
+ }
+ }
+}
+
+impl Default for NoiseResponse {
+ fn default() -> Self {
+ Self {
+ f_type: <U32<LittleEndian>>::new(TYPE_RESPONSE as u32),
+ f_sender: <U32<LittleEndian>>::ZERO,
+ f_receiver: <U32<LittleEndian>>::ZERO,
+ f_ephemeral: [0u8; SIZE_X25519_POINT],
+ f_empty: [0u8; SIZE_TAG],
+ }
+ }
+}
+
+/* Debug formatting (for testing purposes) */
+
+#[cfg(test)]
+impl fmt::Debug for Initiation {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "Initiation {{ {:?} || {:?} }}", self.noise, self.macs)
+ }
+}
+
+#[cfg(test)]
+impl fmt::Debug for Response {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "Response {{ {:?} || {:?} }}", self.noise, self.macs)
+ }
+}
+
+#[cfg(test)]
+impl fmt::Debug for CookieReply {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(
+ f,
+ "CookieReply {{ type = {}, receiver = {}, nonce = {}, cookie = {} }}",
+ self.f_type,
+ self.f_receiver,
+ hex::encode(&self.f_nonce[..]),
+ hex::encode(&self.f_cookie[..]),
+ )
+ }
+}
+
+#[cfg(test)]
+impl fmt::Debug for NoiseInitiation {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f,
+ "NoiseInitiation {{ type = {}, sender = {}, ephemeral = {}, static = {}, timestamp = {} }}",
+ self.f_type.get(),
+ self.f_sender.get(),
+ hex::encode(&self.f_ephemeral[..]),
+ hex::encode(&self.f_static[..]),
+ hex::encode(&self.f_timestamp[..]),
+ )
+ }
+}
+
+#[cfg(test)]
+impl fmt::Debug for NoiseResponse {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f,
+ "NoiseResponse {{ type = {}, sender = {}, receiver = {}, ephemeral = {}, empty = |{} }}",
+ self.f_type,
+ self.f_sender,
+ self.f_receiver,
+ hex::encode(&self.f_ephemeral[..]),
+ hex::encode(&self.f_empty[..])
+ )
+ }
+}
+
+#[cfg(test)]
+impl fmt::Debug for MacsFooter {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(
+ f,
+ "Macs {{ mac1 = {}, mac2 = {} }}",
+ hex::encode(&self.f_mac1[..]),
+ hex::encode(&self.f_mac2[..])
+ )
+ }
+}
+
+/* Equality (for testing purposes) */
+
+#[cfg(test)]
+macro_rules! eq_as_bytes {
+ ($type:path) => {
+ impl PartialEq for $type {
+ fn eq(&self, other: &Self) -> bool {
+ self.as_bytes() == other.as_bytes()
+ }
+ }
+ impl Eq for $type {}
+ };
+}
+
+#[cfg(test)]
+eq_as_bytes!(Initiation);
+
+#[cfg(test)]
+eq_as_bytes!(Response);
+
+#[cfg(test)]
+eq_as_bytes!(CookieReply);
+
+#[cfg(test)]
+eq_as_bytes!(MacsFooter);
+
+#[cfg(test)]
+eq_as_bytes!(NoiseInitiation);
+
+#[cfg(test)]
+eq_as_bytes!(NoiseResponse);
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn message_response_identity() {
+ let mut msg: Response = Default::default();
+
+ msg.noise.f_sender.set(146252);
+ msg.noise.f_receiver.set(554442);
+ msg.noise.f_ephemeral = [
+ 0xc1, 0x66, 0x0a, 0x0c, 0xdc, 0x0f, 0x6c, 0x51, 0x0f, 0xc2, 0xcc, 0x51, 0x52, 0x0c,
+ 0xde, 0x1e, 0xf7, 0xf1, 0xca, 0x90, 0x86, 0x72, 0xad, 0x67, 0xea, 0x89, 0x45, 0x44,
+ 0x13, 0x56, 0x52, 0x1f,
+ ];
+ msg.noise.f_empty = [
+ 0x60, 0x0e, 0x1e, 0x95, 0x41, 0x6b, 0x52, 0x05, 0xa2, 0x09, 0xe1, 0xbf, 0x40, 0x05,
+ 0x2f, 0xde,
+ ];
+ msg.macs.f_mac1 = [
+ 0xf2, 0xad, 0x40, 0xb5, 0xf7, 0xde, 0x77, 0x35, 0x89, 0x19, 0xb7, 0x5c, 0xf9, 0x54,
+ 0x69, 0x29,
+ ];
+ msg.macs.f_mac2 = [
+ 0x4f, 0xd2, 0x1b, 0xfe, 0x77, 0xe6, 0x2e, 0xc9, 0x07, 0xe2, 0x87, 0x17, 0xbb, 0xe5,
+ 0xdf, 0xbb,
+ ];
+
+ let buf: Vec<u8> = msg.as_bytes().to_vec();
+ let msg_p = Response::parse(&buf[..]).unwrap();
+ assert_eq!(msg, *msg_p.into_ref());
+ }
+
+ #[test]
+ fn message_initiate_identity() {
+ let mut msg: Initiation = Default::default();
+
+ msg.noise.f_sender.set(575757);
+ msg.noise.f_ephemeral = [
+ 0xc1, 0x66, 0x0a, 0x0c, 0xdc, 0x0f, 0x6c, 0x51, 0x0f, 0xc2, 0xcc, 0x51, 0x52, 0x0c,
+ 0xde, 0x1e, 0xf7, 0xf1, 0xca, 0x90, 0x86, 0x72, 0xad, 0x67, 0xea, 0x89, 0x45, 0x44,
+ 0x13, 0x56, 0x52, 0x1f,
+ ];
+ msg.noise.f_static = [
+ 0xdc, 0x33, 0x90, 0x15, 0x8f, 0x82, 0x3e, 0x06, 0x44, 0xa0, 0xde, 0x4c, 0x15, 0x6c,
+ 0x5d, 0xa4, 0x65, 0x99, 0xf6, 0x6c, 0xa1, 0x14, 0x77, 0xf9, 0xeb, 0x6a, 0xec, 0xc3,
+ 0x3c, 0xda, 0x47, 0xe1, 0x45, 0xac, 0x8d, 0x43, 0xea, 0x1b, 0x2f, 0x02, 0x45, 0x5d,
+ 0x86, 0x37, 0xee, 0x83, 0x6b, 0x42,
+ ];
+ msg.noise.f_timestamp = [
+ 0x4f, 0x1c, 0x60, 0xec, 0x0e, 0xf6, 0x36, 0xf0, 0x78, 0x28, 0x57, 0x42, 0x60, 0x0e,
+ 0x1e, 0x95, 0x41, 0x6b, 0x52, 0x05, 0xa2, 0x09, 0xe1, 0xbf, 0x40, 0x05, 0x2f, 0xde,
+ ];
+ msg.macs.f_mac1 = [
+ 0xf2, 0xad, 0x40, 0xb5, 0xf7, 0xde, 0x77, 0x35, 0x89, 0x19, 0xb7, 0x5c, 0xf9, 0x54,
+ 0x69, 0x29,
+ ];
+ msg.macs.f_mac2 = [
+ 0x4f, 0xd2, 0x1b, 0xfe, 0x77, 0xe6, 0x2e, 0xc9, 0x07, 0xe2, 0x87, 0x17, 0xbb, 0xe5,
+ 0xdf, 0xbb,
+ ];
+
+ let buf: Vec<u8> = msg.as_bytes().to_vec();
+ let msg_p = Initiation::parse(&buf[..]).unwrap();
+ assert_eq!(msg, *msg_p.into_ref());
+ }
+}
diff --git a/src/wireguard/handshake/mod.rs b/src/wireguard/handshake/mod.rs
new file mode 100644
index 0000000..071a41f
--- /dev/null
+++ b/src/wireguard/handshake/mod.rs
@@ -0,0 +1,21 @@
+/* Implementation of the:
+ *
+ * Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s
+ *
+ * Protocol pattern, see: http://www.noiseprotocol.org/noise.html.
+ * For documentation.
+ */
+
+mod device;
+mod macs;
+mod messages;
+mod noise;
+mod peer;
+mod ratelimiter;
+mod timestamp;
+mod types;
+
+// publicly exposed interface
+
+pub use device::Device;
+pub use messages::{MAX_HANDSHAKE_MSG_SIZE, TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};
diff --git a/src/wireguard/handshake/noise.rs b/src/wireguard/handshake/noise.rs
new file mode 100644
index 0000000..a2a84b0
--- /dev/null
+++ b/src/wireguard/handshake/noise.rs
@@ -0,0 +1,549 @@
+// DH
+use x25519_dalek::PublicKey;
+use x25519_dalek::StaticSecret;
+
+// HASH & MAC
+use blake2::Blake2s;
+use hmac::Hmac;
+
+// AEAD
+use aead::{Aead, NewAead, Payload};
+use chacha20poly1305::ChaCha20Poly1305;
+
+use rand::{CryptoRng, RngCore};
+
+use generic_array::typenum::*;
+use generic_array::*;
+
+use clear_on_drop::clear::Clear;
+use clear_on_drop::clear_stack_on_return;
+
+use subtle::ConstantTimeEq;
+
+use super::device::Device;
+use super::messages::{NoiseInitiation, NoiseResponse};
+use super::messages::{TYPE_INITIATION, TYPE_RESPONSE};
+use super::peer::{Peer, State};
+use super::timestamp;
+use super::types::*;
+
+use super::super::types::{KeyPair, Key};
+
+use std::time::Instant;
+
+// HMAC hasher (generic construction)
+
+type HMACBlake2s = Hmac<Blake2s>;
+
+// convenient alias to pass state temporarily into device.rs and back
+
+type TemporaryState = (u32, PublicKey, GenericArray<u8, U32>, GenericArray<u8, U32>);
+
+const SIZE_CK: usize = 32;
+const SIZE_HS: usize = 32;
+const SIZE_NONCE: usize = 8;
+const SIZE_TAG: usize = 16;
+
+// number of pages to clear after sensitive call
+const CLEAR_PAGES: usize = 1;
+
+// C := Hash(Construction)
+const INITIAL_CK: [u8; SIZE_CK] = [
+ 0x60, 0xe2, 0x6d, 0xae, 0xf3, 0x27, 0xef, 0xc0, 0x2e, 0xc3, 0x35, 0xe2, 0xa0, 0x25, 0xd2, 0xd0,
+ 0x16, 0xeb, 0x42, 0x06, 0xf8, 0x72, 0x77, 0xf5, 0x2d, 0x38, 0xd1, 0x98, 0x8b, 0x78, 0xcd, 0x36,
+];
+
+// H := Hash(C || Identifier)
+const INITIAL_HS: [u8; SIZE_HS] = [
+ 0x22, 0x11, 0xb3, 0x61, 0x08, 0x1a, 0xc5, 0x66, 0x69, 0x12, 0x43, 0xdb, 0x45, 0x8a, 0xd5, 0x32,
+ 0x2d, 0x9c, 0x6c, 0x66, 0x22, 0x93, 0xe8, 0xb7, 0x0e, 0xe1, 0x9c, 0x65, 0xba, 0x07, 0x9e, 0xf3,
+];
+
+const ZERO_NONCE: [u8; 12] = [0u8; 12];
+
+macro_rules! HASH {
+ ( $($input:expr),* ) => {{
+ use blake2::Digest;
+ let mut hsh = Blake2s::new();
+ $(
+ hsh.input($input);
+ )*
+ hsh.result()
+ }};
+}
+
+macro_rules! HMAC {
+ ($key:expr, $($input:expr),*) => {{
+ use hmac::Mac;
+ let mut mac = HMACBlake2s::new_varkey($key).unwrap();
+ $(
+ mac.input($input);
+ )*
+ mac.result().code()
+ }};
+}
+
+macro_rules! KDF1 {
+ ($ck:expr, $input:expr) => {{
+ let mut t0 = HMAC!($ck, $input);
+ let t1 = HMAC!(&t0, &[0x1]);
+ t0.clear();
+ t1
+ }};
+}
+
+macro_rules! KDF2 {
+ ($ck:expr, $input:expr) => {{
+ let mut t0 = HMAC!($ck, $input);
+ let t1 = HMAC!(&t0, &[0x1]);
+ let t2 = HMAC!(&t0, &t1, &[0x2]);
+ t0.clear();
+ (t1, t2)
+ }};
+}
+
+macro_rules! KDF3 {
+ ($ck:expr, $input:expr) => {{
+ let mut t0 = HMAC!($ck, $input);
+ let t1 = HMAC!(&t0, &[0x1]);
+ let t2 = HMAC!(&t0, &t1, &[0x2]);
+ let t3 = HMAC!(&t0, &t2, &[0x3]);
+ t0.clear();
+ (t1, t2, t3)
+ }};
+}
+
+macro_rules! SEAL {
+ ($key:expr, $ad:expr, $pt:expr, $ct:expr) => {
+ ChaCha20Poly1305::new(*GenericArray::from_slice($key))
+ .encrypt(&ZERO_NONCE.into(), Payload { msg: $pt, aad: $ad })
+ .map(|ct| $ct.copy_from_slice(&ct))
+ .unwrap()
+ };
+}
+
+macro_rules! OPEN {
+ ($key:expr, $ad:expr, $pt:expr, $ct:expr) => {
+ ChaCha20Poly1305::new(*GenericArray::from_slice($key))
+ .decrypt(&ZERO_NONCE.into(), Payload { msg: $ct, aad: $ad })
+ .map_err(|_| HandshakeError::DecryptionFailure)
+ .map(|pt| $pt.copy_from_slice(&pt))
+ };
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ const IDENTIFIER: &[u8] = b"WireGuard v1 zx2c4 Jason@zx2c4.com";
+ const CONSTRUCTION: &[u8] = b"Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
+
+ /* Sanity check precomputed initial chain key
+ */
+ #[test]
+ fn precomputed_chain_key() {
+ assert_eq!(INITIAL_CK[..], HASH!(CONSTRUCTION)[..]);
+ }
+
+ /* Sanity check precomputed initial hash transcript
+ */
+ #[test]
+ fn precomputed_hash() {
+ assert_eq!(INITIAL_HS[..], HASH!(INITIAL_CK, IDENTIFIER)[..]);
+ }
+
+ /* Sanity check the HKDF macro
+ *
+ * Test vectors generated using WireGuard-Go
+ */
+ #[test]
+ fn hkdf() {
+ let tests: Vec<(Vec<u8>, Vec<u8>, [u8; 32], [u8; 32], [u8; 32])> = vec![
+ (
+ vec![],
+ vec![],
+ [
+ 0x83, 0x87, 0xb4, 0x6b, 0xf4, 0x3e, 0xcc, 0xfc, 0xf3, 0x49, 0x55, 0x2a, 0x09,
+ 0x5d, 0x83, 0x15, 0xc4, 0x05, 0x5b, 0xeb, 0x90, 0x20, 0x8f, 0xb1, 0xbe, 0x23,
+ 0xb8, 0x94, 0xbc, 0x2e, 0xd5, 0xd0,
+ ],
+ [
+ 0x58, 0xa0, 0xe5, 0xf6, 0xfa, 0xef, 0xcc, 0xf4, 0x80, 0x7b, 0xff, 0x1f, 0x05,
+ 0xfa, 0x8a, 0x92, 0x17, 0x94, 0x57, 0x62, 0x04, 0x0b, 0xce, 0xc2, 0xf4, 0xb4,
+ 0xa6, 0x2b, 0xdf, 0xe0, 0xe8, 0x6e,
+ ],
+ [
+ 0x0c, 0xe6, 0xea, 0x98, 0xec, 0x54, 0x8f, 0x8e, 0x28, 0x1e, 0x93, 0xe3, 0x2d,
+ 0xb6, 0x56, 0x21, 0xc4, 0x5e, 0xb1, 0x8d, 0xc6, 0xf0, 0xa7, 0xad, 0x94, 0x17,
+ 0x86, 0x10, 0xa2, 0xf7, 0x33, 0x8e,
+ ],
+ ),
+ (
+ vec![0xde, 0xad, 0xbe, 0xef],
+ vec![],
+ [
+ 0x55, 0x32, 0x9d, 0xc8, 0x0e, 0x69, 0x0f, 0xd8, 0x6b, 0xd9, 0x66, 0x1f, 0x08,
+ 0x51, 0xc9, 0xb3, 0x68, 0x6d, 0xf2, 0xb1, 0xfd, 0xa0, 0x34, 0x7b, 0xc3, 0xd2,
+ 0x79, 0x58, 0x25, 0x4b, 0x32, 0xc6,
+ ],
+ [
+ 0x8d, 0xfc, 0x6d, 0x33, 0xa8, 0x11, 0x8f, 0xfe, 0x40, 0x8b, 0x31, 0xdd, 0xac,
+ 0x25, 0xf7, 0x2a, 0xee, 0x91, 0x15, 0xa4, 0x5b, 0x69, 0xba, 0x17, 0x6a, 0xd0,
+ 0x12, 0xb2, 0x43, 0x83, 0x4f, 0xee,
+ ],
+ [
+ 0xd6, 0x9e, 0x85, 0x2a, 0x28, 0x96, 0x56, 0x9e, 0xa5, 0x4a, 0x67, 0x96, 0x9a,
+ 0xa1, 0x80, 0x02, 0x87, 0x92, 0x1d, 0xac, 0x53, 0xce, 0x6d, 0xb4, 0xb4, 0xe1,
+ 0x21, 0x92, 0xf2, 0x63, 0xc4, 0xc4,
+ ],
+ ),
+ ];
+
+ for (key, input, t0, t1, t2) in &tests {
+ let tt0 = KDF1!(key, input);
+ debug_assert_eq!(tt0[..], t0[..]);
+
+ let (tt0, tt1) = KDF2!(key, input);
+ debug_assert_eq!(tt0[..], t0[..]);
+ debug_assert_eq!(tt1[..], t1[..]);
+
+ let (tt0, tt1, tt2) = KDF3!(key, input);
+ debug_assert_eq!(tt0[..], t0[..]);
+ debug_assert_eq!(tt1[..], t1[..]);
+ debug_assert_eq!(tt2[..], t2[..]);
+ }
+ }
+}
+
+pub fn create_initiation<R: RngCore + CryptoRng>(
+ rng: &mut R,
+ device: &Device,
+ peer: &Peer,
+ sender: u32,
+ msg: &mut NoiseInitiation,
+) -> Result<(), HandshakeError> {
+ clear_stack_on_return(CLEAR_PAGES, || {
+ // initialize state
+
+ let ck = INITIAL_CK;
+ let hs = INITIAL_HS;
+ let hs = HASH!(&hs, peer.pk.as_bytes());
+
+ msg.f_type.set(TYPE_INITIATION as u32);
+ msg.f_sender.set(sender);
+
+ // (E_priv, E_pub) := DH-Generate()
+
+ let eph_sk = StaticSecret::new(rng);
+ let eph_pk = PublicKey::from(&eph_sk);
+
+ // C := Kdf(C, E_pub)
+
+ let ck = KDF1!(&ck, eph_pk.as_bytes());
+
+ // msg.ephemeral := E_pub
+
+ msg.f_ephemeral = *eph_pk.as_bytes();
+
+ // H := HASH(H, msg.ephemeral)
+
+ let hs = HASH!(&hs, msg.f_ephemeral);
+
+ // (C, k) := Kdf2(C, DH(E_priv, S_pub))
+
+ let (ck, key) = KDF2!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes());
+
+ // msg.static := Aead(k, 0, S_pub, H)
+
+ SEAL!(
+ &key,
+ &hs, // ad
+ device.pk.as_bytes(), // pt
+ &mut msg.f_static // ct || tag
+ );
+
+ // H := Hash(H || msg.static)
+
+ let hs = HASH!(&hs, &msg.f_static[..]);
+
+ // (C, k) := Kdf2(C, DH(S_priv, S_pub))
+
+ let (ck, key) = KDF2!(&ck, peer.ss.as_bytes());
+
+ // msg.timestamp := Aead(k, 0, Timestamp(), H)
+
+ SEAL!(
+ &key,
+ &hs, // ad
+ &timestamp::now(), // pt
+ &mut msg.f_timestamp // ct || tag
+ );
+
+ // H := Hash(H || msg.timestamp)
+
+ let hs = HASH!(&hs, &msg.f_timestamp);
+
+ // update state of peer
+
+ *peer.state.lock() = State::InitiationSent {
+ hs,
+ ck,
+ eph_sk,
+ sender,
+ };
+
+ Ok(())
+ })
+}
+
+pub fn consume_initiation<'a>(
+ device: &'a Device,
+ msg: &NoiseInitiation,
+) -> Result<(&'a Peer, TemporaryState), HandshakeError> {
+ clear_stack_on_return(CLEAR_PAGES, || {
+ // initialize new state
+
+ let ck = INITIAL_CK;
+ let hs = INITIAL_HS;
+ let hs = HASH!(&hs, device.pk.as_bytes());
+
+ // C := Kdf(C, E_pub)
+
+ let ck = KDF1!(&ck, &msg.f_ephemeral);
+
+ // H := HASH(H, msg.ephemeral)
+
+ let hs = HASH!(&hs, &msg.f_ephemeral);
+
+ // (C, k) := Kdf2(C, DH(E_priv, S_pub))
+
+ let eph_r_pk = PublicKey::from(msg.f_ephemeral);
+ let (ck, key) = KDF2!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes());
+
+ // msg.static := Aead(k, 0, S_pub, H)
+
+ let mut pk = [0u8; 32];
+
+ OPEN!(
+ &key,
+ &hs, // ad
+ &mut pk, // pt
+ &msg.f_static // ct || tag
+ )?;
+
+ let peer = device.lookup_pk(&PublicKey::from(pk))?;
+
+ // reset initiation state
+
+ *peer.state.lock() = State::Reset;
+
+ // H := Hash(H || msg.static)
+
+ let hs = HASH!(&hs, &msg.f_static[..]);
+
+ // (C, k) := Kdf2(C, DH(S_priv, S_pub))
+
+ let (ck, key) = KDF2!(&ck, peer.ss.as_bytes());
+
+ // msg.timestamp := Aead(k, 0, Timestamp(), H)
+
+ let mut ts = timestamp::ZERO;
+
+ OPEN!(
+ &key,
+ &hs, // ad
+ &mut ts, // pt
+ &msg.f_timestamp // ct || tag
+ )?;
+
+ // check and update timestamp
+
+ peer.check_replay_flood(device, &ts)?;
+
+ // H := Hash(H || msg.timestamp)
+
+ let hs = HASH!(&hs, &msg.f_timestamp);
+
+ // return state (to create response)
+
+ Ok((peer, (msg.f_sender.get(), eph_r_pk, hs, ck)))
+ })
+}
+
+pub fn create_response<R: RngCore + CryptoRng>(
+ rng: &mut R,
+ peer: &Peer,
+ sender: u32, // sending identifier
+ state: TemporaryState, // state from "consume_initiation"
+ msg: &mut NoiseResponse, // resulting response
+) -> Result<KeyPair, HandshakeError> {
+ clear_stack_on_return(CLEAR_PAGES, || {
+ // unpack state
+
+ let (receiver, eph_r_pk, hs, ck) = state;
+
+ msg.f_type.set(TYPE_RESPONSE as u32);
+ msg.f_sender.set(sender);
+ msg.f_receiver.set(receiver);
+
+ // (E_priv, E_pub) := DH-Generate()
+
+ let eph_sk = StaticSecret::new(rng);
+ let eph_pk = PublicKey::from(&eph_sk);
+
+ // C := Kdf1(C, E_pub)
+
+ let ck = KDF1!(&ck, eph_pk.as_bytes());
+
+ // msg.ephemeral := E_pub
+
+ msg.f_ephemeral = *eph_pk.as_bytes();
+
+ // H := Hash(H || msg.ephemeral)
+
+ let hs = HASH!(&hs, &msg.f_ephemeral);
+
+ // C := Kdf1(C, DH(E_priv, E_pub))
+
+ let ck = KDF1!(&ck, eph_sk.diffie_hellman(&eph_r_pk).as_bytes());
+
+ // C := Kdf1(C, DH(E_priv, S_pub))
+
+ let ck = KDF1!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes());
+
+ // (C, tau, k) := Kdf3(C, Q)
+
+ let (ck, tau, key) = KDF3!(&ck, &peer.psk);
+
+ // H := Hash(H || tau)
+
+ let hs = HASH!(&hs, tau);
+
+ // msg.empty := Aead(k, 0, [], H)
+
+ SEAL!(
+ &key,
+ &hs, // ad
+ &[], // pt
+ &mut msg.f_empty // \epsilon || tag
+ );
+
+ // Not strictly needed
+ // let hs = HASH!(&hs, &msg.f_empty_tag);
+
+ // derive key-pair
+
+ let (key_recv, key_send) = KDF2!(&ck, &[]);
+
+ // return unconfirmed key-pair
+
+ Ok(KeyPair {
+ birth: Instant::now(),
+ initiator: false,
+ send: Key {
+ id: sender,
+ key: key_send.into(),
+ },
+ recv: Key {
+ id: receiver,
+ key: key_recv.into(),
+ },
+ })
+ })
+}
+
+/* The state lock is released while processing the message to
+ * allow concurrent processing of potential responses to the initiation,
+ * in order to better mitigate DoS from malformed response messages.
+ */
+pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result<Output, HandshakeError> {
+ clear_stack_on_return(CLEAR_PAGES, || {
+ // retrieve peer and copy initiation state
+ let peer = device.lookup_id(msg.f_receiver.get())?;
+
+ let (hs, ck, sender, eph_sk) = match *peer.state.lock() {
+ State::InitiationSent {
+ hs,
+ ck,
+ sender,
+ ref eph_sk,
+ } => Ok((hs, ck, sender, StaticSecret::from(eph_sk.to_bytes()))),
+ _ => Err(HandshakeError::InvalidState),
+ }?;
+
+ // C := Kdf1(C, E_pub)
+
+ let ck = KDF1!(&ck, &msg.f_ephemeral);
+
+ // H := Hash(H || msg.ephemeral)
+
+ let hs = HASH!(&hs, &msg.f_ephemeral);
+
+ // C := Kdf1(C, DH(E_priv, E_pub))
+
+ let eph_r_pk = PublicKey::from(msg.f_ephemeral);
+ let ck = KDF1!(&ck, eph_sk.diffie_hellman(&eph_r_pk).as_bytes());
+
+ // C := Kdf1(C, DH(E_priv, S_pub))
+
+ let ck = KDF1!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes());
+
+ // (C, tau, k) := Kdf3(C, Q)
+
+ let (ck, tau, key) = KDF3!(&ck, &peer.psk);
+
+ // H := Hash(H || tau)
+
+ let hs = HASH!(&hs, tau);
+
+ // msg.empty := Aead(k, 0, [], H)
+
+ OPEN!(
+ &key,
+ &hs, // ad
+ &mut [], // pt
+ &msg.f_empty // \epsilon || tag
+ )?;
+
+ // derive key-pair
+
+ let birth = Instant::now();
+ let (key_send, key_recv) = KDF2!(&ck, &[]);
+
+ // check for new initiation sent while lock released
+
+ let mut state = peer.state.lock();
+ let update = match *state {
+ State::InitiationSent {
+ eph_sk: ref old, ..
+ } => old.to_bytes().ct_eq(&eph_sk.to_bytes()).into(),
+ _ => false,
+ };
+
+ if update {
+ // null the initiation state
+ // (to avoid replay of this response message)
+ *state = State::Reset;
+
+ // return confirmed key-pair
+ Ok((
+ Some(peer.pk),
+ None,
+ Some(KeyPair {
+ birth,
+ initiator: true,
+ send: Key {
+ id: sender,
+ key: key_send.into(),
+ },
+ recv: Key {
+ id: msg.f_sender.get(),
+ key: key_recv.into(),
+ },
+ }),
+ ))
+ } else {
+ Err(HandshakeError::InvalidState)
+ }
+ })
+}
diff --git a/src/wireguard/handshake/peer.rs b/src/wireguard/handshake/peer.rs
new file mode 100644
index 0000000..c9e1c40
--- /dev/null
+++ b/src/wireguard/handshake/peer.rs
@@ -0,0 +1,142 @@
+use spin::Mutex;
+
+use std::mem;
+use std::time::{Duration, Instant};
+
+use generic_array::typenum::U32;
+use generic_array::GenericArray;
+
+use x25519_dalek::PublicKey;
+use x25519_dalek::SharedSecret;
+use x25519_dalek::StaticSecret;
+
+use clear_on_drop::clear::Clear;
+
+use super::device::Device;
+use super::macs;
+use super::timestamp;
+use super::types::*;
+
+const TIME_BETWEEN_INITIATIONS: Duration = Duration::from_millis(20);
+
+/* Represents the recomputation and state of a peer.
+ *
+ * This type is only for internal use and not exposed.
+ */
+pub struct Peer {
+ // mutable state
+ pub(crate) state: Mutex<State>,
+ pub(crate) timestamp: Mutex<Option<timestamp::TAI64N>>,
+ pub(crate) last_initiation_consumption: Mutex<Option<Instant>>,
+
+ // state related to DoS mitigation fields
+ pub(crate) macs: Mutex<macs::Generator>,
+
+ // constant state
+ pub(crate) pk: PublicKey, // public key of peer
+ pub(crate) ss: SharedSecret, // precomputed DH(static, static)
+ pub(crate) psk: Psk, // psk of peer
+}
+
+pub enum State {
+ Reset,
+ InitiationSent {
+ sender: u32, // assigned sender id
+ eph_sk: StaticSecret,
+ hs: GenericArray<u8, U32>,
+ ck: GenericArray<u8, U32>,
+ },
+}
+
+impl Drop for State {
+ fn drop(&mut self) {
+ match self {
+ State::InitiationSent { hs, ck, .. } => {
+ // eph_sk already cleared by dalek-x25519
+ hs.clear();
+ ck.clear();
+ }
+ _ => (),
+ }
+ }
+}
+
+impl Peer {
+ pub fn new(
+ pk: PublicKey, // public key of peer
+ ss: SharedSecret, // precomputed DH(static, static)
+ ) -> Self {
+ Self {
+ macs: Mutex::new(macs::Generator::new(pk)),
+ state: Mutex::new(State::Reset),
+ timestamp: Mutex::new(None),
+ last_initiation_consumption: Mutex::new(None),
+ pk: pk,
+ ss: ss,
+ psk: [0u8; 32],
+ }
+ }
+
+ /// Set the state of the peer unconditionally
+ ///
+ /// # Arguments
+ ///
+ pub fn set_state(&self, state_new: State) {
+ *self.state.lock() = state_new;
+ }
+
+ pub fn reset_state(&self) -> Option<u32> {
+ match mem::replace(&mut *self.state.lock(), State::Reset) {
+ State::InitiationSent { sender, .. } => Some(sender),
+ _ => None,
+ }
+ }
+
+ /// Set the mutable state of the peer conditioned on the timestamp being newer
+ ///
+ /// # Arguments
+ ///
+ /// * st_new - The updated state of the peer
+ /// * ts_new - The associated timestamp
+ pub fn check_replay_flood(
+ &self,
+ device: &Device,
+ timestamp_new: &timestamp::TAI64N,
+ ) -> Result<(), HandshakeError> {
+ let mut state = self.state.lock();
+ let mut timestamp = self.timestamp.lock();
+ let mut last_initiation_consumption = self.last_initiation_consumption.lock();
+
+ // check replay attack
+ match *timestamp {
+ Some(timestamp_old) => {
+ if !timestamp::compare(&timestamp_old, &timestamp_new) {
+ return Err(HandshakeError::OldTimestamp);
+ }
+ }
+ _ => (),
+ };
+
+ // check flood attack
+ match *last_initiation_consumption {
+ Some(last) => {
+ if last.elapsed() < TIME_BETWEEN_INITIATIONS {
+ return Err(HandshakeError::InitiationFlood);
+ }
+ }
+ _ => (),
+ }
+
+ // reset state
+ match *state {
+ State::InitiationSent { sender, .. } => device.release(sender),
+ _ => (),
+ }
+
+ // update replay & flood protection
+ *state = State::Reset;
+ *timestamp = Some(*timestamp_new);
+ *last_initiation_consumption = Some(Instant::now());
+ Ok(())
+ }
+}
diff --git a/src/wireguard/handshake/ratelimiter.rs b/src/wireguard/handshake/ratelimiter.rs
new file mode 100644
index 0000000..63d728c
--- /dev/null
+++ b/src/wireguard/handshake/ratelimiter.rs
@@ -0,0 +1,199 @@
+use spin;
+use std::collections::HashMap;
+use std::net::IpAddr;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::{Arc, Condvar, Mutex};
+use std::thread;
+use std::time::{Duration, Instant};
+
+const PACKETS_PER_SECOND: u64 = 20;
+const PACKETS_BURSTABLE: u64 = 5;
+const PACKET_COST: u64 = 1_000_000_000 / PACKETS_PER_SECOND;
+const MAX_TOKENS: u64 = PACKET_COST * PACKETS_BURSTABLE;
+
+const GC_INTERVAL: Duration = Duration::from_secs(1);
+
+struct Entry {
+ pub last_time: Instant,
+ pub tokens: u64,
+}
+
+pub struct RateLimiter(Arc<RateLimiterInner>);
+
+struct RateLimiterInner {
+ gc_running: AtomicBool,
+ gc_dropped: (Mutex<bool>, Condvar),
+ table: spin::RwLock<HashMap<IpAddr, spin::Mutex<Entry>>>,
+}
+
+impl Drop for RateLimiter {
+ fn drop(&mut self) {
+ // wake up & terminate any lingering GC thread
+ let &(ref lock, ref cvar) = &self.0.gc_dropped;
+ let mut dropped = lock.lock().unwrap();
+ *dropped = true;
+ cvar.notify_all();
+ }
+}
+
+impl RateLimiter {
+ pub fn new() -> Self {
+ RateLimiter(Arc::new(RateLimiterInner {
+ gc_dropped: (Mutex::new(false), Condvar::new()),
+ gc_running: AtomicBool::from(false),
+ table: spin::RwLock::new(HashMap::new()),
+ }))
+ }
+
+ pub fn allow(&self, addr: &IpAddr) -> bool {
+ // check if allowed
+ let allowed = {
+ // check for existing entry (only requires read lock)
+ if let Some(entry) = self.0.table.read().get(addr) {
+ // update existing entry
+ let mut entry = entry.lock();
+
+ // add tokens earned since last time
+ entry.tokens = MAX_TOKENS
+ .min(entry.tokens + u64::from(entry.last_time.elapsed().subsec_nanos()));
+ entry.last_time = Instant::now();
+
+ // subtract cost of packet
+ if entry.tokens > PACKET_COST {
+ entry.tokens -= PACKET_COST;
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ // add new entry (write lock)
+ self.0.table.write().insert(
+ *addr,
+ spin::Mutex::new(Entry {
+ last_time: Instant::now(),
+ tokens: MAX_TOKENS - PACKET_COST,
+ }),
+ );
+ true
+ };
+
+ // check that GC thread is scheduled
+ if !self.0.gc_running.swap(true, Ordering::Relaxed) {
+ let limiter = self.0.clone();
+ thread::spawn(move || {
+ let &(ref lock, ref cvar) = &limiter.gc_dropped;
+ let mut dropped = lock.lock().unwrap();
+ while !*dropped {
+ // garbage collect
+ {
+ let mut tw = limiter.table.write();
+ tw.retain(|_, ref mut entry| {
+ entry.lock().last_time.elapsed() <= GC_INTERVAL
+ });
+ if tw.len() == 0 {
+ limiter.gc_running.store(false, Ordering::Relaxed);
+ return;
+ }
+ }
+
+ // wait until stopped or new GC (~1 every sec)
+ let res = cvar.wait_timeout(dropped, GC_INTERVAL).unwrap();
+ dropped = res.0;
+ }
+ });
+ }
+
+ allowed
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std;
+
+ struct Result {
+ allowed: bool,
+ text: &'static str,
+ wait: Duration,
+ }
+
+ #[test]
+ fn test_ratelimiter() {
+ let ratelimiter = RateLimiter::new();
+ let mut expected = vec![];
+ let ips = vec![
+ "127.0.0.1".parse().unwrap(),
+ "192.168.1.1".parse().unwrap(),
+ "172.167.2.3".parse().unwrap(),
+ "97.231.252.215".parse().unwrap(),
+ "248.97.91.167".parse().unwrap(),
+ "188.208.233.47".parse().unwrap(),
+ "104.2.183.179".parse().unwrap(),
+ "72.129.46.120".parse().unwrap(),
+ "2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(),
+ "f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(),
+ "b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(),
+ "a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(),
+ "ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(),
+ "3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(),
+ ];
+
+ for _ in 0..PACKETS_BURSTABLE {
+ expected.push(Result {
+ allowed: true,
+ wait: Duration::new(0, 0),
+ text: "inital burst",
+ });
+ }
+
+ expected.push(Result {
+ allowed: false,
+ wait: Duration::new(0, 0),
+ text: "after burst",
+ });
+
+ expected.push(Result {
+ allowed: true,
+ wait: Duration::new(0, PACKET_COST as u32),
+ text: "filling tokens for single packet",
+ });
+
+ expected.push(Result {
+ allowed: false,
+ wait: Duration::new(0, 0),
+ text: "not having refilled enough",
+ });
+
+ expected.push(Result {
+ allowed: true,
+ wait: Duration::new(0, 2 * PACKET_COST as u32),
+ text: "filling tokens for 2 * packet burst",
+ });
+
+ expected.push(Result {
+ allowed: true,
+ wait: Duration::new(0, 0),
+ text: "second packet in 2 packet burst",
+ });
+
+ expected.push(Result {
+ allowed: false,
+ wait: Duration::new(0, 0),
+ text: "packet following 2 packet burst",
+ });
+
+ for item in expected {
+ std::thread::sleep(item.wait);
+ for ip in ips.iter() {
+ if ratelimiter.allow(&ip) != item.allowed {
+ panic!(
+ "test failed for {} on {}. expected: {}, got: {}",
+ ip, item.text, item.allowed, !item.allowed
+ )
+ }
+ }
+ }
+ }
+}
diff --git a/src/wireguard/handshake/timestamp.rs b/src/wireguard/handshake/timestamp.rs
new file mode 100644
index 0000000..b5bd9f0
--- /dev/null
+++ b/src/wireguard/handshake/timestamp.rs
@@ -0,0 +1,32 @@
+use std::time::{SystemTime, UNIX_EPOCH};
+
+pub type TAI64N = [u8; 12];
+
+const TAI64_EPOCH: u64 = 0x400000000000000a;
+
+pub const ZERO: TAI64N = [0u8; 12];
+
+pub fn now() -> TAI64N {
+ // get system time as duration
+ let sysnow = SystemTime::now();
+ let delta = sysnow.duration_since(UNIX_EPOCH).unwrap();
+
+ // convert to tai64n
+ let tai64_secs = delta.as_secs() + TAI64_EPOCH;
+ let tai64_nano = delta.subsec_nanos();
+
+ // serialize
+ let mut res = [0u8; 12];
+ res[..8].copy_from_slice(&tai64_secs.to_be_bytes()[..]);
+ res[8..].copy_from_slice(&tai64_nano.to_be_bytes()[..]);
+ res
+}
+
+pub fn compare(old: &TAI64N, new: &TAI64N) -> bool {
+ for i in 0..12 {
+ if new[i] > old[i] {
+ return true;
+ }
+ }
+ return false;
+}
diff --git a/src/wireguard/handshake/types.rs b/src/wireguard/handshake/types.rs
new file mode 100644
index 0000000..5f984cc
--- /dev/null
+++ b/src/wireguard/handshake/types.rs
@@ -0,0 +1,90 @@
+use std::error::Error;
+use std::fmt;
+
+use x25519_dalek::PublicKey;
+
+use super::super::types::KeyPair;
+
+/* Internal types for the noise IKpsk2 implementation */
+
+// config error
+
+#[derive(Debug)]
+pub struct ConfigError(String);
+
+impl ConfigError {
+ pub fn new(s: &str) -> Self {
+ ConfigError(s.to_string())
+ }
+}
+
+impl fmt::Display for ConfigError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "ConfigError({})", self.0)
+ }
+}
+
+impl Error for ConfigError {
+ fn description(&self) -> &str {
+ &self.0
+ }
+
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ None
+ }
+}
+
+// handshake error
+
+#[derive(Debug)]
+pub enum HandshakeError {
+ DecryptionFailure,
+ UnknownPublicKey,
+ UnknownReceiverId,
+ InvalidMessageFormat,
+ OldTimestamp,
+ InvalidState,
+ InvalidMac1,
+ RateLimited,
+ InitiationFlood,
+}
+
+impl fmt::Display for HandshakeError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ HandshakeError::DecryptionFailure => write!(f, "Failed to AEAD:OPEN"),
+ HandshakeError::UnknownPublicKey => write!(f, "Unknown public key"),
+ HandshakeError::UnknownReceiverId => {
+ write!(f, "Receiver id not allocated to any handshake")
+ }
+ HandshakeError::InvalidMessageFormat => write!(f, "Invalid handshake message format"),
+ HandshakeError::OldTimestamp => write!(f, "Timestamp is less/equal to the newest"),
+ HandshakeError::InvalidState => write!(f, "Message does not apply to handshake state"),
+ HandshakeError::InvalidMac1 => write!(f, "Message has invalid mac1 field"),
+ HandshakeError::RateLimited => write!(f, "Message was dropped by rate limiter"),
+ HandshakeError::InitiationFlood => {
+ write!(f, "Message was dropped because of initiation flood")
+ }
+ }
+ }
+}
+
+impl Error for HandshakeError {
+ fn description(&self) -> &str {
+ "Generic Handshake Error"
+ }
+
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ None
+ }
+}
+
+pub type Output = (
+ Option<PublicKey>, // external identifier associated with peer
+ Option<Vec<u8>>, // message to send
+ Option<KeyPair>, // resulting key-pair of successful handshake
+);
+
+// preshared key
+
+pub type Psk = [u8; 32];