summaryrefslogtreecommitdiffstats
path: root/src/wireguard/router
diff options
context:
space:
mode:
Diffstat (limited to 'src/wireguard/router')
-rw-r--r--src/wireguard/router/anti_replay.rs157
-rw-r--r--src/wireguard/router/constants.rs7
-rw-r--r--src/wireguard/router/device.rs243
-rw-r--r--src/wireguard/router/ip.rs26
-rw-r--r--src/wireguard/router/messages.rs13
-rw-r--r--src/wireguard/router/mod.rs22
-rw-r--r--src/wireguard/router/peer.rs611
-rw-r--r--src/wireguard/router/tests.rs432
-rw-r--r--src/wireguard/router/types.rs65
-rw-r--r--src/wireguard/router/workers.rs305
10 files changed, 1881 insertions, 0 deletions
diff --git a/src/wireguard/router/anti_replay.rs b/src/wireguard/router/anti_replay.rs
new file mode 100644
index 0000000..b0838bd
--- /dev/null
+++ b/src/wireguard/router/anti_replay.rs
@@ -0,0 +1,157 @@
+use std::mem;
+
+// Implementation of RFC 6479.
+// https://tools.ietf.org/html/rfc6479
+
+#[cfg(target_pointer_width = "64")]
+type Word = u64;
+
+#[cfg(target_pointer_width = "64")]
+const REDUNDANT_BIT_SHIFTS: usize = 6;
+
+#[cfg(target_pointer_width = "32")]
+type Word = u32;
+
+#[cfg(target_pointer_width = "32")]
+const REDUNDANT_BIT_SHIFTS: usize = 5;
+
+// Number of *bits* (not bytes) in a machine word — the name is slightly misleading.
+const SIZE_OF_WORD: usize = mem::size_of::<Word>() * 8;
+
+const BITMAP_BITLEN: usize = 2048; // total bits in the sliding-window bitmap
+const BITMAP_LEN: usize = (BITMAP_BITLEN / SIZE_OF_WORD); // bitmap length in words
+const BITMAP_INDEX_MASK: u64 = BITMAP_LEN as u64 - 1; // word-index mask (BITMAP_LEN is a power of two)
+const BITMAP_LOC_MASK: u64 = (SIZE_OF_WORD - 1) as u64; // bit-position mask within a word
+const WINDOW_SIZE: u64 = (BITMAP_BITLEN - SIZE_OF_WORD) as u64; // usable replay window, in sequence numbers
+
+// Sliding-window replay filter for transport message counters (RFC 6479).
+pub struct AntiReplay {
+ bitmap: [Word; BITMAP_LEN], // ring of bit-words covering the receive window
+ last: u64, // highest sequence number accepted so far
+}
+
+impl Default for AntiReplay {
+ fn default() -> Self {
+ AntiReplay::new()
+ }
+}
+
+impl AntiReplay {
+ pub fn new() -> Self {
+ // sanity-check that the word-size constants agree with each other
+ debug_assert_eq!(1 << REDUNDANT_BIT_SHIFTS, SIZE_OF_WORD);
+ debug_assert_eq!(BITMAP_BITLEN % SIZE_OF_WORD, 0);
+ AntiReplay {
+ last: 0,
+ bitmap: [0; BITMAP_LEN],
+ }
+ }
+
+ // Returns true if check is passed, i.e., not a replay or too old.
+ //
+ // Unlike RFC 6479, zero is allowed.
+ fn check(&self, seq: u64) -> bool {
+ // Larger is always good.
+ if seq > self.last {
+ return true;
+ }
+
+ // Behind the window entirely: reject as too old.
+ if self.last - seq > WINDOW_SIZE {
+ return false;
+ }
+
+ // Inside the window: passes iff the bit has not been set yet.
+ let bit_location = seq & BITMAP_LOC_MASK;
+ let index = (seq >> REDUNDANT_BIT_SHIFTS) & BITMAP_INDEX_MASK;
+
+ self.bitmap[index as usize] & (1 << bit_location) == 0
+ }
+
+ // Should only be called if check returns true.
+ fn update_store(&mut self, seq: u64) {
+ debug_assert!(self.check(seq));
+
+ let index = seq >> REDUNDANT_BIT_SHIFTS;
+
+ if seq > self.last {
+ // Window advanced: clear the words newly covered by the window.
+ let index_cur = self.last >> REDUNDANT_BIT_SHIFTS;
+ let diff = index - index_cur;
+
+ if diff >= BITMAP_LEN as u64 {
+ // Jumped past the whole bitmap: reset everything.
+ self.bitmap = [0; BITMAP_LEN];
+ } else {
+ for i in 0..diff {
+ let real_index = (index_cur + i + 1) & BITMAP_INDEX_MASK;
+ self.bitmap[real_index as usize] = 0;
+ }
+ }
+
+ self.last = seq;
+ }
+
+ // Mark the sequence number as seen.
+ let index = index & BITMAP_INDEX_MASK;
+ let bit_location = seq & BITMAP_LOC_MASK;
+ self.bitmap[index as usize] |= 1 << bit_location;
+ }
+
+ /// Checks and marks a sequence number in the replay filter
+ ///
+ /// # Arguments
+ ///
+ /// - seq: Sequence number check for replay and add to filter
+ ///
+ /// # Returns
+ ///
+ /// `true` if the sequence number is valid (not seen before and not behind the moving window).
+ /// `false` if the sequence number is invalid (already marked or "too old").
+ pub fn update(&mut self, seq: u64) -> bool {
+ if self.check(seq) {
+ self.update_store(seq);
+ true
+ } else {
+ false
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn anti_replay() {
+ let mut ar = AntiReplay::new();
+
+ // strictly increasing sequence numbers are always accepted...
+ for i in 0..20000 {
+ assert!(ar.update(i));
+ }
+
+ // ...and count as replays afterwards
+ for i in (0..20000).rev() {
+ assert!(!ar.check(i));
+ }
+
+ assert!(ar.update(65536));
+ // NOTE(review): the exclusive bound means 65535 itself is never exercised here
+ for i in (65536 - WINDOW_SIZE)..65535 {
+ assert!(ar.update(i));
+ }
+
+ for i in (65536 - 10 * WINDOW_SIZE)..65535 {
+ assert!(!ar.check(i));
+ }
+
+ assert!(ar.update(66000));
+ for i in 65537..66000 {
+ assert!(ar.update(i));
+ }
+ // re-submitting the same counters must now fail
+ for i in 65537..66000 {
+ assert_eq!(ar.update(i), false);
+ }
+
+ // Test max u64.
+ let next = u64::max_value();
+ assert!(ar.update(next));
+ assert!(!ar.check(next));
+ for i in (next - WINDOW_SIZE)..next {
+ assert!(ar.update(i));
+ }
+ for i in (next - 20 * WINDOW_SIZE)..next {
+ assert!(!ar.check(i));
+ }
+ }
+}
diff --git a/src/wireguard/router/constants.rs b/src/wireguard/router/constants.rs
new file mode 100644
index 0000000..0ca824a
--- /dev/null
+++ b/src/wireguard/router/constants.rs
@@ -0,0 +1,7 @@
+// WireGuard semantics constants
+
+// Maximum number of plaintext packets staged per peer while waiting for a key
+// (bounds the staged_packets deque in peer.rs).
+pub const MAX_STAGED_PACKETS: usize = 128;
+
+// performance constants
+
+// Capacity of each parallel worker's job queue.
+pub const WORKER_QUEUE_SIZE: usize = MAX_STAGED_PACKETS;
diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs
new file mode 100644
index 0000000..455020c
--- /dev/null
+++ b/src/wireguard/router/device.rs
@@ -0,0 +1,243 @@
+use std::collections::HashMap;
+use std::net::{Ipv4Addr, Ipv6Addr};
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use std::sync::mpsc::sync_channel;
+use std::sync::mpsc::SyncSender;
+use std::sync::Arc;
+use std::thread;
+use std::time::Instant;
+
+use log::debug;
+use spin::{Mutex, RwLock};
+use treebitmap::IpLookupTable;
+use zerocopy::LayoutVerified;
+
+use super::anti_replay::AntiReplay;
+use super::constants::*;
+use super::ip::*;
+use super::messages::{TransportHeader, TYPE_TRANSPORT};
+use super::peer::{new_peer, Peer, PeerInner};
+use super::types::{Callbacks, RouterError};
+use super::workers::{worker_parallel, JobParallel, Operation};
+use super::SIZE_MESSAGE_PREFIX;
+
+use super::super::types::{bind, tun, Endpoint, KeyPair};
+
+// Shared device state, referenced by the exported Device handle and by every peer.
+pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
+ // inbound writer (TUN)
+ pub inbound: T,
+
+ // outbound writer (Bind)
+ pub outbound: RwLock<Option<B>>,
+
+ // routing
+ pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state
+ pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv4 cryptkey routing
+ pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv6 cryptkey routing
+
+ // work queues
+ pub queue_next: AtomicUsize, // next round-robin index
+ pub queues: Mutex<Vec<SyncSender<JobParallel>>>, // work queues (1 per thread)
+}
+
+// Per-keypair state used when encrypting outbound transport messages.
+pub struct EncryptionState {
+ pub key: [u8; 32], // encryption key
+ pub id: u32, // receiver id
+ pub nonce: u64, // next available nonce
+ pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout)
+}
+
+// Per-keypair state used when decrypting inbound transport messages.
+pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
+ pub keypair: Arc<KeyPair>,
+ pub confirmed: AtomicBool, // initialized from keypair.initiator (see DecryptionState::new)
+ pub protector: Mutex<AntiReplay>, // anti-replay window for the receive counter
+ pub peer: Arc<PeerInner<E, C, T, B>>,
+ pub death: Instant, // time when the key can no longer be used for decryption
+}
+
+// Exported device handle: owns the parallel worker threads and the shared state.
+pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
+ state: Arc<DeviceInner<E, C, T, B>>, // reference to device state
+ handles: Vec<thread::JoinHandle<()>>, // join handles for workers
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Device<E, C, T, B> {
+ fn drop(&mut self) {
+ debug!("router: dropping device");
+
+ // drop all queues
+ // (dropping the senders disconnects the channels; presumably worker_parallel
+ // observes the disconnect and exits — confirm in workers.rs)
+ {
+ let mut queues = self.state.queues.lock();
+ while queues.pop().is_some() {}
+ }
+
+ // join all worker threads
+ while match self.handles.pop() {
+ Some(handle) => {
+ // unpark in case the worker is parked waiting for work
+ handle.thread().unpark();
+ handle.join().unwrap();
+ true
+ }
+ _ => false,
+ } {}
+
+ debug!("router: device dropped");
+ }
+}
+
+// Crypt-key routing: looks up the peer responsible for the destination address
+// of the plaintext IP packet. Returns None when the packet is empty, is neither
+// IPv4 nor IPv6, is too short for its header, or no route matches.
+#[inline(always)]
+fn get_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ device: &Arc<DeviceInner<E, C, T, B>>,
+ packet: &[u8],
+) -> Option<Arc<PeerInner<E, C, T, B>>> {
+ // ensure version access within bounds
+ if packet.len() < 1 {
+ return None;
+ };
+
+ // cast to correct IP header (version is the high nibble of the first byte)
+ match packet[0] >> 4 {
+ VERSION_IP4 => {
+ // check length and cast to IPv4 header
+ let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
+ LayoutVerified::new_from_prefix(packet)?;
+
+ // lookup destination address
+ device
+ .ipv4
+ .read()
+ .longest_match(Ipv4Addr::from(header.f_destination))
+ .and_then(|(_, _, p)| Some(p.clone()))
+ }
+ VERSION_IP6 => {
+ // check length and cast to IPv6 header
+ let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
+ LayoutVerified::new_from_prefix(packet)?;
+
+ // lookup destination address
+ device
+ .ipv6
+ .read()
+ .longest_match(Ipv6Addr::from(header.f_destination))
+ .and_then(|(_, _, p)| Some(p.clone()))
+ }
+ _ => None,
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, T, B> {
+ /// Creates a new router device writing decrypted packets to `tun`,
+ /// spawning `num_workers` parallel crypto worker threads.
+ pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> {
+ // allocate shared device state
+ let inner = DeviceInner {
+ inbound: tun,
+ outbound: RwLock::new(None),
+ queues: Mutex::new(Vec::with_capacity(num_workers)),
+ queue_next: AtomicUsize::new(0),
+ recv: RwLock::new(HashMap::new()),
+ ipv4: RwLock::new(IpLookupTable::new()),
+ ipv6: RwLock::new(IpLookupTable::new()),
+ };
+
+ // start worker threads (one bounded queue per worker)
+ let mut threads = Vec::with_capacity(num_workers);
+ for _ in 0..num_workers {
+ let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
+ inner.queues.lock().push(tx);
+ threads.push(thread::spawn(move || worker_parallel(rx)));
+ }
+
+ // return exported device handle
+ Device {
+ state: Arc::new(inner),
+ handles: threads,
+ }
+ }
+
+ /// A new secret key has been set for the device.
+ /// According to WireGuard semantics, this should cause all "sending" keys to be discarded.
+ /// NOTE(review): currently a no-op — confirm whether discarding is handled elsewhere.
+ pub fn new_sk(&self) {}
+
+ /// Adds a new peer to the device
+ ///
+ /// # Returns
+ ///
+ /// An atomic ref. counted peer (with lifetime matching the device)
+ pub fn new_peer(&self, opaque: C::Opaque) -> Peer<E, C, T, B> {
+ new_peer(self.state.clone(), opaque)
+ }
+
+ /// Cryptkey routes and sends a plaintext message (IP packet)
+ ///
+ /// # Arguments
+ ///
+ /// - msg: IP packet to crypt-key route
+ ///
+ /// # Returns
+ ///
+ /// Ok(()) if a crypt-key route exists for the destination (the packet may
+ /// still be staged inside the peer when no key is available);
+ /// Err(RouterError::NoCryptKeyRoute) otherwise.
+ pub fn send(&self, msg: Vec<u8>) -> Result<(), RouterError> {
+ // ignore header prefix (for in-place transport message construction)
+ let packet = &msg[SIZE_MESSAGE_PREFIX..];
+
+ // lookup peer based on IP packet destination address
+ let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?;
+
+ // schedule for encryption and transmission to peer
+ if let Some(job) = peer.send_job(msg, true) {
+ debug_assert_eq!(job.1.op, Operation::Encryption);
+
+ // add job to worker queue (round-robin over the workers)
+ let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
+ let queues = self.state.queues.lock();
+ queues[idx % queues.len()].send(job).unwrap();
+ }
+
+ Ok(())
+ }
+
+ /// Receive an encrypted transport message
+ ///
+ /// # Arguments
+ ///
+ /// - src: Source address of the packet
+ /// - msg: Encrypted transport message
+ ///
+ /// # Returns
+ ///
+ /// Ok(()) if the header parsed and mapped to a known receiver id;
+ /// Err(MalformedTransportMessage) or Err(UnknownReceiverId) otherwise.
+ pub fn recv(&self, src: E, msg: Vec<u8>) -> Result<(), RouterError> {
+ // parse / cast
+ let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) {
+ Some(v) => v,
+ None => {
+ return Err(RouterError::MalformedTransportMessage);
+ }
+ };
+ let header: LayoutVerified<&[u8], TransportHeader> = header;
+ debug_assert!(
+ header.f_type.get() == TYPE_TRANSPORT as u32,
+ "this should be checked by the message type multiplexer"
+ );
+
+ // lookup peer based on receiver id
+ let dec = self.state.recv.read();
+ let dec = dec
+ .get(&header.f_receiver.get())
+ .ok_or(RouterError::UnknownReceiverId)?;
+
+ // schedule for decryption and TUN write
+ if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
+ debug_assert_eq!(job.1.op, Operation::Decryption);
+
+ // add job to worker queue (round-robin over the workers)
+ let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
+ let queues = self.state.queues.lock();
+ queues[idx % queues.len()].send(job).unwrap();
+ }
+
+ Ok(())
+ }
+
+ /// Set outbound writer
+ ///
+ /// Installs (or replaces) the bind used to transmit encrypted packets.
+ pub fn set_outbound_writer(&self, new: B) {
+ *self.state.outbound.write() = Some(new);
+ }
+}
diff --git a/src/wireguard/router/ip.rs b/src/wireguard/router/ip.rs
new file mode 100644
index 0000000..e66144f
--- /dev/null
+++ b/src/wireguard/router/ip.rs
@@ -0,0 +1,26 @@
+use byteorder::BigEndian;
+use zerocopy::byteorder::U16;
+use zerocopy::{AsBytes, FromBytes};
+
+pub const VERSION_IP4: u8 = 4;
+pub const VERSION_IP6: u8 = 6;
+
+// Partial IPv4 header: only the fields the router reads are named;
+// the `_f_space*` fields pad the layout to match the wire format.
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct IPv4Header {
+ _f_space1: [u8; 2],
+ pub f_total_len: U16<BigEndian>,
+ _f_space2: [u8; 8],
+ pub f_source: [u8; 4],
+ pub f_destination: [u8; 4],
+}
+
+// Partial IPv6 header, padded the same way as IPv4Header above.
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct IPv6Header {
+ _f_space1: [u8; 4],
+ pub f_len: U16<BigEndian>,
+ _f_space2: [u8; 2],
+ pub f_source: [u8; 16],
+ pub f_destination: [u8; 16],
+}
diff --git a/src/wireguard/router/messages.rs b/src/wireguard/router/messages.rs
new file mode 100644
index 0000000..bf4d13b
--- /dev/null
+++ b/src/wireguard/router/messages.rs
@@ -0,0 +1,13 @@
+use byteorder::LittleEndian;
+use zerocopy::byteorder::{U32, U64};
+use zerocopy::{AsBytes, FromBytes};
+
+// Message-type identifier of transport (data) messages.
+pub const TYPE_TRANSPORT: u32 = 4;
+
+// Wire-format header of a transport message: type, receiver id, counter
+// (all little-endian).
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct TransportHeader {
+ pub f_type: U32<LittleEndian>,
+ pub f_receiver: U32<LittleEndian>,
+ pub f_counter: U64<LittleEndian>,
+}
diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs
new file mode 100644
index 0000000..7a29cd9
--- /dev/null
+++ b/src/wireguard/router/mod.rs
@@ -0,0 +1,22 @@
+mod anti_replay;
+mod constants;
+mod device;
+mod ip;
+mod messages;
+mod peer;
+mod types;
+mod workers;
+
+#[cfg(test)]
+mod tests;
+
+use messages::TransportHeader;
+use std::mem;
+
+// Header bytes reserved at the front of message buffers,
+// allowing in-place construction of the transport header.
+pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>();
+// Trailing capacity reserved per message — presumably the 16-byte AEAD tag; confirm in workers.rs.
+pub const CAPACITY_MESSAGE_POSTFIX: usize = 16;
+
+pub use messages::TYPE_TRANSPORT;
+pub use device::Device;
+pub use peer::Peer;
+pub use types::Callbacks;
diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs
new file mode 100644
index 0000000..4f47604
--- /dev/null
+++ b/src/wireguard/router/peer.rs
@@ -0,0 +1,611 @@
+use std::mem;
+use std::net::{IpAddr, SocketAddr};
+use std::sync::atomic::AtomicBool;
+use std::sync::atomic::Ordering;
+use std::sync::mpsc::{sync_channel, SyncSender};
+use std::sync::Arc;
+use std::thread;
+
+use arraydeque::{ArrayDeque, Wrapping};
+use log::debug;
+use spin::Mutex;
+use treebitmap::address::Address;
+use treebitmap::IpLookupTable;
+use zerocopy::LayoutVerified;
+
+use super::super::constants::*;
+use super::super::types::{bind, tun, Endpoint, KeyPair};
+
+use super::anti_replay::AntiReplay;
+use super::device::DecryptionState;
+use super::device::DeviceInner;
+use super::device::EncryptionState;
+use super::messages::TransportHeader;
+
+use futures::*;
+
+use super::workers::Operation;
+use super::workers::{worker_inbound, worker_outbound};
+use super::workers::{JobBuffer, JobInbound, JobOutbound, JobParallel};
+use super::SIZE_MESSAGE_PREFIX;
+
+use super::constants::*;
+use super::types::{Callbacks, RouterError};
+
+// The per-peer "key wheel": at most three keypairs at a time
+// (next = unconfirmed, current = used for encryption, previous = decryption only),
+// plus ids retired but not yet reported back to the caller.
+pub struct KeyWheel {
+ next: Option<Arc<KeyPair>>, // next key state (unconfirmed)
+ current: Option<Arc<KeyPair>>, // current key state (used for encryption)
+ previous: Option<Arc<KeyPair>>, // old key state (used for decryption)
+ retired: Vec<u32>, // retired ids
+}
+
+pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
+ pub device: Arc<DeviceInner<E, C, T, B>>, // owning device state
+ pub opaque: C::Opaque, // caller-supplied callback context
+ pub outbound: Mutex<SyncSender<JobOutbound>>, // in-order queue to the outbound worker
+ pub inbound: Mutex<SyncSender<JobInbound<E, C, T, B>>>, // in-order queue to the inbound worker
+ pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>, // packets awaiting a key
+ pub keys: Mutex<KeyWheel>, // keypairs allocated to this peer
+ pub ekey: Mutex<Option<EncryptionState>>, // active encryption state (None = no sending key)
+ pub endpoint: Mutex<Option<E>>, // current endpoint (set via Peer::set_endpoint)
+}
+
+// Exported peer handle: owns the per-peer worker threads.
+pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
+ state: Arc<PeerInner<E, C, T, B>>,
+ thread_outbound: Option<thread::JoinHandle<()>>,
+ thread_inbound: Option<thread::JoinHandle<()>>,
+}
+
+// Collects `callback(ip, masklen)` for every subnet in `table` mapped to `peer`.
+fn treebit_list<A, R, E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ peer: &Arc<PeerInner<E, C, T, B>>,
+ table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>,
+ callback: Box<dyn Fn(A, u32) -> R>,
+) -> Vec<R>
+where
+ A: Address,
+{
+ let mut res = Vec::new();
+ for subnet in table.read().iter() {
+ let (ip, masklen, p) = subnet;
+ if Arc::ptr_eq(&p, &peer) {
+ res.push(callback(ip, masklen))
+ }
+ }
+ res
+}
+
+// Removes every subnet in `table` mapped to `peer`.
+fn treebit_remove<E: Endpoint, A: Address, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ peer: &Peer<E, C, T, B>,
+ table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>,
+) {
+ let mut m = table.write();
+
+ // collect keys for value
+ // (cannot remove while iterating, so gather the (ip, masklen) keys first)
+ let mut subnets = vec![];
+ for subnet in m.iter() {
+ let (ip, masklen, p) = subnet;
+ if Arc::ptr_eq(&p, &peer.state) {
+ subnets.push((ip, masklen))
+ }
+ }
+
+ // remove all key mappings
+ for (ip, masklen) in subnets {
+ let r = m.remove(ip, masklen);
+ debug_assert!(r.is_some());
+ }
+}
+
+impl EncryptionState {
+ // Creates encryption state from the send-half of a keypair.
+ // NOTE(review): the field comment on `EncryptionState::death` mentions
+ // subtracting keepalive/rekey timeouts, but only REJECT_AFTER_TIME is
+ // added here — confirm which formula is intended.
+ fn new(keypair: &Arc<KeyPair>) -> EncryptionState {
+ EncryptionState {
+ id: keypair.send.id,
+ key: keypair.send.key,
+ nonce: 0,
+ death: keypair.birth + REJECT_AFTER_TIME,
+ }
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionState<E, C, T, B> {
+ // Creates decryption state for a keypair; a keypair created as initiator
+ // starts out confirmed.
+ fn new(
+ peer: &Arc<PeerInner<E, C, T, B>>,
+ keypair: &Arc<KeyPair>,
+ ) -> DecryptionState<E, C, T, B> {
+ DecryptionState {
+ confirmed: AtomicBool::new(keypair.initiator),
+ keypair: keypair.clone(),
+ protector: spin::Mutex::new(AntiReplay::new()),
+ peer: peer.clone(),
+ death: keypair.birth + REJECT_AFTER_TIME,
+ }
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Peer<E, C, T, B> {
+ fn drop(&mut self) {
+ let peer = &self.state;
+
+ // remove from cryptkey router
+
+ treebit_remove(self, &peer.device.ipv4);
+ treebit_remove(self, &peer.device.ipv6);
+
+ // drop channels
+ // (swapping in dummy senders drops the originals, disconnecting the workers)
+
+ mem::replace(&mut *peer.inbound.lock(), sync_channel(0).0);
+ mem::replace(&mut *peer.outbound.lock(), sync_channel(0).0);
+
+ // join with workers
+
+ mem::replace(&mut self.thread_inbound, None).map(|v| v.join());
+ mem::replace(&mut self.thread_outbound, None).map(|v| v.join());
+
+ // release ids from the receiver map
+
+ let mut keys = peer.keys.lock();
+ let mut release = Vec::with_capacity(3);
+
+ keys.next.as_ref().map(|k| release.push(k.recv.id));
+ keys.current.as_ref().map(|k| release.push(k.recv.id));
+ keys.previous.as_ref().map(|k| release.push(k.recv.id));
+
+ if release.len() > 0 {
+ let mut recv = peer.device.recv.write();
+ for id in &release {
+ recv.remove(id);
+ }
+ }
+
+ // null key-material
+
+ keys.next = None;
+ keys.current = None;
+ keys.previous = None;
+
+ *peer.ekey.lock() = None;
+ *peer.endpoint.lock() = None;
+
+ debug!("peer dropped & removed from device");
+ }
+}
+
+// Allocates a peer bound to `device` and spawns its two per-peer worker threads.
+pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ device: Arc<DeviceInner<E, C, T, B>>,
+ opaque: C::Opaque,
+) -> Peer<E, C, T, B> {
+ let (out_tx, out_rx) = sync_channel(128);
+ let (in_tx, in_rx) = sync_channel(128);
+
+ // allocate peer object
+ let peer = {
+ let device = device.clone();
+ Arc::new(PeerInner {
+ opaque,
+ device,
+ inbound: Mutex::new(in_tx),
+ outbound: Mutex::new(out_tx),
+ ekey: spin::Mutex::new(None),
+ endpoint: spin::Mutex::new(None),
+ keys: spin::Mutex::new(KeyWheel {
+ next: None,
+ current: None,
+ previous: None,
+ retired: vec![],
+ }),
+ staged_packets: spin::Mutex::new(ArrayDeque::new()),
+ })
+ };
+
+ // spawn outbound thread
+ // NOTE(review): the handle names below appear swapped relative to the worker
+ // each thread runs (`thread_inbound` runs worker_outbound and vice versa).
+ // Harmless since both handles are only joined on drop, but worth renaming.
+ let thread_inbound = {
+ let peer = peer.clone();
+ let device = device.clone();
+ thread::spawn(move || worker_outbound(device, peer, out_rx))
+ };
+
+ // spawn inbound thread
+ let thread_outbound = {
+ let peer = peer.clone();
+ let device = device.clone();
+ thread::spawn(move || worker_inbound(device, peer, in_rx))
+ };
+
+ Peer {
+ state: peer,
+ thread_inbound: Some(thread_inbound),
+ thread_outbound: Some(thread_outbound),
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E, C, T, B> {
+ // Drains the staged-packet queue, submitting each packet for encryption.
+ // Returns true if at least one packet was submitted.
+ fn send_staged(&self) -> bool {
+ debug!("peer.send_staged");
+ let mut sent = false;
+ let mut staged = self.staged_packets.lock();
+ loop {
+ match staged.pop_front() {
+ Some(msg) => {
+ sent = true;
+ self.send_raw(msg);
+ }
+ None => break sent,
+ }
+ }
+ }
+
+ // Treat the msg as the payload of a transport message
+ // Unlike device.send, peer.send_raw does not buffer messages when a key is not available.
+ fn send_raw(&self, msg: Vec<u8>) -> bool {
+ debug!("peer.send_raw");
+ match self.send_job(msg, false) {
+ Some(job) => {
+ debug!("send_raw: got obtained send_job");
+ // round-robin the job over the device's parallel workers
+ let index = self.device.queue_next.fetch_add(1, Ordering::SeqCst);
+ let queues = self.device.queues.lock();
+ match queues[index % queues.len()].send(job) {
+ Ok(_) => true,
+ Err(_) => false,
+ }
+ }
+ None => false,
+ }
+ }
+
+ // Confirms `keypair` if (and only if) it is the pending `keys.next`:
+ // rotates the key-wheel, installs fresh encryption state, notifies the
+ // callbacks, then transmits any staged packets with the new key.
+ pub fn confirm_key(&self, keypair: &Arc<KeyPair>) {
+ debug!("peer.confirm_key");
+ {
+ // take lock and check keypair = keys.next
+ let mut keys = self.keys.lock();
+ let next = match keys.next.as_ref() {
+ Some(next) => next,
+ None => {
+ return;
+ }
+ };
+ if !Arc::ptr_eq(&next, keypair) {
+ return;
+ }
+
+ // allocate new encryption state
+ let ekey = Some(EncryptionState::new(&next));
+
+ // rotate key-wheel
+ let mut swap = None;
+ mem::swap(&mut keys.next, &mut swap);
+ mem::swap(&mut keys.current, &mut swap);
+ mem::swap(&mut keys.previous, &mut swap);
+
+ // tell the world outside the router that a key was confirmed
+ C::key_confirmed(&self.opaque);
+
+ // set new key for encryption
+ *self.ekey.lock() = ekey;
+ }
+
+ // start transmission of staged packets
+ self.send_staged();
+ }
+
+ // Builds a parallel decryption job for an inbound transport message and
+ // enqueues its in-order completion on the peer's inbound channel.
+ // Returns None if the inbound channel is full or disconnected.
+ pub fn recv_job(
+ &self,
+ src: E,
+ dec: Arc<DecryptionState<E, C, T, B>>,
+ msg: Vec<u8>,
+ ) -> Option<JobParallel> {
+ let (tx, rx) = oneshot();
+ let key = dec.keypair.recv.key;
+ match self.inbound.lock().try_send((dec, src, rx)) {
+ Ok(_) => Some((
+ tx,
+ JobBuffer {
+ msg,
+ key: key,
+ okay: false,
+ op: Operation::Decryption,
+ },
+ )),
+ Err(_) => None,
+ }
+ }
+
+ // Builds a parallel encryption job for an outbound message, filling the
+ // transport header (receiver id + counter) in place. When no sending key
+ // is available and `stage` is set, the packet is staged instead and a new
+ // key is requested via the callbacks.
+ pub fn send_job(&self, mut msg: Vec<u8>, stage: bool) -> Option<JobParallel> {
+ debug!("peer.send_job");
+ debug_assert!(
+ msg.len() >= mem::size_of::<TransportHeader>(),
+ "received message with size: {:}",
+ msg.len()
+ );
+
+ // parse / cast
+ let (header, _) = LayoutVerified::new_from_prefix(&mut msg[..]).unwrap();
+ let mut header: LayoutVerified<&mut [u8], TransportHeader> = header;
+
+ // check if has key
+ let key = {
+ let mut ekey = self.ekey.lock();
+ let key = match ekey.as_mut() {
+ None => None,
+ Some(mut state) => {
+ // avoid integer overflow in nonce
+ if state.nonce >= REJECT_AFTER_MESSAGES - 1 {
+ *ekey = None;
+ None
+ } else {
+ // there should be no stacked packets lingering around
+ debug!("encryption state available, nonce = {}", state.nonce);
+
+ // set transport message fields
+ header.f_counter.set(state.nonce);
+ header.f_receiver.set(state.id);
+ state.nonce += 1;
+ Some(state.key)
+ }
+ }
+ };
+
+ // If no suitable key was found:
+ // 1. Stage packet for later transmission
+ // (NOTE: the deque uses `Wrapping`, so staging past capacity displaces
+ // an existing packet — confirm against arraydeque semantics)
+ // 2. Request new key
+ if key.is_none() && stage {
+ self.staged_packets.lock().push_back(msg);
+ C::need_key(&self.opaque);
+ return None;
+ };
+
+ key
+ }?; // the trailing `?` returns None from send_job when no key was obtained
+
+ // add job to in-order queue and return sender to device for inclusion in worker pool
+ let (tx, rx) = oneshot();
+ match self.outbound.lock().try_send(rx) {
+ Ok(_) => Some((
+ tx,
+ JobBuffer {
+ msg,
+ key,
+ okay: false,
+ op: Operation::Encryption,
+ },
+ )),
+ Err(_) => None,
+ }
+ }
+}
+
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T, B> {
+ /// Set the endpoint of the peer
+ ///
+ /// # Arguments
+ ///
+ /// - `endpoint`, socket address converted to bind endpoint
+ ///
+ /// # Note
+ ///
+ /// This API still permits support for the "sticky socket" behavior,
+ /// as sockets should be "unsticked" when manually updating the endpoint
+ pub fn set_endpoint(&self, endpoint: E) {
+ debug!("peer.set_endpoint");
+ *self.state.endpoint.lock() = Some(endpoint);
+ }
+
+ /// Returns the current endpoint of the peer (for configuration)
+ ///
+ /// # Note
+ ///
+ /// Does not convey potential "sticky socket" information
+ pub fn get_endpoint(&self) -> Option<SocketAddr> {
+ debug!("peer.get_endpoint");
+ self.state
+ .endpoint
+ .lock()
+ .as_ref()
+ .map(|e| e.into_address())
+ }
+
+ /// Zero all key-material related to the peer
+ pub fn zero_keys(&self) {
+ debug!("peer.zero_keys");
+
+ let mut release: Vec<u32> = Vec::with_capacity(3);
+ let mut keys = self.state.keys.lock();
+
+ // update key-wheel
+
+ mem::replace(&mut keys.next, None).map(|k| release.push(k.local_id()));
+ mem::replace(&mut keys.current, None).map(|k| release.push(k.local_id()));
+ mem::replace(&mut keys.previous, None).map(|k| release.push(k.local_id()));
+ keys.retired.extend(&release[..]);
+
+ // update inbound "recv" map
+ {
+ let mut recv = self.state.device.recv.write();
+ for id in release {
+ recv.remove(&id);
+ }
+ }
+
+ // clear encryption state
+ *self.state.ekey.lock() = None;
+ }
+
+ /// Add a new keypair
+ ///
+ /// # Arguments
+ ///
+ /// - new: The new confirmed/unconfirmed key pair
+ ///
+ /// # Returns
+ ///
+ /// A vector of ids which has been released.
+ /// These should be released in the handshake module.
+ ///
+ /// # Note
+ ///
+ /// The number of ids to be released can be at most 3,
+ /// since the only way to add additional keys to the peer is by using this method
+ /// and a peer can have at most 3 keys allocated in the router at any time.
+ pub fn add_keypair(&self, new: KeyPair) -> Vec<u32> {
+ debug!("peer.add_keypair");
+
+ let initiator = new.initiator;
+ let release = {
+ let new = Arc::new(new);
+ let mut keys = self.state.keys.lock();
+ let mut release = mem::replace(&mut keys.retired, vec![]);
+
+ // update key-wheel
+ if new.initiator {
+ // start using key for encryption
+ *self.state.ekey.lock() = Some(EncryptionState::new(&new));
+
+ // move current into previous
+ keys.previous = keys.current.as_ref().map(|v| v.clone());
+ keys.current = Some(new.clone());
+ } else {
+ // store the key and await confirmation
+ keys.previous = keys.next.as_ref().map(|v| v.clone());
+ keys.next = Some(new.clone());
+ };
+
+ // update incoming packet id map
+ {
+ debug!("peer.add_keypair: updating inbound id map");
+ let mut recv = self.state.device.recv.write();
+
+ // purge recv map of previous id
+ keys.previous.as_ref().map(|k| {
+ recv.remove(&k.local_id());
+ release.push(k.local_id());
+ });
+
+ // map new id to decryption state
+ debug_assert!(!recv.contains_key(&new.recv.id));
+ recv.insert(
+ new.recv.id,
+ Arc::new(DecryptionState::new(&self.state, &new)),
+ );
+ }
+ release
+ };
+
+ // schedule confirmation
+ if initiator {
+ debug_assert!(self.state.ekey.lock().is_some());
+ debug!("peer.add_keypair: is initiator, must confirm the key");
+ // attempt to confirm using staged packets
+ if !self.state.send_staged() {
+ // fall back to keepalive packet
+ let ok = self.send_keepalive();
+ debug!(
+ "peer.add_keypair: keepalive for confirmation, sent = {}",
+ ok
+ );
+ }
+ debug!("peer.add_keypair: key attempted confirmed");
+ }
+
+ debug_assert!(
+ release.len() <= 3,
+ "since the key-wheel contains at most 3 keys"
+ );
+ release
+ }
+
+ /// Sends a keepalive: a transport message with an empty payload
+ /// (the buffer holds only the reserved header prefix).
+ pub fn send_keepalive(&self) -> bool {
+ debug!("peer.send_keepalive");
+ self.state.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX])
+ }
+
+ /// Map a subnet to the peer
+ ///
+ /// # Arguments
+ ///
+ /// - `ip`, the mask of the subnet
+ /// - `masklen`, the length of the mask
+ ///
+ /// # Note
+ ///
+ /// The `ip` must not have any bits set right of `masklen`.
+ /// e.g. `192.168.1.0/24` is valid, while `192.168.1.128/24` is not.
+ ///
+ /// If an identical value already exists as part of a prior peer,
+ /// the allowed IP entry will be removed from that peer and added to this peer.
+ pub fn add_subnet(&self, ip: IpAddr, masklen: u32) {
+ debug!("peer.add_subnet");
+ match ip {
+ IpAddr::V4(v4) => {
+ self.state
+ .device
+ .ipv4
+ .write()
+ .insert(v4, masklen, self.state.clone())
+ }
+ IpAddr::V6(v6) => {
+ self.state
+ .device
+ .ipv6
+ .write()
+ .insert(v6, masklen, self.state.clone())
+ }
+ };
+ }
+
+ /// List subnets mapped to the peer
+ ///
+ /// # Returns
+ ///
+ /// A vector of subnets, represented as address/mask-length pairs
+ pub fn list_subnets(&self) -> Vec<(IpAddr, u32)> {
+ debug!("peer.list_subnets");
+ let mut res = Vec::new();
+ res.append(&mut treebit_list(
+ &self.state,
+ &self.state.device.ipv4,
+ Box::new(|ip, masklen| (IpAddr::V4(ip), masklen)),
+ ));
+ res.append(&mut treebit_list(
+ &self.state,
+ &self.state.device.ipv6,
+ Box::new(|ip, masklen| (IpAddr::V6(ip), masklen)),
+ ));
+ res
+ }
+
+ /// Clear subnets mapped to the peer.
+ /// After the call, no subnets will be cryptkey routed to the peer.
+ /// Used for the UAPI command "replace_allowed_ips=true"
+ pub fn remove_subnets(&self) {
+ debug!("peer.remove_subnets");
+ treebit_remove(self, &self.state.device.ipv4);
+ treebit_remove(self, &self.state.device.ipv6);
+ }
+
+ /// Send a raw message to the peer (used for handshake messages)
+ ///
+ /// # Arguments
+ ///
+ /// - `msg`, message body to send to peer
+ ///
+ /// # Returns
+ ///
+ /// Unit if packet was sent, or an error indicating why sending failed
+ pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> {
+ debug!("peer.send");
+ let inner = &self.state;
+ match inner.endpoint.lock().as_ref() {
+ Some(endpoint) => inner
+ .device
+ .outbound
+ .read()
+ .as_ref()
+ .ok_or(RouterError::SendError)
+ .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError)),
+ None => Err(RouterError::NoEndpoint),
+ }
+ }
+
+ /// Discards all staged (not yet encrypted) packets for the peer.
+ pub fn purge_staged_packets(&self) {
+ self.state.staged_packets.lock().clear();
+ }
+}
diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs
new file mode 100644
index 0000000..fbee39e
--- /dev/null
+++ b/src/wireguard/router/tests.rs
@@ -0,0 +1,432 @@
+use std::net::IpAddr;
+use std::sync::atomic::Ordering;
+use std::sync::Arc;
+use std::sync::Mutex;
+use std::thread;
+use std::time::Duration;
+
+use num_cpus;
+use pnet::packet::ipv4::MutableIpv4Packet;
+use pnet::packet::ipv6::MutableIpv6Packet;
+
+use super::super::types::bind::*;
+use super::super::types::*;
+
+use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX};
+
+extern crate test;
+
+const SIZE_KEEPALIVE: usize = 32;
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use env_logger;
+ use log::debug;
+ use std::sync::atomic::AtomicUsize;
+ use test::Bencher;
+
+ // type for tracking events inside the router module
+ struct Flags {
+ send: Mutex<Vec<(usize, bool, bool)>>, // recorded "send" callbacks: (size, data?, sent?)
+ recv: Mutex<Vec<(usize, bool, bool)>>, // recorded "recv" callbacks: (size, data?, sent?)
+ need_key: Mutex<Vec<()>>, // one entry per "need_key" callback
+ key_confirmed: Mutex<Vec<()>>, // one entry per "key_confirmed" callback
+ }
+
+ // cheaply clonable handle to the shared event tracker (Arc clone only)
+ #[derive(Clone)]
+ struct Opaque(Arc<Flags>);
+
+ // unit type implementing Callbacks by recording events into the Opaque tracker
+ struct TestCallbacks();
+
+ impl Opaque {
+ fn new() -> Opaque {
+ Opaque(Arc::new(Flags {
+ send: Mutex::new(vec![]),
+ recv: Mutex::new(vec![]),
+ need_key: Mutex::new(vec![]),
+ key_confirmed: Mutex::new(vec![]),
+ }))
+ }
+
+ // clear all recorded events (reuse the tracker between iterations)
+ fn reset(&self) {
+ self.0.send.lock().unwrap().clear();
+ self.0.recv.lock().unwrap().clear();
+ self.0.need_key.lock().unwrap().clear();
+ self.0.key_confirmed.lock().unwrap().clear();
+ }
+
+ // pop the most recent recorded "send" event, if any
+ fn send(&self) -> Option<(usize, bool, bool)> {
+ self.0.send.lock().unwrap().pop()
+ }
+
+ // pop the most recent recorded "recv" event, if any
+ fn recv(&self) -> Option<(usize, bool, bool)> {
+ self.0.recv.lock().unwrap().pop()
+ }
+
+ // pop a recorded "need_key" event, if any
+ fn need_key(&self) -> Option<()> {
+ self.0.need_key.lock().unwrap().pop()
+ }
+
+ // pop a recorded "key_confirmed" event, if any
+ fn key_confirmed(&self) -> Option<()> {
+ self.0.key_confirmed.lock().unwrap().pop()
+ }
+
+ // has all events been accounted for by assertions?
+ fn is_empty(&self) -> bool {
+ let send = self.0.send.lock().unwrap();
+ let recv = self.0.recv.lock().unwrap();
+ let need_key = self.0.need_key.lock().unwrap();
+ let key_confirmed = self.0.key_confirmed.lock().unwrap();
+ // fix: use logical && (was bitwise &) between the last two terms,
+ // consistent with the other conjunctions
+ send.is_empty() && recv.is_empty() && need_key.is_empty() && key_confirmed.is_empty()
+ }
+ }
+
+ // record every router callback into the shared Flags lists for later assertion
+ impl Callbacks for TestCallbacks {
+ type Opaque = Opaque;
+
+ fn send(t: &Self::Opaque, size: usize, data: bool, sent: bool) {
+ t.0.send.lock().unwrap().push((size, data, sent))
+ }
+
+ fn recv(t: &Self::Opaque, size: usize, data: bool, sent: bool) {
+ t.0.recv.lock().unwrap().push((size, data, sent))
+ }
+
+ fn need_key(t: &Self::Opaque) {
+ t.0.need_key.lock().unwrap().push(());
+ }
+
+ fn key_confirmed(t: &Self::Opaque) {
+ t.0.key_confirmed.lock().unwrap().push(());
+ }
+ }
+
+ // wait for scheduling
+ // (gives worker threads time to process queued jobs before asserting)
+ fn wait() {
+ thread::sleep(Duration::from_millis(50));
+ }
+
+ // initialize test logging; try_init makes repeated calls harmless
+ fn init() {
+ let _ = env_logger::builder().is_test(true).try_init();
+ }
+
+ // construct a dummy IP packet of `size` bytes destined for `ip`
+ // (IPv4 or IPv6 chosen by the address type), preceded by
+ // SIZE_MESSAGE_PREFIX bytes of zeroed headroom for the router
+ fn make_packet(size: usize, ip: IpAddr) -> Vec<u8> {
+ // create "IP packet"
+ let mut msg = Vec::with_capacity(SIZE_MESSAGE_PREFIX + size + 16);
+ msg.resize(SIZE_MESSAGE_PREFIX + size, 0);
+ match ip {
+ IpAddr::V4(ip) => {
+ let mut packet = MutableIpv4Packet::new(&mut msg[SIZE_MESSAGE_PREFIX..]).unwrap();
+ packet.set_destination(ip);
+ packet.set_version(4);
+ }
+ IpAddr::V6(ip) => {
+ let mut packet = MutableIpv6Packet::new(&mut msg[SIZE_MESSAGE_PREFIX..]).unwrap();
+ packet.set_destination(ip);
+ packet.set_version(6);
+ }
+ }
+ msg
+ }
+
+ #[bench]
+ fn bench_outbound(b: &mut Bencher) {
+ // counts total bytes passed to the send callback
+ struct BencherCallbacks {}
+ impl Callbacks for BencherCallbacks {
+ type Opaque = Arc<AtomicUsize>;
+ fn send(t: &Self::Opaque, size: usize, _data: bool, _sent: bool) {
+ t.fetch_add(size, Ordering::SeqCst);
+ }
+ fn recv(_: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {}
+ fn need_key(_: &Self::Opaque) {}
+ fn key_confirmed(_: &Self::Opaque) {}
+ }
+
+ // create device
+ let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false);
+ let router: Device<_, BencherCallbacks, dummy::TunWriter, dummy::VoidBind> =
+ Device::new(num_cpus::get(), tun_writer);
+
+ // add new peer
+ let opaque = Arc::new(AtomicUsize::new(0));
+ let peer = router.new_peer(opaque.clone());
+ peer.add_keypair(dummy::keypair(true));
+
+ // add subnet to peer
+ let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20");
+ let mask: IpAddr = mask.parse().unwrap();
+ let ip1: IpAddr = ip.parse().unwrap();
+ peer.add_subnet(mask, len);
+
+ // every iteration sends 10 MiB (10 * 1024 * 1024 bytes of 1024-byte packets;
+ // the earlier comment said "10 GB", which did not match the threshold below)
+ b.iter(|| {
+ opaque.store(0, Ordering::SeqCst);
+ let msg = make_packet(1024, ip1);
+ while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 {
+ router.send(msg.to_vec()).unwrap();
+ }
+ });
+ }
+
+ // Checks cryptkey routing of outbound packets: for each
+ // (subnet, masklen, destination ip, expect-route) case, with and without a
+ // keypair installed, asserts whether routing succeeds and which callbacks fire.
+ #[test]
+ fn test_outbound() {
+ init();
+
+ // create device
+ let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false);
+ let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer);
+ router.set_outbound_writer(dummy::VoidBind::new());
+
+ let tests = vec![
+ ("192.168.1.0", 24, "192.168.1.20", true),
+ ("172.133.133.133", 32, "172.133.133.133", true),
+ ("172.133.133.133", 32, "172.133.133.132", false),
+ (
+ "2001:db8::ff00:42:0000",
+ 112,
+ "2001:db8::ff00:42:3242",
+ true,
+ ),
+ (
+ "2001:db8::ff00:42:8000",
+ 113,
+ "2001:db8::ff00:42:0660",
+ false,
+ ),
+ (
+ "2001:db8::ff00:42:8000",
+ 113,
+ "2001:db8::ff00:42:ffff",
+ true,
+ ),
+ ];
+
+ for (num, (mask, len, ip, okay)) in tests.iter().enumerate() {
+ for set_key in vec![true, false] {
+ debug!("index = {}, set_key = {}", num, set_key);
+
+ // add new peer
+ let opaque = Opaque::new();
+ let peer = router.new_peer(opaque.clone());
+ let mask: IpAddr = mask.parse().unwrap();
+ if set_key {
+ peer.add_keypair(dummy::keypair(true));
+ }
+
+ // map subnet to peer
+ peer.add_subnet(mask, *len);
+
+ // create "IP packet"
+ let msg = make_packet(1024, ip.parse().unwrap());
+
+ // cryptkey route the IP packet
+ let res = router.send(msg);
+
+ // allow some scheduling
+ wait();
+
+ if *okay {
+ // cryptkey routing succeeded
+ assert!(res.is_ok(), "crypt-key routing should succeed");
+ assert_eq!(
+ opaque.need_key().is_some(),
+ !set_key,
+ "should have requested a new key, if no encryption state was set"
+ );
+ assert_eq!(
+ opaque.send().is_some(),
+ set_key,
+ "transmission should have been attempted"
+ );
+ assert!(
+ opaque.recv().is_none(),
+ "no messages should have been marked as received"
+ );
+ } else {
+ // no such cryptkey route
+ assert!(res.is_err(), "crypt-key routing should fail");
+ assert!(
+ opaque.need_key().is_none(),
+ "should not request a new-key if crypt-key routing failed"
+ );
+ assert_eq!(
+ opaque.send(),
+ if set_key {
+ Some((SIZE_KEEPALIVE, false, false))
+ } else {
+ None
+ },
+ "transmission should only happen if key was set (keepalive)",
+ );
+ assert!(
+ opaque.recv().is_none(),
+ "no messages should have been marked as received",
+ );
+ }
+ }
+ }
+ }
+
+ // Full loopback test between two routers connected by a PairBind:
+ // key confirmation (via keepalive or a staged packet), endpoint learning,
+ // then repeated routing of data packets from peer1 to peer2.
+ #[test]
+ fn test_bidirectional() {
+ init();
+
+ let tests = [
+ (
+ false, // confirm with keepalive
+ ("192.168.1.0", 24, "192.168.1.20", true),
+ ("172.133.133.133", 32, "172.133.133.133", true),
+ ),
+ (
+ true, // confirm with staged packet
+ ("192.168.1.0", 24, "192.168.1.20", true),
+ ("172.133.133.133", 32, "172.133.133.133", true),
+ ),
+ (
+ false, // confirm with keepalive
+ (
+ "2001:db8::ff00:42:8000",
+ 113,
+ "2001:db8::ff00:42:ffff",
+ true,
+ ),
+ (
+ "2001:db8::ff40:42:8000",
+ 113,
+ "2001:db8::ff40:42:ffff",
+ true,
+ ),
+ ),
+ (
+ false, // confirm with keepalive (comment fixed: stage flag below is false)
+ (
+ "2001:db8::ff00:42:8000",
+ 113,
+ "2001:db8::ff00:42:ffff",
+ true,
+ ),
+ (
+ "2001:db8::ff40:42:8000",
+ 113,
+ "2001:db8::ff40:42:ffff",
+ true,
+ ),
+ ),
+ ];
+
+ for (stage, p1, p2) in tests.iter() {
+ let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) =
+ dummy::PairBind::pair();
+
+ // create matching device
+ let (_fake, _, tun_writer1, _) = dummy::TunTest::create(1500, false);
+ let (_fake, _, tun_writer2, _) = dummy::TunTest::create(1500, false);
+
+ let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1);
+ router1.set_outbound_writer(bind_writer1);
+
+ let router2: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer2);
+ router2.set_outbound_writer(bind_writer2);
+
+ // prepare opaque values for tracing callbacks
+
+ let opaq1 = Opaque::new();
+ let opaq2 = Opaque::new();
+
+ // create peers with matching keypairs and assign subnets
+
+ let (mask, len, _ip, _okay) = p1;
+ let peer1 = router1.new_peer(opaq1.clone());
+ let mask: IpAddr = mask.parse().unwrap();
+ peer1.add_subnet(mask, *len);
+ peer1.add_keypair(dummy::keypair(false));
+
+ let (mask, len, _ip, _okay) = p2;
+ let peer2 = router2.new_peer(opaq2.clone());
+ let mask: IpAddr = mask.parse().unwrap();
+ peer2.add_subnet(mask, *len);
+ peer2.set_endpoint(dummy::UnitEndpoint::new());
+
+ if *stage {
+ // stage a packet which can be used for confirmation (in place of a keepalive)
+ let (_mask, _len, ip, _okay) = p2;
+ let msg = make_packet(1024, ip.parse().unwrap());
+ router2.send(msg).expect("failed to sent staged packet");
+
+ wait();
+ assert!(opaq2.recv().is_none());
+ assert!(
+ opaq2.send().is_none(),
+ "sending should fail as not key is set"
+ );
+ assert!(
+ opaq2.need_key().is_some(),
+ "a new key should be requested since a packet was attempted transmitted"
+ );
+ assert!(opaq2.is_empty(), "callbacks should only run once");
+ }
+
+ // this should cause a key-confirmation packet (keepalive or staged packet)
+ // this also causes peer1 to learn the "endpoint" for peer2
+ assert!(peer1.get_endpoint().is_none());
+ peer2.add_keypair(dummy::keypair(true));
+
+ wait();
+ assert!(opaq2.send().is_some());
+ assert!(opaq2.is_empty(), "events on peer2 should be 'send'");
+ assert!(opaq1.is_empty(), "nothing should happened on peer1");
+
+ // read confirming message received by the other end ("across the internet")
+ let mut buf = vec![0u8; 2048];
+ let (len, from) = bind_reader1.read(&mut buf).unwrap();
+ buf.truncate(len);
+ router1.recv(from, buf).unwrap();
+
+ wait();
+ assert!(opaq1.recv().is_some());
+ assert!(opaq1.key_confirmed().is_some());
+ assert!(
+ opaq1.is_empty(),
+ "events on peer1 should be 'recv' and 'key_confirmed'"
+ );
+ assert!(peer1.get_endpoint().is_some());
+ assert!(opaq2.is_empty(), "nothing should happened on peer2");
+
+ // now that peer1 has an endpoint
+ // route packets : peer1 -> peer2
+
+ for _ in 0..10 {
+ assert!(
+ opaq1.is_empty(),
+ "we should have asserted a value for every callback on peer1"
+ );
+ assert!(
+ opaq2.is_empty(),
+ "we should have asserted a value for every callback on peer2"
+ );
+
+ // pass IP packet to router
+ let (_mask, _len, ip, _okay) = p1;
+ let msg = make_packet(1024, ip.parse().unwrap());
+ router1.send(msg).unwrap();
+
+ wait();
+ assert!(opaq1.send().is_some());
+ assert!(opaq1.recv().is_none());
+ assert!(opaq1.need_key().is_none());
+
+ // receive ("across the internet") on the other end
+ let mut buf = vec![0u8; 2048];
+ let (len, from) = bind_reader2.read(&mut buf).unwrap();
+ buf.truncate(len);
+ router2.recv(from, buf).unwrap();
+
+ wait();
+ assert!(opaq2.send().is_none());
+ assert!(opaq2.recv().is_some());
+ assert!(opaq2.need_key().is_none());
+ }
+ }
+ }
+}
diff --git a/src/wireguard/router/types.rs b/src/wireguard/router/types.rs
new file mode 100644
index 0000000..b7c3ae0
--- /dev/null
+++ b/src/wireguard/router/types.rs
@@ -0,0 +1,65 @@
+use std::error::Error;
+use std::fmt;
+
+ /// Marker trait for the user-supplied per-peer value handed to callbacks.
+ pub trait Opaque: Send + Sync + 'static {}
+
+ /// Blanket implementation: any `Send + Sync + 'static` type qualifies.
+ impl<T> Opaque for T where T: Send + Sync + 'static {}
+
+ /// A send/recv callback takes 4 arguments:
+ ///
+ /// * `0`, a reference to the opaque value assigned to the peer
+ /// * `1`, size of the message (in bytes)
+ /// * `2`, a bool indicating whether the message contained data (not just keepalive)
+ /// * `3`, a bool indicating whether the message was transmitted (i.e. did the peer have an associated endpoint?)
+ pub trait Callback<T>: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {}
+
+ impl<T, F> Callback<T> for F where F: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {}
+
+ /// A key callback takes 1 argument
+ ///
+ /// * `0`, a reference to the opaque value assigned to the peer
+ pub trait KeyCallback<T>: Fn(&T) -> () + Sync + Send + 'static {}
+
+ /// Blanket implementation: any matching `Fn` qualifies.
+ impl<T, F> KeyCallback<T> for F where F: Fn(&T) -> () + Sync + Send + 'static {}
+
+ /// Static callback interface implemented by the router's consumer;
+ /// the router invokes these on message transmission/reception and key events.
+ pub trait Callbacks: Send + Sync + 'static {
+ type Opaque: Opaque;
+ fn send(opaque: &Self::Opaque, size: usize, data: bool, sent: bool);
+ fn recv(opaque: &Self::Opaque, size: usize, data: bool, sent: bool);
+ fn need_key(opaque: &Self::Opaque);
+ fn key_confirmed(opaque: &Self::Opaque);
+ }
+
+ /// Errors surfaced by the router to its consumer.
+ #[derive(Debug)]
+ pub enum RouterError {
+ NoCryptKeyRoute, // no cryptkey route configured for the subnet
+ MalformedIPHeader, // packet is not a valid IPv4/IPv6 header
+ MalformedTransportMessage, // transport message failed structural validation
+ UnknownReceiverId, // no decryption state associated with the receiver id
+ NoEndpoint, // peer has no known endpoint address
+ SendError, // no outbound writer configured, or the bind failed to transmit
+ }
+
+ impl fmt::Display for RouterError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"),
+ RouterError::MalformedIPHeader => write!(f, "IP header is malformed"),
+ // fix: distinct message (was a copy-paste of the IP-header message)
+ RouterError::MalformedTransportMessage => write!(f, "Transport message is malformed"),
+ RouterError::UnknownReceiverId => {
+ write!(f, "No decryption state associated with receiver id")
+ }
+ RouterError::NoEndpoint => write!(f, "No endpoint for peer"),
+ RouterError::SendError => write!(f, "Failed to send packet on bind"),
+ }
+ }
+ }
+
+impl Error for RouterError {
+ fn description(&self) -> &str {
+ "Generic Handshake Error"
+ }
+
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ None
+ }
+}
diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs
new file mode 100644
index 0000000..2e89bb0
--- /dev/null
+++ b/src/wireguard/router/workers.rs
@@ -0,0 +1,305 @@
+use std::mem;
+use std::sync::mpsc::Receiver;
+use std::sync::Arc;
+
+use futures::sync::oneshot;
+use futures::*;
+
+use log::debug;
+
+use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
+use std::net::{Ipv4Addr, Ipv6Addr};
+use std::sync::atomic::Ordering;
+use zerocopy::{AsBytes, LayoutVerified};
+
+use super::device::{DecryptionState, DeviceInner};
+use super::messages::{TransportHeader, TYPE_TRANSPORT};
+use super::peer::PeerInner;
+use super::types::Callbacks;
+
+use super::super::types::{Endpoint, tun, bind};
+use super::ip::*;
+
+const SIZE_TAG: usize = 16;
+
+ // direction of the work performed on a JobBuffer by a parallel worker
+ #[derive(PartialEq, Debug)]
+ pub enum Operation {
+ Encryption,
+ Decryption,
+ }
+
+ pub struct JobBuffer {
+ pub msg: Vec<u8>, // message buffer (nonce and receiver id set)
+ pub key: [u8; 32], // chacha20poly1305 key
+ pub okay: bool, // state of the job (set false on authentication failure)
+ pub op: Operation, // should the buffer be encrypted / decrypted?
+ }
+
+ // a parallel job: the buffer to process and the channel on which to return it
+ pub type JobParallel = (oneshot::Sender<JobBuffer>, JobBuffer);
+
+ // an inbound job: decryption state, message endpoint (stored on the peer
+ // after successful decryption) and the pending buffer
+ #[allow(type_alias_bounds)]
+ pub type JobInbound<E, C, T, B: bind::Writer<E>> = (
+ Arc<DecryptionState<E, C, T, B>>,
+ E,
+ oneshot::Receiver<JobBuffer>,
+ );
+
+ // an outbound job: the pending (encrypted) buffer to transmit
+ pub type JobOutbound = oneshot::Receiver<JobBuffer>;
+
+ // Cryptkey-route check for an inbound (decrypted) IP packet: the packet's
+ // *source* address must map back to the peer it arrived from.
+ // Returns the header-declared length of the IP packet on success, None if
+ // the header cannot be parsed or the route check fails.
+ // NOTE: indexes packet[0]; callers must pass a non-empty slice
+ // (worker_inbound only calls this when length > 0).
+ #[inline(always)]
+ fn check_route<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ device: &Arc<DeviceInner<E, C, T, B>>,
+ peer: &Arc<PeerInner<E, C, T, B>>,
+ packet: &[u8],
+ ) -> Option<usize> {
+ match packet[0] >> 4 {
+ VERSION_IP4 => {
+ // check length and cast to IPv4 header
+ let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
+ LayoutVerified::new_from_prefix(packet)?;
+
+ // check IPv4 source address
+ device
+ .ipv4
+ .read()
+ .longest_match(Ipv4Addr::from(header.f_source))
+ .and_then(|(_, _, p)| {
+ if Arc::ptr_eq(p, &peer) {
+ // IPv4 total length covers header + payload
+ Some(header.f_total_len.get() as usize)
+ } else {
+ None
+ }
+ })
+ }
+ VERSION_IP6 => {
+ // check length and cast to IPv6 header
+ let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
+ LayoutVerified::new_from_prefix(packet)?;
+
+ // check IPv6 source address
+ device
+ .ipv6
+ .read()
+ .longest_match(Ipv6Addr::from(header.f_source))
+ .and_then(|(_, _, p)| {
+ if Arc::ptr_eq(p, &peer) {
+ // IPv6 length field excludes the fixed header; add it back
+ Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>())
+ } else {
+ None
+ }
+ })
+ }
+ _ => None,
+ }
+ }
+
+ // Sequential inbound worker for a single peer: consumes decryption jobs in
+ // the order they were queued and finalizes each one (replay check, key
+ // confirmation, endpoint update, TUN write, recv callback).
+ // Returns when the job channel is closed.
+ pub fn worker_inbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ device: Arc<DeviceInner<E, C, T, B>>, // related device
+ peer: Arc<PeerInner<E, C, T, B>>, // related peer
+ receiver: Receiver<JobInbound<E, C, T, B>>,
+ ) {
+ loop {
+ // fetch job
+ let (state, endpoint, rx) = match receiver.recv() {
+ Ok(v) => v,
+ _ => {
+ return;
+ }
+ };
+ debug!("inbound worker: obtained job");
+
+ // wait for job to complete (blocks until the parallel worker delivers)
+ let _ = rx
+ .map(|buf| {
+ debug!("inbound worker: job complete");
+ if buf.okay {
+ // cast transport header
+ let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
+ match LayoutVerified::new_from_prefix(&buf.msg[..]) {
+ Some(v) => v,
+ None => {
+ debug!("inbound worker: failed to parse message");
+ return;
+ }
+ };
+
+ debug_assert!(
+ packet.len() >= CHACHA20_POLY1305.tag_len(),
+ "this should be checked earlier in the pipeline (decryption should fail)"
+ );
+
+ // check for replay (RFC 6479 sliding window; drop replayed counters)
+ if !state.protector.lock().update(header.f_counter.get()) {
+ debug!("inbound worker: replay detected");
+ return;
+ }
+
+ // check for confirms key (first authenticated message for this keypair)
+ if !state.confirmed.swap(true, Ordering::SeqCst) {
+ debug!("inbound worker: message confirms key");
+ peer.confirm_key(&state.keypair);
+ }
+
+ // update endpoint (learn where the peer last sent from)
+ *peer.endpoint.lock() = Some(endpoint);
+
+ // calculate length of IP packet + padding
+ let length = packet.len() - SIZE_TAG;
+ debug!("inbound worker: plaintext length = {}", length);
+
+ // check if should be written to TUN
+ // (zero-length plaintext is a keepalive, never written)
+ let mut sent = false;
+ if length > 0 {
+ if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) {
+ debug_assert!(inner_len <= length, "should be validated");
+ if inner_len <= length {
+ sent = match device.inbound.write(&packet[..inner_len]) {
+ Err(e) => {
+ debug!("failed to write inbound packet to TUN: {:?}", e);
+ false
+ }
+ Ok(_) => true,
+ }
+ }
+ }
+ } else {
+ debug!("inbound worker: received keepalive")
+ }
+
+ // trigger callback
+ C::recv(&peer.opaque, buf.msg.len(), length == 0, sent);
+ } else {
+ debug!("inbound worker: authentication failure")
+ }
+ })
+ .wait();
+ }
+ }
+
+ // Sequential outbound worker for a single peer: consumes encryption jobs in
+ // the order they were queued and transmits each completed message on the
+ // UDP bind, then fires the send callback.
+ // Returns when the job channel is closed.
+ pub fn worker_outbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ device: Arc<DeviceInner<E, C, T, B>>, // related device
+ peer: Arc<PeerInner<E, C, T, B>>, // related peer
+ receiver: Receiver<JobOutbound>,
+ ) {
+ loop {
+ // fetch job
+ let rx = match receiver.recv() {
+ Ok(v) => v,
+ _ => {
+ return;
+ }
+ };
+ debug!("outbound worker: obtained job");
+
+ // wait for job to complete (blocks until the parallel worker delivers)
+ let _ = rx
+ .map(|buf| {
+ debug!("outbound worker: job complete");
+ if buf.okay {
+ // write to UDP bind
+ // (requires a known endpoint and a configured outbound writer)
+ let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() {
+ let send : &Option<B> = &*device.outbound.read();
+ if let Some(writer) = send.as_ref() {
+ match writer.write(&buf.msg[..], dst) {
+ Err(e) => {
+ debug!("failed to send outbound packet: {:?}", e);
+ false
+ }
+ Ok(_) => true,
+ }
+ } else {
+ false
+ }
+ } else {
+ false
+ };
+
+ // trigger callback
+ // (data = true only if the message carries more than header + tag,
+ // i.e. it is not a keepalive)
+ C::send(
+ &peer.opaque,
+ buf.msg.len(),
+ buf.msg.len() > SIZE_TAG + mem::size_of::<TransportHeader>(),
+ xmit,
+ );
+ }
+ })
+ .wait();
+ }
+ }
+
+ // CPU-bound worker: performs the actual ChaCha20-Poly1305 seal/open of each
+ // job buffer, then hands the buffer back to the sequential worker via the
+ // job's oneshot channel. Returns when the job channel is closed.
+ pub fn worker_parallel(receiver: Receiver<JobParallel>) {
+ loop {
+ // fetch next job
+ let (tx, mut buf) = match receiver.recv() {
+ Err(_) => {
+ return;
+ }
+ Ok(val) => val,
+ };
+ debug!("parallel worker: obtained job");
+
+ // make space for tag (TODO: consider moving this out)
+ if buf.op == Operation::Encryption {
+ buf.msg.extend([0u8; SIZE_TAG].iter());
+ }
+
+ // cast and check size of packet
+ let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
+ match LayoutVerified::new_from_prefix(&mut buf.msg[..]) {
+ Some(v) => v,
+ None => {
+ debug_assert!(
+ false,
+ "parallel worker: failed to parse message (insufficient size)"
+ );
+ continue;
+ }
+ };
+ debug_assert!(packet.len() >= CHACHA20_POLY1305.tag_len());
+
+ // do the weird ring AEAD dance
+ let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &buf.key[..]).unwrap());
+
+ // create a nonce object: 4 leading zero bytes, then the 8-byte counter field
+ let mut nonce = [0u8; 12];
+ debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
+ nonce[4..].copy_from_slice(header.f_counter.as_bytes());
+ let nonce = Nonce::assume_unique_for_key(nonce);
+
+ match buf.op {
+ Operation::Encryption => {
+ debug!("parallel worker: process encryption");
+
+ // set the type field
+ header.f_type.set(TYPE_TRANSPORT);
+
+ // encrypt content of transport message in-place
+ let end = packet.len() - SIZE_TAG;
+ let tag = key
+ .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end])
+ .unwrap();
+
+ // append tag
+ packet[end..].copy_from_slice(tag.as_ref());
+
+ buf.okay = true;
+ }
+ Operation::Decryption => {
+ debug!("parallel worker: process decryption");
+
+ // opening failure is signaled by fault state
+ buf.okay = match key.open_in_place(nonce, Aad::empty(), packet) {
+ Ok(_) => true,
+ Err(_) => false,
+ };
+ }
+ }
+
+ // pass ownership to consumer (the sequential inbound/outbound worker)
+ let okay = tx.send(buf);
+ debug!(
+ "parallel worker: passing ownership to sequential worker: {}",
+ okay.is_ok()
+ );
+ }
+ }