diff options
Diffstat (limited to 'src/wireguard/router')
-rw-r--r-- | src/wireguard/router/anti_replay.rs | 157 | ||||
-rw-r--r-- | src/wireguard/router/constants.rs | 7 | ||||
-rw-r--r-- | src/wireguard/router/device.rs | 243 | ||||
-rw-r--r-- | src/wireguard/router/ip.rs | 26 | ||||
-rw-r--r-- | src/wireguard/router/messages.rs | 13 | ||||
-rw-r--r-- | src/wireguard/router/mod.rs | 22 | ||||
-rw-r--r-- | src/wireguard/router/peer.rs | 611 | ||||
-rw-r--r-- | src/wireguard/router/tests.rs | 432 | ||||
-rw-r--r-- | src/wireguard/router/types.rs | 65 | ||||
-rw-r--r-- | src/wireguard/router/workers.rs | 305 |
10 files changed, 1881 insertions, 0 deletions
diff --git a/src/wireguard/router/anti_replay.rs b/src/wireguard/router/anti_replay.rs new file mode 100644 index 0000000..b0838bd --- /dev/null +++ b/src/wireguard/router/anti_replay.rs @@ -0,0 +1,157 @@ +use std::mem; + +// Implementation of RFC 6479. +// https://tools.ietf.org/html/rfc6479 + +#[cfg(target_pointer_width = "64")] +type Word = u64; + +#[cfg(target_pointer_width = "64")] +const REDUNDANT_BIT_SHIFTS: usize = 6; + +#[cfg(target_pointer_width = "32")] +type Word = u32; + +#[cfg(target_pointer_width = "32")] +const REDUNDANT_BIT_SHIFTS: usize = 5; + +const SIZE_OF_WORD: usize = mem::size_of::<Word>() * 8; + +const BITMAP_BITLEN: usize = 2048; +const BITMAP_LEN: usize = (BITMAP_BITLEN / SIZE_OF_WORD); +const BITMAP_INDEX_MASK: u64 = BITMAP_LEN as u64 - 1; +const BITMAP_LOC_MASK: u64 = (SIZE_OF_WORD - 1) as u64; +const WINDOW_SIZE: u64 = (BITMAP_BITLEN - SIZE_OF_WORD) as u64; + +pub struct AntiReplay { + bitmap: [Word; BITMAP_LEN], + last: u64, +} + +impl Default for AntiReplay { + fn default() -> Self { + AntiReplay::new() + } +} + +impl AntiReplay { + pub fn new() -> Self { + debug_assert_eq!(1 << REDUNDANT_BIT_SHIFTS, SIZE_OF_WORD); + debug_assert_eq!(BITMAP_BITLEN % SIZE_OF_WORD, 0); + AntiReplay { + last: 0, + bitmap: [0; BITMAP_LEN], + } + } + + // Returns true if check is passed, i.e., not a replay or too old. + // + // Unlike RFC 6479, zero is allowed. + fn check(&self, seq: u64) -> bool { + // Larger is always good. + if seq > self.last { + return true; + } + + if self.last - seq > WINDOW_SIZE { + return false; + } + + let bit_location = seq & BITMAP_LOC_MASK; + let index = (seq >> REDUNDANT_BIT_SHIFTS) & BITMAP_INDEX_MASK; + + self.bitmap[index as usize] & (1 << bit_location) == 0 + } + + // Should only be called if check returns true. + fn update_store(&mut self, seq: u64) { + debug_assert!(self.check(seq)); + + let index = seq >> REDUNDANT_BIT_SHIFTS; + + if seq > self.last { + let index_cur = self.last >> REDUNDANT_BIT_SHIFTS; + let diff = index - index_cur; + + if diff >= BITMAP_LEN as u64 { + self.bitmap = [0; BITMAP_LEN]; + } else { + for i in 0..diff { + let real_index = (index_cur + i + 1) & BITMAP_INDEX_MASK; + self.bitmap[real_index as usize] = 0; + } + } + + self.last = seq; + } + + let index = index & BITMAP_INDEX_MASK; + let bit_location = seq & BITMAP_LOC_MASK; + self.bitmap[index as usize] |= 1 << bit_location; + } + + /// Checks and marks a sequence number in the replay filter + /// + /// # Arguments + /// + /// - seq: Sequence number check for replay and add to filter + /// + /// # Returns + /// + /// Ok(()) if sequence number is valid (not marked and not behind the moving window). + /// Err if the sequence number is invalid (already marked or "too old"). + pub fn update(&mut self, seq: u64) -> bool { + if self.check(seq) { + self.update_store(seq); + true + } else { + false + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn anti_replay() { + let mut ar = AntiReplay::new(); + + for i in 0..20000 { + assert!(ar.update(i)); + } + + for i in (0..20000).rev() { + assert!(!ar.check(i)); + } + + assert!(ar.update(65536)); + for i in (65536 - WINDOW_SIZE)..65535 { + assert!(ar.update(i)); + } + + for i in (65536 - 10 * WINDOW_SIZE)..65535 { + assert!(!ar.check(i)); + } + + assert!(ar.update(66000)); + for i in 65537..66000 { + assert!(ar.update(i)); + } + for i in 65537..66000 { + assert_eq!(ar.update(i), false); + } + + // Test max u64. + let next = u64::max_value(); + assert!(ar.update(next)); + assert!(!ar.check(next)); + for i in (next - WINDOW_SIZE)..next { + assert!(ar.update(i)); + } + for i in (next - 20 * WINDOW_SIZE)..next { + assert!(!ar.check(i)); + } + } +} diff --git a/src/wireguard/router/constants.rs b/src/wireguard/router/constants.rs new file mode 100644 index 0000000..0ca824a --- /dev/null +++ b/src/wireguard/router/constants.rs @@ -0,0 +1,7 @@ +// WireGuard semantics constants + +pub const MAX_STAGED_PACKETS: usize = 128; + +// performance constants + +pub const WORKER_QUEUE_SIZE: usize = MAX_STAGED_PACKETS; diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs new file mode 100644 index 0000000..455020c --- /dev/null +++ b/src/wireguard/router/device.rs @@ -0,0 +1,243 @@ +use std::collections::HashMap; +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::mpsc::sync_channel; +use std::sync::mpsc::SyncSender; +use std::sync::Arc; +use std::thread; +use std::time::Instant; + +use log::debug; +use spin::{Mutex, RwLock}; +use treebitmap::IpLookupTable; +use zerocopy::LayoutVerified; + +use super::anti_replay::AntiReplay; +use super::constants::*; +use super::ip::*; +use super::messages::{TransportHeader, TYPE_TRANSPORT}; +use super::peer::{new_peer, Peer, PeerInner}; +use super::types::{Callbacks, RouterError}; +use super::workers::{worker_parallel, JobParallel, Operation}; +use super::SIZE_MESSAGE_PREFIX; + +use super::super::types::{bind, tun, Endpoint, KeyPair}; + +pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { + // inbound writer (TUN) + pub inbound: T, + + // outbound writer (Bind) + pub outbound: RwLock<Option<B>>, + + // routing + pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state + pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv4 cryptkey routing + pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv6 cryptkey routing + + // work queues + pub queue_next: AtomicUsize, // next round-robin index + pub queues: Mutex<Vec<SyncSender<JobParallel>>>, // work queues (1 per thread) +} + +pub struct EncryptionState { + pub key: [u8; 32], // encryption key + pub id: u32, // receiver id + pub nonce: u64, // next available nonce + pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout) +} + +pub struct DecryptionState<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { + pub keypair: Arc<KeyPair>, + pub confirmed: AtomicBool, + pub protector: Mutex<AntiReplay>, + pub peer: Arc<PeerInner<E, C, T, B>>, + pub death: Instant, // time when the key can no longer be used for decryption +} + +pub struct Device<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { + state: Arc<DeviceInner<E, C, T, B>>, // reference to device state + handles: Vec<thread::JoinHandle<()>>, // join handles for workers +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Device<E, C, T, B> { + fn drop(&mut self) { + debug!("router: dropping device"); + + // drop all queues + { + let mut queues = self.state.queues.lock(); + while queues.pop().is_some() {} + } + + // join all worker threads + while match self.handles.pop() { + Some(handle) => { + handle.thread().unpark(); + handle.join().unwrap(); + true + } + _ => false, + } {} + + debug!("router: device dropped"); + } +} + +#[inline(always)] +fn get_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( + device: &Arc<DeviceInner<E, C, T, B>>, + packet: &[u8], +) -> Option<Arc<PeerInner<E, C, T, B>>> { + // ensure version access within bounds + if packet.len() < 1 { + return None; + }; + + // cast to correct IP header + match packet[0] >> 4 { + VERSION_IP4 => { + // check length and cast to IPv4 header + let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + // lookup destination address + device + .ipv4 + .read() + .longest_match(Ipv4Addr::from(header.f_destination)) + .and_then(|(_, _, p)| Some(p.clone())) + } + VERSION_IP6 => { + // check length and cast to IPv6 header + let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + // lookup destination address + device + .ipv6 + .read() + .longest_match(Ipv6Addr::from(header.f_destination)) + .and_then(|(_, _, p)| Some(p.clone())) + } + _ => None, + } +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, T, B> { + pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> { + // allocate shared device state + let inner = DeviceInner { + inbound: tun, + outbound: RwLock::new(None), + queues: Mutex::new(Vec::with_capacity(num_workers)), + queue_next: AtomicUsize::new(0), + recv: RwLock::new(HashMap::new()), + ipv4: RwLock::new(IpLookupTable::new()), + ipv6: RwLock::new(IpLookupTable::new()), + }; + + // start worker threads + let mut threads = Vec::with_capacity(num_workers); + for _ in 0..num_workers { + let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE); + inner.queues.lock().push(tx); + threads.push(thread::spawn(move || worker_parallel(rx))); + } + + // return exported device handle + Device { + state: Arc::new(inner), + handles: threads, + } + } + + /// A new secret key has been set for the device. + /// According to WireGuard semantics, this should cause all "sending" keys to be discarded. + pub fn new_sk(&self) {} + + /// Adds a new peer to the device + /// + /// # Returns + /// + /// A atomic ref. counted peer (with liftime matching the device) + pub fn new_peer(&self, opaque: C::Opaque) -> Peer<E, C, T, B> { + new_peer(self.state.clone(), opaque) + } + + /// Cryptkey routes and sends a plaintext message (IP packet) + /// + /// # Arguments + /// + /// - msg: IP packet to crypt-key route + /// + pub fn send(&self, msg: Vec<u8>) -> Result<(), RouterError> { + // ignore header prefix (for in-place transport message construction) + let packet = &msg[SIZE_MESSAGE_PREFIX..]; + + // lookup peer based on IP packet destination address + let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?; + + // schedule for encryption and transmission to peer + if let Some(job) = peer.send_job(msg, true) { + debug_assert_eq!(job.1.op, Operation::Encryption); + + // add job to worker queue + let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); + let queues = self.state.queues.lock(); + queues[idx % queues.len()].send(job).unwrap(); + } + + Ok(()) + } + + /// Receive an encrypted transport message + /// + /// # Arguments + /// + /// - src: Source address of the packet + /// - msg: Encrypted transport message + /// + /// # Returns + /// + /// + pub fn recv(&self, src: E, msg: Vec<u8>) -> Result<(), RouterError> { + // parse / cast + let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) { + Some(v) => v, + None => { + return Err(RouterError::MalformedTransportMessage); + } + }; + let header: LayoutVerified<&[u8], TransportHeader> = header; + debug_assert!( + header.f_type.get() == TYPE_TRANSPORT as u32, + "this should be checked by the message type multiplexer" + ); + + // lookup peer based on receiver id + let dec = self.state.recv.read(); + let dec = dec + .get(&header.f_receiver.get()) + .ok_or(RouterError::UnknownReceiverId)?; + + // schedule for decryption and TUN write + if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) { + debug_assert_eq!(job.1.op, Operation::Decryption); + + // add job to worker queue + let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); + let queues = self.state.queues.lock(); + queues[idx % queues.len()].send(job).unwrap(); + } + + Ok(()) + } + + /// Set outbound writer + /// + /// + pub fn set_outbound_writer(&self, new: B) { + *self.state.outbound.write() = Some(new); + } +} diff --git a/src/wireguard/router/ip.rs b/src/wireguard/router/ip.rs new file mode 100644 index 0000000..e66144f --- /dev/null +++ b/src/wireguard/router/ip.rs @@ -0,0 +1,26 @@ +use byteorder::BigEndian; +use zerocopy::byteorder::U16; +use zerocopy::{AsBytes, FromBytes}; + +pub const VERSION_IP4: u8 = 4; +pub const VERSION_IP6: u8 = 6; + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct IPv4Header { + _f_space1: [u8; 2], + pub f_total_len: U16<BigEndian>, + _f_space2: [u8; 8], + pub f_source: [u8; 4], + pub f_destination: [u8; 4], +} + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct IPv6Header { + _f_space1: [u8; 4], + pub f_len: U16<BigEndian>, + _f_space2: [u8; 2], + pub f_source: [u8; 16], + pub f_destination: [u8; 16], +} diff --git a/src/wireguard/router/messages.rs b/src/wireguard/router/messages.rs new file mode 100644 index 0000000..bf4d13b --- /dev/null +++ b/src/wireguard/router/messages.rs @@ -0,0 +1,13 @@ +use byteorder::LittleEndian; +use zerocopy::byteorder::{U32, U64}; +use zerocopy::{AsBytes, FromBytes}; + +pub const TYPE_TRANSPORT: u32 = 4; + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct TransportHeader { + pub f_type: U32<LittleEndian>, + pub f_receiver: U32<LittleEndian>, + pub f_counter: U64<LittleEndian>, +} diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs new file mode 100644 index 0000000..7a29cd9 --- /dev/null +++ b/src/wireguard/router/mod.rs @@ -0,0 +1,22 @@ +mod anti_replay; +mod constants; +mod device; +mod ip; +mod messages; +mod peer; +mod types; +mod workers; + +#[cfg(test)] +mod tests; + +use messages::TransportHeader; +use std::mem; + +pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>(); +pub const CAPACITY_MESSAGE_POSTFIX: usize = 16; + +pub use messages::TYPE_TRANSPORT; +pub use device::Device; +pub use peer::Peer; +pub use types::Callbacks; diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs new file mode 100644 index 0000000..4f47604 --- /dev/null +++ b/src/wireguard/router/peer.rs @@ -0,0 +1,611 @@ +use std::mem; +use std::net::{IpAddr, SocketAddr}; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::mpsc::{sync_channel, SyncSender}; +use std::sync::Arc; +use std::thread; + +use arraydeque::{ArrayDeque, Wrapping}; +use log::debug; +use spin::Mutex; +use treebitmap::address::Address; +use treebitmap::IpLookupTable; +use zerocopy::LayoutVerified; + +use super::super::constants::*; +use super::super::types::{bind, tun, Endpoint, KeyPair}; + +use super::anti_replay::AntiReplay; +use super::device::DecryptionState; +use super::device::DeviceInner; +use super::device::EncryptionState; +use super::messages::TransportHeader; + +use futures::*; + +use super::workers::Operation; +use super::workers::{worker_inbound, worker_outbound}; +use super::workers::{JobBuffer, JobInbound, JobOutbound, JobParallel}; +use super::SIZE_MESSAGE_PREFIX; + +use super::constants::*; +use super::types::{Callbacks, RouterError}; + +pub struct KeyWheel { + next: Option<Arc<KeyPair>>, // next key state (unconfirmed) + current: Option<Arc<KeyPair>>, // current key state (used for encryption) + previous: Option<Arc<KeyPair>>, // old key state (used for decryption) + retired: Vec<u32>, // retired ids +} + +pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { + pub device: Arc<DeviceInner<E, C, T, B>>, + pub opaque: C::Opaque, + pub outbound: Mutex<SyncSender<JobOutbound>>, + pub inbound: Mutex<SyncSender<JobInbound<E, C, T, B>>>, + pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>, + pub keys: Mutex<KeyWheel>, + pub ekey: Mutex<Option<EncryptionState>>, + pub endpoint: Mutex<Option<E>>, +} + +pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { + state: Arc<PeerInner<E, C, T, B>>, + thread_outbound: Option<thread::JoinHandle<()>>, + thread_inbound: Option<thread::JoinHandle<()>>, +} + +fn treebit_list<A, R, E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( + peer: &Arc<PeerInner<E, C, T, B>>, + table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>, + callback: Box<dyn Fn(A, u32) -> R>, +) -> Vec<R> +where + A: Address, +{ + let mut res = Vec::new(); + for subnet in table.read().iter() { + let (ip, masklen, p) = subnet; + if Arc::ptr_eq(&p, &peer) { + res.push(callback(ip, masklen)) + } + } + res +} + +fn treebit_remove<E: Endpoint, A: Address, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( + peer: &Peer<E, C, T, B>, + table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>, +) { + let mut m = table.write(); + + // collect keys for value + let mut subnets = vec![]; + for subnet in m.iter() { + let (ip, masklen, p) = subnet; + if Arc::ptr_eq(&p, &peer.state) { + subnets.push((ip, masklen)) + } + } + + // remove all key mappings + for (ip, masklen) in subnets { + let r = m.remove(ip, masklen); + debug_assert!(r.is_some()); + } +} + +impl EncryptionState { + fn new(keypair: &Arc<KeyPair>) -> EncryptionState { + EncryptionState { + id: keypair.send.id, + key: keypair.send.key, + nonce: 0, + death: keypair.birth + REJECT_AFTER_TIME, + } + } +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionState<E, C, T, B> { + fn new( + peer: &Arc<PeerInner<E, C, T, B>>, + keypair: &Arc<KeyPair>, + ) -> DecryptionState<E, C, T, B> { + DecryptionState { + confirmed: AtomicBool::new(keypair.initiator), + keypair: keypair.clone(), + protector: spin::Mutex::new(AntiReplay::new()), + peer: peer.clone(), + death: keypair.birth + REJECT_AFTER_TIME, + } + } +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Peer<E, C, T, B> { + fn drop(&mut self) { + let peer = &self.state; + + // remove from cryptkey router + + treebit_remove(self, &peer.device.ipv4); + treebit_remove(self, &peer.device.ipv6); + + // drop channels + + mem::replace(&mut *peer.inbound.lock(), sync_channel(0).0); + mem::replace(&mut *peer.outbound.lock(), sync_channel(0).0); + + // join with workers + + mem::replace(&mut self.thread_inbound, None).map(|v| v.join()); + mem::replace(&mut self.thread_outbound, None).map(|v| v.join()); + + // release ids from the receiver map + + let mut keys = peer.keys.lock(); + let mut release = Vec::with_capacity(3); + + keys.next.as_ref().map(|k| release.push(k.recv.id)); + keys.current.as_ref().map(|k| release.push(k.recv.id)); + keys.previous.as_ref().map(|k| release.push(k.recv.id)); + + if release.len() > 0 { + let mut recv = peer.device.recv.write(); + for id in &release { + recv.remove(id); + } + } + + // null key-material + + keys.next = None; + keys.current = None; + keys.previous = None; + + *peer.ekey.lock() = None; + *peer.endpoint.lock() = None; + + debug!("peer dropped & removed from device"); + } +} + +pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( + device: Arc<DeviceInner<E, C, T, B>>, + opaque: C::Opaque, +) -> Peer<E, C, T, B> { + let (out_tx, out_rx) = sync_channel(128); + let (in_tx, in_rx) = sync_channel(128); + + // allocate peer object + let peer = { + let device = device.clone(); + Arc::new(PeerInner { + opaque, + device, + inbound: Mutex::new(in_tx), + outbound: Mutex::new(out_tx), + ekey: spin::Mutex::new(None), + endpoint: spin::Mutex::new(None), + keys: spin::Mutex::new(KeyWheel { + next: None, + current: None, + previous: None, + retired: vec![], + }), + staged_packets: spin::Mutex::new(ArrayDeque::new()), + }) + }; + + // spawn outbound thread + let thread_inbound = { + let peer = peer.clone(); + let device = device.clone(); + thread::spawn(move || worker_outbound(device, peer, out_rx)) + }; + + // spawn inbound thread + let thread_outbound = { + let peer = peer.clone(); + let device = device.clone(); + thread::spawn(move || worker_inbound(device, peer, in_rx)) + }; + + Peer { + state: peer, + thread_inbound: Some(thread_inbound), + thread_outbound: Some(thread_outbound), + } +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E, C, T, B> { + fn send_staged(&self) -> bool { + debug!("peer.send_staged"); + let mut sent = false; + let mut staged = self.staged_packets.lock(); + loop { + match staged.pop_front() { + Some(msg) => { + sent = true; + self.send_raw(msg); + } + None => break sent, + } + } + } + + // Treat the msg as the payload of a transport message + // Unlike device.send, peer.send_raw does not buffer messages when a key is not available. + fn send_raw(&self, msg: Vec<u8>) -> bool { + debug!("peer.send_raw"); + match self.send_job(msg, false) { + Some(job) => { + debug!("send_raw: got obtained send_job"); + let index = self.device.queue_next.fetch_add(1, Ordering::SeqCst); + let queues = self.device.queues.lock(); + match queues[index % queues.len()].send(job) { + Ok(_) => true, + Err(_) => false, + } + } + None => false, + } + } + + pub fn confirm_key(&self, keypair: &Arc<KeyPair>) { + debug!("peer.confirm_key"); + { + // take lock and check keypair = keys.next + let mut keys = self.keys.lock(); + let next = match keys.next.as_ref() { + Some(next) => next, + None => { + return; + } + }; + if !Arc::ptr_eq(&next, keypair) { + return; + } + + // allocate new encryption state + let ekey = Some(EncryptionState::new(&next)); + + // rotate key-wheel + let mut swap = None; + mem::swap(&mut keys.next, &mut swap); + mem::swap(&mut keys.current, &mut swap); + mem::swap(&mut keys.previous, &mut swap); + + // tell the world outside the router that a key was confirmed + C::key_confirmed(&self.opaque); + + // set new key for encryption + *self.ekey.lock() = ekey; + } + + // start transmission of staged packets + self.send_staged(); + } + + pub fn recv_job( + &self, + src: E, + dec: Arc<DecryptionState<E, C, T, B>>, + msg: Vec<u8>, + ) -> Option<JobParallel> { + let (tx, rx) = oneshot(); + let key = dec.keypair.recv.key; + match self.inbound.lock().try_send((dec, src, rx)) { + Ok(_) => Some(( + tx, + JobBuffer { + msg, + key: key, + okay: false, + op: Operation::Decryption, + }, + )), + Err(_) => None, + } + } + + pub fn send_job(&self, mut msg: Vec<u8>, stage: bool) -> Option<JobParallel> { + debug!("peer.send_job"); + debug_assert!( + msg.len() >= mem::size_of::<TransportHeader>(), + "received message with size: {:}", + msg.len() + ); + + // parse / cast + let (header, _) = LayoutVerified::new_from_prefix(&mut msg[..]).unwrap(); + let mut header: LayoutVerified<&mut [u8], TransportHeader> = header; + + // check if has key + let key = { + let mut ekey = self.ekey.lock(); + let key = match ekey.as_mut() { + None => None, + Some(mut state) => { + // avoid integer overflow in nonce + if state.nonce >= REJECT_AFTER_MESSAGES - 1 { + *ekey = None; + None + } else { + // there should be no stacked packets lingering around + debug!("encryption state available, nonce = {}", state.nonce); + + // set transport message fields + header.f_counter.set(state.nonce); + header.f_receiver.set(state.id); + state.nonce += 1; + Some(state.key) + } + } + }; + + // If not suitable key was found: + // 1. Stage packet for later transmission + // 2. Request new key + if key.is_none() && stage { + self.staged_packets.lock().push_back(msg); + C::need_key(&self.opaque); + return None; + }; + + key + }?; + + // add job to in-order queue and return sendeer to device for inclusion in worker pool + let (tx, rx) = oneshot(); + match self.outbound.lock().try_send(rx) { + Ok(_) => Some(( + tx, + JobBuffer { + msg, + key, + okay: false, + op: Operation::Encryption, + }, + )), + Err(_) => None, + } + } +} + +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T, B> { + /// Set the endpoint of the peer + /// + /// # Arguments + /// + /// - `endpoint`, socket address converted to bind endpoint + /// + /// # Note + /// + /// This API still permits support for the "sticky socket" behavior, + /// as sockets should be "unsticked" when manually updating the endpoint + pub fn set_endpoint(&self, endpoint: E) { + debug!("peer.set_endpoint"); + *self.state.endpoint.lock() = Some(endpoint); + } + + /// Returns the current endpoint of the peer (for configuration) + /// + /// # Note + /// + /// Does not convey potential "sticky socket" information + pub fn get_endpoint(&self) -> Option<SocketAddr> { + debug!("peer.get_endpoint"); + self.state + .endpoint + .lock() + .as_ref() + .map(|e| e.into_address()) + } + + /// Zero all key-material related to the peer + pub fn zero_keys(&self) { + debug!("peer.zero_keys"); + + let mut release: Vec<u32> = Vec::with_capacity(3); + let mut keys = self.state.keys.lock(); + + // update key-wheel + + mem::replace(&mut keys.next, None).map(|k| release.push(k.local_id())); + mem::replace(&mut keys.current, None).map(|k| release.push(k.local_id())); + mem::replace(&mut keys.previous, None).map(|k| release.push(k.local_id())); + keys.retired.extend(&release[..]); + + // update inbound "recv" map + { + let mut recv = self.state.device.recv.write(); + for id in release { + recv.remove(&id); + } + } + + // clear encryption state + *self.state.ekey.lock() = None; + } + + /// Add a new keypair + /// + /// # Arguments + /// + /// - new: The new confirmed/unconfirmed key pair + /// + /// # Returns + /// + /// A vector of ids which has been released. + /// These should be released in the handshake module. + /// + /// # Note + /// + /// The number of ids to be released can be at most 3, + /// since the only way to add additional keys to the peer is by using this method + /// and a peer can have at most 3 keys allocated in the router at any time. + pub fn add_keypair(&self, new: KeyPair) -> Vec<u32> { + debug!("peer.add_keypair"); + + let initiator = new.initiator; + let release = { + let new = Arc::new(new); + let mut keys = self.state.keys.lock(); + let mut release = mem::replace(&mut keys.retired, vec![]); + + // update key-wheel + if new.initiator { + // start using key for encryption + *self.state.ekey.lock() = Some(EncryptionState::new(&new)); + + // move current into previous + keys.previous = keys.current.as_ref().map(|v| v.clone()); + keys.current = Some(new.clone()); + } else { + // store the key and await confirmation + keys.previous = keys.next.as_ref().map(|v| v.clone()); + keys.next = Some(new.clone()); + }; + + // update incoming packet id map + { + debug!("peer.add_keypair: updating inbound id map"); + let mut recv = self.state.device.recv.write(); + + // purge recv map of previous id + keys.previous.as_ref().map(|k| { + recv.remove(&k.local_id()); + release.push(k.local_id()); + }); + + // map new id to decryption state + debug_assert!(!recv.contains_key(&new.recv.id)); + recv.insert( + new.recv.id, + Arc::new(DecryptionState::new(&self.state, &new)), + ); + } + release + }; + + // schedule confirmation + if initiator { + debug_assert!(self.state.ekey.lock().is_some()); + debug!("peer.add_keypair: is initiator, must confirm the key"); + // attempt to confirm using staged packets + if !self.state.send_staged() { + // fall back to keepalive packet + let ok = self.send_keepalive(); + debug!( + "peer.add_keypair: keepalive for confirmation, sent = {}", + ok + ); + } + debug!("peer.add_keypair: key attempted confirmed"); + } + + debug_assert!( + release.len() <= 3, + "since the key-wheel contains at most 3 keys" + ); + release + } + + pub fn send_keepalive(&self) -> bool { + debug!("peer.send_keepalive"); + self.state.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX]) + } + + /// Map a subnet to the peer + /// + /// # Arguments + /// + /// - `ip`, the mask of the subnet + /// - `masklen`, the length of the mask + /// + /// # Note + /// + /// The `ip` must not have any bits set right of `masklen`. + /// e.g. `192.168.1.0/24` is valid, while `192.168.1.128/24` is not. + /// + /// If an identical value already exists as part of a prior peer, + /// the allowed IP entry will be removed from that peer and added to this peer. + pub fn add_subnet(&self, ip: IpAddr, masklen: u32) { + debug!("peer.add_subnet"); + match ip { + IpAddr::V4(v4) => { + self.state + .device + .ipv4 + .write() + .insert(v4, masklen, self.state.clone()) + } + IpAddr::V6(v6) => { + self.state + .device + .ipv6 + .write() + .insert(v6, masklen, self.state.clone()) + } + }; + } + + /// List subnets mapped to the peer + /// + /// # Returns + /// + /// A vector of subnets, represented by as mask/size + pub fn list_subnets(&self) -> Vec<(IpAddr, u32)> { + debug!("peer.list_subnets"); + let mut res = Vec::new(); + res.append(&mut treebit_list( + &self.state, + &self.state.device.ipv4, + Box::new(|ip, masklen| (IpAddr::V4(ip), masklen)), + )); + res.append(&mut treebit_list( + &self.state, + &self.state.device.ipv6, + Box::new(|ip, masklen| (IpAddr::V6(ip), masklen)), + )); + res + } + + /// Clear subnets mapped to the peer. + /// After the call, no subnets will be cryptkey routed to the peer. + /// Used for the UAPI command "replace_allowed_ips=true" + pub fn remove_subnets(&self) { + debug!("peer.remove_subnets"); + treebit_remove(self, &self.state.device.ipv4); + treebit_remove(self, &self.state.device.ipv6); + } + + /// Send a raw message to the peer (used for handshake messages) + /// + /// # Arguments + /// + /// - `msg`, message body to send to peer + /// + /// # Returns + /// + /// Unit if packet was sent, or an error indicating why sending failed + pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> { + debug!("peer.send"); + let inner = &self.state; + match inner.endpoint.lock().as_ref() { + Some(endpoint) => inner + .device + .outbound + .read() + .as_ref() + .ok_or(RouterError::SendError) + .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError)), + None => Err(RouterError::NoEndpoint), + } + } + + pub fn purge_staged_packets(&self) { + self.state.staged_packets.lock().clear(); + } +} diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs new file mode 100644 index 0000000..fbee39e --- /dev/null +++ b/src/wireguard/router/tests.rs @@ -0,0 +1,432 @@ +use std::net::IpAddr; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::sync::Mutex; +use std::thread; +use std::time::Duration; + +use num_cpus; +use pnet::packet::ipv4::MutableIpv4Packet; +use pnet::packet::ipv6::MutableIpv6Packet; + +use super::super::types::bind::*; +use super::super::types::*; + +use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX}; + +extern crate test; + +const SIZE_KEEPALIVE: usize = 32; + +#[cfg(test)] +mod tests { + use super::*; + use env_logger; + use log::debug; + use std::sync::atomic::AtomicUsize; + use test::Bencher; + + // type for tracking events inside the router module + struct Flags { + send: Mutex<Vec<(usize, bool, bool)>>, + recv: Mutex<Vec<(usize, bool, bool)>>, + need_key: Mutex<Vec<()>>, + key_confirmed: Mutex<Vec<()>>, + } + + #[derive(Clone)] + struct Opaque(Arc<Flags>); + + struct TestCallbacks(); + + impl Opaque { + fn new() -> Opaque { + Opaque(Arc::new(Flags { + send: Mutex::new(vec![]), + recv: Mutex::new(vec![]), + need_key: Mutex::new(vec![]), + key_confirmed: Mutex::new(vec![]), + })) + } + + fn reset(&self) { + self.0.send.lock().unwrap().clear(); + self.0.recv.lock().unwrap().clear(); + self.0.need_key.lock().unwrap().clear(); + self.0.key_confirmed.lock().unwrap().clear(); + } + + fn send(&self) -> Option<(usize, bool, bool)> { + self.0.send.lock().unwrap().pop() + } + + fn recv(&self) -> Option<(usize, bool, bool)> { + self.0.recv.lock().unwrap().pop() + } + + fn need_key(&self) -> Option<()> { + self.0.need_key.lock().unwrap().pop() + } + + fn key_confirmed(&self) -> Option<()> { + self.0.key_confirmed.lock().unwrap().pop() + } + + // has all events been accounted for by assertions? + fn is_empty(&self) -> bool { + let send = self.0.send.lock().unwrap(); + let recv = self.0.recv.lock().unwrap(); + let need_key = self.0.need_key.lock().unwrap(); + let key_confirmed = self.0.key_confirmed.lock().unwrap(); + send.is_empty() && recv.is_empty() && need_key.is_empty() & key_confirmed.is_empty() + } + } + + impl Callbacks for TestCallbacks { + type Opaque = Opaque; + + fn send(t: &Self::Opaque, size: usize, data: bool, sent: bool) { + t.0.send.lock().unwrap().push((size, data, sent)) + } + + fn recv(t: &Self::Opaque, size: usize, data: bool, sent: bool) { + t.0.recv.lock().unwrap().push((size, data, sent)) + } + + fn need_key(t: &Self::Opaque) { + t.0.need_key.lock().unwrap().push(()); + } + + fn key_confirmed(t: &Self::Opaque) { + t.0.key_confirmed.lock().unwrap().push(()); + } + } + + // wait for scheduling + fn wait() { + thread::sleep(Duration::from_millis(50)); + } + + fn init() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + fn make_packet(size: usize, ip: IpAddr) -> Vec<u8> { + // create "IP packet" + let mut msg = Vec::with_capacity(SIZE_MESSAGE_PREFIX + size + 16); + msg.resize(SIZE_MESSAGE_PREFIX + size, 0); + match ip { + IpAddr::V4(ip) => { + let mut packet = MutableIpv4Packet::new(&mut msg[SIZE_MESSAGE_PREFIX..]).unwrap(); + packet.set_destination(ip); + packet.set_version(4); + } + IpAddr::V6(ip) => { + let mut packet = MutableIpv6Packet::new(&mut msg[SIZE_MESSAGE_PREFIX..]).unwrap(); + packet.set_destination(ip); + packet.set_version(6); + } + } + msg + } + + #[bench] + fn bench_outbound(b: &mut Bencher) { + struct BencherCallbacks {} + impl Callbacks for BencherCallbacks { + type Opaque = Arc<AtomicUsize>; + fn send(t: &Self::Opaque, size: usize, _data: bool, _sent: bool) { + t.fetch_add(size, Ordering::SeqCst); + } + fn recv(_: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {} + fn need_key(_: &Self::Opaque) {} + fn key_confirmed(_: &Self::Opaque) {} + } + + // create device + let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false); + let router: Device<_, BencherCallbacks, dummy::TunWriter, dummy::VoidBind> = + Device::new(num_cpus::get(), tun_writer); + + // add new peer + let opaque = Arc::new(AtomicUsize::new(0)); + let peer = router.new_peer(opaque.clone()); + peer.add_keypair(dummy::keypair(true)); + + // add subnet to peer + let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20"); + let mask: IpAddr = mask.parse().unwrap(); + let ip1: IpAddr = ip.parse().unwrap(); + peer.add_subnet(mask, len); + + // every iteration sends 10 GB + b.iter(|| { + opaque.store(0, Ordering::SeqCst); + let msg = make_packet(1024, ip1); + while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 { + router.send(msg.to_vec()).unwrap(); + } + }); + } + + #[test] + fn test_outbound() { + init(); + + // create device + let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false); + let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer); + router.set_outbound_writer(dummy::VoidBind::new()); + + let tests = vec![ + ("192.168.1.0", 24, "192.168.1.20", true), + ("172.133.133.133", 32, "172.133.133.133", true), + ("172.133.133.133", 32, "172.133.133.132", false), + ( + "2001:db8::ff00:42:0000", + 112, + "2001:db8::ff00:42:3242", + true, + ), + ( + "2001:db8::ff00:42:8000", + 113, + "2001:db8::ff00:42:0660", + false, + ), + ( + "2001:db8::ff00:42:8000", + 113, + "2001:db8::ff00:42:ffff", + true, + ), + ]; + + for (num, (mask, len, ip, okay)) in tests.iter().enumerate() { + for set_key in vec![true, false] { + debug!("index = {}, set_key = {}", num, set_key); + + // add new peer + let opaque = Opaque::new(); + let peer = router.new_peer(opaque.clone()); + let mask: IpAddr = mask.parse().unwrap(); + if set_key { + peer.add_keypair(dummy::keypair(true)); + } + + // map subnet to peer + peer.add_subnet(mask, *len); + + // create "IP packet" + let msg = make_packet(1024, ip.parse().unwrap()); + + // cryptkey route the IP packet + let res = router.send(msg); + + // allow some scheduling + wait(); + + if *okay { + // cryptkey routing succeeded + assert!(res.is_ok(), "crypt-key routing should succeed"); + assert_eq!( + opaque.need_key().is_some(), + !set_key, + "should have requested a new key, if no encryption state was set" + ); + assert_eq!( + opaque.send().is_some(), + set_key, + "transmission should have been attempted" + ); + assert!( + opaque.recv().is_none(), + "no messages should have been marked as received" + ); + } else { + // no such cryptkey route + assert!(res.is_err(), "crypt-key routing should fail"); + assert!( + opaque.need_key().is_none(), + "should not request a new-key if crypt-key routing failed" + ); + assert_eq!( + opaque.send(), + if set_key { + Some((SIZE_KEEPALIVE, false, false)) + } else { + None + }, + "transmission should only happen if key was set (keepalive)", + ); + assert!( + opaque.recv().is_none(), + "no messages should have been marked as received", + ); + } + } + } + } + + #[test] + fn test_bidirectional() { + init(); + + let tests = [ + ( + false, // confirm with keepalive + ("192.168.1.0", 24, "192.168.1.20", true), + ("172.133.133.133", 32, "172.133.133.133", true), + ), + ( + true, // confirm with staged packet + ("192.168.1.0", 24, "192.168.1.20", true), + ("172.133.133.133", 32, "172.133.133.133", true), + ), + ( + false, // confirm with keepalive + ( + "2001:db8::ff00:42:8000", + 113, + "2001:db8::ff00:42:ffff", + true, + ), + ( + "2001:db8::ff40:42:8000", + 113, + "2001:db8::ff40:42:ffff", + true, + ), + ), + ( + false, // confirm with staged packet + ( + "2001:db8::ff00:42:8000", + 113, + "2001:db8::ff00:42:ffff", + true, + ), + ( + "2001:db8::ff40:42:8000", + 113, + "2001:db8::ff40:42:ffff", + true, + ), + ), + ]; + + for (stage, p1, p2) in tests.iter() { + let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) = + dummy::PairBind::pair(); + + // create matching device + let (_fake, _, tun_writer1, _) = dummy::TunTest::create(1500, false); + let (_fake, _, tun_writer2, _) = dummy::TunTest::create(1500, false); + + let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1); + router1.set_outbound_writer(bind_writer1); + + let router2: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer2); + router2.set_outbound_writer(bind_writer2); + + // prepare opaque values for tracing callbacks + + let opaq1 = Opaque::new(); + let opaq2 = Opaque::new(); + + // create peers with matching keypairs and assign subnets + + let (mask, len, _ip, _okay) = p1; + let peer1 = router1.new_peer(opaq1.clone()); + let mask: IpAddr = mask.parse().unwrap(); + peer1.add_subnet(mask, *len); + peer1.add_keypair(dummy::keypair(false)); + + let (mask, len, _ip, _okay) = p2; + let peer2 = router2.new_peer(opaq2.clone()); + let mask: IpAddr = mask.parse().unwrap(); + peer2.add_subnet(mask, *len); + peer2.set_endpoint(dummy::UnitEndpoint::new()); + + if *stage { + // stage a packet which can be used for confirmation (in place of a keepalive) + let (_mask, _len, ip, _okay) = p2; + let msg = make_packet(1024, ip.parse().unwrap()); + router2.send(msg).expect("failed to sent staged packet"); + + wait(); + assert!(opaq2.recv().is_none()); + assert!( + opaq2.send().is_none(), + "sending should fail as not key is set" + ); + assert!( + opaq2.need_key().is_some(), + "a new key should be requested since a packet was attempted transmitted" + ); + assert!(opaq2.is_empty(), "callbacks should only run once"); + } + + // this should cause a key-confirmation packet (keepalive or staged packet) + // this also causes peer1 to learn the "endpoint" for peer2 + assert!(peer1.get_endpoint().is_none()); + peer2.add_keypair(dummy::keypair(true)); + + wait(); + assert!(opaq2.send().is_some()); + assert!(opaq2.is_empty(), "events on peer2 should be 'send'"); + assert!(opaq1.is_empty(), "nothing should happened on peer1"); + + // read confirming message received by the other end ("across the internet") + let mut buf = vec![0u8; 2048]; + let (len, from) = bind_reader1.read(&mut buf).unwrap(); + buf.truncate(len); + router1.recv(from, buf).unwrap(); + + wait(); + assert!(opaq1.recv().is_some()); + assert!(opaq1.key_confirmed().is_some()); + assert!( + opaq1.is_empty(), + "events on peer1 should be 'recv' and 'key_confirmed'" + ); + assert!(peer1.get_endpoint().is_some()); + assert!(opaq2.is_empty(), "nothing should happened on peer2"); + + // now that peer1 has an endpoint + // route packets : peer1 -> peer2 + + for _ in 0..10 { + assert!( + opaq1.is_empty(), + "we should have asserted a value for every callback on peer1" + ); + assert!( + opaq2.is_empty(), + "we should have asserted a value for every callback on peer2" + ); + + // pass IP packet to router + let (_mask, _len, ip, _okay) = p1; + let msg = make_packet(1024, ip.parse().unwrap()); + router1.send(msg).unwrap(); + + wait(); + assert!(opaq1.send().is_some()); + assert!(opaq1.recv().is_none()); + assert!(opaq1.need_key().is_none()); + + // receive ("across the internet") on the other end + let mut buf = vec![0u8; 2048]; + let (len, from) = bind_reader2.read(&mut buf).unwrap(); + buf.truncate(len); + router2.recv(from, buf).unwrap(); + + wait(); + assert!(opaq2.send().is_none()); + assert!(opaq2.recv().is_some()); + assert!(opaq2.need_key().is_none()); + } + } + } +} diff --git a/src/wireguard/router/types.rs b/src/wireguard/router/types.rs new file mode 100644 index 0000000..b7c3ae0 --- /dev/null +++ b/src/wireguard/router/types.rs @@ -0,0 +1,65 @@ +use std::error::Error; +use std::fmt; + +pub trait Opaque: Send + Sync + 'static {} + +impl<T> Opaque for T where T: Send + Sync + 'static {} + +/// A send/recv callback takes 3 arguments: +/// +/// * `0`, a reference to the opaque value assigned to the peer +/// * `1`, a bool indicating whether the message contained data (not just keepalive) +/// * `2`, a bool indicating whether the message was transmitted (i.e. did the peer have an associated endpoint?) +pub trait Callback<T>: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {} + +impl<T, F> Callback<T> for F where F: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {} + +/// A key callback takes 1 argument +/// +/// * `0`, a reference to the opaque value assigned to the peer +pub trait KeyCallback<T>: Fn(&T) -> () + Sync + Send + 'static {} + +impl<T, F> KeyCallback<T> for F where F: Fn(&T) -> () + Sync + Send + 'static {} + +pub trait Callbacks: Send + Sync + 'static { + type Opaque: Opaque; + fn send(opaque: &Self::Opaque, size: usize, data: bool, sent: bool); + fn recv(opaque: &Self::Opaque, size: usize, data: bool, sent: bool); + fn need_key(opaque: &Self::Opaque); + fn key_confirmed(opaque: &Self::Opaque); +} + +#[derive(Debug)] +pub enum RouterError { + NoCryptKeyRoute, + MalformedIPHeader, + MalformedTransportMessage, + UnknownReceiverId, + NoEndpoint, + SendError, +} + +impl fmt::Display for RouterError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"), + RouterError::MalformedIPHeader => write!(f, "IP header is malformed"), + RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"), + RouterError::UnknownReceiverId => { + write!(f, "No decryption state associated with receiver id") + } + RouterError::NoEndpoint => write!(f, "No endpoint for peer"), + RouterError::SendError => write!(f, "Failed to send packet on bind"), + } + } +} + +impl Error for RouterError { + fn description(&self) -> &str { + "Generic Handshake Error" + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs new file mode 100644 index 0000000..2e89bb0 --- /dev/null +++ b/src/wireguard/router/workers.rs @@ -0,0 +1,305 @@ +use std::mem; +use std::sync::mpsc::Receiver; +use std::sync::Arc; + +use futures::sync::oneshot; +use futures::*; + +use log::debug; + +use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::sync::atomic::Ordering; +use zerocopy::{AsBytes, LayoutVerified}; + +use super::device::{DecryptionState, DeviceInner}; +use super::messages::{TransportHeader, TYPE_TRANSPORT}; +use super::peer::PeerInner; +use super::types::Callbacks; + +use super::super::types::{Endpoint, tun, bind}; +use super::ip::*; + +const SIZE_TAG: usize = 16; + +#[derive(PartialEq, Debug)] +pub enum Operation { + Encryption, + Decryption, +} + +pub struct JobBuffer { + pub msg: Vec<u8>, // message buffer (nonce and receiver id set) + pub key: [u8; 32], // chacha20poly1305 key + pub okay: bool, // state of the job + pub op: Operation, // should be buffer be encrypted / decrypted? +} + +pub type JobParallel = (oneshot::Sender<JobBuffer>, JobBuffer); + +#[allow(type_alias_bounds)] +pub type JobInbound<E, C, T, B: bind::Writer<E>> = ( + Arc<DecryptionState<E, C, T, B>>, + E, + oneshot::Receiver<JobBuffer>, +); + +pub type JobOutbound = oneshot::Receiver<JobBuffer>; + +#[inline(always)] +fn check_route<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( + device: &Arc<DeviceInner<E, C, T, B>>, + peer: &Arc<PeerInner<E, C, T, B>>, + packet: &[u8], +) -> Option<usize> { + match packet[0] >> 4 { + VERSION_IP4 => { + // check length and cast to IPv4 header + let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + // check IPv4 source address + device + .ipv4 + .read() + .longest_match(Ipv4Addr::from(header.f_source)) + .and_then(|(_, _, p)| { + if Arc::ptr_eq(p, &peer) { + Some(header.f_total_len.get() as usize) + } else { + None + } + }) + } + VERSION_IP6 => { + // check length and cast to IPv6 header + let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + // check IPv6 source address + device + .ipv6 + .read() + .longest_match(Ipv6Addr::from(header.f_source)) + .and_then(|(_, _, p)| { + if Arc::ptr_eq(p, &peer) { + Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>()) + } else { + None + } + }) + } + _ => None, + } +} + +pub fn worker_inbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( + device: Arc<DeviceInner<E, C, T, B>>, // related device + peer: Arc<PeerInner<E, C, T, B>>, // related peer + receiver: Receiver<JobInbound<E, C, T, B>>, +) { + loop { + // fetch job + let (state, endpoint, rx) = match receiver.recv() { + Ok(v) => v, + _ => { + return; + } + }; + debug!("inbound worker: obtained job"); + + // wait for job to complete + let _ = rx + .map(|buf| { + debug!("inbound worker: job complete"); + if buf.okay { + // cast transport header + let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = + match LayoutVerified::new_from_prefix(&buf.msg[..]) { + Some(v) => v, + None => { + debug!("inbound worker: failed to parse message"); + return; + } + }; + + debug_assert!( + packet.len() >= CHACHA20_POLY1305.tag_len(), + "this should be checked earlier in the pipeline (decryption should fail)" + ); + + // check for replay + if !state.protector.lock().update(header.f_counter.get()) { + debug!("inbound worker: replay detected"); + return; + } + + // check for confirms key + if !state.confirmed.swap(true, Ordering::SeqCst) { + debug!("inbound worker: message confirms key"); + peer.confirm_key(&state.keypair); + } + + // update endpoint + *peer.endpoint.lock() = Some(endpoint); + + // calculate length of IP packet + padding + let length = packet.len() - SIZE_TAG; + debug!("inbound worker: plaintext length = {}", length); + + // check if should be written to TUN + let mut sent = false; + if length > 0 { + if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) { + debug_assert!(inner_len <= length, "should be validated"); + if inner_len <= length { + sent = match device.inbound.write(&packet[..inner_len]) { + Err(e) => { + debug!("failed to write inbound packet to TUN: {:?}", e); + false + } + Ok(_) => true, + } + } + } + } else { + debug!("inbound worker: received keepalive") + } + + // trigger callback + C::recv(&peer.opaque, buf.msg.len(), length == 0, sent); + } else { + debug!("inbound worker: authentication failure") + } + }) + .wait(); + } +} + +pub fn worker_outbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( + device: Arc<DeviceInner<E, C, T, B>>, // related device + peer: Arc<PeerInner<E, C, T, B>>, // related peer + receiver: Receiver<JobOutbound>, +) { + loop { + // fetch job + let rx = match receiver.recv() { + Ok(v) => v, + _ => { + return; + } + }; + debug!("outbound worker: obtained job"); + + // wait for job to complete + let _ = rx + .map(|buf| { + debug!("outbound worker: job complete"); + if buf.okay { + // write to UDP bind + let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() { + let send : &Option<B> = &*device.outbound.read(); + if let Some(writer) = send.as_ref() { + match writer.write(&buf.msg[..], dst) { + Err(e) => { + debug!("failed to send outbound packet: {:?}", e); + false + } + Ok(_) => true, + } + } else { + false + } + } else { + false + }; + + // trigger callback + C::send( + &peer.opaque, + buf.msg.len(), + buf.msg.len() > SIZE_TAG + mem::size_of::<TransportHeader>(), + xmit, + ); + } + }) + .wait(); + } +} + +pub fn worker_parallel(receiver: Receiver<JobParallel>) { + loop { + // fetch next job + let (tx, mut buf) = match receiver.recv() { + Err(_) => { + return; + } + Ok(val) => val, + }; + debug!("parallel worker: obtained job"); + + // make space for tag (TODO: consider moving this out) + if buf.op == Operation::Encryption { + buf.msg.extend([0u8; SIZE_TAG].iter()); + } + + // cast and check size of packet + let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = + match LayoutVerified::new_from_prefix(&mut buf.msg[..]) { + Some(v) => v, + None => { + debug_assert!( + false, + "parallel worker: failed to parse message (insufficient size)" + ); + continue; + } + }; + debug_assert!(packet.len() >= CHACHA20_POLY1305.tag_len()); + + // do the weird ring AEAD dance + let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &buf.key[..]).unwrap()); + + // create a nonce object + let mut nonce = [0u8; 12]; + debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); + nonce[4..].copy_from_slice(header.f_counter.as_bytes()); + let nonce = Nonce::assume_unique_for_key(nonce); + + match buf.op { + Operation::Encryption => { + debug!("parallel worker: process encryption"); + + // set the type field + header.f_type.set(TYPE_TRANSPORT); + + // encrypt content of transport message in-place + let end = packet.len() - SIZE_TAG; + let tag = key + .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end]) + .unwrap(); + + // append tag + packet[end..].copy_from_slice(tag.as_ref()); + + buf.okay = true; + } + Operation::Decryption => { + debug!("parallel worker: process decryption"); + + // opening failure is signaled by fault state + buf.okay = match key.open_in_place(nonce, Aad::empty(), packet) { + Ok(_) => true, + Err(_) => false, + }; + } + } + + // pass ownership to consumer + let okay = tx.send(buf); + debug!( + "parallel worker: passing ownership to sequential worker: {}", + okay.is_ok() + ); + } +} |