aboutsummaryrefslogtreecommitdiffstats
path: root/src/wireguard
diff options
context:
space:
mode:
Diffstat (limited to 'src/wireguard')
-rw-r--r--src/wireguard/config.rs186
-rw-r--r--src/wireguard/constants.rs20
-rw-r--r--src/wireguard/handshake/device.rs574
-rw-r--r--src/wireguard/handshake/macs.rs327
-rw-r--r--src/wireguard/handshake/messages.rs363
-rw-r--r--src/wireguard/handshake/mod.rs21
-rw-r--r--src/wireguard/handshake/noise.rs549
-rw-r--r--src/wireguard/handshake/peer.rs142
-rw-r--r--src/wireguard/handshake/ratelimiter.rs199
-rw-r--r--src/wireguard/handshake/timestamp.rs32
-rw-r--r--src/wireguard/handshake/types.rs90
-rw-r--r--src/wireguard/mod.rs23
-rw-r--r--src/wireguard/router/anti_replay.rs157
-rw-r--r--src/wireguard/router/constants.rs7
-rw-r--r--src/wireguard/router/device.rs243
-rw-r--r--src/wireguard/router/ip.rs26
-rw-r--r--src/wireguard/router/messages.rs13
-rw-r--r--src/wireguard/router/mod.rs22
-rw-r--r--src/wireguard/router/peer.rs611
-rw-r--r--src/wireguard/router/tests.rs432
-rw-r--r--src/wireguard/router/types.rs65
-rw-r--r--src/wireguard/router/workers.rs305
-rw-r--r--src/wireguard/tests.rs46
-rw-r--r--src/wireguard/timers.rs234
-rw-r--r--src/wireguard/types/bind.rs23
-rw-r--r--src/wireguard/types/dummy.rs323
-rw-r--r--src/wireguard/types/endpoint.rs7
-rw-r--r--src/wireguard/types/keys.rs36
-rw-r--r--src/wireguard/types/mod.rs10
-rw-r--r--src/wireguard/types/tun.rs56
-rw-r--r--src/wireguard/wireguard.rs407
31 files changed, 5549 insertions, 0 deletions
diff --git a/src/wireguard/config.rs b/src/wireguard/config.rs
new file mode 100644
index 0000000..0f2953d
--- /dev/null
+++ b/src/wireguard/config.rs
@@ -0,0 +1,186 @@
+use std::net::{IpAddr, SocketAddr};
+use x25519_dalek::{PublicKey, StaticSecret};
+
+use super::wireguard::Wireguard;
+use super::types::bind::Bind;
+use super::types::tun::Tun;
+
+/// The goal of the configuration interface is, among others,
+/// to hide the IO implementations (over which the WG device is generic),
+/// from the configuration and UAPI code.
+
+/// Describes a snapshot of the state of a peer
+pub struct PeerState {
+ rx_bytes: u64,
+ tx_bytes: u64,
+ last_handshake_time_sec: u64,
+ last_handshake_time_nsec: u64,
+ public_key: PublicKey,
+ allowed_ips: Vec<(IpAddr, u32)>,
+}
+
+pub enum ConfigError {
+ NoSuchPeer
+}
+
+impl ConfigError {
+
+ fn errno(&self) -> i32 {
+ match self {
+ NoSuchPeer => 1,
+ }
+ }
+}
+
+/// Exposed configuration interface
+pub trait Configuration {
+ /// Updates the private key of the device
+ ///
+ /// # Arguments
+ ///
+ /// - `sk`: The new private key (or None, if the private key should be cleared)
+ fn set_private_key(&self, sk: Option<StaticSecret>);
+
+ /// Returns the private key of the device
+ ///
+ /// # Returns
+ ///
+ /// The private if set, otherwise None.
+ fn get_private_key(&self) -> Option<StaticSecret>;
+
+ /// Returns the protocol version of the device
+ ///
+ /// # Returns
+ ///
+ /// An integer indicating the protocol version
+ fn get_protocol_version(&self) -> usize;
+
+ fn set_listen_port(&self, port: u16) -> Option<ConfigError>;
+
+ /// Set the firewall mark (or similar, depending on platform)
+ ///
+ /// # Arguments
+ ///
+ /// - `mark`: The fwmark value
+ ///
+ /// # Returns
+ ///
+ /// An error if this operation is not supported by the underlying
+ /// "bind" implementation.
+ fn set_fwmark(&self, mark: Option<u32>) -> Option<ConfigError>;
+
+ /// Removes all peers from the device
+ fn replace_peers(&self);
+
+ /// Remove the peer from the
+ ///
+ /// # Arguments
+ ///
+ /// - `peer`: The public key of the peer to remove
+ ///
+ /// # Returns
+ ///
+ /// If the peer does not exists this operation is a noop
+ fn remove_peer(&self, peer: PublicKey);
+
+ /// Adds a new peer to the device
+ ///
+ /// # Arguments
+ ///
+ /// - `peer`: The public key of the peer to add
+ ///
+ /// # Returns
+ ///
+ /// A bool indicating if the peer was added.
+ ///
+ /// If the peer already exists this operation is a noop
+ fn add_peer(&self, peer: PublicKey) -> bool;
+
+ /// Update the psk of a peer
+ ///
+ /// # Arguments
+ ///
+ /// - `peer`: The public key of the peer
+ /// - `psk`: The new psk or None if the psk should be unset
+ ///
+ /// # Returns
+ ///
+ /// An error if no such peer exists
+ fn set_preshared_key(&self, peer: PublicKey, psk: Option<[u8; 32]>) -> Option<ConfigError>;
+
+ /// Update the endpoint of the
+ ///
+ /// # Arguments
+ ///
+ /// - `peer': The public key of the peer
+ /// - `psk`
+ fn set_endpoint(&self, peer: PublicKey, addr: SocketAddr) -> Option<ConfigError>;
+
+ /// Update the endpoint of the
+ ///
+ /// # Arguments
+ ///
+ /// - `peer': The public key of the peer
+ /// - `psk`
+ fn set_persistent_keepalive_interval(&self, peer: PublicKey) -> Option<ConfigError>;
+
+ /// Remove all allowed IPs from the peer
+ ///
+ /// # Arguments
+ ///
+ /// - `peer': The public key of the peer
+ ///
+ /// # Returns
+ ///
+ /// An error if no such peer exists
+ fn replace_allowed_ips(&self, peer: PublicKey) -> Option<ConfigError>;
+
+ /// Add a new allowed subnet to the peer
+ ///
+ /// # Arguments
+ ///
+ /// - `peer`: The public key of the peer
+ /// - `ip`: Subnet mask
+ /// - `masklen`:
+ ///
+ /// # Returns
+ ///
+ /// An error if the peer does not exist
+ ///
+ /// # Note:
+ ///
+ /// The API must itself sanitize the (ip, masklen) set:
+ /// The ip should be masked to remove any set bits right of the first "masklen" bits.
+ fn add_allowed_ip(&self, peer: PublicKey, ip: IpAddr, masklen: u32) -> Option<ConfigError>;
+
+ /// Returns the state of all peers
+ ///
+ /// # Returns
+ ///
+ /// A list of structures describing the state of each peer
+ fn get_peers(&self) -> Vec<PeerState>;
+}
+
+impl <T : Tun, B : Bind>Configuration for Wireguard<T, B> {
+
+ fn set_private_key(&self, sk : Option<StaticSecret>) {
+ self.set_key(sk)
+ }
+
+ fn get_private_key(&self) -> Option<StaticSecret> {
+ self.get_sk()
+ }
+
+ fn get_protocol_version(&self) -> usize {
+ 1
+ }
+
+ fn set_listen_port(&self, port : u16) -> Option<ConfigError> {
+ None
+ }
+
+ fn set_fwmark(&self, mark: Option<u32>) -> Option<ConfigError> {
+ None
+ }
+
+} \ No newline at end of file
diff --git a/src/wireguard/constants.rs b/src/wireguard/constants.rs
new file mode 100644
index 0000000..72de8d9
--- /dev/null
+++ b/src/wireguard/constants.rs
@@ -0,0 +1,20 @@
+use std::time::Duration;
+use std::u64;
+
+pub const REKEY_AFTER_MESSAGES: u64 = u64::MAX - (1 << 16);
+pub const REJECT_AFTER_MESSAGES: u64 = u64::MAX - (1 << 4);
+
+pub const REKEY_AFTER_TIME: Duration = Duration::from_secs(120);
+pub const REJECT_AFTER_TIME: Duration = Duration::from_secs(180);
+pub const REKEY_ATTEMPT_TIME: Duration = Duration::from_secs(90);
+pub const REKEY_TIMEOUT: Duration = Duration::from_secs(5);
+pub const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);
+
+pub const MAX_TIMER_HANDSHAKES: usize = 18;
+
+pub const TIMER_MAX_DURATION: Duration = Duration::from_secs(200);
+pub const TIMERS_TICK: Duration = Duration::from_millis(100);
+pub const TIMERS_SLOTS: usize = (TIMER_MAX_DURATION.as_micros() / TIMERS_TICK.as_micros()) as usize;
+pub const TIMERS_CAPACITY: usize = 1024;
+
+pub const MESSAGE_PADDING_MULTIPLE: usize = 16;
diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs
new file mode 100644
index 0000000..6a55f6e
--- /dev/null
+++ b/src/wireguard/handshake/device.rs
@@ -0,0 +1,574 @@
+use spin::RwLock;
+use std::collections::HashMap;
+use std::net::SocketAddr;
+use std::sync::Mutex;
+use zerocopy::AsBytes;
+
+use byteorder::{ByteOrder, LittleEndian};
+
+use rand::prelude::*;
+
+use x25519_dalek::PublicKey;
+use x25519_dalek::StaticSecret;
+
+use super::macs;
+use super::messages::{CookieReply, Initiation, Response};
+use super::messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};
+use super::noise;
+use super::peer::Peer;
+use super::ratelimiter::RateLimiter;
+use super::types::*;
+
+const MAX_PEER_PER_DEVICE: usize = 1 << 20;
+
+pub struct Device {
+ pub sk: StaticSecret, // static secret key
+ pub pk: PublicKey, // static public key
+ macs: macs::Validator, // validator for the mac fields
+ pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state
+ id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key
+ limiter: Mutex<RateLimiter>,
+}
+
+/* A mutable reference to the device needs to be held during configuration.
+ * Wrapping the device in a RwLock enables peer config after "configuration time"
+ */
+impl Device {
+ /// Initialize a new handshake state machine
+ ///
+ /// # Arguments
+ ///
+ /// * `sk` - x25519 scalar representing the local private key
+ pub fn new(sk: StaticSecret) -> Device {
+ let pk = PublicKey::from(&sk);
+ Device {
+ pk,
+ sk,
+ macs: macs::Validator::new(pk),
+ pk_map: HashMap::new(),
+ id_map: RwLock::new(HashMap::new()),
+ limiter: Mutex::new(RateLimiter::new()),
+ }
+ }
+
+ /// Update the secret key of the device
+ ///
+ /// # Arguments
+ ///
+ /// * `sk` - x25519 scalar representing the local private key
+ pub fn set_sk(&mut self, sk: StaticSecret) {
+ // update secret and public key
+ let pk = PublicKey::from(&sk);
+ self.sk = sk;
+ self.pk = pk;
+ self.macs = macs::Validator::new(pk);
+
+ // recalculate the shared secrets for every peer
+ let mut ids = vec![];
+ for mut peer in self.pk_map.values_mut() {
+ peer.reset_state().map(|id| ids.push(id));
+ peer.ss = self.sk.diffie_hellman(&peer.pk)
+ }
+
+ // release ids from aborted handshakes
+ for id in ids {
+ self.release(id)
+ }
+ }
+
+ /// Return the secret key of the device
+ ///
+ /// # Returns
+ ///
+ /// A secret key (x25519 scalar)
+ pub fn get_sk(&self) -> StaticSecret {
+ StaticSecret::from(self.sk.to_bytes())
+ }
+
+ /// Add a new public key to the state machine
+ /// To remove public keys, you must create a new machine instance
+ ///
+ /// # Arguments
+ ///
+ /// * `pk` - The public key to add
+ /// * `identifier` - Associated identifier which can be used to distinguish the peers
+ pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> {
+ // check that the pk is not added twice
+ if let Some(_) = self.pk_map.get(pk.as_bytes()) {
+ return Err(ConfigError::new("Duplicate public key"));
+ };
+
+ // check that the pk is not that of the device
+ if *self.pk.as_bytes() == *pk.as_bytes() {
+ return Err(ConfigError::new(
+ "Public key corresponds to secret key of interface",
+ ));
+ }
+
+ // ensure less than 2^20 peers
+ if self.pk_map.len() > MAX_PEER_PER_DEVICE {
+ return Err(ConfigError::new("Too many peers for device"));
+ }
+
+ // map the public key to the peer state
+ self.pk_map
+ .insert(*pk.as_bytes(), Peer::new(pk, self.sk.diffie_hellman(&pk)));
+
+ Ok(())
+ }
+
+ /// Remove a peer by public key
+ /// To remove public keys, you must create a new machine instance
+ ///
+ /// # Arguments
+ ///
+ /// * `pk` - The public key of the peer to remove
+ ///
+ /// # Returns
+ ///
+ /// The call might fail if the public key is not found
+ pub fn remove(&mut self, pk: PublicKey) -> Result<(), ConfigError> {
+ // take write-lock on receive id table
+ let mut id_map = self.id_map.write();
+
+ // remove the peer
+ self.pk_map
+ .remove(pk.as_bytes())
+ .ok_or(ConfigError::new("Public key not in device"))?;
+
+ // pruge the id map (linear scan)
+ id_map.retain(|_, v| v != pk.as_bytes());
+ Ok(())
+ }
+
+ /// Add a psk to the peer
+ ///
+ /// # Arguments
+ ///
+ /// * `pk` - The public key of the peer
+ /// * `psk` - The psk to set / unset
+ ///
+ /// # Returns
+ ///
+ /// The call might fail if the public key is not found
+ pub fn set_psk(&mut self, pk: PublicKey, psk: Option<Psk>) -> Result<(), ConfigError> {
+ match self.pk_map.get_mut(pk.as_bytes()) {
+ Some(mut peer) => {
+ peer.psk = match psk {
+ Some(v) => v,
+ None => [0u8; 32],
+ };
+ Ok(())
+ }
+ _ => Err(ConfigError::new("No such public key")),
+ }
+ }
+
+ /// Return the psk for the peer
+ ///
+ /// # Arguments
+ ///
+ /// * `pk` - The public key of the peer
+ ///
+ /// # Returns
+ ///
+ /// A 32 byte array holding the PSK
+ ///
+ /// The call might fail if the public key is not found
+ pub fn get_psk(&self, pk: PublicKey) -> Result<Psk, ConfigError> {
+ match self.pk_map.get(pk.as_bytes()) {
+ Some(peer) => Ok(peer.psk),
+ _ => Err(ConfigError::new("No such public key")),
+ }
+ }
+
+ /// Release an id back to the pool
+ ///
+ /// # Arguments
+ ///
+ /// * `id` - The (sender) id to release
+ pub fn release(&self, id: u32) {
+ let mut m = self.id_map.write();
+ debug_assert!(m.contains_key(&id), "Releasing id not allocated");
+ m.remove(&id);
+ }
+
+ /// Begin a new handshake
+ ///
+ /// # Arguments
+ ///
+ /// * `pk` - Public key of peer to initiate handshake for
+ pub fn begin<R: RngCore + CryptoRng>(
+ &self,
+ rng: &mut R,
+ pk: &PublicKey,
+ ) -> Result<Vec<u8>, HandshakeError> {
+ match self.pk_map.get(pk.as_bytes()) {
+ None => Err(HandshakeError::UnknownPublicKey),
+ Some(peer) => {
+ let sender = self.allocate(rng, peer);
+
+ let mut msg = Initiation::default();
+
+ noise::create_initiation(rng, self, peer, sender, &mut msg.noise)?;
+
+ // add macs to initation
+
+ peer.macs
+ .lock()
+ .generate(msg.noise.as_bytes(), &mut msg.macs);
+
+ Ok(msg.as_bytes().to_owned())
+ }
+ }
+ }
+
+ /// Process a handshake message.
+ ///
+ /// # Arguments
+ ///
+ /// * `msg` - Byte slice containing the message (untrusted input)
+ pub fn process<'a, R: RngCore + CryptoRng, S>(
+ &self,
+ rng: &mut R, // rng instance to sample randomness from
+ msg: &[u8], // message buffer
+ src: Option<&'a S>, // optional source endpoint, set when "under load"
+ ) -> Result<Output, HandshakeError>
+ where
+ &'a S: Into<&'a SocketAddr>,
+ {
+ // ensure type read in-range
+ if msg.len() < 4 {
+ return Err(HandshakeError::InvalidMessageFormat);
+ }
+
+ // de-multiplex the message type field
+ match LittleEndian::read_u32(msg) {
+ TYPE_INITIATION => {
+ // parse message
+ let msg = Initiation::parse(msg)?;
+
+ // check mac1 field
+ self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
+
+ // address validation & DoS mitigation
+ if let Some(src) = src {
+ // obtain ref to socket addr
+ let src = src.into();
+
+ // check mac2 field
+ if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
+ let mut reply = Default::default();
+ self.macs.create_cookie_reply(
+ rng,
+ msg.noise.f_sender.get(),
+ src,
+ &msg.macs,
+ &mut reply,
+ );
+ return Ok((None, Some(reply.as_bytes().to_owned()), None));
+ }
+
+ // check ratelimiter
+ if !self.limiter.lock().unwrap().allow(&src.ip()) {
+ return Err(HandshakeError::RateLimited);
+ }
+ }
+
+ // consume the initiation
+ let (peer, st) = noise::consume_initiation(self, &msg.noise)?;
+
+ // allocate new index for response
+ let sender = self.allocate(rng, peer);
+
+ // prepare memory for response, TODO: take slice for zero allocation
+ let mut resp = Response::default();
+
+ // create response (release id on error)
+ let keys = noise::create_response(rng, peer, sender, st, &mut resp.noise).map_err(
+ |e| {
+ self.release(sender);
+ e
+ },
+ )?;
+
+ // add macs to response
+ peer.macs
+ .lock()
+ .generate(resp.noise.as_bytes(), &mut resp.macs);
+
+ // return unconfirmed keypair and the response as vector
+ Ok((Some(peer.pk), Some(resp.as_bytes().to_owned()), Some(keys)))
+ }
+ TYPE_RESPONSE => {
+ let msg = Response::parse(msg)?;
+
+ // check mac1 field
+ self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?;
+
+ // address validation & DoS mitigation
+ if let Some(src) = src {
+ // obtain ref to socket addr
+ let src = src.into();
+
+ // check mac2 field
+ if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) {
+ let mut reply = Default::default();
+ self.macs.create_cookie_reply(
+ rng,
+ msg.noise.f_sender.get(),
+ src,
+ &msg.macs,
+ &mut reply,
+ );
+ return Ok((None, Some(reply.as_bytes().to_owned()), None));
+ }
+
+ // check ratelimiter
+ if !self.limiter.lock().unwrap().allow(&src.ip()) {
+ return Err(HandshakeError::RateLimited);
+ }
+ }
+
+ // consume inner playload
+ noise::consume_response(self, &msg.noise)
+ }
+ TYPE_COOKIE_REPLY => {
+ let msg = CookieReply::parse(msg)?;
+
+ // lookup peer
+ let peer = self.lookup_id(msg.f_receiver.get())?;
+
+ // validate cookie reply
+ peer.macs.lock().process(&msg)?;
+
+ // this prompts no new message and
+ // DOES NOT cryptographically verify the peer
+ Ok((None, None, None))
+ }
+ _ => Err(HandshakeError::InvalidMessageFormat),
+ }
+ }
+
+ // Internal function
+ //
+ // Return the peer associated with the public key
+ pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer, HandshakeError> {
+ self.pk_map
+ .get(pk.as_bytes())
+ .ok_or(HandshakeError::UnknownPublicKey)
+ }
+
+ // Internal function
+ //
+ // Return the peer currently associated with the receiver identifier
+ pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer, HandshakeError> {
+ let im = self.id_map.read();
+ let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?;
+ match self.pk_map.get(pk) {
+ Some(peer) => Ok(peer),
+ _ => unreachable!(), // if the id-lookup succeeded, the peer should exist
+ }
+ }
+
+ // Internal function
+ //
+ // Allocated a new receiver identifier for the peer
+ fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, peer: &Peer) -> u32 {
+ loop {
+ let id = rng.gen();
+
+ // check membership with read lock
+ if self.id_map.read().contains_key(&id) {
+ continue;
+ }
+
+ // take write lock and add index
+ let mut m = self.id_map.write();
+ if !m.contains_key(&id) {
+ m.insert(id, *peer.pk.as_bytes());
+ return id;
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::super::messages::*;
+ use super::*;
+ use hex;
+ use rand::rngs::OsRng;
+ use std::net::SocketAddr;
+ use std::thread;
+ use std::time::Duration;
+
+ fn setup_devices<R: RngCore + CryptoRng>(
+ rng: &mut R,
+ ) -> (PublicKey, Device, PublicKey, Device) {
+ // generate new keypairs
+
+ let sk1 = StaticSecret::new(rng);
+ let pk1 = PublicKey::from(&sk1);
+
+ let sk2 = StaticSecret::new(rng);
+ let pk2 = PublicKey::from(&sk2);
+
+ // pick random psk
+
+ let mut psk = [0u8; 32];
+ rng.fill_bytes(&mut psk[..]);
+
+ // intialize devices on both ends
+
+ let mut dev1 = Device::new(sk1);
+ let mut dev2 = Device::new(sk2);
+
+ dev1.add(pk2).unwrap();
+ dev2.add(pk1).unwrap();
+
+ dev1.set_psk(pk2, Some(psk)).unwrap();
+ dev2.set_psk(pk1, Some(psk)).unwrap();
+
+ (pk1, dev1, pk2, dev2)
+ }
+
+ /* Test longest possible handshake interaction (7 messages):
+ *
+ * 1. I -> R (initation)
+ * 2. I <- R (cookie reply)
+ * 3. I -> R (initation)
+ * 4. I <- R (response)
+ * 5. I -> R (cookie reply)
+ * 6. I -> R (initation)
+ * 7. I <- R (response)
+ */
+ #[test]
+ fn handshake_under_load() {
+ let mut rng = OsRng::new().unwrap();
+ let (_pk1, dev1, pk2, dev2) = setup_devices(&mut rng);
+
+ let src1: SocketAddr = "172.16.0.1:8080".parse().unwrap();
+ let src2: SocketAddr = "172.16.0.2:7070".parse().unwrap();
+
+ // 1. device-1 : create first initation
+ let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
+
+ // 2. device-2 : responds with CookieReply
+ let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
+ (None, Some(msg), None) => msg,
+ _ => panic!("unexpected response"),
+ };
+
+ // device-1 : processes CookieReply (no response)
+ match dev1.process(&mut rng, &msg_cookie, Some(&src2)).unwrap() {
+ (None, None, None) => (),
+ _ => panic!("unexpected response"),
+ }
+
+ // avoid initation flood
+ thread::sleep(Duration::from_millis(20));
+
+ // 3. device-1 : create second initation
+ let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
+
+ // 4. device-2 : responds with noise response
+ let msg_response = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
+ (Some(_), Some(msg), Some(kp)) => {
+ assert_eq!(kp.initiator, false);
+ msg
+ }
+ _ => panic!("unexpected response"),
+ };
+
+ // 5. device-1 : responds with CookieReply
+ let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
+ (None, Some(msg), None) => msg,
+ _ => panic!("unexpected response"),
+ };
+
+ // device-2 : processes CookieReply (no response)
+ match dev2.process(&mut rng, &msg_cookie, Some(&src1)).unwrap() {
+ (None, None, None) => (),
+ _ => panic!("unexpected response"),
+ }
+
+ // avoid initation flood
+ thread::sleep(Duration::from_millis(20));
+
+ // 6. device-1 : create third initation
+ let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
+
+ // 7. device-2 : responds with noise response
+ let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(&src1)).unwrap() {
+ (Some(_), Some(msg), Some(kp)) => {
+ assert_eq!(kp.initiator, false);
+ (msg, kp)
+ }
+ _ => panic!("unexpected response"),
+ };
+
+ // device-1 : process noise response
+ let kp2 = match dev1.process(&mut rng, &msg_response, Some(&src2)).unwrap() {
+ (Some(_), None, Some(kp)) => {
+ assert_eq!(kp.initiator, true);
+ kp
+ }
+ _ => panic!("unexpected response"),
+ };
+
+ assert_eq!(kp1.send, kp2.recv);
+ assert_eq!(kp1.recv, kp2.send);
+ }
+
+ #[test]
+ fn handshake_no_load() {
+ let mut rng = OsRng::new().unwrap();
+ let (pk1, mut dev1, pk2, mut dev2) = setup_devices(&mut rng);
+
+ // do a few handshakes (every handshake should succeed)
+
+ for i in 0..10 {
+ println!("handshake : {}", i);
+
+ // create initiation
+
+ let msg1 = dev1.begin(&mut rng, &pk2).unwrap();
+
+ println!("msg1 = {} : {} bytes", hex::encode(&msg1[..]), msg1.len());
+ println!("msg1 = {:?}", Initiation::parse(&msg1[..]).unwrap());
+
+ // process initiation and create response
+
+ let (_, msg2, ks_r) = dev2.process(&mut rng, &msg1, None).unwrap();
+
+ let ks_r = ks_r.unwrap();
+ let msg2 = msg2.unwrap();
+
+ println!("msg2 = {} : {} bytes", hex::encode(&msg2[..]), msg2.len());
+ println!("msg2 = {:?}", Response::parse(&msg2[..]).unwrap());
+
+ assert!(!ks_r.initiator, "Responders key-pair is confirmed");
+
+ // process response and obtain confirmed key-pair
+
+ let (_, msg3, ks_i) = dev1.process(&mut rng, &msg2, None).unwrap();
+ let ks_i = ks_i.unwrap();
+
+ assert!(msg3.is_none(), "Returned message after response");
+ assert!(ks_i.initiator, "Initiators key-pair is not confirmed");
+
+ assert_eq!(ks_i.send, ks_r.recv, "KeyI.send != KeyR.recv");
+ assert_eq!(ks_i.recv, ks_r.send, "KeyI.recv != KeyR.send");
+
+ dev1.release(ks_i.send.id);
+ dev2.release(ks_r.send.id);
+
+ // to avoid flood detection
+ thread::sleep(Duration::from_millis(20));
+ }
+
+ dev1.remove(pk2).unwrap();
+ dev2.remove(pk1).unwrap();
+ }
+}
diff --git a/src/wireguard/handshake/macs.rs b/src/wireguard/handshake/macs.rs
new file mode 100644
index 0000000..689826b
--- /dev/null
+++ b/src/wireguard/handshake/macs.rs
@@ -0,0 +1,327 @@
+use generic_array::GenericArray;
+use rand::{CryptoRng, RngCore};
+use spin::RwLock;
+use std::time::{Duration, Instant};
+
+// types to coalesce into bytes
+use std::net::SocketAddr;
+use x25519_dalek::PublicKey;
+
+// AEAD
+use aead::{Aead, NewAead, Payload};
+use chacha20poly1305::XChaCha20Poly1305;
+
+// MAC
+use blake2::Blake2s;
+use subtle::ConstantTimeEq;
+
+use super::messages::{CookieReply, MacsFooter, TYPE_COOKIE_REPLY};
+use super::types::HandshakeError;
+
+const LABEL_MAC1: &[u8] = b"mac1----";
+const LABEL_COOKIE: &[u8] = b"cookie--";
+
+const SIZE_COOKIE: usize = 16;
+const SIZE_SECRET: usize = 32;
+const SIZE_MAC: usize = 16; // blake2s-mac128
+const SIZE_TAG: usize = 16; // xchacha20poly1305 tag
+
+const COOKIE_UPDATE_INTERVAL: Duration = Duration::from_secs(120);
+
+macro_rules! HASH {
+ ( $($input:expr),* ) => {{
+ use blake2::Digest;
+ let mut hsh = Blake2s::new();
+ $(
+ hsh.input($input);
+ )*
+ hsh.result()
+ }};
+}
+
+macro_rules! MAC {
+ ( $key:expr, $($input:expr),* ) => {{
+ use blake2::VarBlake2s;
+ use digest::Input;
+ use digest::VariableOutput;
+ let mut tag = [0u8; SIZE_MAC];
+ let mut mac = VarBlake2s::new_keyed($key, SIZE_MAC);
+ $(
+ mac.input($input);
+ )*
+ mac.variable_result(|buf| tag.copy_from_slice(buf));
+ tag
+ }};
+}
+
+macro_rules! XSEAL {
+ ($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{
+ let ct = XChaCha20Poly1305::new(*GenericArray::from_slice($key))
+ .encrypt(
+ GenericArray::from_slice($nonce),
+ Payload { msg: $pt, aad: $ad },
+ )
+ .unwrap();
+ debug_assert_eq!(ct.len(), $pt.len() + SIZE_TAG);
+ $ct.copy_from_slice(&ct);
+ }};
+}
+
+macro_rules! XOPEN {
+ ($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{
+ debug_assert_eq!($ct.len(), $pt.len() + SIZE_TAG);
+ XChaCha20Poly1305::new(*GenericArray::from_slice($key))
+ .decrypt(
+ GenericArray::from_slice($nonce),
+ Payload { msg: $ct, aad: $ad },
+ )
+ .map_err(|_| HandshakeError::DecryptionFailure)
+ .map(|pt| $pt.copy_from_slice(&pt))
+ }};
+}
+
+struct Cookie {
+ value: [u8; 16],
+ birth: Instant,
+}
+
+pub struct Generator {
+ mac1_key: [u8; 32],
+ cookie_key: [u8; 32], // xchacha20poly key for opening cookie response
+ last_mac1: Option<[u8; 16]>,
+ cookie: Option<Cookie>,
+}
+
+fn addr_to_mac_bytes(addr: &SocketAddr) -> Vec<u8> {
+ match addr {
+ SocketAddr::V4(addr) => {
+ let mut res = Vec::with_capacity(4 + 2);
+ res.extend(&addr.ip().octets());
+ res.extend(&addr.port().to_le_bytes());
+ res
+ }
+ SocketAddr::V6(addr) => {
+ let mut res = Vec::with_capacity(16 + 2);
+ res.extend(&addr.ip().octets());
+ res.extend(&addr.port().to_le_bytes());
+ res
+ }
+ }
+}
+
+impl Generator {
+ /// Initalize a new mac field generator
+ ///
+ /// # Arguments
+ ///
+ /// - pk: The public key of the peer to which the generator is associated
+ ///
+ /// # Returns
+ ///
+ /// A freshly initated generator
+ pub fn new(pk: PublicKey) -> Generator {
+ Generator {
+ mac1_key: HASH!(LABEL_MAC1, pk.as_bytes()).into(),
+ cookie_key: HASH!(LABEL_COOKIE, pk.as_bytes()).into(),
+ last_mac1: None,
+ cookie: None,
+ }
+ }
+
+ /// Process a CookieReply message
+ ///
+ /// # Arguments
+ ///
+ /// - reply: CookieReply to process
+ ///
+ /// # Returns
+ ///
+ /// Can fail if the cookie reply fails to validate
+ /// (either indicating that it is outdated or malformed)
+ pub fn process(&mut self, reply: &CookieReply) -> Result<(), HandshakeError> {
+ let mac1 = self.last_mac1.ok_or(HandshakeError::InvalidState)?;
+ let mut tau = [0u8; SIZE_COOKIE];
+ XOPEN!(
+ &self.cookie_key, // key
+ &reply.f_nonce, // nonce
+ &mac1, // ad
+ &mut tau, // pt
+ &reply.f_cookie // ct || tag
+ )?;
+ self.cookie = Some(Cookie {
+ birth: Instant::now(),
+ value: tau,
+ });
+ Ok(())
+ }
+
+ /// Generate both mac fields for an inner message
+ ///
+ /// # Arguments
+ ///
+ /// - inner: A byteslice representing the inner message to be covered
+ /// - macs: The destination mac footer for the resulting macs
+ pub fn generate(&mut self, inner: &[u8], macs: &mut MacsFooter) {
+ macs.f_mac1 = MAC!(&self.mac1_key, inner);
+ macs.f_mac2 = match &self.cookie {
+ Some(cookie) => {
+ if cookie.birth.elapsed() > COOKIE_UPDATE_INTERVAL {
+ self.cookie = None;
+ [0u8; SIZE_MAC]
+ } else {
+ MAC!(&cookie.value, inner, macs.f_mac1)
+ }
+ }
+ None => [0u8; SIZE_MAC],
+ };
+ self.last_mac1 = Some(macs.f_mac1);
+ }
+}
+
+struct Secret {
+ value: [u8; 32],
+ birth: Instant,
+}
+
+pub struct Validator {
+ mac1_key: [u8; 32], // mac1 key, derived from device public key
+ cookie_key: [u8; 32], // xchacha20poly key for sealing cookie response
+ secret: RwLock<Secret>,
+}
+
+impl Validator {
+ pub fn new(pk: PublicKey) -> Validator {
+ Validator {
+ mac1_key: HASH!(LABEL_MAC1, pk.as_bytes()).into(),
+ cookie_key: HASH!(LABEL_COOKIE, pk.as_bytes()).into(),
+ secret: RwLock::new(Secret {
+ value: [0u8; SIZE_SECRET],
+ birth: Instant::now() - Duration::new(86400, 0),
+ }),
+ }
+ }
+
+ fn get_tau(&self, src: &[u8]) -> Option<[u8; SIZE_COOKIE]> {
+ let secret = self.secret.read();
+ if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL {
+ Some(MAC!(&secret.value, src))
+ } else {
+ None
+ }
+ }
+
+ fn get_set_tau<R: RngCore + CryptoRng>(&self, rng: &mut R, src: &[u8]) -> [u8; SIZE_COOKIE] {
+ // check if current value is still valid
+ {
+ let secret = self.secret.read();
+ if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL {
+ return MAC!(&secret.value, src);
+ };
+ }
+
+ // take write lock, check again
+ {
+ let mut secret = self.secret.write();
+ if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL {
+ return MAC!(&secret.value, src);
+ };
+
+ // set new random cookie secret
+ rng.fill_bytes(&mut secret.value);
+ secret.birth = Instant::now();
+ MAC!(&secret.value, src)
+ }
+ }
+
+ pub fn create_cookie_reply<R: RngCore + CryptoRng>(
+ &self,
+ rng: &mut R,
+ receiver: u32, // receiver id of incoming message
+ src: &SocketAddr, // source address of incoming message
+ macs: &MacsFooter, // footer of incoming message
+ msg: &mut CookieReply, // resulting cookie reply
+ ) {
+ let src = addr_to_mac_bytes(src);
+ msg.f_type.set(TYPE_COOKIE_REPLY as u32);
+ msg.f_receiver.set(receiver);
+ rng.fill_bytes(&mut msg.f_nonce);
+ XSEAL!(
+ &self.cookie_key, // key
+ &msg.f_nonce, // nonce
+ &macs.f_mac1, // ad
+ &self.get_set_tau(rng, &src), // pt
+ &mut msg.f_cookie // ct || tag
+ );
+ }
+
+ /// Check the mac1 field against the inner message
+ ///
+ /// # Arguments
+ ///
+ /// - inner: The inner message covered by the mac1 field
+ /// - macs: The mac footer
+ pub fn check_mac1(&self, inner: &[u8], macs: &MacsFooter) -> Result<(), HandshakeError> {
+ let valid_mac1: bool = MAC!(&self.mac1_key, inner).ct_eq(&macs.f_mac1).into();
+ if !valid_mac1 {
+ Err(HandshakeError::InvalidMac1)
+ } else {
+ Ok(())
+ }
+ }
+
+ pub fn check_mac2(&self, inner: &[u8], src: &SocketAddr, macs: &MacsFooter) -> bool {
+ let src = addr_to_mac_bytes(src);
+ match self.get_tau(&src) {
+ Some(tau) => MAC!(&tau, inner, macs.f_mac1).ct_eq(&macs.f_mac2).into(),
+ None => false,
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use proptest::prelude::*;
+ use rand::rngs::OsRng;
+ use x25519_dalek::StaticSecret;
+
+ fn new_validator_generator() -> (Validator, Generator) {
+ let mut rng = OsRng::new().unwrap();
+ let sk = StaticSecret::new(&mut rng);
+ let pk = PublicKey::from(&sk);
+ (Validator::new(pk), Generator::new(pk))
+ }
+
+ proptest! {
+ #[test]
+ fn test_cookie_reply(inner1 : Vec<u8>, inner2 : Vec<u8>, receiver : u32) {
+ let mut msg = CookieReply::default();
+ let mut rng = OsRng::new().expect("failed to create rng");
+ let mut macs = MacsFooter::default();
+ let src = "192.0.2.16:8080".parse().unwrap();
+ let (validator, mut generator) = new_validator_generator();
+
+ // generate mac1 for first message
+ generator.generate(&inner1[..], &mut macs);
+ assert_ne!(macs.f_mac1, [0u8; SIZE_MAC], "mac1 should be set");
+ assert_eq!(macs.f_mac2, [0u8; SIZE_MAC], "mac2 should not be set");
+
+ // check validity of mac1
+ validator.check_mac1(&inner1[..], &macs).expect("mac1 of inner1 did not validate");
+ assert_eq!(validator.check_mac2(&inner1[..], &src, &macs), false, "mac2 of inner2 did not validate");
+ validator.create_cookie_reply(&mut rng, receiver, &src, &macs, &mut msg);
+
+ // consume cookie reply
+ generator.process(&msg).expect("failed to process CookieReply");
+
+ // generate mac2 & mac2 for second message
+ generator.generate(&inner2[..], &mut macs);
+ assert_ne!(macs.f_mac1, [0u8; SIZE_MAC], "mac1 should be set");
+ assert_ne!(macs.f_mac2, [0u8; SIZE_MAC], "mac2 should be set");
+
+ // check validity of mac1 and mac2
+ validator.check_mac1(&inner2[..], &macs).expect("mac1 of inner2 did not validate");
+ assert!(validator.check_mac2(&inner2[..], &src, &macs), "mac2 of inner2 did not validate");
+ }
+ }
+}
diff --git a/src/wireguard/handshake/messages.rs b/src/wireguard/handshake/messages.rs
new file mode 100644
index 0000000..29d80af
--- /dev/null
+++ b/src/wireguard/handshake/messages.rs
@@ -0,0 +1,363 @@
+#[cfg(test)]
+use hex;
+
+#[cfg(test)]
+use std::fmt;
+
+use std::mem;
+
+use byteorder::LittleEndian;
+use zerocopy::byteorder::U32;
+use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified};
+
+use super::types::*;
+
+const SIZE_MAC: usize = 16;
+const SIZE_TAG: usize = 16; // poly1305 tag
+const SIZE_XNONCE: usize = 24; // xchacha20 nonce
+const SIZE_COOKIE: usize = 16; //
+const SIZE_X25519_POINT: usize = 32; // x25519 public key
+const SIZE_TIMESTAMP: usize = 12;
+
+pub const TYPE_INITIATION: u32 = 1;
+pub const TYPE_RESPONSE: u32 = 2;
+pub const TYPE_COOKIE_REPLY: u32 = 3;
+
+const fn max(a: usize, b: usize) -> usize {
+ let m: usize = (a > b) as usize;
+ m * a + (1 - m) * b
+}
+
+pub const MAX_HANDSHAKE_MSG_SIZE: usize = max(
+ max(mem::size_of::<Response>(), mem::size_of::<Initiation>()),
+ mem::size_of::<CookieReply>(),
+);
+
+/* Handshake messsages */
+
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct Response {
+ pub noise: NoiseResponse, // inner message covered by macs
+ pub macs: MacsFooter,
+}
+
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct Initiation {
+ pub noise: NoiseInitiation, // inner message covered by macs
+ pub macs: MacsFooter,
+}
+
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct CookieReply {
+ pub f_type: U32<LittleEndian>,
+ pub f_receiver: U32<LittleEndian>,
+ pub f_nonce: [u8; SIZE_XNONCE],
+ pub f_cookie: [u8; SIZE_COOKIE + SIZE_TAG],
+}
+
+/* Inner sub-messages */
+
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct MacsFooter {
+ pub f_mac1: [u8; SIZE_MAC],
+ pub f_mac2: [u8; SIZE_MAC],
+}
+
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct NoiseInitiation {
+ pub f_type: U32<LittleEndian>,
+ pub f_sender: U32<LittleEndian>,
+ pub f_ephemeral: [u8; SIZE_X25519_POINT],
+ pub f_static: [u8; SIZE_X25519_POINT + SIZE_TAG],
+ pub f_timestamp: [u8; SIZE_TIMESTAMP + SIZE_TAG],
+}
+
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct NoiseResponse {
+ pub f_type: U32<LittleEndian>,
+ pub f_sender: U32<LittleEndian>,
+ pub f_receiver: U32<LittleEndian>,
+ pub f_ephemeral: [u8; SIZE_X25519_POINT],
+ pub f_empty: [u8; SIZE_TAG],
+}
+
+/* Zero copy parsing of handshake messages */
+
+impl Initiation {
+ pub fn parse<B: ByteSlice>(bytes: B) -> Result<LayoutVerified<B, Self>, HandshakeError> {
+ let msg: LayoutVerified<B, Self> =
+ LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?;
+
+ if msg.noise.f_type.get() != (TYPE_INITIATION as u32) {
+ return Err(HandshakeError::InvalidMessageFormat);
+ }
+
+ Ok(msg)
+ }
+}
+
+impl Response {
+ pub fn parse<B: ByteSlice>(bytes: B) -> Result<LayoutVerified<B, Self>, HandshakeError> {
+ let msg: LayoutVerified<B, Self> =
+ LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?;
+
+ if msg.noise.f_type.get() != (TYPE_RESPONSE as u32) {
+ return Err(HandshakeError::InvalidMessageFormat);
+ }
+
+ Ok(msg)
+ }
+}
+
+impl CookieReply {
+ pub fn parse<B: ByteSlice>(bytes: B) -> Result<LayoutVerified<B, Self>, HandshakeError> {
+ let msg: LayoutVerified<B, Self> =
+ LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?;
+
+ if msg.f_type.get() != (TYPE_COOKIE_REPLY as u32) {
+ return Err(HandshakeError::InvalidMessageFormat);
+ }
+
+ Ok(msg)
+ }
+}
+
+/* Default values */
+
+impl Default for Response {
+ fn default() -> Self {
+ Self {
+ noise: Default::default(),
+ macs: Default::default(),
+ }
+ }
+}
+
+impl Default for Initiation {
+ fn default() -> Self {
+ Self {
+ noise: Default::default(),
+ macs: Default::default(),
+ }
+ }
+}
+
+impl Default for CookieReply {
+ fn default() -> Self {
+ Self {
+ f_type: <U32<LittleEndian>>::new(TYPE_COOKIE_REPLY as u32),
+ f_receiver: <U32<LittleEndian>>::ZERO,
+ f_nonce: [0u8; SIZE_XNONCE],
+ f_cookie: [0u8; SIZE_COOKIE + SIZE_TAG],
+ }
+ }
+}
+
+impl Default for MacsFooter {
+ fn default() -> Self {
+ Self {
+ f_mac1: [0u8; SIZE_MAC],
+ f_mac2: [0u8; SIZE_MAC],
+ }
+ }
+}
+
+impl Default for NoiseInitiation {
+ fn default() -> Self {
+ Self {
+ f_type: <U32<LittleEndian>>::new(TYPE_INITIATION as u32),
+ f_sender: <U32<LittleEndian>>::ZERO,
+ f_ephemeral: [0u8; SIZE_X25519_POINT],
+ f_static: [0u8; SIZE_X25519_POINT + SIZE_TAG],
+ f_timestamp: [0u8; SIZE_TIMESTAMP + SIZE_TAG],
+ }
+ }
+}
+
+impl Default for NoiseResponse {
+ fn default() -> Self {
+ Self {
+ f_type: <U32<LittleEndian>>::new(TYPE_RESPONSE as u32),
+ f_sender: <U32<LittleEndian>>::ZERO,
+ f_receiver: <U32<LittleEndian>>::ZERO,
+ f_ephemeral: [0u8; SIZE_X25519_POINT],
+ f_empty: [0u8; SIZE_TAG],
+ }
+ }
+}
+
+/* Debug formatting (for testing purposes) */
+
+#[cfg(test)]
+impl fmt::Debug for Initiation {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "Initiation {{ {:?} || {:?} }}", self.noise, self.macs)
+ }
+}
+
+#[cfg(test)]
+impl fmt::Debug for Response {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "Response {{ {:?} || {:?} }}", self.noise, self.macs)
+ }
+}
+
+#[cfg(test)]
+impl fmt::Debug for CookieReply {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(
+ f,
+ "CookieReply {{ type = {}, receiver = {}, nonce = {}, cookie = {} }}",
+ self.f_type,
+ self.f_receiver,
+ hex::encode(&self.f_nonce[..]),
+ hex::encode(&self.f_cookie[..]),
+ )
+ }
+}
+
+#[cfg(test)]
+impl fmt::Debug for NoiseInitiation {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f,
+ "NoiseInitiation {{ type = {}, sender = {}, ephemeral = {}, static = {}, timestamp = {} }}",
+ self.f_type.get(),
+ self.f_sender.get(),
+ hex::encode(&self.f_ephemeral[..]),
+ hex::encode(&self.f_static[..]),
+ hex::encode(&self.f_timestamp[..]),
+ )
+ }
+}
+
+#[cfg(test)]
+impl fmt::Debug for NoiseResponse {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f,
+ "NoiseResponse {{ type = {}, sender = {}, receiver = {}, ephemeral = {}, empty = |{} }}",
+ self.f_type,
+ self.f_sender,
+ self.f_receiver,
+ hex::encode(&self.f_ephemeral[..]),
+ hex::encode(&self.f_empty[..])
+ )
+ }
+}
+
+#[cfg(test)]
+impl fmt::Debug for MacsFooter {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(
+ f,
+ "Macs {{ mac1 = {}, mac2 = {} }}",
+ hex::encode(&self.f_mac1[..]),
+ hex::encode(&self.f_mac2[..])
+ )
+ }
+}
+
+/* Equality (for testing purposes) */
+
+#[cfg(test)]
+macro_rules! eq_as_bytes {
+ ($type:path) => {
+ impl PartialEq for $type {
+ fn eq(&self, other: &Self) -> bool {
+ self.as_bytes() == other.as_bytes()
+ }
+ }
+ impl Eq for $type {}
+ };
+}
+
+#[cfg(test)]
+eq_as_bytes!(Initiation);
+
+#[cfg(test)]
+eq_as_bytes!(Response);
+
+#[cfg(test)]
+eq_as_bytes!(CookieReply);
+
+#[cfg(test)]
+eq_as_bytes!(MacsFooter);
+
+#[cfg(test)]
+eq_as_bytes!(NoiseInitiation);
+
+#[cfg(test)]
+eq_as_bytes!(NoiseResponse);
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn message_response_identity() {
+ let mut msg: Response = Default::default();
+
+ msg.noise.f_sender.set(146252);
+ msg.noise.f_receiver.set(554442);
+ msg.noise.f_ephemeral = [
+ 0xc1, 0x66, 0x0a, 0x0c, 0xdc, 0x0f, 0x6c, 0x51, 0x0f, 0xc2, 0xcc, 0x51, 0x52, 0x0c,
+ 0xde, 0x1e, 0xf7, 0xf1, 0xca, 0x90, 0x86, 0x72, 0xad, 0x67, 0xea, 0x89, 0x45, 0x44,
+ 0x13, 0x56, 0x52, 0x1f,
+ ];
+ msg.noise.f_empty = [
+ 0x60, 0x0e, 0x1e, 0x95, 0x41, 0x6b, 0x52, 0x05, 0xa2, 0x09, 0xe1, 0xbf, 0x40, 0x05,
+ 0x2f, 0xde,
+ ];
+ msg.macs.f_mac1 = [
+ 0xf2, 0xad, 0x40, 0xb5, 0xf7, 0xde, 0x77, 0x35, 0x89, 0x19, 0xb7, 0x5c, 0xf9, 0x54,
+ 0x69, 0x29,
+ ];
+ msg.macs.f_mac2 = [
+ 0x4f, 0xd2, 0x1b, 0xfe, 0x77, 0xe6, 0x2e, 0xc9, 0x07, 0xe2, 0x87, 0x17, 0xbb, 0xe5,
+ 0xdf, 0xbb,
+ ];
+
+ let buf: Vec<u8> = msg.as_bytes().to_vec();
+ let msg_p = Response::parse(&buf[..]).unwrap();
+ assert_eq!(msg, *msg_p.into_ref());
+ }
+
+ #[test]
+ fn message_initiate_identity() {
+ let mut msg: Initiation = Default::default();
+
+ msg.noise.f_sender.set(575757);
+ msg.noise.f_ephemeral = [
+ 0xc1, 0x66, 0x0a, 0x0c, 0xdc, 0x0f, 0x6c, 0x51, 0x0f, 0xc2, 0xcc, 0x51, 0x52, 0x0c,
+ 0xde, 0x1e, 0xf7, 0xf1, 0xca, 0x90, 0x86, 0x72, 0xad, 0x67, 0xea, 0x89, 0x45, 0x44,
+ 0x13, 0x56, 0x52, 0x1f,
+ ];
+ msg.noise.f_static = [
+ 0xdc, 0x33, 0x90, 0x15, 0x8f, 0x82, 0x3e, 0x06, 0x44, 0xa0, 0xde, 0x4c, 0x15, 0x6c,
+ 0x5d, 0xa4, 0x65, 0x99, 0xf6, 0x6c, 0xa1, 0x14, 0x77, 0xf9, 0xeb, 0x6a, 0xec, 0xc3,
+ 0x3c, 0xda, 0x47, 0xe1, 0x45, 0xac, 0x8d, 0x43, 0xea, 0x1b, 0x2f, 0x02, 0x45, 0x5d,
+ 0x86, 0x37, 0xee, 0x83, 0x6b, 0x42,
+ ];
+ msg.noise.f_timestamp = [
+ 0x4f, 0x1c, 0x60, 0xec, 0x0e, 0xf6, 0x36, 0xf0, 0x78, 0x28, 0x57, 0x42, 0x60, 0x0e,
+ 0x1e, 0x95, 0x41, 0x6b, 0x52, 0x05, 0xa2, 0x09, 0xe1, 0xbf, 0x40, 0x05, 0x2f, 0xde,
+ ];
+ msg.macs.f_mac1 = [
+ 0xf2, 0xad, 0x40, 0xb5, 0xf7, 0xde, 0x77, 0x35, 0x89, 0x19, 0xb7, 0x5c, 0xf9, 0x54,
+ 0x69, 0x29,
+ ];
+ msg.macs.f_mac2 = [
+ 0x4f, 0xd2, 0x1b, 0xfe, 0x77, 0xe6, 0x2e, 0xc9, 0x07, 0xe2, 0x87, 0x17, 0xbb, 0xe5,
+ 0xdf, 0xbb,
+ ];
+
+ let buf: Vec<u8> = msg.as_bytes().to_vec();
+ let msg_p = Initiation::parse(&buf[..]).unwrap();
+ assert_eq!(msg, *msg_p.into_ref());
+ }
+}
diff --git a/src/wireguard/handshake/mod.rs b/src/wireguard/handshake/mod.rs
new file mode 100644
index 0000000..071a41f
--- /dev/null
+++ b/src/wireguard/handshake/mod.rs
@@ -0,0 +1,21 @@
+/* Implementation of the:
+ *
+ * Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s
+ *
+ * Protocol pattern, see: http://www.noiseprotocol.org/noise.html.
+ * For documentation.
+ */
+
+mod device;
+mod macs;
+mod messages;
+mod noise;
+mod peer;
+mod ratelimiter;
+mod timestamp;
+mod types;
+
+// publicly exposed interface
+
+pub use device::Device;
+pub use messages::{MAX_HANDSHAKE_MSG_SIZE, TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};
diff --git a/src/wireguard/handshake/noise.rs b/src/wireguard/handshake/noise.rs
new file mode 100644
index 0000000..a2a84b0
--- /dev/null
+++ b/src/wireguard/handshake/noise.rs
@@ -0,0 +1,549 @@
+// DH
+use x25519_dalek::PublicKey;
+use x25519_dalek::StaticSecret;
+
+// HASH & MAC
+use blake2::Blake2s;
+use hmac::Hmac;
+
+// AEAD
+use aead::{Aead, NewAead, Payload};
+use chacha20poly1305::ChaCha20Poly1305;
+
+use rand::{CryptoRng, RngCore};
+
+use generic_array::typenum::*;
+use generic_array::*;
+
+use clear_on_drop::clear::Clear;
+use clear_on_drop::clear_stack_on_return;
+
+use subtle::ConstantTimeEq;
+
+use super::device::Device;
+use super::messages::{NoiseInitiation, NoiseResponse};
+use super::messages::{TYPE_INITIATION, TYPE_RESPONSE};
+use super::peer::{Peer, State};
+use super::timestamp;
+use super::types::*;
+
+use super::super::types::{KeyPair, Key};
+
+use std::time::Instant;
+
+// HMAC hasher (generic construction)
+
+type HMACBlake2s = Hmac<Blake2s>;
+
+// convenient alias to pass state temporarily into device.rs and back
+
+type TemporaryState = (u32, PublicKey, GenericArray<u8, U32>, GenericArray<u8, U32>);
+
+const SIZE_CK: usize = 32;
+const SIZE_HS: usize = 32;
+const SIZE_NONCE: usize = 8;
+const SIZE_TAG: usize = 16;
+
+// number of pages to clear after sensitive call
+const CLEAR_PAGES: usize = 1;
+
+// C := Hash(Construction)
+const INITIAL_CK: [u8; SIZE_CK] = [
+ 0x60, 0xe2, 0x6d, 0xae, 0xf3, 0x27, 0xef, 0xc0, 0x2e, 0xc3, 0x35, 0xe2, 0xa0, 0x25, 0xd2, 0xd0,
+ 0x16, 0xeb, 0x42, 0x06, 0xf8, 0x72, 0x77, 0xf5, 0x2d, 0x38, 0xd1, 0x98, 0x8b, 0x78, 0xcd, 0x36,
+];
+
+// H := Hash(C || Identifier)
+const INITIAL_HS: [u8; SIZE_HS] = [
+ 0x22, 0x11, 0xb3, 0x61, 0x08, 0x1a, 0xc5, 0x66, 0x69, 0x12, 0x43, 0xdb, 0x45, 0x8a, 0xd5, 0x32,
+ 0x2d, 0x9c, 0x6c, 0x66, 0x22, 0x93, 0xe8, 0xb7, 0x0e, 0xe1, 0x9c, 0x65, 0xba, 0x07, 0x9e, 0xf3,
+];
+
+const ZERO_NONCE: [u8; 12] = [0u8; 12];
+
+macro_rules! HASH {
+ ( $($input:expr),* ) => {{
+ use blake2::Digest;
+ let mut hsh = Blake2s::new();
+ $(
+ hsh.input($input);
+ )*
+ hsh.result()
+ }};
+}
+
+macro_rules! HMAC {
+ ($key:expr, $($input:expr),*) => {{
+ use hmac::Mac;
+ let mut mac = HMACBlake2s::new_varkey($key).unwrap();
+ $(
+ mac.input($input);
+ )*
+ mac.result().code()
+ }};
+}
+
+macro_rules! KDF1 {
+ ($ck:expr, $input:expr) => {{
+ let mut t0 = HMAC!($ck, $input);
+ let t1 = HMAC!(&t0, &[0x1]);
+ t0.clear();
+ t1
+ }};
+}
+
+macro_rules! KDF2 {
+ ($ck:expr, $input:expr) => {{
+ let mut t0 = HMAC!($ck, $input);
+ let t1 = HMAC!(&t0, &[0x1]);
+ let t2 = HMAC!(&t0, &t1, &[0x2]);
+ t0.clear();
+ (t1, t2)
+ }};
+}
+
+macro_rules! KDF3 {
+ ($ck:expr, $input:expr) => {{
+ let mut t0 = HMAC!($ck, $input);
+ let t1 = HMAC!(&t0, &[0x1]);
+ let t2 = HMAC!(&t0, &t1, &[0x2]);
+ let t3 = HMAC!(&t0, &t2, &[0x3]);
+ t0.clear();
+ (t1, t2, t3)
+ }};
+}
+
+macro_rules! SEAL {
+ ($key:expr, $ad:expr, $pt:expr, $ct:expr) => {
+ ChaCha20Poly1305::new(*GenericArray::from_slice($key))
+ .encrypt(&ZERO_NONCE.into(), Payload { msg: $pt, aad: $ad })
+ .map(|ct| $ct.copy_from_slice(&ct))
+ .unwrap()
+ };
+}
+
+macro_rules! OPEN {
+ ($key:expr, $ad:expr, $pt:expr, $ct:expr) => {
+ ChaCha20Poly1305::new(*GenericArray::from_slice($key))
+ .decrypt(&ZERO_NONCE.into(), Payload { msg: $ct, aad: $ad })
+ .map_err(|_| HandshakeError::DecryptionFailure)
+ .map(|pt| $pt.copy_from_slice(&pt))
+ };
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ const IDENTIFIER: &[u8] = b"WireGuard v1 zx2c4 Jason@zx2c4.com";
+ const CONSTRUCTION: &[u8] = b"Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
+
+ /* Sanity check precomputed initial chain key
+ */
+ #[test]
+ fn precomputed_chain_key() {
+ assert_eq!(INITIAL_CK[..], HASH!(CONSTRUCTION)[..]);
+ }
+
+ /* Sanity check precomputed initial hash transcript
+ */
+ #[test]
+ fn precomputed_hash() {
+ assert_eq!(INITIAL_HS[..], HASH!(INITIAL_CK, IDENTIFIER)[..]);
+ }
+
+ /* Sanity check the HKDF macro
+ *
+ * Test vectors generated using WireGuard-Go
+ */
+ #[test]
+ fn hkdf() {
+ let tests: Vec<(Vec<u8>, Vec<u8>, [u8; 32], [u8; 32], [u8; 32])> = vec![
+ (
+ vec![],
+ vec![],
+ [
+ 0x83, 0x87, 0xb4, 0x6b, 0xf4, 0x3e, 0xcc, 0xfc, 0xf3, 0x49, 0x55, 0x2a, 0x09,
+ 0x5d, 0x83, 0x15, 0xc4, 0x05, 0x5b, 0xeb, 0x90, 0x20, 0x8f, 0xb1, 0xbe, 0x23,
+ 0xb8, 0x94, 0xbc, 0x2e, 0xd5, 0xd0,
+ ],
+ [
+ 0x58, 0xa0, 0xe5, 0xf6, 0xfa, 0xef, 0xcc, 0xf4, 0x80, 0x7b, 0xff, 0x1f, 0x05,
+ 0xfa, 0x8a, 0x92, 0x17, 0x94, 0x57, 0x62, 0x04, 0x0b, 0xce, 0xc2, 0xf4, 0xb4,
+ 0xa6, 0x2b, 0xdf, 0xe0, 0xe8, 0x6e,
+ ],
+ [
+ 0x0c, 0xe6, 0xea, 0x98, 0xec, 0x54, 0x8f, 0x8e, 0x28, 0x1e, 0x93, 0xe3, 0x2d,
+ 0xb6, 0x56, 0x21, 0xc4, 0x5e, 0xb1, 0x8d, 0xc6, 0xf0, 0xa7, 0xad, 0x94, 0x17,
+ 0x86, 0x10, 0xa2, 0xf7, 0x33, 0x8e,
+ ],
+ ),
+ (
+ vec![0xde, 0xad, 0xbe, 0xef],
+ vec![],
+ [
+ 0x55, 0x32, 0x9d, 0xc8, 0x0e, 0x69, 0x0f, 0xd8, 0x6b, 0xd9, 0x66, 0x1f, 0x08,
+ 0x51, 0xc9, 0xb3, 0x68, 0x6d, 0xf2, 0xb1, 0xfd, 0xa0, 0x34, 0x7b, 0xc3, 0xd2,
+ 0x79, 0x58, 0x25, 0x4b, 0x32, 0xc6,
+ ],
+ [
+ 0x8d, 0xfc, 0x6d, 0x33, 0xa8, 0x11, 0x8f, 0xfe, 0x40, 0x8b, 0x31, 0xdd, 0xac,
+ 0x25, 0xf7, 0x2a, 0xee, 0x91, 0x15, 0xa4, 0x5b, 0x69, 0xba, 0x17, 0x6a, 0xd0,
+ 0x12, 0xb2, 0x43, 0x83, 0x4f, 0xee,
+ ],
+ [
+ 0xd6, 0x9e, 0x85, 0x2a, 0x28, 0x96, 0x56, 0x9e, 0xa5, 0x4a, 0x67, 0x96, 0x9a,
+ 0xa1, 0x80, 0x02, 0x87, 0x92, 0x1d, 0xac, 0x53, 0xce, 0x6d, 0xb4, 0xb4, 0xe1,
+ 0x21, 0x92, 0xf2, 0x63, 0xc4, 0xc4,
+ ],
+ ),
+ ];
+
+ for (key, input, t0, t1, t2) in &tests {
+ let tt0 = KDF1!(key, input);
+ debug_assert_eq!(tt0[..], t0[..]);
+
+ let (tt0, tt1) = KDF2!(key, input);
+ debug_assert_eq!(tt0[..], t0[..]);
+ debug_assert_eq!(tt1[..], t1[..]);
+
+ let (tt0, tt1, tt2) = KDF3!(key, input);
+ debug_assert_eq!(tt0[..], t0[..]);
+ debug_assert_eq!(tt1[..], t1[..]);
+ debug_assert_eq!(tt2[..], t2[..]);
+ }
+ }
+}
+
+pub fn create_initiation<R: RngCore + CryptoRng>(
+ rng: &mut R,
+ device: &Device,
+ peer: &Peer,
+ sender: u32,
+ msg: &mut NoiseInitiation,
+) -> Result<(), HandshakeError> {
+ clear_stack_on_return(CLEAR_PAGES, || {
+ // initialize state
+
+ let ck = INITIAL_CK;
+ let hs = INITIAL_HS;
+ let hs = HASH!(&hs, peer.pk.as_bytes());
+
+ msg.f_type.set(TYPE_INITIATION as u32);
+ msg.f_sender.set(sender);
+
+ // (E_priv, E_pub) := DH-Generate()
+
+ let eph_sk = StaticSecret::new(rng);
+ let eph_pk = PublicKey::from(&eph_sk);
+
+ // C := Kdf(C, E_pub)
+
+ let ck = KDF1!(&ck, eph_pk.as_bytes());
+
+ // msg.ephemeral := E_pub
+
+ msg.f_ephemeral = *eph_pk.as_bytes();
+
+ // H := HASH(H, msg.ephemeral)
+
+ let hs = HASH!(&hs, msg.f_ephemeral);
+
+ // (C, k) := Kdf2(C, DH(E_priv, S_pub))
+
+ let (ck, key) = KDF2!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes());
+
+ // msg.static := Aead(k, 0, S_pub, H)
+
+ SEAL!(
+ &key,
+ &hs, // ad
+ device.pk.as_bytes(), // pt
+ &mut msg.f_static // ct || tag
+ );
+
+ // H := Hash(H || msg.static)
+
+ let hs = HASH!(&hs, &msg.f_static[..]);
+
+ // (C, k) := Kdf2(C, DH(S_priv, S_pub))
+
+ let (ck, key) = KDF2!(&ck, peer.ss.as_bytes());
+
+ // msg.timestamp := Aead(k, 0, Timestamp(), H)
+
+ SEAL!(
+ &key,
+ &hs, // ad
+ &timestamp::now(), // pt
+ &mut msg.f_timestamp // ct || tag
+ );
+
+ // H := Hash(H || msg.timestamp)
+
+ let hs = HASH!(&hs, &msg.f_timestamp);
+
+ // update state of peer
+
+ *peer.state.lock() = State::InitiationSent {
+ hs,
+ ck,
+ eph_sk,
+ sender,
+ };
+
+ Ok(())
+ })
+}
+
+pub fn consume_initiation<'a>(
+ device: &'a Device,
+ msg: &NoiseInitiation,
+) -> Result<(&'a Peer, TemporaryState), HandshakeError> {
+ clear_stack_on_return(CLEAR_PAGES, || {
+ // initialize new state
+
+ let ck = INITIAL_CK;
+ let hs = INITIAL_HS;
+ let hs = HASH!(&hs, device.pk.as_bytes());
+
+ // C := Kdf(C, E_pub)
+
+ let ck = KDF1!(&ck, &msg.f_ephemeral);
+
+ // H := HASH(H, msg.ephemeral)
+
+ let hs = HASH!(&hs, &msg.f_ephemeral);
+
+ // (C, k) := Kdf2(C, DH(E_priv, S_pub))
+
+ let eph_r_pk = PublicKey::from(msg.f_ephemeral);
+ let (ck, key) = KDF2!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes());
+
+ // msg.static := Aead(k, 0, S_pub, H)
+
+ let mut pk = [0u8; 32];
+
+ OPEN!(
+ &key,
+ &hs, // ad
+ &mut pk, // pt
+ &msg.f_static // ct || tag
+ )?;
+
+ let peer = device.lookup_pk(&PublicKey::from(pk))?;
+
+ // reset initiation state
+
+ *peer.state.lock() = State::Reset;
+
+ // H := Hash(H || msg.static)
+
+ let hs = HASH!(&hs, &msg.f_static[..]);
+
+ // (C, k) := Kdf2(C, DH(S_priv, S_pub))
+
+ let (ck, key) = KDF2!(&ck, peer.ss.as_bytes());
+
+ // msg.timestamp := Aead(k, 0, Timestamp(), H)
+
+ let mut ts = timestamp::ZERO;
+
+ OPEN!(
+ &key,
+ &hs, // ad
+ &mut ts, // pt
+ &msg.f_timestamp // ct || tag
+ )?;
+
+ // check and update timestamp
+
+ peer.check_replay_flood(device, &ts)?;
+
+ // H := Hash(H || msg.timestamp)
+
+ let hs = HASH!(&hs, &msg.f_timestamp);
+
+ // return state (to create response)
+
+ Ok((peer, (msg.f_sender.get(), eph_r_pk, hs, ck)))
+ })
+}
+
+pub fn create_response<R: RngCore + CryptoRng>(
+ rng: &mut R,
+ peer: &Peer,
+ sender: u32, // sending identifier
+ state: TemporaryState, // state from "consume_initiation"
+ msg: &mut NoiseResponse, // resulting response
+) -> Result<KeyPair, HandshakeError> {
+ clear_stack_on_return(CLEAR_PAGES, || {
+ // unpack state
+
+ let (receiver, eph_r_pk, hs, ck) = state;
+
+ msg.f_type.set(TYPE_RESPONSE as u32);
+ msg.f_sender.set(sender);
+ msg.f_receiver.set(receiver);
+
+ // (E_priv, E_pub) := DH-Generate()
+
+ let eph_sk = StaticSecret::new(rng);
+ let eph_pk = PublicKey::from(&eph_sk);
+
+ // C := Kdf1(C, E_pub)
+
+ let ck = KDF1!(&ck, eph_pk.as_bytes());
+
+ // msg.ephemeral := E_pub
+
+ msg.f_ephemeral = *eph_pk.as_bytes();
+
+ // H := Hash(H || msg.ephemeral)
+
+ let hs = HASH!(&hs, &msg.f_ephemeral);
+
+ // C := Kdf1(C, DH(E_priv, E_pub))
+
+ let ck = KDF1!(&ck, eph_sk.diffie_hellman(&eph_r_pk).as_bytes());
+
+ // C := Kdf1(C, DH(E_priv, S_pub))
+
+ let ck = KDF1!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes());
+
+ // (C, tau, k) := Kdf3(C, Q)
+
+ let (ck, tau, key) = KDF3!(&ck, &peer.psk);
+
+ // H := Hash(H || tau)
+
+ let hs = HASH!(&hs, tau);
+
+ // msg.empty := Aead(k, 0, [], H)
+
+ SEAL!(
+ &key,
+ &hs, // ad
+ &[], // pt
+ &mut msg.f_empty // \epsilon || tag
+ );
+
+ // Not strictly needed
+ // let hs = HASH!(&hs, &msg.f_empty_tag);
+
+ // derive key-pair
+
+ let (key_recv, key_send) = KDF2!(&ck, &[]);
+
+ // return unconfirmed key-pair
+
+ Ok(KeyPair {
+ birth: Instant::now(),
+ initiator: false,
+ send: Key {
+ id: sender,
+ key: key_send.into(),
+ },
+ recv: Key {
+ id: receiver,
+ key: key_recv.into(),
+ },
+ })
+ })
+}
+
+/* The state lock is released while processing the message to
+ * allow concurrent processing of potential responses to the initiation,
+ * in order to better mitigate DoS from malformed response messages.
+ */
+pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result<Output, HandshakeError> {
+ clear_stack_on_return(CLEAR_PAGES, || {
+ // retrieve peer and copy initiation state
+ let peer = device.lookup_id(msg.f_receiver.get())?;
+
+ let (hs, ck, sender, eph_sk) = match *peer.state.lock() {
+ State::InitiationSent {
+ hs,
+ ck,
+ sender,
+ ref eph_sk,
+ } => Ok((hs, ck, sender, StaticSecret::from(eph_sk.to_bytes()))),
+ _ => Err(HandshakeError::InvalidState),
+ }?;
+
+ // C := Kdf1(C, E_pub)
+
+ let ck = KDF1!(&ck, &msg.f_ephemeral);
+
+ // H := Hash(H || msg.ephemeral)
+
+ let hs = HASH!(&hs, &msg.f_ephemeral);
+
+ // C := Kdf1(C, DH(E_priv, E_pub))
+
+ let eph_r_pk = PublicKey::from(msg.f_ephemeral);
+ let ck = KDF1!(&ck, eph_sk.diffie_hellman(&eph_r_pk).as_bytes());
+
+ // C := Kdf1(C, DH(E_priv, S_pub))
+
+ let ck = KDF1!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes());
+
+ // (C, tau, k) := Kdf3(C, Q)
+
+ let (ck, tau, key) = KDF3!(&ck, &peer.psk);
+
+ // H := Hash(H || tau)
+
+ let hs = HASH!(&hs, tau);
+
+ // msg.empty := Aead(k, 0, [], H)
+
+ OPEN!(
+ &key,
+ &hs, // ad
+ &mut [], // pt
+ &msg.f_empty // \epsilon || tag
+ )?;
+
+ // derive key-pair
+
+ let birth = Instant::now();
+ let (key_send, key_recv) = KDF2!(&ck, &[]);
+
+ // check for new initiation sent while lock released
+
+ let mut state = peer.state.lock();
+ let update = match *state {
+ State::InitiationSent {
+ eph_sk: ref old, ..
+ } => old.to_bytes().ct_eq(&eph_sk.to_bytes()).into(),
+ _ => false,
+ };
+
+ if update {
+ // null the initiation state
+ // (to avoid replay of this response message)
+ *state = State::Reset;
+
+ // return confirmed key-pair
+ Ok((
+ Some(peer.pk),
+ None,
+ Some(KeyPair {
+ birth,
+ initiator: true,
+ send: Key {
+ id: sender,
+ key: key_send.into(),
+ },
+ recv: Key {
+ id: msg.f_sender.get(),
+ key: key_recv.into(),
+ },
+ }),
+ ))
+ } else {
+ Err(HandshakeError::InvalidState)
+ }
+ })
+}
diff --git a/src/wireguard/handshake/peer.rs b/src/wireguard/handshake/peer.rs
new file mode 100644
index 0000000..c9e1c40
--- /dev/null
+++ b/src/wireguard/handshake/peer.rs
@@ -0,0 +1,142 @@
+use spin::Mutex;
+
+use std::mem;
+use std::time::{Duration, Instant};
+
+use generic_array::typenum::U32;
+use generic_array::GenericArray;
+
+use x25519_dalek::PublicKey;
+use x25519_dalek::SharedSecret;
+use x25519_dalek::StaticSecret;
+
+use clear_on_drop::clear::Clear;
+
+use super::device::Device;
+use super::macs;
+use super::timestamp;
+use super::types::*;
+
+const TIME_BETWEEN_INITIATIONS: Duration = Duration::from_millis(20);
+
+/* Represents the recomputation and state of a peer.
+ *
+ * This type is only for internal use and not exposed.
+ */
+pub struct Peer {
+ // mutable state
+ pub(crate) state: Mutex<State>,
+ pub(crate) timestamp: Mutex<Option<timestamp::TAI64N>>,
+ pub(crate) last_initiation_consumption: Mutex<Option<Instant>>,
+
+ // state related to DoS mitigation fields
+ pub(crate) macs: Mutex<macs::Generator>,
+
+ // constant state
+ pub(crate) pk: PublicKey, // public key of peer
+ pub(crate) ss: SharedSecret, // precomputed DH(static, static)
+ pub(crate) psk: Psk, // psk of peer
+}
+
+pub enum State {
+ Reset,
+ InitiationSent {
+ sender: u32, // assigned sender id
+ eph_sk: StaticSecret,
+ hs: GenericArray<u8, U32>,
+ ck: GenericArray<u8, U32>,
+ },
+}
+
+impl Drop for State {
+ fn drop(&mut self) {
+ match self {
+ State::InitiationSent { hs, ck, .. } => {
+ // eph_sk already cleared by dalek-x25519
+ hs.clear();
+ ck.clear();
+ }
+ _ => (),
+ }
+ }
+}
+
+impl Peer {
+ pub fn new(
+ pk: PublicKey, // public key of peer
+ ss: SharedSecret, // precomputed DH(static, static)
+ ) -> Self {
+ Self {
+ macs: Mutex::new(macs::Generator::new(pk)),
+ state: Mutex::new(State::Reset),
+ timestamp: Mutex::new(None),
+ last_initiation_consumption: Mutex::new(None),
+ pk: pk,
+ ss: ss,
+ psk: [0u8; 32],
+ }
+ }
+
+ /// Set the state of the peer unconditionally
+ ///
+ /// # Arguments
+ ///
+ pub fn set_state(&self, state_new: State) {
+ *self.state.lock() = state_new;
+ }
+
+ pub fn reset_state(&self) -> Option<u32> {
+ match mem::replace(&mut *self.state.lock(), State::Reset) {
+ State::InitiationSent { sender, .. } => Some(sender),
+ _ => None,
+ }
+ }
+
+ /// Set the mutable state of the peer conditioned on the timestamp being newer
+ ///
+ /// # Arguments
+ ///
+ /// * st_new - The updated state of the peer
+ /// * ts_new - The associated timestamp
+ pub fn check_replay_flood(
+ &self,
+ device: &Device,
+ timestamp_new: &timestamp::TAI64N,
+ ) -> Result<(), HandshakeError> {
+ let mut state = self.state.lock();
+ let mut timestamp = self.timestamp.lock();
+ let mut last_initiation_consumption = self.last_initiation_consumption.lock();
+
+ // check replay attack
+ match *timestamp {
+ Some(timestamp_old) => {
+ if !timestamp::compare(&timestamp_old, &timestamp_new) {
+ return Err(HandshakeError::OldTimestamp);
+ }
+ }
+ _ => (),
+ };
+
+ // check flood attack
+ match *last_initiation_consumption {
+ Some(last) => {
+ if last.elapsed() < TIME_BETWEEN_INITIATIONS {
+ return Err(HandshakeError::InitiationFlood);
+ }
+ }
+ _ => (),
+ }
+
+ // reset state
+ match *state {
+ State::InitiationSent { sender, .. } => device.release(sender),
+ _ => (),
+ }
+
+ // update replay & flood protection
+ *state = State::Reset;
+ *timestamp = Some(*timestamp_new);
+ *last_initiation_consumption = Some(Instant::now());
+ Ok(())
+ }
+}
diff --git a/src/wireguard/handshake/ratelimiter.rs b/src/wireguard/handshake/ratelimiter.rs
new file mode 100644
index 0000000..63d728c
--- /dev/null
+++ b/src/wireguard/handshake/ratelimiter.rs
@@ -0,0 +1,199 @@
+use spin;
+use std::collections::HashMap;
+use std::net::IpAddr;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::{Arc, Condvar, Mutex};
+use std::thread;
+use std::time::{Duration, Instant};
+
+const PACKETS_PER_SECOND: u64 = 20;
+const PACKETS_BURSTABLE: u64 = 5;
+const PACKET_COST: u64 = 1_000_000_000 / PACKETS_PER_SECOND;
+const MAX_TOKENS: u64 = PACKET_COST * PACKETS_BURSTABLE;
+
+const GC_INTERVAL: Duration = Duration::from_secs(1);
+
+struct Entry {
+ pub last_time: Instant,
+ pub tokens: u64,
+}
+
+pub struct RateLimiter(Arc<RateLimiterInner>);
+
+struct RateLimiterInner {
+ gc_running: AtomicBool,
+ gc_dropped: (Mutex<bool>, Condvar),
+ table: spin::RwLock<HashMap<IpAddr, spin::Mutex<Entry>>>,
+}
+
+impl Drop for RateLimiter {
+ fn drop(&mut self) {
+ // wake up & terminate any lingering GC thread
+ let &(ref lock, ref cvar) = &self.0.gc_dropped;
+ let mut dropped = lock.lock().unwrap();
+ *dropped = true;
+ cvar.notify_all();
+ }
+}
+
+impl RateLimiter {
+ pub fn new() -> Self {
+ RateLimiter(Arc::new(RateLimiterInner {
+ gc_dropped: (Mutex::new(false), Condvar::new()),
+ gc_running: AtomicBool::from(false),
+ table: spin::RwLock::new(HashMap::new()),
+ }))
+ }
+
+ pub fn allow(&self, addr: &IpAddr) -> bool {
+ // check if allowed
+ let allowed = {
+ // check for existing entry (only requires read lock)
+ if let Some(entry) = self.0.table.read().get(addr) {
+ // update existing entry
+ let mut entry = entry.lock();
+
+ // add tokens earned since last time
+ entry.tokens = MAX_TOKENS
+ .min(entry.tokens + u64::from(entry.last_time.elapsed().subsec_nanos()));
+ entry.last_time = Instant::now();
+
+ // subtract cost of packet
+ if entry.tokens > PACKET_COST {
+ entry.tokens -= PACKET_COST;
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ // add new entry (write lock)
+ self.0.table.write().insert(
+ *addr,
+ spin::Mutex::new(Entry {
+ last_time: Instant::now(),
+ tokens: MAX_TOKENS - PACKET_COST,
+ }),
+ );
+ true
+ };
+
+ // check that GC thread is scheduled
+ if !self.0.gc_running.swap(true, Ordering::Relaxed) {
+ let limiter = self.0.clone();
+ thread::spawn(move || {
+ let &(ref lock, ref cvar) = &limiter.gc_dropped;
+ let mut dropped = lock.lock().unwrap();
+ while !*dropped {
+ // garbage collect
+ {
+ let mut tw = limiter.table.write();
+ tw.retain(|_, ref mut entry| {
+ entry.lock().last_time.elapsed() <= GC_INTERVAL
+ });
+ if tw.len() == 0 {
+ limiter.gc_running.store(false, Ordering::Relaxed);
+ return;
+ }
+ }
+
+ // wait until stopped or new GC (~1 every sec)
+ let res = cvar.wait_timeout(dropped, GC_INTERVAL).unwrap();
+ dropped = res.0;
+ }
+ });
+ }
+
+ allowed
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std;
+
+ struct Result {
+ allowed: bool,
+ text: &'static str,
+ wait: Duration,
+ }
+
+ #[test]
+ fn test_ratelimiter() {
+ let ratelimiter = RateLimiter::new();
+ let mut expected = vec![];
+ let ips = vec![
+ "127.0.0.1".parse().unwrap(),
+ "192.168.1.1".parse().unwrap(),
+ "172.167.2.3".parse().unwrap(),
+ "97.231.252.215".parse().unwrap(),
+ "248.97.91.167".parse().unwrap(),
+ "188.208.233.47".parse().unwrap(),
+ "104.2.183.179".parse().unwrap(),
+ "72.129.46.120".parse().unwrap(),
+ "2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(),
+ "f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(),
+ "b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(),
+ "a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(),
+ "ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(),
+ "3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(),
+ ];
+
+ for _ in 0..PACKETS_BURSTABLE {
+ expected.push(Result {
+ allowed: true,
+ wait: Duration::new(0, 0),
+ text: "inital burst",
+ });
+ }
+
+ expected.push(Result {
+ allowed: false,
+ wait: Duration::new(0, 0),
+ text: "after burst",
+ });
+
+ expected.push(Result {
+ allowed: true,
+ wait: Duration::new(0, PACKET_COST as u32),
+ text: "filling tokens for single packet",
+ });
+
+ expected.push(Result {
+ allowed: false,
+ wait: Duration::new(0, 0),
+ text: "not having refilled enough",
+ });
+
+ expected.push(Result {
+ allowed: true,
+ wait: Duration::new(0, 2 * PACKET_COST as u32),
+ text: "filling tokens for 2 * packet burst",
+ });
+
+ expected.push(Result {
+ allowed: true,
+ wait: Duration::new(0, 0),
+ text: "second packet in 2 packet burst",
+ });
+
+ expected.push(Result {
+ allowed: false,
+ wait: Duration::new(0, 0),
+ text: "packet following 2 packet burst",
+ });
+
+ for item in expected {
+ std::thread::sleep(item.wait);
+ for ip in ips.iter() {
+ if ratelimiter.allow(&ip) != item.allowed {
+ panic!(
+ "test failed for {} on {}. expected: {}, got: {}",
+ ip, item.text, item.allowed, !item.allowed
+ )
+ }
+ }
+ }
+ }
+}
diff --git a/src/wireguard/handshake/timestamp.rs b/src/wireguard/handshake/timestamp.rs
new file mode 100644
index 0000000..b5bd9f0
--- /dev/null
+++ b/src/wireguard/handshake/timestamp.rs
@@ -0,0 +1,32 @@
+use std::time::{SystemTime, UNIX_EPOCH};
+
+pub type TAI64N = [u8; 12];
+
+const TAI64_EPOCH: u64 = 0x400000000000000a;
+
+pub const ZERO: TAI64N = [0u8; 12];
+
+pub fn now() -> TAI64N {
+ // get system time as duration
+ let sysnow = SystemTime::now();
+ let delta = sysnow.duration_since(UNIX_EPOCH).unwrap();
+
+ // convert to tai64n
+ let tai64_secs = delta.as_secs() + TAI64_EPOCH;
+ let tai64_nano = delta.subsec_nanos();
+
+ // serialize
+ let mut res = [0u8; 12];
+ res[..8].copy_from_slice(&tai64_secs.to_be_bytes()[..]);
+ res[8..].copy_from_slice(&tai64_nano.to_be_bytes()[..]);
+ res
+}
+
+pub fn compare(old: &TAI64N, new: &TAI64N) -> bool {
+ for i in 0..12 {
+ if new[i] > old[i] {
+ return true;
+ }
+ }
+ return false;
+}
diff --git a/src/wireguard/handshake/types.rs b/src/wireguard/handshake/types.rs
new file mode 100644
index 0000000..5f984cc
--- /dev/null
+++ b/src/wireguard/handshake/types.rs
@@ -0,0 +1,90 @@
+use std::error::Error;
+use std::fmt;
+
+use x25519_dalek::PublicKey;
+
+use super::super::types::KeyPair;
+
+/* Internal types for the noise IKpsk2 implementation */
+
+// config error
+
+#[derive(Debug)]
+pub struct ConfigError(String);
+
+impl ConfigError {
+ pub fn new(s: &str) -> Self {
+ ConfigError(s.to_string())
+ }
+}
+
+impl fmt::Display for ConfigError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "ConfigError({})", self.0)
+ }
+}
+
+impl Error for ConfigError {
+ fn description(&self) -> &str {
+ &self.0
+ }
+
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ None
+ }
+}
+
+// handshake error
+
+#[derive(Debug)]
+pub enum HandshakeError {
+ DecryptionFailure,
+ UnknownPublicKey,
+ UnknownReceiverId,
+ InvalidMessageFormat,
+ OldTimestamp,
+ InvalidState,
+ InvalidMac1,
+ RateLimited,
+ InitiationFlood,
+}
+
+impl fmt::Display for HandshakeError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ HandshakeError::DecryptionFailure => write!(f, "Failed to AEAD:OPEN"),
+ HandshakeError::UnknownPublicKey => write!(f, "Unknown public key"),
+ HandshakeError::UnknownReceiverId => {
+ write!(f, "Receiver id not allocated to any handshake")
+ }
+ HandshakeError::InvalidMessageFormat => write!(f, "Invalid handshake message format"),
+ HandshakeError::OldTimestamp => write!(f, "Timestamp is less/equal to the newest"),
+ HandshakeError::InvalidState => write!(f, "Message does not apply to handshake state"),
+ HandshakeError::InvalidMac1 => write!(f, "Message has invalid mac1 field"),
+ HandshakeError::RateLimited => write!(f, "Message was dropped by rate limiter"),
+ HandshakeError::InitiationFlood => {
+ write!(f, "Message was dropped because of initiation flood")
+ }
+ }
+ }
+}
+
+impl Error for HandshakeError {
+ fn description(&self) -> &str {
+ "Generic Handshake Error"
+ }
+
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ None
+ }
+}
+
+pub type Output = (
+ Option<PublicKey>, // external identifier associated with peer
+ Option<Vec<u8>>, // message to send
+ Option<KeyPair>, // resulting key-pair of successful handshake
+);
+
+// preshared key
+
+pub type Psk = [u8; 32];
diff --git a/src/wireguard/mod.rs b/src/wireguard/mod.rs
new file mode 100644
index 0000000..9417e57
--- /dev/null
+++ b/src/wireguard/mod.rs
@@ -0,0 +1,23 @@
+mod wireguard;
+// mod config;
+mod constants;
+mod timers;
+
+mod handshake;
+mod router;
+mod types;
+
+#[cfg(test)]
+mod tests;
+
+/// The WireGuard sub-module contains a pure, configurable implementation of WireGuard.
+/// The implementation is generic over:
+///
+/// - TUN type, specifying how packets are received on the interface side: a reader/writer and MTU reporting interface.
+/// - Bind type, specifying how WireGuard messages are sent/received from the internet and what constitutes an "endpoint"
+
+pub use wireguard::{Wireguard, Peer};
+
+pub use types::bind;
+pub use types::tun;
+pub use types::Endpoint; \ No newline at end of file
diff --git a/src/wireguard/router/anti_replay.rs b/src/wireguard/router/anti_replay.rs
new file mode 100644
index 0000000..b0838bd
--- /dev/null
+++ b/src/wireguard/router/anti_replay.rs
@@ -0,0 +1,157 @@
+use std::mem;
+
+// Implementation of RFC 6479.
+// https://tools.ietf.org/html/rfc6479
+
+#[cfg(target_pointer_width = "64")]
+type Word = u64;
+
+#[cfg(target_pointer_width = "64")]
+const REDUNDANT_BIT_SHIFTS: usize = 6;
+
+#[cfg(target_pointer_width = "32")]
+type Word = u32;
+
+#[cfg(target_pointer_width = "32")]
+const REDUNDANT_BIT_SHIFTS: usize = 5;
+
+const SIZE_OF_WORD: usize = mem::size_of::<Word>() * 8;
+
+const BITMAP_BITLEN: usize = 2048;
+const BITMAP_LEN: usize = (BITMAP_BITLEN / SIZE_OF_WORD);
+const BITMAP_INDEX_MASK: u64 = BITMAP_LEN as u64 - 1;
+const BITMAP_LOC_MASK: u64 = (SIZE_OF_WORD - 1) as u64;
+const WINDOW_SIZE: u64 = (BITMAP_BITLEN - SIZE_OF_WORD) as u64;
+
+pub struct AntiReplay {
+ bitmap: [Word; BITMAP_LEN],
+ last: u64,
+}
+
+impl Default for AntiReplay {
+ fn default() -> Self {
+ AntiReplay::new()
+ }
+}
+
+impl AntiReplay {
+ pub fn new() -> Self {
+ debug_assert_eq!(1 << REDUNDANT_BIT_SHIFTS, SIZE_OF_WORD);
+ debug_assert_eq!(BITMAP_BITLEN % SIZE_OF_WORD, 0);
+ AntiReplay {
+ last: 0,
+ bitmap: [0; BITMAP_LEN],
+ }
+ }
+
+ // Returns true if check is passed, i.e., not a replay or too old.
+ //
+ // Unlike RFC 6479, zero is allowed.
+ fn check(&self, seq: u64) -> bool {
+ // Larger is always good.
+ if seq > self.last {
+ return true;
+ }
+
+ if self.last - seq > WINDOW_SIZE {
+ return false;
+ }
+
+ let bit_location = seq & BITMAP_LOC_MASK;
+ let index = (seq >> REDUNDANT_BIT_SHIFTS) & BITMAP_INDEX_MASK;
+
+ self.bitmap[index as usize] & (1 << bit_location) == 0
+ }
+
+ // Should only be called if check returns true.
+ fn update_store(&mut self, seq: u64) {
+ debug_assert!(self.check(seq));
+
+ let index = seq >> REDUNDANT_BIT_SHIFTS;
+
+ if seq > self.last {
+ let index_cur = self.last >> REDUNDANT_BIT_SHIFTS;
+ let diff = index - index_cur;
+
+ if diff >= BITMAP_LEN as u64 {
+ self.bitmap = [0; BITMAP_LEN];
+ } else {
+ for i in 0..diff {
+ let real_index = (index_cur + i + 1) & BITMAP_INDEX_MASK;
+ self.bitmap[real_index as usize] = 0;
+ }
+ }
+
+ self.last = seq;
+ }
+
+ let index = index & BITMAP_INDEX_MASK;
+ let bit_location = seq & BITMAP_LOC_MASK;
+ self.bitmap[index as usize] |= 1 << bit_location;
+ }
+
+ /// Checks and marks a sequence number in the replay filter
+ ///
+ /// # Arguments
+ ///
+ /// - seq: Sequence number check for replay and add to filter
+ ///
+ /// # Returns
+ ///
+ /// Ok(()) if sequence number is valid (not marked and not behind the moving window).
+ /// Err if the sequence number is invalid (already marked or "too old").
+ pub fn update(&mut self, seq: u64) -> bool {
+ if self.check(seq) {
+ self.update_store(seq);
+ true
+ } else {
+ false
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn anti_replay() {
+ let mut ar = AntiReplay::new();
+
+ for i in 0..20000 {
+ assert!(ar.update(i));
+ }
+
+ for i in (0..20000).rev() {
+ assert!(!ar.check(i));
+ }
+
+ assert!(ar.update(65536));
+ for i in (65536 - WINDOW_SIZE)..65535 {
+ assert!(ar.update(i));
+ }
+
+ for i in (65536 - 10 * WINDOW_SIZE)..65535 {
+ assert!(!ar.check(i));
+ }
+
+ assert!(ar.update(66000));
+ for i in 65537..66000 {
+ assert!(ar.update(i));
+ }
+ for i in 65537..66000 {
+ assert_eq!(ar.update(i), false);
+ }
+
+ // Test max u64.
+ let next = u64::max_value();
+ assert!(ar.update(next));
+ assert!(!ar.check(next));
+ for i in (next - WINDOW_SIZE)..next {
+ assert!(ar.update(i));
+ }
+ for i in (next - 20 * WINDOW_SIZE)..next {
+ assert!(!ar.check(i));
+ }
+ }
+}
diff --git a/src/wireguard/router/constants.rs b/src/wireguard/router/constants.rs
new file mode 100644
index 0000000..0ca824a
--- /dev/null
+++ b/src/wireguard/router/constants.rs
@@ -0,0 +1,7 @@
+// WireGuard semantics constants
+
+pub const MAX_STAGED_PACKETS: usize = 128;
+
+// performance constants
+
+pub const WORKER_QUEUE_SIZE: usize = MAX_STAGED_PACKETS;
diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs
new file mode 100644
index 0000000..455020c
--- /dev/null
+++ b/src/wireguard/router/device.rs
@@ -0,0 +1,243 @@
+use std::collections::HashMap;
+use std::net::{Ipv4Addr, Ipv6Addr};
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use std::sync::mpsc::sync_channel;
+use std::sync::mpsc::SyncSender;
+use std::sync::Arc;
+use std::thread;
+use std::time::Instant;
+
+use log::debug;
+use spin::{Mutex, RwLock};
+use treebitmap::IpLookupTable;
+use zerocopy::LayoutVerified;
+
+use super::anti_replay::AntiReplay;
+use super::constants::*;
+use super::ip::*;
+use super::messages::{TransportHeader, TYPE_TRANSPORT};
+use super::peer::{new_peer, Peer, PeerInner};
+use super::types::{Callbacks, RouterError};
+use super::workers::{worker_parallel, JobParallel, Operation};
+use super::SIZE_MESSAGE_PREFIX;
+
+use super::super::types::{bind, tun, Endpoint, KeyPair};
+
+pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
+ // inbound writer (TUN)
+ pub inbound: T,
+
+ // outbound writer (Bind)
+ pub outbound: RwLock<Option<B>>,
+
+ // routing
+ pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state
+ pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv4 cryptkey routing
+ pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv6 cryptkey routing
+
+ // work queues
+ pub queue_next: AtomicUsize, // next round-robin index
+ pub queues: Mutex<Vec<SyncSender<JobParallel>>>, // work queues (1 per thread)
+}
+
+pub struct EncryptionState {
+ pub key: [u8; 32], // encryption key
+ pub id: u32, // receiver id
+ pub nonce: u64, // next available nonce
+ pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout)
+}
+
+pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
+ pub keypair: Arc<KeyPair>,
+ pub confirmed: AtomicBool,
+ pub protector: Mutex<AntiReplay>,
+ pub peer: Arc<PeerInner<E, C, T, B>>,
+ pub death: Instant, // time when the key can no longer be used for decryption
+}
+
+pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
+ state: Arc<DeviceInner<E, C, T, B>>, // reference to device state
+ handles: Vec<thread::JoinHandle<()>>, // join handles for workers
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Device<E, C, T, B> {
+ fn drop(&mut self) {
+ debug!("router: dropping device");
+
+ // drop all queues
+ {
+ let mut queues = self.state.queues.lock();
+ while queues.pop().is_some() {}
+ }
+
+ // join all worker threads
+ while match self.handles.pop() {
+ Some(handle) => {
+ handle.thread().unpark();
+ handle.join().unwrap();
+ true
+ }
+ _ => false,
+ } {}
+
+ debug!("router: device dropped");
+ }
+}
+
+#[inline(always)]
+fn get_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ device: &Arc<DeviceInner<E, C, T, B>>,
+ packet: &[u8],
+) -> Option<Arc<PeerInner<E, C, T, B>>> {
+ // ensure version access within bounds
+ if packet.len() < 1 {
+ return None;
+ };
+
+ // cast to correct IP header
+ match packet[0] >> 4 {
+ VERSION_IP4 => {
+ // check length and cast to IPv4 header
+ let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
+ LayoutVerified::new_from_prefix(packet)?;
+
+ // lookup destination address
+ device
+ .ipv4
+ .read()
+ .longest_match(Ipv4Addr::from(header.f_destination))
+ .and_then(|(_, _, p)| Some(p.clone()))
+ }
+ VERSION_IP6 => {
+ // check length and cast to IPv6 header
+ let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
+ LayoutVerified::new_from_prefix(packet)?;
+
+ // lookup destination address
+ device
+ .ipv6
+ .read()
+ .longest_match(Ipv6Addr::from(header.f_destination))
+ .and_then(|(_, _, p)| Some(p.clone()))
+ }
+ _ => None,
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, T, B> {
+ pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> {
+ // allocate shared device state
+ let inner = DeviceInner {
+ inbound: tun,
+ outbound: RwLock::new(None),
+ queues: Mutex::new(Vec::with_capacity(num_workers)),
+ queue_next: AtomicUsize::new(0),
+ recv: RwLock::new(HashMap::new()),
+ ipv4: RwLock::new(IpLookupTable::new()),
+ ipv6: RwLock::new(IpLookupTable::new()),
+ };
+
+ // start worker threads
+ let mut threads = Vec::with_capacity(num_workers);
+ for _ in 0..num_workers {
+ let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
+ inner.queues.lock().push(tx);
+ threads.push(thread::spawn(move || worker_parallel(rx)));
+ }
+
+ // return exported device handle
+ Device {
+ state: Arc::new(inner),
+ handles: threads,
+ }
+ }
+
+ /// A new secret key has been set for the device.
+ /// According to WireGuard semantics, this should cause all "sending" keys to be discarded.
+ pub fn new_sk(&self) {}
+
+ /// Adds a new peer to the device
+ ///
+ /// # Returns
+ ///
+ /// A atomic ref. counted peer (with liftime matching the device)
+ pub fn new_peer(&self, opaque: C::Opaque) -> Peer<E, C, T, B> {
+ new_peer(self.state.clone(), opaque)
+ }
+
+ /// Cryptkey routes and sends a plaintext message (IP packet)
+ ///
+ /// # Arguments
+ ///
+ /// - msg: IP packet to crypt-key route
+ ///
+ pub fn send(&self, msg: Vec<u8>) -> Result<(), RouterError> {
+ // ignore header prefix (for in-place transport message construction)
+ let packet = &msg[SIZE_MESSAGE_PREFIX..];
+
+ // lookup peer based on IP packet destination address
+ let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?;
+
+ // schedule for encryption and transmission to peer
+ if let Some(job) = peer.send_job(msg, true) {
+ debug_assert_eq!(job.1.op, Operation::Encryption);
+
+ // add job to worker queue
+ let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
+ let queues = self.state.queues.lock();
+ queues[idx % queues.len()].send(job).unwrap();
+ }
+
+ Ok(())
+ }
+
+ /// Receive an encrypted transport message
+ ///
+ /// # Arguments
+ ///
+ /// - src: Source address of the packet
+ /// - msg: Encrypted transport message
+ ///
+ /// # Returns
+ ///
+ ///
+ pub fn recv(&self, src: E, msg: Vec<u8>) -> Result<(), RouterError> {
+ // parse / cast
+ let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) {
+ Some(v) => v,
+ None => {
+ return Err(RouterError::MalformedTransportMessage);
+ }
+ };
+ let header: LayoutVerified<&[u8], TransportHeader> = header;
+ debug_assert!(
+ header.f_type.get() == TYPE_TRANSPORT as u32,
+ "this should be checked by the message type multiplexer"
+ );
+
+ // lookup peer based on receiver id
+ let dec = self.state.recv.read();
+ let dec = dec
+ .get(&header.f_receiver.get())
+ .ok_or(RouterError::UnknownReceiverId)?;
+
+ // schedule for decryption and TUN write
+ if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
+ debug_assert_eq!(job.1.op, Operation::Decryption);
+
+ // add job to worker queue
+ let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
+ let queues = self.state.queues.lock();
+ queues[idx % queues.len()].send(job).unwrap();
+ }
+
+ Ok(())
+ }
+
+ /// Set outbound writer
+ ///
+ ///
+ pub fn set_outbound_writer(&self, new: B) {
+ *self.state.outbound.write() = Some(new);
+ }
+}
diff --git a/src/wireguard/router/ip.rs b/src/wireguard/router/ip.rs
new file mode 100644
index 0000000..e66144f
--- /dev/null
+++ b/src/wireguard/router/ip.rs
@@ -0,0 +1,26 @@
+use byteorder::BigEndian;
+use zerocopy::byteorder::U16;
+use zerocopy::{AsBytes, FromBytes};
+
+pub const VERSION_IP4: u8 = 4;
+pub const VERSION_IP6: u8 = 6;
+
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct IPv4Header {
+ _f_space1: [u8; 2],
+ pub f_total_len: U16<BigEndian>,
+ _f_space2: [u8; 8],
+ pub f_source: [u8; 4],
+ pub f_destination: [u8; 4],
+}
+
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct IPv6Header {
+ _f_space1: [u8; 4],
+ pub f_len: U16<BigEndian>,
+ _f_space2: [u8; 2],
+ pub f_source: [u8; 16],
+ pub f_destination: [u8; 16],
+}
diff --git a/src/wireguard/router/messages.rs b/src/wireguard/router/messages.rs
new file mode 100644
index 0000000..bf4d13b
--- /dev/null
+++ b/src/wireguard/router/messages.rs
@@ -0,0 +1,13 @@
+use byteorder::LittleEndian;
+use zerocopy::byteorder::{U32, U64};
+use zerocopy::{AsBytes, FromBytes};
+
+pub const TYPE_TRANSPORT: u32 = 4;
+
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct TransportHeader {
+ pub f_type: U32<LittleEndian>,
+ pub f_receiver: U32<LittleEndian>,
+ pub f_counter: U64<LittleEndian>,
+}
diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs
new file mode 100644
index 0000000..7a29cd9
--- /dev/null
+++ b/src/wireguard/router/mod.rs
@@ -0,0 +1,22 @@
+mod anti_replay;
+mod constants;
+mod device;
+mod ip;
+mod messages;
+mod peer;
+mod types;
+mod workers;
+
+#[cfg(test)]
+mod tests;
+
+use messages::TransportHeader;
+use std::mem;
+
+pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>();
+pub const CAPACITY_MESSAGE_POSTFIX: usize = 16;
+
+pub use messages::TYPE_TRANSPORT;
+pub use device::Device;
+pub use peer::Peer;
+pub use types::Callbacks;
diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs
new file mode 100644
index 0000000..4f47604
--- /dev/null
+++ b/src/wireguard/router/peer.rs
@@ -0,0 +1,611 @@
+use std::mem;
+use std::net::{IpAddr, SocketAddr};
+use std::sync::atomic::AtomicBool;
+use std::sync::atomic::Ordering;
+use std::sync::mpsc::{sync_channel, SyncSender};
+use std::sync::Arc;
+use std::thread;
+
+use arraydeque::{ArrayDeque, Wrapping};
+use log::debug;
+use spin::Mutex;
+use treebitmap::address::Address;
+use treebitmap::IpLookupTable;
+use zerocopy::LayoutVerified;
+
+use super::super::constants::*;
+use super::super::types::{bind, tun, Endpoint, KeyPair};
+
+use super::anti_replay::AntiReplay;
+use super::device::DecryptionState;
+use super::device::DeviceInner;
+use super::device::EncryptionState;
+use super::messages::TransportHeader;
+
+use futures::*;
+
+use super::workers::Operation;
+use super::workers::{worker_inbound, worker_outbound};
+use super::workers::{JobBuffer, JobInbound, JobOutbound, JobParallel};
+use super::SIZE_MESSAGE_PREFIX;
+
+use super::constants::*;
+use super::types::{Callbacks, RouterError};
+
+pub struct KeyWheel {
+ next: Option<Arc<KeyPair>>, // next key state (unconfirmed)
+ current: Option<Arc<KeyPair>>, // current key state (used for encryption)
+ previous: Option<Arc<KeyPair>>, // old key state (used for decryption)
+ retired: Vec<u32>, // retired ids
+}
+
+pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
+ pub device: Arc<DeviceInner<E, C, T, B>>,
+ pub opaque: C::Opaque,
+ pub outbound: Mutex<SyncSender<JobOutbound>>,
+ pub inbound: Mutex<SyncSender<JobInbound<E, C, T, B>>>,
+ pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>,
+ pub keys: Mutex<KeyWheel>,
+ pub ekey: Mutex<Option<EncryptionState>>,
+ pub endpoint: Mutex<Option<E>>,
+}
+
+pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
+ state: Arc<PeerInner<E, C, T, B>>,
+ thread_outbound: Option<thread::JoinHandle<()>>,
+ thread_inbound: Option<thread::JoinHandle<()>>,
+}
+
+fn treebit_list<A, R, E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ peer: &Arc<PeerInner<E, C, T, B>>,
+ table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>,
+ callback: Box<dyn Fn(A, u32) -> R>,
+) -> Vec<R>
+where
+ A: Address,
+{
+ let mut res = Vec::new();
+ for subnet in table.read().iter() {
+ let (ip, masklen, p) = subnet;
+ if Arc::ptr_eq(&p, &peer) {
+ res.push(callback(ip, masklen))
+ }
+ }
+ res
+}
+
+fn treebit_remove<E: Endpoint, A: Address, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ peer: &Peer<E, C, T, B>,
+ table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>,
+) {
+ let mut m = table.write();
+
+ // collect keys for value
+ let mut subnets = vec![];
+ for subnet in m.iter() {
+ let (ip, masklen, p) = subnet;
+ if Arc::ptr_eq(&p, &peer.state) {
+ subnets.push((ip, masklen))
+ }
+ }
+
+ // remove all key mappings
+ for (ip, masklen) in subnets {
+ let r = m.remove(ip, masklen);
+ debug_assert!(r.is_some());
+ }
+}
+
+impl EncryptionState {
+ fn new(keypair: &Arc<KeyPair>) -> EncryptionState {
+ EncryptionState {
+ id: keypair.send.id,
+ key: keypair.send.key,
+ nonce: 0,
+ death: keypair.birth + REJECT_AFTER_TIME,
+ }
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionState<E, C, T, B> {
+ fn new(
+ peer: &Arc<PeerInner<E, C, T, B>>,
+ keypair: &Arc<KeyPair>,
+ ) -> DecryptionState<E, C, T, B> {
+ DecryptionState {
+ confirmed: AtomicBool::new(keypair.initiator),
+ keypair: keypair.clone(),
+ protector: spin::Mutex::new(AntiReplay::new()),
+ peer: peer.clone(),
+ death: keypair.birth + REJECT_AFTER_TIME,
+ }
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Peer<E, C, T, B> {
+ fn drop(&mut self) {
+ let peer = &self.state;
+
+ // remove from cryptkey router
+
+ treebit_remove(self, &peer.device.ipv4);
+ treebit_remove(self, &peer.device.ipv6);
+
+ // drop channels
+
+ mem::replace(&mut *peer.inbound.lock(), sync_channel(0).0);
+ mem::replace(&mut *peer.outbound.lock(), sync_channel(0).0);
+
+ // join with workers
+
+ mem::replace(&mut self.thread_inbound, None).map(|v| v.join());
+ mem::replace(&mut self.thread_outbound, None).map(|v| v.join());
+
+ // release ids from the receiver map
+
+ let mut keys = peer.keys.lock();
+ let mut release = Vec::with_capacity(3);
+
+ keys.next.as_ref().map(|k| release.push(k.recv.id));
+ keys.current.as_ref().map(|k| release.push(k.recv.id));
+ keys.previous.as_ref().map(|k| release.push(k.recv.id));
+
+ if release.len() > 0 {
+ let mut recv = peer.device.recv.write();
+ for id in &release {
+ recv.remove(id);
+ }
+ }
+
+ // null key-material
+
+ keys.next = None;
+ keys.current = None;
+ keys.previous = None;
+
+ *peer.ekey.lock() = None;
+ *peer.endpoint.lock() = None;
+
+ debug!("peer dropped & removed from device");
+ }
+}
+
+pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ device: Arc<DeviceInner<E, C, T, B>>,
+ opaque: C::Opaque,
+) -> Peer<E, C, T, B> {
+ let (out_tx, out_rx) = sync_channel(128);
+ let (in_tx, in_rx) = sync_channel(128);
+
+ // allocate peer object
+ let peer = {
+ let device = device.clone();
+ Arc::new(PeerInner {
+ opaque,
+ device,
+ inbound: Mutex::new(in_tx),
+ outbound: Mutex::new(out_tx),
+ ekey: spin::Mutex::new(None),
+ endpoint: spin::Mutex::new(None),
+ keys: spin::Mutex::new(KeyWheel {
+ next: None,
+ current: None,
+ previous: None,
+ retired: vec![],
+ }),
+ staged_packets: spin::Mutex::new(ArrayDeque::new()),
+ })
+ };
+
+ // spawn outbound thread
+ let thread_inbound = {
+ let peer = peer.clone();
+ let device = device.clone();
+ thread::spawn(move || worker_outbound(device, peer, out_rx))
+ };
+
+ // spawn inbound thread
+ let thread_outbound = {
+ let peer = peer.clone();
+ let device = device.clone();
+ thread::spawn(move || worker_inbound(device, peer, in_rx))
+ };
+
+ Peer {
+ state: peer,
+ thread_inbound: Some(thread_inbound),
+ thread_outbound: Some(thread_outbound),
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E, C, T, B> {
+ fn send_staged(&self) -> bool {
+ debug!("peer.send_staged");
+ let mut sent = false;
+ let mut staged = self.staged_packets.lock();
+ loop {
+ match staged.pop_front() {
+ Some(msg) => {
+ sent = true;
+ self.send_raw(msg);
+ }
+ None => break sent,
+ }
+ }
+ }
+
+ // Treat the msg as the payload of a transport message
+ // Unlike device.send, peer.send_raw does not buffer messages when a key is not available.
+ fn send_raw(&self, msg: Vec<u8>) -> bool {
+ debug!("peer.send_raw");
+ match self.send_job(msg, false) {
+ Some(job) => {
+ debug!("send_raw: got obtained send_job");
+ let index = self.device.queue_next.fetch_add(1, Ordering::SeqCst);
+ let queues = self.device.queues.lock();
+ match queues[index % queues.len()].send(job) {
+ Ok(_) => true,
+ Err(_) => false,
+ }
+ }
+ None => false,
+ }
+ }
+
+ pub fn confirm_key(&self, keypair: &Arc<KeyPair>) {
+ debug!("peer.confirm_key");
+ {
+ // take lock and check keypair = keys.next
+ let mut keys = self.keys.lock();
+ let next = match keys.next.as_ref() {
+ Some(next) => next,
+ None => {
+ return;
+ }
+ };
+ if !Arc::ptr_eq(&next, keypair) {
+ return;
+ }
+
+ // allocate new encryption state
+ let ekey = Some(EncryptionState::new(&next));
+
+ // rotate key-wheel
+ let mut swap = None;
+ mem::swap(&mut keys.next, &mut swap);
+ mem::swap(&mut keys.current, &mut swap);
+ mem::swap(&mut keys.previous, &mut swap);
+
+ // tell the world outside the router that a key was confirmed
+ C::key_confirmed(&self.opaque);
+
+ // set new key for encryption
+ *self.ekey.lock() = ekey;
+ }
+
+ // start transmission of staged packets
+ self.send_staged();
+ }
+
+ pub fn recv_job(
+ &self,
+ src: E,
+ dec: Arc<DecryptionState<E, C, T, B>>,
+ msg: Vec<u8>,
+ ) -> Option<JobParallel> {
+ let (tx, rx) = oneshot();
+ let key = dec.keypair.recv.key;
+ match self.inbound.lock().try_send((dec, src, rx)) {
+ Ok(_) => Some((
+ tx,
+ JobBuffer {
+ msg,
+ key: key,
+ okay: false,
+ op: Operation::Decryption,
+ },
+ )),
+ Err(_) => None,
+ }
+ }
+
+ pub fn send_job(&self, mut msg: Vec<u8>, stage: bool) -> Option<JobParallel> {
+ debug!("peer.send_job");
+ debug_assert!(
+ msg.len() >= mem::size_of::<TransportHeader>(),
+ "received message with size: {:}",
+ msg.len()
+ );
+
+ // parse / cast
+ let (header, _) = LayoutVerified::new_from_prefix(&mut msg[..]).unwrap();
+ let mut header: LayoutVerified<&mut [u8], TransportHeader> = header;
+
+ // check if has key
+ let key = {
+ let mut ekey = self.ekey.lock();
+ let key = match ekey.as_mut() {
+ None => None,
+ Some(mut state) => {
+ // avoid integer overflow in nonce
+ if state.nonce >= REJECT_AFTER_MESSAGES - 1 {
+ *ekey = None;
+ None
+ } else {
+ // there should be no stacked packets lingering around
+ debug!("encryption state available, nonce = {}", state.nonce);
+
+ // set transport message fields
+ header.f_counter.set(state.nonce);
+ header.f_receiver.set(state.id);
+ state.nonce += 1;
+ Some(state.key)
+ }
+ }
+ };
+
+ // If not suitable key was found:
+ // 1. Stage packet for later transmission
+ // 2. Request new key
+ if key.is_none() && stage {
+ self.staged_packets.lock().push_back(msg);
+ C::need_key(&self.opaque);
+ return None;
+ };
+
+ key
+ }?;
+
+ // add job to in-order queue and return sendeer to device for inclusion in worker pool
+ let (tx, rx) = oneshot();
+ match self.outbound.lock().try_send(rx) {
+ Ok(_) => Some((
+ tx,
+ JobBuffer {
+ msg,
+ key,
+ okay: false,
+ op: Operation::Encryption,
+ },
+ )),
+ Err(_) => None,
+ }
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T, B> {
+ /// Set the endpoint of the peer
+ ///
+ /// # Arguments
+ ///
+ /// - `endpoint`, socket address converted to bind endpoint
+ ///
+ /// # Note
+ ///
+ /// This API still permits support for the "sticky socket" behavior,
+ /// as sockets should be "unsticked" when manually updating the endpoint
+ pub fn set_endpoint(&self, endpoint: E) {
+ debug!("peer.set_endpoint");
+ *self.state.endpoint.lock() = Some(endpoint);
+ }
+
+ /// Returns the current endpoint of the peer (for configuration)
+ ///
+ /// # Note
+ ///
+ /// Does not convey potential "sticky socket" information
+ pub fn get_endpoint(&self) -> Option<SocketAddr> {
+ debug!("peer.get_endpoint");
+ self.state
+ .endpoint
+ .lock()
+ .as_ref()
+ .map(|e| e.into_address())
+ }
+
+ /// Zero all key-material related to the peer
+ pub fn zero_keys(&self) {
+ debug!("peer.zero_keys");
+
+ let mut release: Vec<u32> = Vec::with_capacity(3);
+ let mut keys = self.state.keys.lock();
+
+ // update key-wheel
+
+ mem::replace(&mut keys.next, None).map(|k| release.push(k.local_id()));
+ mem::replace(&mut keys.current, None).map(|k| release.push(k.local_id()));
+ mem::replace(&mut keys.previous, None).map(|k| release.push(k.local_id()));
+ keys.retired.extend(&release[..]);
+
+ // update inbound "recv" map
+ {
+ let mut recv = self.state.device.recv.write();
+ for id in release {
+ recv.remove(&id);
+ }
+ }
+
+ // clear encryption state
+ *self.state.ekey.lock() = None;
+ }
+
+ /// Add a new keypair
+ ///
+ /// # Arguments
+ ///
+ /// - new: The new confirmed/unconfirmed key pair
+ ///
+ /// # Returns
+ ///
+ /// A vector of ids which has been released.
+ /// These should be released in the handshake module.
+ ///
+ /// # Note
+ ///
+ /// The number of ids to be released can be at most 3,
+ /// since the only way to add additional keys to the peer is by using this method
+ /// and a peer can have at most 3 keys allocated in the router at any time.
+ pub fn add_keypair(&self, new: KeyPair) -> Vec<u32> {
+ debug!("peer.add_keypair");
+
+ let initiator = new.initiator;
+ let release = {
+ let new = Arc::new(new);
+ let mut keys = self.state.keys.lock();
+ let mut release = mem::replace(&mut keys.retired, vec![]);
+
+ // update key-wheel
+ if new.initiator {
+ // start using key for encryption
+ *self.state.ekey.lock() = Some(EncryptionState::new(&new));
+
+ // move current into previous
+ keys.previous = keys.current.as_ref().map(|v| v.clone());
+ keys.current = Some(new.clone());
+ } else {
+ // store the key and await confirmation
+ keys.previous = keys.next.as_ref().map(|v| v.clone());
+ keys.next = Some(new.clone());
+ };
+
+ // update incoming packet id map
+ {
+ debug!("peer.add_keypair: updating inbound id map");
+ let mut recv = self.state.device.recv.write();
+
+ // purge recv map of previous id
+ keys.previous.as_ref().map(|k| {
+ recv.remove(&k.local_id());
+ release.push(k.local_id());
+ });
+
+ // map new id to decryption state
+ debug_assert!(!recv.contains_key(&new.recv.id));
+ recv.insert(
+ new.recv.id,
+ Arc::new(DecryptionState::new(&self.state, &new)),
+ );
+ }
+ release
+ };
+
+ // schedule confirmation
+ if initiator {
+ debug_assert!(self.state.ekey.lock().is_some());
+ debug!("peer.add_keypair: is initiator, must confirm the key");
+ // attempt to confirm using staged packets
+ if !self.state.send_staged() {
+ // fall back to keepalive packet
+ let ok = self.send_keepalive();
+ debug!(
+ "peer.add_keypair: keepalive for confirmation, sent = {}",
+ ok
+ );
+ }
+ debug!("peer.add_keypair: key attempted confirmed");
+ }
+
+ debug_assert!(
+ release.len() <= 3,
+ "since the key-wheel contains at most 3 keys"
+ );
+ release
+ }
+
+ pub fn send_keepalive(&self) -> bool {
+ debug!("peer.send_keepalive");
+ self.state.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX])
+ }
+
+ /// Map a subnet to the peer
+ ///
+ /// # Arguments
+ ///
+ /// - `ip`, the mask of the subnet
+ /// - `masklen`, the length of the mask
+ ///
+ /// # Note
+ ///
+ /// The `ip` must not have any bits set right of `masklen`.
+ /// e.g. `192.168.1.0/24` is valid, while `192.168.1.128/24` is not.
+ ///
+ /// If an identical value already exists as part of a prior peer,
+ /// the allowed IP entry will be removed from that peer and added to this peer.
+ pub fn add_subnet(&self, ip: IpAddr, masklen: u32) {
+ debug!("peer.add_subnet");
+ match ip {
+ IpAddr::V4(v4) => {
+ self.state
+ .device
+ .ipv4
+ .write()
+ .insert(v4, masklen, self.state.clone())
+ }
+ IpAddr::V6(v6) => {
+ self.state
+ .device
+ .ipv6
+ .write()
+ .insert(v6, masklen, self.state.clone())
+ }
+ };
+ }
+
+ /// List subnets mapped to the peer
+ ///
+ /// # Returns
+ ///
+ /// A vector of subnets, represented by as mask/size
+ pub fn list_subnets(&self) -> Vec<(IpAddr, u32)> {
+ debug!("peer.list_subnets");
+ let mut res = Vec::new();
+ res.append(&mut treebit_list(
+ &self.state,
+ &self.state.device.ipv4,
+ Box::new(|ip, masklen| (IpAddr::V4(ip), masklen)),
+ ));
+ res.append(&mut treebit_list(
+ &self.state,
+ &self.state.device.ipv6,
+ Box::new(|ip, masklen| (IpAddr::V6(ip), masklen)),
+ ));
+ res
+ }
+
+ /// Clear subnets mapped to the peer.
+ /// After the call, no subnets will be cryptkey routed to the peer.
+ /// Used for the UAPI command "replace_allowed_ips=true"
+ pub fn remove_subnets(&self) {
+ debug!("peer.remove_subnets");
+ treebit_remove(self, &self.state.device.ipv4);
+ treebit_remove(self, &self.state.device.ipv6);
+ }
+
+ /// Send a raw message to the peer (used for handshake messages)
+ ///
+ /// # Arguments
+ ///
+ /// - `msg`, message body to send to peer
+ ///
+ /// # Returns
+ ///
+ /// Unit if packet was sent, or an error indicating why sending failed
+ pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> {
+ debug!("peer.send");
+ let inner = &self.state;
+ match inner.endpoint.lock().as_ref() {
+ Some(endpoint) => inner
+ .device
+ .outbound
+ .read()
+ .as_ref()
+ .ok_or(RouterError::SendError)
+ .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError)),
+ None => Err(RouterError::NoEndpoint),
+ }
+ }
+
+ pub fn purge_staged_packets(&self) {
+ self.state.staged_packets.lock().clear();
+ }
+}
diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs
new file mode 100644
index 0000000..fbee39e
--- /dev/null
+++ b/src/wireguard/router/tests.rs
@@ -0,0 +1,432 @@
+use std::net::IpAddr;
+use std::sync::atomic::Ordering;
+use std::sync::Arc;
+use std::sync::Mutex;
+use std::thread;
+use std::time::Duration;
+
+use num_cpus;
+use pnet::packet::ipv4::MutableIpv4Packet;
+use pnet::packet::ipv6::MutableIpv6Packet;
+
+use super::super::types::bind::*;
+use super::super::types::*;
+
+use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX};
+
+extern crate test;
+
+const SIZE_KEEPALIVE: usize = 32;
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use env_logger;
+ use log::debug;
+ use std::sync::atomic::AtomicUsize;
+ use test::Bencher;
+
+ // type for tracking events inside the router module
+ struct Flags {
+ send: Mutex<Vec<(usize, bool, bool)>>,
+ recv: Mutex<Vec<(usize, bool, bool)>>,
+ need_key: Mutex<Vec<()>>,
+ key_confirmed: Mutex<Vec<()>>,
+ }
+
+ #[derive(Clone)]
+ struct Opaque(Arc<Flags>);
+
+ struct TestCallbacks();
+
+ impl Opaque {
+ fn new() -> Opaque {
+ Opaque(Arc::new(Flags {
+ send: Mutex::new(vec![]),
+ recv: Mutex::new(vec![]),
+ need_key: Mutex::new(vec![]),
+ key_confirmed: Mutex::new(vec![]),
+ }))
+ }
+
+ fn reset(&self) {
+ self.0.send.lock().unwrap().clear();
+ self.0.recv.lock().unwrap().clear();
+ self.0.need_key.lock().unwrap().clear();
+ self.0.key_confirmed.lock().unwrap().clear();
+ }
+
+ fn send(&self) -> Option<(usize, bool, bool)> {
+ self.0.send.lock().unwrap().pop()
+ }
+
+ fn recv(&self) -> Option<(usize, bool, bool)> {
+ self.0.recv.lock().unwrap().pop()
+ }
+
+ fn need_key(&self) -> Option<()> {
+ self.0.need_key.lock().unwrap().pop()
+ }
+
+ fn key_confirmed(&self) -> Option<()> {
+ self.0.key_confirmed.lock().unwrap().pop()
+ }
+
+ // has all events been accounted for by assertions?
+ fn is_empty(&self) -> bool {
+ let send = self.0.send.lock().unwrap();
+ let recv = self.0.recv.lock().unwrap();
+ let need_key = self.0.need_key.lock().unwrap();
+ let key_confirmed = self.0.key_confirmed.lock().unwrap();
+ send.is_empty() && recv.is_empty() && need_key.is_empty() & key_confirmed.is_empty()
+ }
+ }
+
+ impl Callbacks for TestCallbacks {
+ type Opaque = Opaque;
+
+ fn send(t: &Self::Opaque, size: usize, data: bool, sent: bool) {
+ t.0.send.lock().unwrap().push((size, data, sent))
+ }
+
+ fn recv(t: &Self::Opaque, size: usize, data: bool, sent: bool) {
+ t.0.recv.lock().unwrap().push((size, data, sent))
+ }
+
+ fn need_key(t: &Self::Opaque) {
+ t.0.need_key.lock().unwrap().push(());
+ }
+
+ fn key_confirmed(t: &Self::Opaque) {
+ t.0.key_confirmed.lock().unwrap().push(());
+ }
+ }
+
+ // wait for scheduling
+ fn wait() {
+ thread::sleep(Duration::from_millis(50));
+ }
+
+ fn init() {
+ let _ = env_logger::builder().is_test(true).try_init();
+ }
+
+ fn make_packet(size: usize, ip: IpAddr) -> Vec<u8> {
+ // create "IP packet"
+ let mut msg = Vec::with_capacity(SIZE_MESSAGE_PREFIX + size + 16);
+ msg.resize(SIZE_MESSAGE_PREFIX + size, 0);
+ match ip {
+ IpAddr::V4(ip) => {
+ let mut packet = MutableIpv4Packet::new(&mut msg[SIZE_MESSAGE_PREFIX..]).unwrap();
+ packet.set_destination(ip);
+ packet.set_version(4);
+ }
+ IpAddr::V6(ip) => {
+ let mut packet = MutableIpv6Packet::new(&mut msg[SIZE_MESSAGE_PREFIX..]).unwrap();
+ packet.set_destination(ip);
+ packet.set_version(6);
+ }
+ }
+ msg
+ }
+
+ #[bench]
+ fn bench_outbound(b: &mut Bencher) {
+ struct BencherCallbacks {}
+ impl Callbacks for BencherCallbacks {
+ type Opaque = Arc<AtomicUsize>;
+ fn send(t: &Self::Opaque, size: usize, _data: bool, _sent: bool) {
+ t.fetch_add(size, Ordering::SeqCst);
+ }
+ fn recv(_: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {}
+ fn need_key(_: &Self::Opaque) {}
+ fn key_confirmed(_: &Self::Opaque) {}
+ }
+
+ // create device
+ let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false);
+ let router: Device<_, BencherCallbacks, dummy::TunWriter, dummy::VoidBind> =
+ Device::new(num_cpus::get(), tun_writer);
+
+ // add new peer
+ let opaque = Arc::new(AtomicUsize::new(0));
+ let peer = router.new_peer(opaque.clone());
+ peer.add_keypair(dummy::keypair(true));
+
+ // add subnet to peer
+ let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20");
+ let mask: IpAddr = mask.parse().unwrap();
+ let ip1: IpAddr = ip.parse().unwrap();
+ peer.add_subnet(mask, len);
+
+ // every iteration sends 10 GB
+ b.iter(|| {
+ opaque.store(0, Ordering::SeqCst);
+ let msg = make_packet(1024, ip1);
+ while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 {
+ router.send(msg.to_vec()).unwrap();
+ }
+ });
+ }
+
+ #[test]
+ fn test_outbound() {
+ init();
+
+ // create device
+ let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false);
+ let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer);
+ router.set_outbound_writer(dummy::VoidBind::new());
+
+ let tests = vec![
+ ("192.168.1.0", 24, "192.168.1.20", true),
+ ("172.133.133.133", 32, "172.133.133.133", true),
+ ("172.133.133.133", 32, "172.133.133.132", false),
+ (
+ "2001:db8::ff00:42:0000",
+ 112,
+ "2001:db8::ff00:42:3242",
+ true,
+ ),
+ (
+ "2001:db8::ff00:42:8000",
+ 113,
+ "2001:db8::ff00:42:0660",
+ false,
+ ),
+ (
+ "2001:db8::ff00:42:8000",
+ 113,
+ "2001:db8::ff00:42:ffff",
+ true,
+ ),
+ ];
+
+ for (num, (mask, len, ip, okay)) in tests.iter().enumerate() {
+ for set_key in vec![true, false] {
+ debug!("index = {}, set_key = {}", num, set_key);
+
+ // add new peer
+ let opaque = Opaque::new();
+ let peer = router.new_peer(opaque.clone());
+ let mask: IpAddr = mask.parse().unwrap();
+ if set_key {
+ peer.add_keypair(dummy::keypair(true));
+ }
+
+ // map subnet to peer
+ peer.add_subnet(mask, *len);
+
+ // create "IP packet"
+ let msg = make_packet(1024, ip.parse().unwrap());
+
+ // cryptkey route the IP packet
+ let res = router.send(msg);
+
+ // allow some scheduling
+ wait();
+
+ if *okay {
+ // cryptkey routing succeeded
+ assert!(res.is_ok(), "crypt-key routing should succeed");
+ assert_eq!(
+ opaque.need_key().is_some(),
+ !set_key,
+ "should have requested a new key, if no encryption state was set"
+ );
+ assert_eq!(
+ opaque.send().is_some(),
+ set_key,
+ "transmission should have been attempted"
+ );
+ assert!(
+ opaque.recv().is_none(),
+ "no messages should have been marked as received"
+ );
+ } else {
+ // no such cryptkey route
+ assert!(res.is_err(), "crypt-key routing should fail");
+ assert!(
+ opaque.need_key().is_none(),
+ "should not request a new-key if crypt-key routing failed"
+ );
+ assert_eq!(
+ opaque.send(),
+ if set_key {
+ Some((SIZE_KEEPALIVE, false, false))
+ } else {
+ None
+ },
+ "transmission should only happen if key was set (keepalive)",
+ );
+ assert!(
+ opaque.recv().is_none(),
+ "no messages should have been marked as received",
+ );
+ }
+ }
+ }
+ }
+
+ #[test]
+ fn test_bidirectional() {
+ init();
+
+ let tests = [
+ (
+ false, // confirm with keepalive
+ ("192.168.1.0", 24, "192.168.1.20", true),
+ ("172.133.133.133", 32, "172.133.133.133", true),
+ ),
+ (
+ true, // confirm with staged packet
+ ("192.168.1.0", 24, "192.168.1.20", true),
+ ("172.133.133.133", 32, "172.133.133.133", true),
+ ),
+ (
+ false, // confirm with keepalive
+ (
+ "2001:db8::ff00:42:8000",
+ 113,
+ "2001:db8::ff00:42:ffff",
+ true,
+ ),
+ (
+ "2001:db8::ff40:42:8000",
+ 113,
+ "2001:db8::ff40:42:ffff",
+ true,
+ ),
+ ),
+ (
+ false, // confirm with staged packet
+ (
+ "2001:db8::ff00:42:8000",
+ 113,
+ "2001:db8::ff00:42:ffff",
+ true,
+ ),
+ (
+ "2001:db8::ff40:42:8000",
+ 113,
+ "2001:db8::ff40:42:ffff",
+ true,
+ ),
+ ),
+ ];
+
+ for (stage, p1, p2) in tests.iter() {
+ let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) =
+ dummy::PairBind::pair();
+
+ // create matching device
+ let (_fake, _, tun_writer1, _) = dummy::TunTest::create(1500, false);
+ let (_fake, _, tun_writer2, _) = dummy::TunTest::create(1500, false);
+
+ let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1);
+ router1.set_outbound_writer(bind_writer1);
+
+ let router2: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer2);
+ router2.set_outbound_writer(bind_writer2);
+
+ // prepare opaque values for tracing callbacks
+
+ let opaq1 = Opaque::new();
+ let opaq2 = Opaque::new();
+
+ // create peers with matching keypairs and assign subnets
+
+ let (mask, len, _ip, _okay) = p1;
+ let peer1 = router1.new_peer(opaq1.clone());
+ let mask: IpAddr = mask.parse().unwrap();
+ peer1.add_subnet(mask, *len);
+ peer1.add_keypair(dummy::keypair(false));
+
+ let (mask, len, _ip, _okay) = p2;
+ let peer2 = router2.new_peer(opaq2.clone());
+ let mask: IpAddr = mask.parse().unwrap();
+ peer2.add_subnet(mask, *len);
+ peer2.set_endpoint(dummy::UnitEndpoint::new());
+
+ if *stage {
+ // stage a packet which can be used for confirmation (in place of a keepalive)
+ let (_mask, _len, ip, _okay) = p2;
+ let msg = make_packet(1024, ip.parse().unwrap());
+ router2.send(msg).expect("failed to sent staged packet");
+
+ wait();
+ assert!(opaq2.recv().is_none());
+ assert!(
+ opaq2.send().is_none(),
+ "sending should fail as not key is set"
+ );
+ assert!(
+ opaq2.need_key().is_some(),
+ "a new key should be requested since a packet was attempted transmitted"
+ );
+ assert!(opaq2.is_empty(), "callbacks should only run once");
+ }
+
+ // this should cause a key-confirmation packet (keepalive or staged packet)
+ // this also causes peer1 to learn the "endpoint" for peer2
+ assert!(peer1.get_endpoint().is_none());
+ peer2.add_keypair(dummy::keypair(true));
+
+ wait();
+ assert!(opaq2.send().is_some());
+ assert!(opaq2.is_empty(), "events on peer2 should be 'send'");
+ assert!(opaq1.is_empty(), "nothing should happened on peer1");
+
+ // read confirming message received by the other end ("across the internet")
+ let mut buf = vec![0u8; 2048];
+ let (len, from) = bind_reader1.read(&mut buf).unwrap();
+ buf.truncate(len);
+ router1.recv(from, buf).unwrap();
+
+ wait();
+ assert!(opaq1.recv().is_some());
+ assert!(opaq1.key_confirmed().is_some());
+ assert!(
+ opaq1.is_empty(),
+ "events on peer1 should be 'recv' and 'key_confirmed'"
+ );
+ assert!(peer1.get_endpoint().is_some());
+ assert!(opaq2.is_empty(), "nothing should happened on peer2");
+
+ // now that peer1 has an endpoint
+ // route packets : peer1 -> peer2
+
+ for _ in 0..10 {
+ assert!(
+ opaq1.is_empty(),
+ "we should have asserted a value for every callback on peer1"
+ );
+ assert!(
+ opaq2.is_empty(),
+ "we should have asserted a value for every callback on peer2"
+ );
+
+ // pass IP packet to router
+ let (_mask, _len, ip, _okay) = p1;
+ let msg = make_packet(1024, ip.parse().unwrap());
+ router1.send(msg).unwrap();
+
+ wait();
+ assert!(opaq1.send().is_some());
+ assert!(opaq1.recv().is_none());
+ assert!(opaq1.need_key().is_none());
+
+ // receive ("across the internet") on the other end
+ let mut buf = vec![0u8; 2048];
+ let (len, from) = bind_reader2.read(&mut buf).unwrap();
+ buf.truncate(len);
+ router2.recv(from, buf).unwrap();
+
+ wait();
+ assert!(opaq2.send().is_none());
+ assert!(opaq2.recv().is_some());
+ assert!(opaq2.need_key().is_none());
+ }
+ }
+ }
+}
diff --git a/src/wireguard/router/types.rs b/src/wireguard/router/types.rs
new file mode 100644
index 0000000..b7c3ae0
--- /dev/null
+++ b/src/wireguard/router/types.rs
@@ -0,0 +1,65 @@
+use std::error::Error;
+use std::fmt;
+
+pub trait Opaque: Send + Sync + 'static {}
+
+impl<T> Opaque for T where T: Send + Sync + 'static {}
+
+/// A send/recv callback takes 3 arguments:
+///
+/// * `0`, a reference to the opaque value assigned to the peer
+/// * `1`, a bool indicating whether the message contained data (not just keepalive)
+/// * `2`, a bool indicating whether the message was transmitted (i.e. did the peer have an associated endpoint?)
+pub trait Callback<T>: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {}
+
+impl<T, F> Callback<T> for F where F: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {}
+
+/// A key callback takes 1 argument
+///
+/// * `0`, a reference to the opaque value assigned to the peer
+pub trait KeyCallback<T>: Fn(&T) -> () + Sync + Send + 'static {}
+
+impl<T, F> KeyCallback<T> for F where F: Fn(&T) -> () + Sync + Send + 'static {}
+
+pub trait Callbacks: Send + Sync + 'static {
+ type Opaque: Opaque;
+ fn send(opaque: &Self::Opaque, size: usize, data: bool, sent: bool);
+ fn recv(opaque: &Self::Opaque, size: usize, data: bool, sent: bool);
+ fn need_key(opaque: &Self::Opaque);
+ fn key_confirmed(opaque: &Self::Opaque);
+}
+
+#[derive(Debug)]
+pub enum RouterError {
+ NoCryptKeyRoute,
+ MalformedIPHeader,
+ MalformedTransportMessage,
+ UnknownReceiverId,
+ NoEndpoint,
+ SendError,
+}
+
+impl fmt::Display for RouterError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"),
+ RouterError::MalformedIPHeader => write!(f, "IP header is malformed"),
+ RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"),
+ RouterError::UnknownReceiverId => {
+ write!(f, "No decryption state associated with receiver id")
+ }
+ RouterError::NoEndpoint => write!(f, "No endpoint for peer"),
+ RouterError::SendError => write!(f, "Failed to send packet on bind"),
+ }
+ }
+}
+
+impl Error for RouterError {
+ fn description(&self) -> &str {
+ "Generic Handshake Error"
+ }
+
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ None
+ }
+}
diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs
new file mode 100644
index 0000000..2e89bb0
--- /dev/null
+++ b/src/wireguard/router/workers.rs
@@ -0,0 +1,305 @@
+use std::mem;
+use std::sync::mpsc::Receiver;
+use std::sync::Arc;
+
+use futures::sync::oneshot;
+use futures::*;
+
+use log::debug;
+
+use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
+use std::net::{Ipv4Addr, Ipv6Addr};
+use std::sync::atomic::Ordering;
+use zerocopy::{AsBytes, LayoutVerified};
+
+use super::device::{DecryptionState, DeviceInner};
+use super::messages::{TransportHeader, TYPE_TRANSPORT};
+use super::peer::PeerInner;
+use super::types::Callbacks;
+
+use super::super::types::{Endpoint, tun, bind};
+use super::ip::*;
+
+const SIZE_TAG: usize = 16;
+
+#[derive(PartialEq, Debug)]
+pub enum Operation {
+ Encryption,
+ Decryption,
+}
+
+pub struct JobBuffer {
+ pub msg: Vec<u8>, // message buffer (nonce and receiver id set)
+ pub key: [u8; 32], // chacha20poly1305 key
+ pub okay: bool, // state of the job
+ pub op: Operation, // should be buffer be encrypted / decrypted?
+}
+
+pub type JobParallel = (oneshot::Sender<JobBuffer>, JobBuffer);
+
+#[allow(type_alias_bounds)]
+pub type JobInbound<E, C, T, B: bind::Writer<E>> = (
+ Arc<DecryptionState<E, C, T, B>>,
+ E,
+ oneshot::Receiver<JobBuffer>,
+);
+
+pub type JobOutbound = oneshot::Receiver<JobBuffer>;
+
+#[inline(always)]
+fn check_route<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ device: &Arc<DeviceInner<E, C, T, B>>,
+ peer: &Arc<PeerInner<E, C, T, B>>,
+ packet: &[u8],
+) -> Option<usize> {
+ match packet[0] >> 4 {
+ VERSION_IP4 => {
+ // check length and cast to IPv4 header
+ let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
+ LayoutVerified::new_from_prefix(packet)?;
+
+ // check IPv4 source address
+ device
+ .ipv4
+ .read()
+ .longest_match(Ipv4Addr::from(header.f_source))
+ .and_then(|(_, _, p)| {
+ if Arc::ptr_eq(p, &peer) {
+ Some(header.f_total_len.get() as usize)
+ } else {
+ None
+ }
+ })
+ }
+ VERSION_IP6 => {
+ // check length and cast to IPv6 header
+ let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
+ LayoutVerified::new_from_prefix(packet)?;
+
+ // check IPv6 source address
+ device
+ .ipv6
+ .read()
+ .longest_match(Ipv6Addr::from(header.f_source))
+ .and_then(|(_, _, p)| {
+ if Arc::ptr_eq(p, &peer) {
+ Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>())
+ } else {
+ None
+ }
+ })
+ }
+ _ => None,
+ }
+}
+
+pub fn worker_inbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ device: Arc<DeviceInner<E, C, T, B>>, // related device
+ peer: Arc<PeerInner<E, C, T, B>>, // related peer
+ receiver: Receiver<JobInbound<E, C, T, B>>,
+) {
+ loop {
+ // fetch job
+ let (state, endpoint, rx) = match receiver.recv() {
+ Ok(v) => v,
+ _ => {
+ return;
+ }
+ };
+ debug!("inbound worker: obtained job");
+
+ // wait for job to complete
+ let _ = rx
+ .map(|buf| {
+ debug!("inbound worker: job complete");
+ if buf.okay {
+ // cast transport header
+ let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
+ match LayoutVerified::new_from_prefix(&buf.msg[..]) {
+ Some(v) => v,
+ None => {
+ debug!("inbound worker: failed to parse message");
+ return;
+ }
+ };
+
+ debug_assert!(
+ packet.len() >= CHACHA20_POLY1305.tag_len(),
+ "this should be checked earlier in the pipeline (decryption should fail)"
+ );
+
+ // check for replay
+ if !state.protector.lock().update(header.f_counter.get()) {
+ debug!("inbound worker: replay detected");
+ return;
+ }
+
+ // check for confirms key
+ if !state.confirmed.swap(true, Ordering::SeqCst) {
+ debug!("inbound worker: message confirms key");
+ peer.confirm_key(&state.keypair);
+ }
+
+ // update endpoint
+ *peer.endpoint.lock() = Some(endpoint);
+
+ // calculate length of IP packet + padding
+ let length = packet.len() - SIZE_TAG;
+ debug!("inbound worker: plaintext length = {}", length);
+
+ // check if should be written to TUN
+ let mut sent = false;
+ if length > 0 {
+ if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) {
+ debug_assert!(inner_len <= length, "should be validated");
+ if inner_len <= length {
+ sent = match device.inbound.write(&packet[..inner_len]) {
+ Err(e) => {
+ debug!("failed to write inbound packet to TUN: {:?}", e);
+ false
+ }
+ Ok(_) => true,
+ }
+ }
+ }
+ } else {
+ debug!("inbound worker: received keepalive")
+ }
+
+ // trigger callback
+ C::recv(&peer.opaque, buf.msg.len(), length == 0, sent);
+ } else {
+ debug!("inbound worker: authentication failure")
+ }
+ })
+ .wait();
+ }
+}
+
+pub fn worker_outbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ device: Arc<DeviceInner<E, C, T, B>>, // related device
+ peer: Arc<PeerInner<E, C, T, B>>, // related peer
+ receiver: Receiver<JobOutbound>,
+) {
+ loop {
+ // fetch job
+ let rx = match receiver.recv() {
+ Ok(v) => v,
+ _ => {
+ return;
+ }
+ };
+ debug!("outbound worker: obtained job");
+
+ // wait for job to complete
+ let _ = rx
+ .map(|buf| {
+ debug!("outbound worker: job complete");
+ if buf.okay {
+ // write to UDP bind
+ let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() {
+ let send : &Option<B> = &*device.outbound.read();
+ if let Some(writer) = send.as_ref() {
+ match writer.write(&buf.msg[..], dst) {
+ Err(e) => {
+ debug!("failed to send outbound packet: {:?}", e);
+ false
+ }
+ Ok(_) => true,
+ }
+ } else {
+ false
+ }
+ } else {
+ false
+ };
+
+ // trigger callback
+ C::send(
+ &peer.opaque,
+ buf.msg.len(),
+ buf.msg.len() > SIZE_TAG + mem::size_of::<TransportHeader>(),
+ xmit,
+ );
+ }
+ })
+ .wait();
+ }
+}
+
+pub fn worker_parallel(receiver: Receiver<JobParallel>) {
+ loop {
+ // fetch next job
+ let (tx, mut buf) = match receiver.recv() {
+ Err(_) => {
+ return;
+ }
+ Ok(val) => val,
+ };
+ debug!("parallel worker: obtained job");
+
+ // make space for tag (TODO: consider moving this out)
+ if buf.op == Operation::Encryption {
+ buf.msg.extend([0u8; SIZE_TAG].iter());
+ }
+
+ // cast and check size of packet
+ let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
+ match LayoutVerified::new_from_prefix(&mut buf.msg[..]) {
+ Some(v) => v,
+ None => {
+ debug_assert!(
+ false,
+ "parallel worker: failed to parse message (insufficient size)"
+ );
+ continue;
+ }
+ };
+ debug_assert!(packet.len() >= CHACHA20_POLY1305.tag_len());
+
+ // do the weird ring AEAD dance
+ let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &buf.key[..]).unwrap());
+
+ // create a nonce object
+ let mut nonce = [0u8; 12];
+ debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
+ nonce[4..].copy_from_slice(header.f_counter.as_bytes());
+ let nonce = Nonce::assume_unique_for_key(nonce);
+
+ match buf.op {
+ Operation::Encryption => {
+ debug!("parallel worker: process encryption");
+
+ // set the type field
+ header.f_type.set(TYPE_TRANSPORT);
+
+ // encrypt content of transport message in-place
+ let end = packet.len() - SIZE_TAG;
+ let tag = key
+ .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end])
+ .unwrap();
+
+ // append tag
+ packet[end..].copy_from_slice(tag.as_ref());
+
+ buf.okay = true;
+ }
+ Operation::Decryption => {
+ debug!("parallel worker: process decryption");
+
+ // opening failure is signaled by fault state
+ buf.okay = match key.open_in_place(nonce, Aad::empty(), packet) {
+ Ok(_) => true,
+ Err(_) => false,
+ };
+ }
+ }
+
+ // pass ownership to consumer
+ let okay = tx.send(buf);
+ debug!(
+ "parallel worker: passing ownership to sequential worker: {}",
+ okay.is_ok()
+ );
+ }
+}
diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs
new file mode 100644
index 0000000..0148d5d
--- /dev/null
+++ b/src/wireguard/tests.rs
@@ -0,0 +1,46 @@
+use super::types::tun::Tun;
+use super::types::{bind, dummy, tun};
+use super::wireguard::Wireguard;
+
+use std::thread;
+use std::time::Duration;
+
+fn init() {
+ let _ = env_logger::builder().is_test(true).try_init();
+}
+
+/* Create and configure two matching pure instances of WireGuard
+ *
+ */
+#[test]
+fn test_pure_wireguard() {
+ init();
+
+ // create WG instances for fake TUN devices
+
+ let (fake1, tun_reader1, tun_writer1, mtu1) = dummy::TunTest::create(1500, true);
+ let wg1: Wireguard<dummy::TunTest, dummy::PairBind> =
+ Wireguard::new(vec![tun_reader1], tun_writer1, mtu1);
+
+ let (fake2, tun_reader2, tun_writer2, mtu2) = dummy::TunTest::create(1500, true);
+ let wg2: Wireguard<dummy::TunTest, dummy::PairBind> =
+ Wireguard::new(vec![tun_reader2], tun_writer2, mtu2);
+
+ // create pair bind to connect the interfaces "over the internet"
+
+ let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) = dummy::PairBind::pair();
+
+ wg1.set_writer(bind_writer1);
+ wg2.set_writer(bind_writer2);
+
+ wg1.add_reader(bind_reader1);
+ wg2.add_reader(bind_reader2);
+
+ // generate (public, pivate) key pairs
+
+ // configure cryptkey router
+
+ // create IP packets
+
+ thread::sleep(Duration::from_millis(500));
+}
diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs
new file mode 100644
index 0000000..2792c7b
--- /dev/null
+++ b/src/wireguard/timers.rs
@@ -0,0 +1,234 @@
+use std::marker::PhantomData;
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use std::sync::Arc;
+use std::time::Duration;
+
+use log::info;
+
+use hjul::{Runner, Timer};
+
+use super::constants::*;
+use super::router::Callbacks;
+use super::types::{bind, tun};
+use super::wireguard::{Peer, PeerInner};
+
+pub struct Timers {
+ handshake_pending: AtomicBool,
+ handshake_attempts: AtomicUsize,
+
+ retransmit_handshake: Timer,
+ send_keepalive: Timer,
+ send_persistent_keepalive: Timer,
+ sent_lastminute_handshake: AtomicBool,
+ zero_key_material: Timer,
+ new_handshake: Timer,
+ need_another_keepalive: AtomicBool,
+}
+
+impl Timers {
+ #[inline(always)]
+ fn need_another_keepalive(&self) -> bool {
+ self.need_another_keepalive.swap(false, Ordering::SeqCst)
+ }
+}
+
+impl <T: tun::Tun, B: bind::Bind>Peer<T, B> {
+ /* should be called after an authenticated data packet is sent */
+ pub fn timers_data_sent(&self) {
+ self.timers().new_handshake.start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT);
+ }
+
+ /* should be called after an authenticated data packet is received */
+ pub fn timers_data_received(&self) {
+ if !self.timers().send_keepalive.start(KEEPALIVE_TIMEOUT) {
+ self.timers().need_another_keepalive.store(true, Ordering::SeqCst)
+ }
+ }
+
+ /* Should be called after any type of authenticated packet is sent, whether:
+ * - keepalive
+ * - data
+ * - handshake
+ */
+ pub fn timers_any_authenticated_packet_sent(&self) {
+ self.timers().send_keepalive.stop()
+ }
+
+ /* Should be called after any type of authenticated packet is received, whether:
+ * - keepalive
+ * - data
+ * - handshake
+ */
+ pub fn timers_any_authenticated_packet_received(&self) {
+ self.timers().new_handshake.stop();
+ }
+
+ /* Should be called after a handshake initiation message is sent. */
+ pub fn timers_handshake_initiated(&self) {
+ self.timers().send_keepalive.stop();
+ self.timers().retransmit_handshake.reset(REKEY_TIMEOUT);
+ }
+
+ /* Should be called after a handshake response message is received and processed
+ * or when getting key confirmation via the first data message.
+ */
+ pub fn timers_handshake_complete(&self) {
+ self.timers().handshake_attempts.store(0, Ordering::SeqCst);
+ self.timers().sent_lastminute_handshake.store(false, Ordering::SeqCst);
+ // TODO: Store time in peer for config
+ // self.walltime_last_handshake
+ }
+
+ /* Should be called after an ephemeral key is created, which is before sending a
+ * handshake response or after receiving a handshake response.
+ */
+ pub fn timers_session_derived(&self) {
+ self.timers().zero_key_material.reset(REJECT_AFTER_TIME * 3);
+ }
+
+ /* Should be called before a packet with authentication, whether
+ * keepalive, data, or handshake is sent, or after one is received.
+ */
+ pub fn timers_any_authenticated_packet_traversal(&self) {
+ let keepalive = self.state.keepalive.load(Ordering::Acquire);
+ if keepalive > 0 {
+ self.timers().send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64));
+ }
+ }
+}
+
+impl Timers {
+ pub fn new<T, B>(runner: &Runner, peer: Peer<T, B>) -> Timers
+ where
+ T: tun::Tun,
+ B: bind::Bind,
+ {
+ // create a timer instance for the provided peer
+ Timers {
+ handshake_pending: AtomicBool::new(false),
+ need_another_keepalive: AtomicBool::new(false),
+ sent_lastminute_handshake: AtomicBool::new(false),
+ handshake_attempts: AtomicUsize::new(0),
+ retransmit_handshake: {
+ let peer = peer.clone();
+ runner.timer(move || {
+ if peer.timers().handshake_retry() {
+ info!("Retransmit handshake for {}", peer);
+ peer.new_handshake();
+ } else {
+ info!("Failed to complete handshake for {}", peer);
+ peer.router.purge_staged_packets();
+ peer.timers().send_keepalive.stop();
+ peer.timers().zero_key_material.start(REJECT_AFTER_TIME * 3);
+ }
+ })
+ },
+ send_keepalive: {
+ let peer = peer.clone();
+ runner.timer(move || {
+ peer.router.send_keepalive();
+ if peer.timers().need_another_keepalive() {
+ peer.timers().send_keepalive.start(KEEPALIVE_TIMEOUT);
+ }
+ })
+ },
+ new_handshake: {
+ let peer = peer.clone();
+ runner.timer(move || {
+ info!(
+ "Retrying handshake with {}, because we stopped hearing back after {} seconds",
+ peer,
+ (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT).as_secs()
+ );
+ peer.new_handshake();
+ peer.timers.read().handshake_begun();
+ })
+ },
+ zero_key_material: {
+ let peer = peer.clone();
+ runner.timer(move || {
+ peer.router.zero_keys();
+ })
+ },
+ send_persistent_keepalive: {
+ let peer = peer.clone();
+ runner.timer(move || {
+ let keepalive = peer.state.keepalive.load(Ordering::Acquire);
+ if keepalive > 0 {
+ peer.router.send_keepalive();
+ peer.timers().send_keepalive.stop();
+ peer.timers().send_persistent_keepalive.start(Duration::from_secs(keepalive as u64));
+ }
+ })
+ }
+ }
+ }
+
+ fn handshake_begun(&self) {
+ self.handshake_pending.store(true, Ordering::SeqCst);
+ self.handshake_attempts.store(0, Ordering::SeqCst);
+ self.retransmit_handshake.reset(REKEY_TIMEOUT);
+ }
+
+ fn handshake_retry(&self) -> bool {
+ if self.handshake_attempts.fetch_add(1, Ordering::SeqCst) <= MAX_TIMER_HANDSHAKES {
+ self.retransmit_handshake.reset(REKEY_TIMEOUT);
+ true
+ } else {
+ self.handshake_pending.store(false, Ordering::SeqCst);
+ false
+ }
+ }
+
+ pub fn updated_persistent_keepalive(&self, keepalive: usize) {
+ if keepalive > 0 {
+ self.send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64));
+ }
+ }
+
+ pub fn dummy(runner: &Runner) -> Timers {
+ Timers {
+ handshake_pending: AtomicBool::new(false),
+ need_another_keepalive: AtomicBool::new(false),
+ sent_lastminute_handshake: AtomicBool::new(false),
+ handshake_attempts: AtomicUsize::new(0),
+ retransmit_handshake: runner.timer(|| {}),
+ new_handshake: runner.timer(|| {}),
+ send_keepalive: runner.timer(|| {}),
+ send_persistent_keepalive: runner.timer(|| {}),
+ zero_key_material: runner.timer(|| {})
+ }
+ }
+
+ pub fn handshake_sent(&self) {
+ self.send_keepalive.stop();
+ }
+}
+
+/* Instance of the router callbacks */
+
+pub struct Events<T, B>(PhantomData<(T, B)>);
+
+impl<T: tun::Tun, B: bind::Bind> Callbacks for Events<T, B> {
+ type Opaque = Arc<PeerInner<B>>;
+
+ fn send(peer: &Self::Opaque, size: usize, data: bool, sent: bool) {
+ peer.tx_bytes.fetch_add(size as u64, Ordering::Relaxed);
+ }
+
+ fn recv(peer: &Self::Opaque, size: usize, data: bool, sent: bool) {
+ peer.rx_bytes.fetch_add(size as u64, Ordering::Relaxed);
+ }
+
+ fn need_key(peer: &Self::Opaque) {
+ let timers = peer.timers();
+ if !timers.handshake_pending.swap(true, Ordering::SeqCst) {
+ timers.handshake_attempts.store(0, Ordering::SeqCst);
+ timers.new_handshake.fire();
+ }
+ }
+
+ fn key_confirmed(peer: &Self::Opaque) {
+ peer.timers().retransmit_handshake.stop();
+ }
+}
diff --git a/src/wireguard/types/bind.rs b/src/wireguard/types/bind.rs
new file mode 100644
index 0000000..3d3f187
--- /dev/null
+++ b/src/wireguard/types/bind.rs
@@ -0,0 +1,23 @@
+use super::Endpoint;
+use std::error::Error;
+
+pub trait Reader<E: Endpoint>: Send + Sync {
+ type Error: Error;
+
+ fn read(&self, buf: &mut [u8]) -> Result<(usize, E), Self::Error>;
+}
+
+pub trait Writer<E: Endpoint>: Send + Sync + Clone + 'static {
+ type Error: Error;
+
+ fn write(&self, buf: &[u8], dst: &E) -> Result<(), Self::Error>;
+}
+
+pub trait Bind: Send + Sync + 'static {
+ type Error: Error;
+ type Endpoint: Endpoint;
+
+ /* Until Rust gets type equality constraints these have to be generic */
+ type Writer: Writer<Self::Endpoint>;
+ type Reader: Reader<Self::Endpoint>;
+}
diff --git a/src/wireguard/types/dummy.rs b/src/wireguard/types/dummy.rs
new file mode 100644
index 0000000..2403c9b
--- /dev/null
+++ b/src/wireguard/types/dummy.rs
@@ -0,0 +1,323 @@
+use std::error::Error;
+use std::fmt;
+use std::marker;
+use std::net::SocketAddr;
+use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
+use std::sync::Arc;
+use std::sync::Mutex;
+use std::time::Instant;
+use std::sync::atomic::{Ordering, AtomicUsize};
+
+use super::*;
+
+/* This submodule provides pure/dummy implementations of the IO interfaces
+ * for use in unit tests thoughout the project.
+ */
+
+/* Error implementation */
+
+#[derive(Debug)]
+pub enum BindError {
+ Disconnected,
+}
+
+impl Error for BindError {
+ fn description(&self) -> &str {
+ "Generic Bind Error"
+ }
+
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ None
+ }
+}
+
+impl fmt::Display for BindError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ BindError::Disconnected => write!(f, "PairBind disconnected"),
+ }
+ }
+}
+
+/* TUN implementation */
+
+#[derive(Debug)]
+pub enum TunError {
+ Disconnected
+}
+
+impl Error for TunError {
+ fn description(&self) -> &str {
+ "Generic Tun Error"
+ }
+
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ None
+ }
+}
+
+impl fmt::Display for TunError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "Not Possible")
+ }
+}
+
+/* Endpoint implementation */
+
+#[derive(Clone, Copy)]
+pub struct UnitEndpoint {}
+
+impl Endpoint for UnitEndpoint {
+ fn from_address(_: SocketAddr) -> UnitEndpoint {
+ UnitEndpoint {}
+ }
+
+ fn into_address(&self) -> SocketAddr {
+ "127.0.0.1:8080".parse().unwrap()
+ }
+
+ fn clear_src(&self) {}
+}
+
+impl UnitEndpoint {
+ pub fn new() -> UnitEndpoint {
+ UnitEndpoint {}
+ }
+}
+
+/* */
+
+pub struct TunTest {}
+
+pub struct TunFakeIO {
+ store: bool,
+ tx: SyncSender<Vec<u8>>,
+ rx: Receiver<Vec<u8>>
+}
+
+pub struct TunReader {
+ rx: Receiver<Vec<u8>>
+}
+
+pub struct TunWriter {
+ store: bool,
+ tx: Mutex<SyncSender<Vec<u8>>>
+}
+
+#[derive(Clone)]
+pub struct TunMTU {
+ mtu: Arc<AtomicUsize>
+}
+
+impl tun::Reader for TunReader {
+ type Error = TunError;
+
+ fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error> {
+ match self.rx.recv() {
+ Ok(m) => {
+ buf[offset..].copy_from_slice(&m[..]);
+ Ok(m.len())
+ }
+ Err(_) => Err(TunError::Disconnected)
+ }
+ }
+}
+
+impl tun::Writer for TunWriter {
+ type Error = TunError;
+
+ fn write(&self, src: &[u8]) -> Result<(), Self::Error> {
+ if self.store {
+ let m = src.to_owned();
+ match self.tx.lock().unwrap().send(m) {
+ Ok(_) => Ok(()),
+ Err(_) => Err(TunError::Disconnected)
+ }
+ } else {
+ Ok(())
+ }
+ }
+}
+
+impl tun::MTU for TunMTU {
+ fn mtu(&self) -> usize {
+ self.mtu.load(Ordering::Acquire)
+ }
+}
+
+impl tun::Tun for TunTest {
+ type Writer = TunWriter;
+ type Reader = TunReader;
+ type MTU = TunMTU;
+ type Error = TunError;
+}
+
+impl TunFakeIO {
+ pub fn write(&self, msg : Vec<u8>) {
+ if self.store {
+ self.tx.send(msg).unwrap();
+ }
+ }
+
+ pub fn read(&self) -> Vec<u8> {
+ self.rx.recv().unwrap()
+ }
+}
+
+impl TunTest {
+ pub fn create(mtu : usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunMTU) {
+
+ let (tx1, rx1) = if store { sync_channel(32) } else { sync_channel(1) };
+ let (tx2, rx2) = if store { sync_channel(32) } else { sync_channel(1) };
+
+ let fake = TunFakeIO{tx: tx1, rx: rx2, store};
+ let reader = TunReader{rx : rx1};
+ let writer = TunWriter{tx : Mutex::new(tx2), store};
+ let mtu = TunMTU{mtu : Arc::new(AtomicUsize::new(mtu))};
+
+ (fake, reader, writer, mtu)
+ }
+}
+
+/* Void Bind */
+
+#[derive(Clone, Copy)]
+pub struct VoidBind {}
+
+impl bind::Reader<UnitEndpoint> for VoidBind {
+ type Error = BindError;
+
+ fn read(&self, _buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> {
+ Ok((0, UnitEndpoint {}))
+ }
+}
+
+impl bind::Writer<UnitEndpoint> for VoidBind {
+ type Error = BindError;
+
+ fn write(&self, _buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
+ Ok(())
+ }
+}
+
+impl bind::Bind for VoidBind {
+ type Error = BindError;
+ type Endpoint = UnitEndpoint;
+
+ type Reader = VoidBind;
+ type Writer = VoidBind;
+}
+
+impl VoidBind {
+ pub fn new() -> VoidBind {
+ VoidBind {}
+ }
+}
+
+/* Pair Bind */
+
+#[derive(Clone)]
+pub struct PairReader<E> {
+ recv: Arc<Mutex<Receiver<Vec<u8>>>>,
+ _marker: marker::PhantomData<E>,
+}
+
+impl bind::Reader<UnitEndpoint> for PairReader<UnitEndpoint> {
+ type Error = BindError;
+ fn read(&self, buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> {
+ let vec = self
+ .recv
+ .lock()
+ .unwrap()
+ .recv()
+ .map_err(|_| BindError::Disconnected)?;
+ let len = vec.len();
+ buf[..len].copy_from_slice(&vec[..]);
+ Ok((vec.len(), UnitEndpoint {}))
+ }
+}
+
+impl bind::Writer<UnitEndpoint> for PairWriter<UnitEndpoint> {
+ type Error = BindError;
+ fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
+ let owned = buf.to_owned();
+ match self.send.lock().unwrap().send(owned) {
+ Err(_) => Err(BindError::Disconnected),
+ Ok(_) => Ok(()),
+ }
+ }
+}
+
+#[derive(Clone)]
+pub struct PairWriter<E> {
+ send: Arc<Mutex<SyncSender<Vec<u8>>>>,
+ _marker: marker::PhantomData<E>,
+}
+
+#[derive(Clone)]
+pub struct PairBind {}
+
+impl PairBind {
+ pub fn pair<E>() -> (
+ (PairReader<E>, PairWriter<E>),
+ (PairReader<E>, PairWriter<E>),
+ ) {
+ let (tx1, rx1) = sync_channel(128);
+ let (tx2, rx2) = sync_channel(128);
+ (
+ (
+ PairReader {
+ recv: Arc::new(Mutex::new(rx1)),
+ _marker: marker::PhantomData,
+ },
+ PairWriter {
+ send: Arc::new(Mutex::new(tx2)),
+ _marker: marker::PhantomData,
+ },
+ ),
+ (
+ PairReader {
+ recv: Arc::new(Mutex::new(rx2)),
+ _marker: marker::PhantomData,
+ },
+ PairWriter {
+ send: Arc::new(Mutex::new(tx1)),
+ _marker: marker::PhantomData,
+ },
+ ),
+ )
+ }
+}
+
+impl bind::Bind for PairBind {
+ type Error = BindError;
+ type Endpoint = UnitEndpoint;
+ type Reader = PairReader<Self::Endpoint>;
+ type Writer = PairWriter<Self::Endpoint>;
+}
+
+pub fn keypair(initiator: bool) -> KeyPair {
+ let k1 = Key {
+ key: [0x53u8; 32],
+ id: 0x646e6573,
+ };
+ let k2 = Key {
+ key: [0x52u8; 32],
+ id: 0x76636572,
+ };
+ if initiator {
+ KeyPair {
+ birth: Instant::now(),
+ initiator: true,
+ send: k1,
+ recv: k2,
+ }
+ } else {
+ KeyPair {
+ birth: Instant::now(),
+ initiator: false,
+ send: k2,
+ recv: k1,
+ }
+ }
+}
diff --git a/src/wireguard/types/endpoint.rs b/src/wireguard/types/endpoint.rs
new file mode 100644
index 0000000..f4f93da
--- /dev/null
+++ b/src/wireguard/types/endpoint.rs
@@ -0,0 +1,7 @@
+use std::net::SocketAddr;
+
+pub trait Endpoint: Send + 'static {
+ fn from_address(addr: SocketAddr) -> Self;
+ fn into_address(&self) -> SocketAddr;
+ fn clear_src(&self);
+}
diff --git a/src/wireguard/types/keys.rs b/src/wireguard/types/keys.rs
new file mode 100644
index 0000000..282c4ae
--- /dev/null
+++ b/src/wireguard/types/keys.rs
@@ -0,0 +1,36 @@
+use clear_on_drop::clear::Clear;
+use std::time::Instant;
+
+#[derive(Debug, Clone)]
+pub struct Key {
+ pub key: [u8; 32],
+ pub id: u32,
+}
+
+// zero key on drop
+impl Drop for Key {
+ fn drop(&mut self) {
+ self.key.clear()
+ }
+}
+
+#[cfg(test)]
+impl PartialEq for Key {
+ fn eq(&self, other: &Self) -> bool {
+ self.id == other.id && self.key[..] == other.key[..]
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct KeyPair {
+ pub birth: Instant, // when was the key-pair created
+ pub initiator: bool, // has the key-pair been confirmed?
+ pub send: Key, // key for outbound messages
+ pub recv: Key, // key for inbound messages
+}
+
+impl KeyPair {
+ pub fn local_id(&self) -> u32 {
+ self.recv.id
+ }
+}
diff --git a/src/wireguard/types/mod.rs b/src/wireguard/types/mod.rs
new file mode 100644
index 0000000..e0725f3
--- /dev/null
+++ b/src/wireguard/types/mod.rs
@@ -0,0 +1,10 @@
+mod endpoint;
+mod keys;
+pub mod tun;
+pub mod bind;
+
+#[cfg(test)]
+pub mod dummy;
+
+pub use endpoint::Endpoint;
+pub use keys::{Key, KeyPair}; \ No newline at end of file
diff --git a/src/wireguard/types/tun.rs b/src/wireguard/types/tun.rs
new file mode 100644
index 0000000..2ba16ff
--- /dev/null
+++ b/src/wireguard/types/tun.rs
@@ -0,0 +1,56 @@
+use std::error::Error;
+
+pub trait Writer: Send + Sync + 'static {
+ type Error: Error;
+
+ /// Receive a cryptkey routed IP packet
+ ///
+ /// # Arguments
+ ///
+ /// - src: Buffer containing the IP packet to be written
+ ///
+ /// # Returns
+ ///
+ /// Unit type or an error
+ fn write(&self, src: &[u8]) -> Result<(), Self::Error>;
+}
+
+pub trait Reader: Send + 'static {
+ type Error: Error;
+
+ /// Reads an IP packet into dst[offset:] from the tunnel device
+ ///
+ /// The reason for providing space for a prefix
+ /// is to efficiently accommodate platforms on which the packet is prefaced by a header.
+ /// This space is later used to construct the transport message inplace.
+ ///
+ /// # Arguments
+ ///
+ /// - buf: Destination buffer (enough space for MTU bytes + header)
+ /// - offset: Offset for the beginning of the IP packet
+ ///
+ /// # Returns
+ ///
+ /// The size of the IP packet (ignoring the header) or an std::error::Error instance:
+ fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error>;
+}
+
+pub trait MTU: Send + Sync + Clone + 'static {
+ /// Returns the MTU of the device
+ ///
+ /// This function needs to be efficient (called for every read).
+ /// The goto implementation strategy is to .load an atomic variable,
+ /// then use e.g. netlink to update the variable in a separate thread.
+ ///
+ /// # Returns
+ ///
+ /// The MTU of the interface in bytes
+ fn mtu(&self) -> usize;
+}
+
+pub trait Tun: Send + Sync + 'static {
+ type Writer: Writer;
+ type Reader: Reader;
+ type MTU: MTU;
+ type Error: Error;
+}
diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs
new file mode 100644
index 0000000..7a22280
--- /dev/null
+++ b/src/wireguard/wireguard.rs
@@ -0,0 +1,407 @@
+use super::constants::*;
+use super::handshake;
+use super::router;
+use super::timers::{Events, Timers};
+
+use super::types::bind::Reader as BindReader;
+use super::types::bind::{Bind, Writer};
+use super::types::tun::{Reader, Tun, MTU};
+use super::types::Endpoint;
+
+use hjul::Runner;
+
+use std::fmt;
+use std::ops::Deref;
+use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
+use std::sync::Arc;
+use std::thread;
+use std::time::{Duration, Instant};
+
+use std::collections::HashMap;
+
+use log::debug;
+use rand::rngs::OsRng;
+use spin::{Mutex, RwLock, RwLockReadGuard};
+
+use byteorder::{ByteOrder, LittleEndian};
+use crossbeam_channel::{bounded, Sender};
+use x25519_dalek::{PublicKey, StaticSecret};
+
+const SIZE_HANDSHAKE_QUEUE: usize = 128;
+const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
+const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
+
+pub struct Peer<T: Tun, B: Bind> {
+ pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>,
+ pub state: Arc<PeerInner<B>>,
+}
+
+impl<T: Tun, B: Bind> Clone for Peer<T, B> {
+ fn clone(&self) -> Peer<T, B> {
+ Peer {
+ router: self.router.clone(),
+ state: self.state.clone(),
+ }
+ }
+}
+
+pub struct PeerInner<B: Bind> {
+ pub keepalive: AtomicUsize, // keepalive interval
+ pub rx_bytes: AtomicU64,
+ pub tx_bytes: AtomicU64,
+ pub queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue
+ pub pk: PublicKey, // DISCUSS: Change layout in handshake module (adopt pattern of router), to avoid this.
+ pub timers: RwLock<Timers>, //
+}
+
+impl<B: Bind> PeerInner<B> {
+ #[inline(always)]
+ pub fn timers(&self) -> RwLockReadGuard<Timers> {
+ self.timers.read()
+ }
+}
+
+impl<T: Tun, B: Bind> fmt::Display for Peer<T, B> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "peer()")
+ }
+}
+
+impl<T: Tun, B: Bind> Deref for Peer<T, B> {
+ type Target = PeerInner<B>;
+ fn deref(&self) -> &Self::Target {
+ &self.state
+ }
+}
+
+impl<B: Bind> PeerInner<B> {
+ pub fn new_handshake(&self) {
+ // TODO: clear endpoint source address ("unsticky")
+ self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap();
+ }
+}
+
+struct Handshake {
+ device: handshake::Device,
+ active: bool,
+}
+
+pub enum HandshakeJob<E> {
+ Message(Vec<u8>, E),
+ New(PublicKey),
+}
+
+struct WireguardInner<T: Tun, B: Bind> {
+ // provides access to the MTU value of the tun device
+ // (otherwise owned solely by the router and a dedicated read IO thread)
+ mtu: T::MTU,
+ send: RwLock<Option<B::Writer>>,
+
+ // identify and configuration map
+ peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
+
+ // cryptkey router
+ router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>,
+
+ // handshake related state
+ handshake: RwLock<Handshake>,
+ under_load: AtomicBool,
+ pending: AtomicUsize, // num of pending handshake packets in queue
+ queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>,
+}
+
+pub struct Wireguard<T: Tun, B: Bind> {
+ runner: Runner,
+ state: Arc<WireguardInner<T, B>>,
+}
+
+/* Returns the padded length of a message:
+ *
+ * # Arguments
+ *
+ * - `size` : Size of unpadded message
+ * - `mtu` : Maximum transmission unit of the device
+ *
+ * # Returns
+ *
+ * The padded length (always less than or equal to the MTU)
+ */
+#[inline(always)]
+const fn padding(size: usize, mtu: usize) -> usize {
+ #[inline(always)]
+ const fn min(a: usize, b: usize) -> usize {
+ let m = (a > b) as usize;
+ a * m + (1 - m) * b
+ }
+ let pad = MESSAGE_PADDING_MULTIPLE;
+ min(mtu, size + (pad - size % pad) % pad)
+}
+
+impl<T: Tun, B: Bind> Wireguard<T, B> {
+ pub fn set_key(&self, sk: Option<StaticSecret>) {
+ let mut handshake = self.state.handshake.write();
+ match sk {
+ None => {
+ let mut rng = OsRng::new().unwrap();
+ handshake.device.set_sk(StaticSecret::new(&mut rng));
+ handshake.active = false;
+ }
+ Some(sk) => {
+ handshake.device.set_sk(sk);
+ handshake.active = true;
+ }
+ }
+ }
+
+ pub fn get_sk(&self) -> Option<StaticSecret> {
+ let handshake = self.state.handshake.read();
+ if handshake.active {
+ Some(handshake.device.get_sk())
+ } else {
+ None
+ }
+ }
+
+ pub fn new_peer(&self, pk: PublicKey) -> Peer<T, B> {
+ let state = Arc::new(PeerInner {
+ pk,
+ queue: Mutex::new(self.state.queue.lock().clone()),
+ keepalive: AtomicUsize::new(0),
+ rx_bytes: AtomicU64::new(0),
+ tx_bytes: AtomicU64::new(0),
+ timers: RwLock::new(Timers::dummy(&self.runner)),
+ });
+
+ let router = Arc::new(self.state.router.new_peer(state.clone()));
+
+ let peer = Peer { router, state };
+
+ /* The need for dummy timers arises from the chicken-egg
+ * problem of the timer callbacks being able to set timers themselves.
+ *
+ * This is in fact the only place where the write lock is ever taken.
+ */
+ *peer.timers.write() = Timers::new(&self.runner, peer.clone());
+ peer
+ }
+
+ /* Begin consuming messages from the reader.
+ *
+ * Any previous reader thread is stopped by closing the previous reader,
+ * which unblocks the thread and causes an error on reader.read
+ */
+ pub fn add_reader(&self, reader: B::Reader) {
+ let wg = self.state.clone();
+ thread::spawn(move || {
+ let mut last_under_load =
+ Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
+
+ loop {
+ // create vector big enough for any message given current MTU
+ let size = wg.mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
+ let mut msg: Vec<u8> = Vec::with_capacity(size);
+ msg.resize(size, 0);
+
+ // read UDP packet into vector
+ let (size, src) = match reader.read(&mut msg) {
+ Err(e) => {
+ debug!("Bind reader closed with {}", e);
+ return;
+ }
+ Ok(v) => v,
+ };
+ msg.truncate(size);
+
+ // message type de-multiplexer
+ if msg.len() < std::mem::size_of::<u32>() {
+ continue;
+ }
+ match LittleEndian::read_u32(&msg[..]) {
+ handshake::TYPE_COOKIE_REPLY
+ | handshake::TYPE_INITIATION
+ | handshake::TYPE_RESPONSE => {
+ // update under_load flag
+ if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
+ last_under_load = Instant::now();
+ wg.under_load.store(true, Ordering::SeqCst);
+ } else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
+ wg.under_load.store(false, Ordering::SeqCst);
+ }
+
+ wg.queue
+ .lock()
+ .send(HandshakeJob::Message(msg, src))
+ .unwrap();
+ }
+ router::TYPE_TRANSPORT => {
+ // transport message
+ let _ = wg.router.recv(src, msg).map_err(|e| {
+ debug!("Failed to handle incoming transport message: {}", e);
+ });
+ }
+ _ => (),
+ }
+ }
+ });
+ }
+
+ pub fn set_writer(&self, writer: B::Writer) {
+ // TODO: Consider unifying these and avoid Clone requirement on writer
+ *self.state.send.write() = Some(writer.clone());
+ self.state.router.set_outbound_writer(writer);
+ }
+
+ pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> {
+ // create device state
+ let mut rng = OsRng::new().unwrap();
+ let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
+ let wg = Arc::new(WireguardInner {
+ mtu: mtu.clone(),
+ peers: RwLock::new(HashMap::new()),
+ send: RwLock::new(None),
+ router: router::Device::new(num_cpus::get(), writer), // router owns the writing half
+ pending: AtomicUsize::new(0),
+ handshake: RwLock::new(Handshake {
+ device: handshake::Device::new(StaticSecret::new(&mut rng)),
+ active: false,
+ }),
+ under_load: AtomicBool::new(false),
+ queue: Mutex::new(tx),
+ });
+
+ // start handshake workers
+ for _ in 0..num_cpus::get() {
+ let wg = wg.clone();
+ let rx = rx.clone();
+ thread::spawn(move || {
+ // prepare OsRng instance for this thread
+ let mut rng = OsRng::new().unwrap();
+
+ // process elements from the handshake queue
+ for job in rx {
+ wg.pending.fetch_sub(1, Ordering::SeqCst);
+ let state = wg.handshake.read();
+ if !state.active {
+ continue;
+ }
+
+ match job {
+ HandshakeJob::Message(msg, src) => {
+ // feed message to handshake device
+ let src_validate = (&src).into_address(); // TODO avoid
+
+ // process message
+ match state.device.process(
+ &mut rng,
+ &msg[..],
+ if wg.under_load.load(Ordering::Relaxed) {
+ Some(&src_validate)
+ } else {
+ None
+ },
+ ) {
+ Ok((pk, resp, keypair)) => {
+ // send response
+ let mut resp_len: u64 = 0;
+ if let Some(msg) = resp {
+ resp_len = msg.len() as u64;
+ let send: &Option<B::Writer> = &*wg.send.read();
+ if let Some(writer) = send.as_ref() {
+ let _ = writer.write(&msg[..], &src).map_err(|e| {
+ debug!(
+ "handshake worker, failed to send response, error = {}",
+ e
+ )
+ });
+ }
+ }
+
+ // update timers
+ if let Some(pk) = pk {
+ // authenticated handshake packet received
+ if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
+ // add to rx_bytes and tx_bytes
+ let req_len = msg.len() as u64;
+ peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed);
+ peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed);
+
+ // update endpoint
+ peer.router.set_endpoint(src);
+
+ // add keypair to peer
+ keypair.map(|kp| {
+ // free any unused ids
+ for id in peer.router.add_keypair(kp) {
+ state.device.release(id);
+ }
+ });
+ }
+ }
+ }
+ Err(e) => debug!("handshake worker, error = {:?}", e),
+ }
+ }
+ HandshakeJob::New(pk) => {
+ let _ = state.device.begin(&mut rng, &pk).map(|msg| {
+ if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
+ let _ = peer.router.send(&msg[..]).map_err(|e| {
+ debug!("handshake worker, failed to send handshake initiation, error = {}", e)
+ });
+ }
+ });
+ }
+ }
+ }
+ });
+ }
+
+ // start TUN read IO threads (multiple threads to support multi-queue interfaces)
+ debug_assert!(
+ readers.len() > 0,
+ "attempted to create WG device without TUN readers"
+ );
+ while let Some(reader) = readers.pop() {
+ let wg = wg.clone();
+ let mtu = mtu.clone();
+ thread::spawn(move || loop {
+ // create vector big enough for any transport message (based on MTU)
+ let mtu = mtu.mtu();
+ let size = mtu + router::SIZE_MESSAGE_PREFIX;
+ let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
+ msg.resize(size, 0);
+
+ // read a new IP packet
+ let payload = match reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX) {
+ Ok(payload) => payload,
+ Err(e) => {
+ debug!("TUN worker, failed to read from tun device: {}", e);
+ return;
+ }
+ };
+ debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu);
+
+ // truncate padding
+ let payload = padding(payload, mtu);
+ msg.truncate(router::SIZE_MESSAGE_PREFIX + payload);
+ debug_assert!(payload <= mtu);
+ debug_assert_eq!(
+ if payload < mtu {
+ (msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE
+ } else {
+ 0
+ },
+ 0
+ );
+
+ // crypt-key route
+ let e = wg.router.send(msg);
+ debug!("TUN worker, router returned {:?}", e);
+ });
+ }
+
+ Wireguard {
+ state: wg,
+ runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY),
+ }
+ }
+}