From 3ba0247634bbaa1da61532ca43e67fb2ad6c1106 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Mon, 18 Nov 2019 13:13:55 +0100 Subject: Better compartmentalization of cryptokey router --- src/wireguard/router/device copy.rs | 228 +++++++++++++++++++++++++++++++++ src/wireguard/router/device.rs | 14 +- src/wireguard/router/peer.rs | 84 +----------- src/wireguard/router/route.rs | 246 ++++++++++++++++++++++-------------- src/wireguard/router/tests.rs | 4 +- src/wireguard/router/workers.rs | 4 +- 6 files changed, 395 insertions(+), 185 deletions(-) create mode 100644 src/wireguard/router/device copy.rs (limited to 'src') diff --git a/src/wireguard/router/device copy.rs b/src/wireguard/router/device copy.rs new file mode 100644 index 0000000..04b2045 --- /dev/null +++ b/src/wireguard/router/device copy.rs @@ -0,0 +1,228 @@ +use std::collections::HashMap; + +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::mpsc::sync_channel; +use std::sync::mpsc::SyncSender; +use std::sync::Arc; +use std::thread; +use std::time::Instant; + +use log::debug; +use spin::{Mutex, RwLock}; +use treebitmap::IpLookupTable; +use zerocopy::LayoutVerified; + +use super::anti_replay::AntiReplay; +use super::constants::*; + +use super::messages::{TransportHeader, TYPE_TRANSPORT}; +use super::peer::{new_peer, Peer, PeerInner}; +use super::types::{Callbacks, RouterError}; +use super::workers::{worker_parallel, JobParallel}; +use super::SIZE_MESSAGE_PREFIX; + +use super::route::get_route; + +use super::super::{bind, tun, Endpoint, KeyPair}; + +pub struct DeviceInner> { + // inbound writer (TUN) + pub inbound: T, + + // outbound writer (Bind) + pub outbound: RwLock<(bool, Option)>, + + // routing + pub recv: RwLock>>>, // receiver id -> decryption state + pub ipv4: RwLock>>>, // ipv4 cryptkey routing + pub ipv6: RwLock>>>, // ipv6 cryptkey routing + + // work queues + pub queue_next: AtomicUsize, // next round-robin index + pub queues: Mutex>>, // work queues (1 per thread) +} + +pub struct EncryptionState { + pub keypair: Arc, // keypair + pub nonce: u64, // next available nonce + pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout) +} + +pub struct DecryptionState> { + pub keypair: Arc, + pub confirmed: AtomicBool, + pub protector: Mutex, + pub peer: Arc>, + pub death: Instant, // time when the key can no longer be used for decryption +} + +pub struct Device> { + state: Arc>, // reference to device state + handles: Vec>, // join handles for workers +} + +impl> Drop for Device { + fn drop(&mut self) { + debug!("router: dropping device"); + + // drop all queues + { + let mut queues = self.state.queues.lock(); + while queues.pop().is_some() {} + } + + // join all worker threads + while match self.handles.pop() { + Some(handle) => { + handle.thread().unpark(); + handle.join().unwrap(); + true + } + _ => false, + } {} + + debug!("router: device dropped"); + } +} + +impl> Device { + pub fn new(num_workers: usize, tun: T) -> Device { + // allocate shared device state + let inner = DeviceInner { + inbound: tun, + outbound: RwLock::new((true, None)), + queues: Mutex::new(Vec::with_capacity(num_workers)), + queue_next: AtomicUsize::new(0), + recv: RwLock::new(HashMap::new()), + ipv4: RwLock::new(IpLookupTable::new()), + ipv6: RwLock::new(IpLookupTable::new()), + }; + + // start worker threads + let mut threads = Vec::with_capacity(num_workers); + for _ in 0..num_workers { + let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE); + inner.queues.lock().push(tx); + threads.push(thread::spawn(move || worker_parallel(rx))); + } + + // return exported device handle + Device { + state: Arc::new(inner), + handles: threads, + } + } + + /// Brings the router down. + /// When the router is brought down it: + /// - Prevents transmission of outbound messages. + pub fn down(&self) { + self.state.outbound.write().0 = false; + } + + /// Brints the router up + /// When the router is brought up it enables the transmission of outbound messages. + pub fn up(&self) { + self.state.outbound.write().0 = true; + } + + /// A new secret key has been set for the device. + /// According to WireGuard semantics, this should cause all "sending" keys to be discarded. + pub fn new_sk(&self) {} + + /// Adds a new peer to the device + /// + /// # Returns + /// + /// A atomic ref. counted peer (with liftime matching the device) + pub fn new_peer(&self, opaque: C::Opaque) -> Peer { + new_peer(self.state.clone(), opaque) + } + + /// Cryptkey routes and sends a plaintext message (IP packet) + /// + /// # Arguments + /// + /// - msg: IP packet to crypt-key route + /// + pub fn send(&self, msg: Vec) -> Result<(), RouterError> { + debug_assert!(msg.len() > SIZE_MESSAGE_PREFIX); + log::trace!( + "Router, outbound packet = {}", + hex::encode(&msg[SIZE_MESSAGE_PREFIX..]) + ); + + // ignore header prefix (for in-place transport message construction) + let packet = &msg[SIZE_MESSAGE_PREFIX..]; + + // lookup peer based on IP packet destination address + let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptoKeyRoute)?; + + // schedule for encryption and transmission to peer + if let Some(job) = peer.send_job(msg, true) { + // add job to worker queue + let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); + let queues = self.state.queues.lock(); + queues[idx % queues.len()].send(job).unwrap(); + } + + Ok(()) + } + + /// Receive an encrypted transport message + /// + /// # Arguments + /// + /// - src: Source address of the packet + /// - msg: Encrypted transport message + /// + /// # Returns + /// + /// + pub fn recv(&self, src: E, msg: Vec) -> Result<(), RouterError> { + // parse / cast + let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) { + Some(v) => v, + None => { + return Err(RouterError::MalformedTransportMessage); + } + }; + + let header: LayoutVerified<&[u8], TransportHeader> = header; + + debug_assert!( + header.f_type.get() == TYPE_TRANSPORT as u32, + "this should be checked by the message type multiplexer" + ); + + log::trace!( + "Router, handle transport message: (receiver = {}, counter = {})", + header.f_receiver, + header.f_counter + ); + + // lookup peer based on receiver id + let dec = self.state.recv.read(); + let dec = dec + .get(&header.f_receiver.get()) + .ok_or(RouterError::UnknownReceiverId)?; + + // schedule for decryption and TUN write + if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) { + // add job to worker queue + let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); + let queues = self.state.queues.lock(); + queues[idx % queues.len()].send(job).unwrap(); + } + + Ok(()) + } + + /// Set outbound writer + /// + /// + pub fn set_outbound_writer(&self, new: B) { + self.state.outbound.write().1 = Some(new); + } +} diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index 04b2045..7adcf8a 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -22,7 +22,7 @@ use super::types::{Callbacks, RouterError}; use super::workers::{worker_parallel, JobParallel}; use super::SIZE_MESSAGE_PREFIX; -use super::route::get_route; +use super::route::RoutingTable; use super::super::{bind, tun, Endpoint, KeyPair}; @@ -35,8 +35,7 @@ pub struct DeviceInner>>>, // receiver id -> decryption state - pub ipv4: RwLock>>>, // ipv4 cryptkey routing - pub ipv6: RwLock>>>, // ipv6 cryptkey routing + pub table: RoutingTable>, // work queues pub queue_next: AtomicUsize, // next round-robin index @@ -95,8 +94,7 @@ impl> Device> Device> Deref for Pe } } -fn treebit_list>( - peer: &Arc>, - table: &spin::RwLock>>>, - callback: Box R>, -) -> Vec -where - A: Address, -{ - let mut res = Vec::new(); - for subnet in table.read().iter() { - let (ip, masklen, p) = subnet; - if Arc::ptr_eq(&p, &peer) { - res.push(callback(ip, masklen)) - } - } - res -} - -fn treebit_remove>( - peer: &Peer, - table: &spin::RwLock>>>, -) { - let mut m = table.write(); - - // collect keys for value - let mut subnets = vec![]; - for subnet in m.iter() { - let (ip, masklen, p) = subnet; - if Arc::ptr_eq(&p, &peer.state) { - subnets.push((ip, masklen)) - } - } - - // remove all key mappings - for (ip, masklen) in subnets { - let r = m.remove(ip, masklen); - debug_assert!(r.is_some()); - } -} - impl EncryptionState { fn new(keypair: &Arc) -> EncryptionState { EncryptionState { @@ -134,8 +92,7 @@ impl> Drop for Pee // remove from cryptkey router - treebit_remove(self, &peer.device.ipv4); - treebit_remove(self, &peer.device.ipv6); + self.state.device.table.remove(peer); // drop channels @@ -560,23 +517,10 @@ impl> Peer { - self.state - .device - .ipv4 - .write() - .insert(v4.mask(masklen), masklen, self.state.clone()) - } - IpAddr::V6(v6) => { - self.state - .device - .ipv6 - .write() - .insert(v6.mask(masklen), masklen, self.state.clone()) - } - }; + self.state + .device + .table + .insert(ip, masklen, self.state.clone()) } /// List subnets mapped to the peer @@ -585,28 +529,14 @@ impl> Peer Vec<(IpAddr, u32)> { - debug!("peer.list_allowed_ips"); - let mut res = Vec::new(); - res.append(&mut treebit_list( - &self.state, - &self.state.device.ipv4, - Box::new(|ip, masklen| (IpAddr::V4(ip), masklen)), - )); - res.append(&mut treebit_list( - &self.state, - &self.state.device.ipv6, - Box::new(|ip, masklen| (IpAddr::V6(ip), masklen)), - )); - res + self.state.device.table.list(&self.state) } /// Clear subnets mapped to the peer. /// After the call, no subnets will be cryptkey routed to the peer. /// Used for the UAPI command "replace_allowed_ips=true" pub fn remove_allowed_ips(&self) { - debug!("peer.remove_allowed_ips"); - treebit_remove(self, &self.state.device.ipv4); - treebit_remove(self, &self.state.device.ipv6); + self.state.device.table.remove(&self.state) } pub fn clear_src(&self) { diff --git a/src/wireguard/router/route.rs b/src/wireguard/router/route.rs index 29e7635..e5f5955 100644 --- a/src/wireguard/router/route.rs +++ b/src/wireguard/router/route.rs @@ -1,113 +1,163 @@ -use super::super::{bind, tun, Endpoint}; -use super::device::DeviceInner; use super::ip::*; -use super::peer::PeerInner; -use super::types::Callbacks; -use log::trace; use zerocopy::LayoutVerified; use std::mem; -use std::net::{Ipv4Addr, Ipv6Addr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::sync::Arc; -#[inline(always)] -pub fn get_route>( - device: &Arc>, - packet: &[u8], -) -> Option>> { - match packet.get(0)? >> 4 { - VERSION_IP4 => { - // check length and cast to IPv4 header - let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = - LayoutVerified::new_from_prefix(packet)?; - - log::trace!( - "Router, get route for IPv4 destination: {:?}", - Ipv4Addr::from(header.f_destination) - ); - - // check IPv4 source address - device - .ipv4 - .read() - .longest_match(Ipv4Addr::from(header.f_destination)) - .and_then(|(_, _, p)| Some(p.clone())) +use spin::RwLock; +use treebitmap::address::Address; +use treebitmap::IpLookupTable; + +/* Functions for obtaining and validating "cryptokey" routes */ + +pub struct RoutingTable { + ipv4: RwLock>>, + ipv6: RwLock>>, +} + +impl RoutingTable { + pub fn new() -> Self { + RoutingTable { + ipv4: RwLock::new(IpLookupTable::new()), + ipv6: RwLock::new(IpLookupTable::new()), } - VERSION_IP6 => { - // check length and cast to IPv6 header - let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = - LayoutVerified::new_from_prefix(packet)?; - - log::trace!( - "Router, get route for IPv6 destination: {:?}", - Ipv6Addr::from(header.f_destination) - ); - - // check IPv6 source address - device - .ipv6 - .read() - .longest_match(Ipv6Addr::from(header.f_destination)) - .and_then(|(_, _, p)| Some(p.clone())) + } + + fn collect(table: &IpLookupTable>, value: &Arc) -> Vec<(A, u32)> + where + A: Address, + { + let mut res = Vec::new(); + for (ip, cidr, v) in table.iter() { + if Arc::ptr_eq(v, value) { + res.push((ip, cidr)) + } } - _ => None, + res } -} -#[inline(always)] -pub fn check_route>( - device: &Arc>, - peer: &Arc>, - packet: &[u8], -) -> Option { - match packet.get(0)? >> 4 { - VERSION_IP4 => { - // check length and cast to IPv4 header - let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = - LayoutVerified::new_from_prefix(packet)?; - - log::trace!( - "Router, check route for IPv4 source: {:?}", - Ipv4Addr::from(header.f_source) - ); - - // check IPv4 source address - device - .ipv4 - .read() - .longest_match(Ipv4Addr::from(header.f_source)) - .and_then(|(_, _, p)| { - if Arc::ptr_eq(p, peer) { - Some(header.f_total_len.get() as usize) - } else { - None - } - }) + pub fn list(&self, value: &Arc) -> Vec<(IpAddr, u32)> { + let mut res = vec![]; + res.extend( + Self::collect(&*self.ipv4.read(), value) + .into_iter() + .map(|(ip, cidr)| (IpAddr::V4(ip), cidr)), + ); + res.extend( + Self::collect(&*self.ipv6.read(), value) + .into_iter() + .map(|(ip, cidr)| (IpAddr::V6(ip), cidr)), + ); + res + } + + pub fn remove(&self, value: &Arc) { + let mut v4 = self.ipv4.write(); + let mut v6 = self.ipv6.write(); + for (ip, cidr) in Self::collect(&*v4, value) { + v4.remove(ip, cidr); } - VERSION_IP6 => { - // check length and cast to IPv6 header - let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = - LayoutVerified::new_from_prefix(packet)?; - - log::trace!( - "Router, check route for IPv6 source: {:?}", - Ipv6Addr::from(header.f_source) - ); - - // check IPv6 source address - device - .ipv6 - .read() - .longest_match(Ipv6Addr::from(header.f_source)) - .and_then(|(_, _, p)| { - if Arc::ptr_eq(p, peer) { - Some(header.f_len.get() as usize + mem::size_of::()) - } else { - None - } - }) + for (ip, cidr) in Self::collect(&*v6, value) { + v6.remove(ip, cidr); } - _ => None, + } + + #[inline(always)] + pub fn get_route(&self, packet: &[u8]) -> Option> { + match packet.get(0)? >> 4 { + VERSION_IP4 => { + // check length and cast to IPv4 header + let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + log::trace!( + "Router, get route for IPv4 destination: {:?}", + Ipv4Addr::from(header.f_destination) + ); + + // check IPv4 source address + self.ipv4 + .read() + .longest_match(Ipv4Addr::from(header.f_destination)) + .and_then(|(_, _, p)| Some(p.clone())) + } + VERSION_IP6 => { + // check length and cast to IPv6 header + let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + log::trace!( + "Router, get route for IPv6 destination: {:?}", + Ipv6Addr::from(header.f_destination) + ); + + // check IPv6 source address + self.ipv6 + .read() + .longest_match(Ipv6Addr::from(header.f_destination)) + .and_then(|(_, _, p)| Some(p.clone())) + } + _ => None, + } + } + + #[inline(always)] + pub fn check_route(&self, peer: &Arc, packet: &[u8]) -> Option { + match packet.get(0)? >> 4 { + VERSION_IP4 => { + // check length and cast to IPv4 header + let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + log::trace!( + "Router, check route for IPv4 source: {:?}", + Ipv4Addr::from(header.f_source) + ); + + // check IPv4 source address + self.ipv4 + .read() + .longest_match(Ipv4Addr::from(header.f_source)) + .and_then(|(_, _, p)| { + if Arc::ptr_eq(p, peer) { + Some(header.f_total_len.get() as usize) + } else { + None + } + }) + } + VERSION_IP6 => { + // check length and cast to IPv6 header + let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + log::trace!( + "Router, check route for IPv6 source: {:?}", + Ipv6Addr::from(header.f_source) + ); + + // check IPv6 source address + self.ipv6 + .read() + .longest_match(Ipv6Addr::from(header.f_source)) + .and_then(|(_, _, p)| { + if Arc::ptr_eq(p, peer) { + Some(header.f_len.get() as usize + mem::size_of::()) + } else { + None + } + }) + } + _ => None, + } + } + + pub fn insert(&self, ip: IpAddr, cidr: u32, value: Arc) { + match ip { + IpAddr::V4(v4) => self.ipv4.write().insert(v4.mask(cidr), cidr, value), + IpAddr::V6(v6) => self.ipv6.write().insert(v6.mask(cidr), cidr, value), + }; } } diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs index d14b438..24c1b56 100644 --- a/src/wireguard/router/tests.rs +++ b/src/wireguard/router/tests.rs @@ -86,11 +86,11 @@ mod tests { impl Callbacks for TestCallbacks { type Opaque = Opaque; - fn send(t: &Self::Opaque, size: usize, sent: bool, keypair: &Arc, counter: u64) { + fn send(t: &Self::Opaque, size: usize, sent: bool, _keypair: &Arc, _counter: u64) { t.0.send.lock().unwrap().push((size, sent)) } - fn recv(t: &Self::Opaque, size: usize, sent: bool, keypair: &Arc) { + fn recv(t: &Self::Opaque, size: usize, sent: bool, _keypair: &Arc) { t.0.recv.lock().unwrap().push((size, sent)) } diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs index d87174f..cd8015b 100644 --- a/src/wireguard/router/workers.rs +++ b/src/wireguard/router/workers.rs @@ -14,7 +14,6 @@ use zerocopy::{AsBytes, LayoutVerified}; use super::device::{DecryptionState, DeviceInner}; use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::peer::PeerInner; -use super::route::check_route; use super::types::Callbacks; use super::REJECT_AFTER_MESSAGES; @@ -108,7 +107,8 @@ pub fn worker_inbound 0 { - if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) { + if let Some(inner_len) = device.table.check_route(&peer, &packet[..length]) + { // TODO: Consider moving the cryptkey route check to parallel decryption worker debug_assert!(inner_len <= length, "should be validated earlier"); if inner_len <= length { -- cgit v1.2.3-59-g8ed1b