diff options
Diffstat (limited to '')
-rw-r--r-- | Cargo.lock | 2 | ||||
-rw-r--r-- | Cargo.toml | 2 | ||||
-rw-r--r-- | src/platform/dummy/bind.rs | 29 | ||||
-rw-r--r-- | src/platform/dummy/tun.rs | 27 | ||||
-rw-r--r-- | src/wireguard/handshake/noise.rs | 8 | ||||
-rw-r--r-- | src/wireguard/router/device.rs | 10 | ||||
-rw-r--r-- | src/wireguard/router/peer.rs | 12 | ||||
-rw-r--r-- | src/wireguard/router/tests.rs | 8 | ||||
-rw-r--r-- | src/wireguard/router/types.rs | 4 | ||||
-rw-r--r-- | src/wireguard/tests.rs | 116 | ||||
-rw-r--r-- | src/wireguard/timers.rs | 39 | ||||
-rw-r--r-- | src/wireguard/wireguard.rs | 47 |
12 files changed, 242 insertions, 62 deletions
@@ -1608,6 +1608,8 @@ dependencies = [ "pnet 0.22.0 (registry+https://github.com/rust-lang/crates.io-index)", "proptest 0.9.4 (registry+https://github.com/rust-lang/crates.io-index)", "rand 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)", + "rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", + "rand_core 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", "ring 0.16.7 (registry+https://github.com/rust-lang/crates.io-index)", "spin 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", "subtle 2.1.1 (registry+https://github.com/rust-lang/crates.io-index)", @@ -46,3 +46,5 @@ features = ["nightly"] [dev-dependencies] proptest = "0.9.4" pnet = "^0.22" +rand_chacha = "0.2.1" +rand_core = "0.5" diff --git a/src/platform/dummy/bind.rs b/src/platform/dummy/bind.rs index 984b886..3497656 100644 --- a/src/platform/dummy/bind.rs +++ b/src/platform/dummy/bind.rs @@ -1,7 +1,12 @@ +use hex; use std::error::Error; use std::fmt; use std::marker; +use log::debug; +use rand::rngs::OsRng; +use rand::Rng; + use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::Arc; use std::sync::Mutex; @@ -95,6 +100,7 @@ impl VoidBind { #[derive(Clone)] pub struct PairReader<E> { + id: u32, recv: Arc<Mutex<Receiver<Vec<u8>>>>, _marker: marker::PhantomData<E>, } @@ -110,13 +116,25 @@ impl Reader<UnitEndpoint> for PairReader<UnitEndpoint> { .map_err(|_| BindError::Disconnected)?; let len = vec.len(); buf[..len].copy_from_slice(&vec[..]); - Ok((vec.len(), UnitEndpoint {})) + debug!( + "dummy({}): read ({}, {})", + self.id, + len, + hex::encode(&buf[..len]) + ); + Ok((len, UnitEndpoint {})) } } impl Writer<UnitEndpoint> for PairWriter<UnitEndpoint> { type Error = BindError; fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { + debug!( + "dummy({}): write ({}, {})", + self.id, + buf.len(), + hex::encode(buf) + ); let owned = buf.to_owned(); match self.send.lock().unwrap().send(owned) { Err(_) => Err(BindError::Disconnected), @@ -127,6 +145,7 @@ impl Writer<UnitEndpoint> for PairWriter<UnitEndpoint> { #[derive(Clone)] pub struct PairWriter<E> { + id: u32, send: Arc<Mutex<SyncSender<Vec<u8>>>>, _marker: marker::PhantomData<E>, } @@ -139,25 +158,33 @@ impl PairBind { (PairReader<E>, PairWriter<E>), (PairReader<E>, PairWriter<E>), ) { + let mut rng = OsRng::new().unwrap(); + let id1: u32 = rng.gen(); + let id2: u32 = rng.gen(); + let (tx1, rx1) = sync_channel(128); let (tx2, rx2) = sync_channel(128); ( ( PairReader { + id: id1, recv: Arc::new(Mutex::new(rx1)), _marker: marker::PhantomData, }, PairWriter { + id: id1, send: Arc::new(Mutex::new(tx2)), _marker: marker::PhantomData, }, ), ( PairReader { + id: id2, recv: Arc::new(Mutex::new(rx2)), _marker: marker::PhantomData, }, PairWriter { + id: id2, send: Arc::new(Mutex::new(tx1)), _marker: marker::PhantomData, }, diff --git a/src/platform/dummy/tun.rs b/src/platform/dummy/tun.rs index fb87d2f..185b328 100644 --- a/src/platform/dummy/tun.rs +++ b/src/platform/dummy/tun.rs @@ -1,3 +1,8 @@ +use hex; +use log::debug; +use rand::rngs::OsRng; +use rand::Rng; + use std::cmp::min; use std::error::Error; use std::fmt; @@ -61,16 +66,19 @@ impl fmt::Display for TunError { pub struct TunTest {} pub struct TunFakeIO { + id: u32, store: bool, tx: SyncSender<Vec<u8>>, rx: Receiver<Vec<u8>>, } pub struct TunReader { + id: u32, rx: Receiver<Vec<u8>>, } pub struct TunWriter { + id: u32, store: bool, tx: Mutex<SyncSender<Vec<u8>>>, } @@ -88,6 +96,12 @@ impl Reader for TunReader { Ok(msg) => { let n = min(buf.len() - offset, msg.len()); buf[offset..offset + n].copy_from_slice(&msg[..n]); + debug!( + "dummy::TUN({}) : read ({}, {})", + self.id, + n, + hex::encode(&buf[offset..offset + n]) + ); Ok(n) } Err(_) => Err(TunError::Disconnected), @@ -99,6 +113,12 @@ impl Writer for TunWriter { type Error = TunError; fn write(&self, src: &[u8]) -> Result<(), Self::Error> { + debug!( + "dummy::TUN({}) : write ({}, {})", + self.id, + src.len(), + hex::encode(src) + ); if self.store { let m = src.to_owned(); match self.tx.lock().unwrap().send(m) { @@ -149,13 +169,18 @@ impl TunTest { sync_channel(1) }; + let mut rng = OsRng::new().unwrap(); + let id: u32 = rng.gen(); + let fake = TunFakeIO { + id, tx: tx1, rx: rx2, store, }; - let reader = TunReader { rx: rx1 }; + let reader = TunReader { id, rx: rx1 }; let writer = TunWriter { + id, tx: Mutex::new(tx2), store, }; diff --git a/src/wireguard/handshake/noise.rs b/src/wireguard/handshake/noise.rs index a2a84b0..68e738d 100644 --- a/src/wireguard/handshake/noise.rs +++ b/src/wireguard/handshake/noise.rs @@ -12,6 +12,8 @@ use chacha20poly1305::ChaCha20Poly1305; use rand::{CryptoRng, RngCore}; +use log::debug; + use generic_array::typenum::*; use generic_array::*; @@ -27,7 +29,7 @@ use super::peer::{Peer, State}; use super::timestamp; use super::types::*; -use super::super::types::{KeyPair, Key}; +use super::super::types::{Key, KeyPair}; use std::time::Instant; @@ -222,6 +224,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>( sender: u32, msg: &mut NoiseInitiation, ) -> Result<(), HandshakeError> { + debug!("create initation"); clear_stack_on_return(CLEAR_PAGES, || { // initialize state @@ -300,6 +303,7 @@ pub fn consume_initiation<'a>( device: &'a Device, msg: &NoiseInitiation, ) -> Result<(&'a Peer, TemporaryState), HandshakeError> { + debug!("consume initation"); clear_stack_on_return(CLEAR_PAGES, || { // initialize new state @@ -377,6 +381,7 @@ pub fn create_response<R: RngCore + CryptoRng>( state: TemporaryState, // state from "consume_initiation" msg: &mut NoiseResponse, // resulting response ) -> Result<KeyPair, HandshakeError> { + debug!("create response"); clear_stack_on_return(CLEAR_PAGES, || { // unpack state @@ -457,6 +462,7 @@ pub fn create_response<R: RngCore + CryptoRng>( * in order to better mitigate DoS from malformed response messages. */ pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result<Output, HandshakeError> { + debug!("consume response"); clear_stack_on_return(CLEAR_PAGES, || { // retrieve peer and copy initiation state let peer = device.lookup_id(msg.f_receiver.get())?; diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index b122bf4..254b3de 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -89,13 +89,7 @@ fn get_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( device: &Arc<DeviceInner<E, C, T, B>>, packet: &[u8], ) -> Option<Arc<PeerInner<E, C, T, B>>> { - // ensure version access within bounds - if packet.len() < 1 { - return None; - }; - - // cast to correct IP header - match packet[0] >> 4 { + match packet.get(0)? >> 4 { VERSION_IP4 => { // check length and cast to IPv4 header let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = @@ -176,7 +170,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, let packet = &msg[SIZE_MESSAGE_PREFIX..]; // lookup peer based on IP packet destination address - let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?; + let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptoKeyRoute)?; // schedule for encryption and transmission to peer if let Some(job) = peer.send_job(msg, true) { diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index 0b193a4..66a6e9f 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -531,8 +531,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T /// /// If an identical value already exists as part of a prior peer, /// the allowed IP entry will be removed from that peer and added to this peer. - pub fn add_subnet(&self, ip: IpAddr, masklen: u32) { - debug!("peer.add_subnet"); + pub fn add_allowed_ips(&self, ip: IpAddr, masklen: u32) { + debug!("peer.add_allowed_ips"); match ip { IpAddr::V4(v4) => { self.state @@ -556,8 +556,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T /// # Returns /// /// A vector of subnets, represented by as mask/size - pub fn list_subnets(&self) -> Vec<(IpAddr, u32)> { - debug!("peer.list_subnets"); + pub fn list_allowed_ips(&self) -> Vec<(IpAddr, u32)> { + debug!("peer.list_allowed_ips"); let mut res = Vec::new(); res.append(&mut treebit_list( &self.state, @@ -575,8 +575,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T /// Clear subnets mapped to the peer. /// After the call, no subnets will be cryptkey routed to the peer. /// Used for the UAPI command "replace_allowed_ips=true" - pub fn remove_subnets(&self) { - debug!("peer.remove_subnets"); + pub fn remove_allowed_ips(&self) { + debug!("peer.remove_allowed_ips"); treebit_remove(self, &self.state.device.ipv4); treebit_remove(self, &self.state.device.ipv6); } diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs index d44a612..6184993 100644 --- a/src/wireguard/router/tests.rs +++ b/src/wireguard/router/tests.rs @@ -157,7 +157,7 @@ mod tests { let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20"); let mask: IpAddr = mask.parse().unwrap(); let ip1: IpAddr = ip.parse().unwrap(); - peer.add_subnet(mask, len); + peer.add_allowed_ips(mask, len); // every iteration sends 10 GB b.iter(|| { @@ -215,7 +215,7 @@ mod tests { } // map subnet to peer - peer.add_subnet(mask, *len); + peer.add_allowed_ips(mask, *len); // create "IP packet" let msg = make_packet(1024, ip.parse().unwrap()); @@ -339,13 +339,13 @@ mod tests { let (mask, len, _ip, _okay) = p1; let peer1 = router1.new_peer(opaq1.clone()); let mask: IpAddr = mask.parse().unwrap(); - peer1.add_subnet(mask, *len); + peer1.add_allowed_ips(mask, *len); peer1.add_keypair(dummy_keypair(false)); let (mask, len, _ip, _okay) = p2; let peer2 = router2.new_peer(opaq2.clone()); let mask: IpAddr = mask.parse().unwrap(); - peer2.add_subnet(mask, *len); + peer2.add_allowed_ips(mask, *len); peer2.set_endpoint(dummy::UnitEndpoint::new()); if *stage { diff --git a/src/wireguard/router/types.rs b/src/wireguard/router/types.rs index 52ee4f1..9f769fe 100644 --- a/src/wireguard/router/types.rs +++ b/src/wireguard/router/types.rs @@ -31,7 +31,7 @@ pub trait Callbacks: Send + Sync + 'static { #[derive(Debug)] pub enum RouterError { - NoCryptKeyRoute, + NoCryptoKeyRoute, MalformedIPHeader, MalformedTransportMessage, UnknownReceiverId, @@ -42,7 +42,7 @@ pub enum RouterError { impl fmt::Display for RouterError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"), + RouterError::NoCryptoKeyRoute => write!(f, "No cryptokey route configured for subnet"), RouterError::MalformedIPHeader => write!(f, "IP header is malformed"), RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"), RouterError::UnknownReceiverId => { diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs index 7c87056..28dedec 100644 --- a/src/wireguard/tests.rs +++ b/src/wireguard/tests.rs @@ -5,13 +5,23 @@ use std::net::IpAddr; use std::thread; use std::time::Duration; -use rand::rngs::OsRng; +use hex; + +use rand_chacha::ChaCha8Rng; +use rand_core::{RngCore, SeedableRng}; use x25519_dalek::{PublicKey, StaticSecret}; use pnet::packet::ipv4::MutableIpv4Packet; use pnet::packet::ipv6::MutableIpv6Packet; -fn make_packet(size: usize, src: IpAddr, dst: IpAddr) -> Vec<u8> { +fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec<u8> { + // expand pseudo random payload + let mut rng: _ = ChaCha8Rng::seed_from_u64(id); + let mut p: Vec<u8> = vec![]; + for _ in 0..size { + p.push(rng.next_u32() as u8); + } + // create "IP packet" let mut msg = Vec::with_capacity(size); msg.resize(size, 0); @@ -19,21 +29,25 @@ fn make_packet(size: usize, src: IpAddr, dst: IpAddr) -> Vec<u8> { IpAddr::V4(dst) => { let mut packet = MutableIpv4Packet::new(&mut msg[..]).unwrap(); packet.set_destination(dst); + packet.set_total_length(size as u16); packet.set_source(if let IpAddr::V4(src) = src { src } else { panic!("src.version != dst.version") }); + packet.set_payload(&p[..]); packet.set_version(4); } IpAddr::V6(dst) => { let mut packet = MutableIpv6Packet::new(&mut msg[..]).unwrap(); packet.set_destination(dst); + packet.set_payload_length((size - MutableIpv6Packet::minimum_packet_size()) as u16); packet.set_source(if let IpAddr::V6(src) = src { src } else { panic!("src.version != dst.version") }); + packet.set_payload(&p[..]); packet.set_version(6); } } @@ -55,7 +69,7 @@ fn wait() { fn test_pure_wireguard() { init(); - // create WG instances for fake TUN devices + // create WG instances for dummy TUN devices let (fake1, tun_reader1, tun_writer1, mtu1) = dummy::TunTest::create(1500, true); let wg1: Wireguard<dummy::TunTest, dummy::PairBind> = @@ -77,10 +91,20 @@ fn test_pure_wireguard() { // generate (public, pivate) key pairs - let mut rng = OsRng::new().unwrap(); - let sk1 = StaticSecret::new(&mut rng); - let sk2 = StaticSecret::new(&mut rng); + let sk1 = StaticSecret::from([ + 0x3f, 0x69, 0x86, 0xd1, 0xc0, 0xec, 0x25, 0xa0, 0x9c, 0x8e, 0x56, 0xb5, 0x1d, 0xb7, 0x3c, + 0xed, 0x56, 0x8e, 0x59, 0x9d, 0xd9, 0xc3, 0x98, 0x67, 0x74, 0x69, 0x90, 0xc3, 0x43, 0x36, + 0x78, 0x89, + ]); + + let sk2 = StaticSecret::from([ + 0xfb, 0xd1, 0xd6, 0xe4, 0x65, 0x06, 0xd2, 0xe5, 0xc5, 0xdf, 0x6e, 0xab, 0x51, 0x71, 0xd8, + 0x70, 0xb5, 0xb7, 0x77, 0x51, 0xb4, 0xbe, 0xfb, 0xbc, 0x88, 0x62, 0x40, 0xca, 0x2c, 0xc2, + 0x66, 0xe2, + ]); + let pk1 = PublicKey::from(&sk1); + let pk2 = PublicKey::from(&sk2); wg1.new_peer(pk2); @@ -94,21 +118,79 @@ fn test_pure_wireguard() { let peer2 = wg1.lookup_peer(&pk2).unwrap(); let peer1 = wg2.lookup_peer(&pk1).unwrap(); - peer1.router.add_subnet("192.168.2.0".parse().unwrap(), 24); - peer2.router.add_subnet("192.168.1.0".parse().unwrap(), 24); + peer1 + .router + .add_allowed_ips("192.168.1.0".parse().unwrap(), 24); + + peer2 + .router + .add_allowed_ips("192.168.2.0".parse().unwrap(), 24); - // set endpoints + // set endpoint (the other should be learned dynamically) - peer1.router.set_endpoint(dummy::UnitEndpoint::new()); peer2.router.set_endpoint(dummy::UnitEndpoint::new()); - // create IP packets (causing a new handshake) + let num_packets = 20; + + // send IP packets (causing a new handshake) + + { + let mut packets: Vec<Vec<u8>> = Vec::with_capacity(num_packets); + + for id in 0..num_packets { + packets.push(make_packet( + 50 + 50 * id as usize, // size + "192.168.1.20".parse().unwrap(), // src + "192.168.2.10".parse().unwrap(), // dst + id as u64, // prng seed + )); + } + + let mut backup = packets.clone(); + + while let Some(p) = packets.pop() { + fake1.write(p); + } - let packet_p1_to_p2 = make_packet( - 1000, - "192.168.2.20".parse().unwrap(), // src - "192.168.1.10".parse().unwrap(), // dst - ); + wait(); - fake1.write(packet_p1_to_p2); + while let Some(p) = backup.pop() { + assert_eq!( + hex::encode(fake2.read()), + hex::encode(p), + "Failed to receive valid IPv4 packet unmodified and in-order" + ); + } + } + + // send IP packets (other direction) + + { + let mut packets: Vec<Vec<u8>> = Vec::with_capacity(num_packets); + + for id in 0..num_packets { + packets.push(make_packet( + 50 + 50 * id as usize, // size + "192.168.2.10".parse().unwrap(), // src + "192.168.1.20".parse().unwrap(), // dst + (id + 100) as u64, // prng seed + )); + } + + let mut backup = packets.clone(); + + while let Some(p) = packets.pop() { + fake2.write(p); + } + + wait(); + + while let Some(p) = backup.pop() { + assert_eq!( + hex::encode(fake1.read()), + hex::encode(p), + "Failed to receive valid IPv4 packet unmodified and in-order" + ); + } + } } diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index 5ebc746..3b16bf6 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -7,10 +7,10 @@ use log::info; use hjul::{Runner, Timer}; -use super::{bind, tun}; use super::constants::*; -use super::router::{Callbacks, message_data_len}; +use super::router::{message_data_len, Callbacks}; use super::wireguard::{Peer, PeerInner}; +use super::{bind, tun}; pub struct Timers { handshake_pending: AtomicBool, @@ -32,16 +32,20 @@ impl Timers { } } -impl <B: bind::Bind>PeerInner<B> { +impl<B: bind::Bind> PeerInner<B> { /* should be called after an authenticated data packet is sent */ pub fn timers_data_sent(&self) { - self.timers().new_handshake.start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT); + self.timers() + .new_handshake + .start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT); } /* should be called after an authenticated data packet is received */ pub fn timers_data_received(&self) { if !self.timers().send_keepalive.start(KEEPALIVE_TIMEOUT) { - self.timers().need_another_keepalive.store(true, Ordering::SeqCst) + self.timers() + .need_another_keepalive + .store(true, Ordering::SeqCst) } } @@ -74,7 +78,9 @@ impl <B: bind::Bind>PeerInner<B> { */ pub fn timers_handshake_complete(&self) { self.timers().handshake_attempts.store(0, Ordering::SeqCst); - self.timers().sent_lastminute_handshake.store(false, Ordering::SeqCst); + self.timers() + .sent_lastminute_handshake + .store(false, Ordering::SeqCst); // TODO: Store time in peer for config // self.walltime_last_handshake } @@ -92,7 +98,9 @@ impl <B: bind::Bind>PeerInner<B> { pub fn timers_any_authenticated_packet_traversal(&self) { let keepalive = self.keepalive.load(Ordering::Acquire); if keepalive > 0 { - self.timers().send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64)); + self.timers() + .send_persistent_keepalive + .reset(Duration::from_secs(keepalive as u64)); } } @@ -149,11 +157,7 @@ impl Timers { new_handshake: { let peer = peer.clone(); runner.timer(move || { - info!( - "Retrying handshake with {}, because we stopped hearing back after {} seconds", - peer, - (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT).as_secs() - ); + info!("Initiate new handshake with {}", peer); peer.new_handshake(); peer.timers.read().handshake_begun(); }) @@ -171,10 +175,12 @@ impl Timers { if keepalive > 0 { peer.router.send_keepalive(); peer.timers().send_keepalive.stop(); - peer.timers().send_persistent_keepalive.start(Duration::from_secs(keepalive as u64)); + peer.timers() + .send_persistent_keepalive + .start(Duration::from_secs(keepalive as u64)); } }) - } + }, } } @@ -196,7 +202,8 @@ impl Timers { pub fn updated_persistent_keepalive(&self, keepalive: usize) { if keepalive > 0 { - self.send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64)); + self.send_persistent_keepalive + .reset(Duration::from_secs(keepalive as u64)); } } @@ -210,7 +217,7 @@ impl Timers { new_handshake: runner.timer(|| {}), send_keepalive: runner.timer(|| {}), send_persistent_keepalive: runner.timer(|| {}), - zero_key_material: runner.timer(|| {}) + zero_key_material: runner.timer(|| {}), } } diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 25544d9..233559e 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -21,6 +21,7 @@ use std::collections::HashMap; use log::debug; use rand::rngs::OsRng; +use rand::Rng; use spin::{Mutex, RwLock, RwLockReadGuard}; use byteorder::{ByteOrder, LittleEndian}; @@ -37,6 +38,8 @@ pub struct Peer<T: Tun, B: Bind> { } pub struct PeerInner<B: Bind> { + pub id: u64, + pub keepalive: AtomicUsize, // keepalive interval pub rx_bytes: AtomicU64, pub tx_bytes: AtomicU64, @@ -50,6 +53,9 @@ pub struct PeerInner<B: Bind> { } pub struct WireguardInner<T: Tun, B: Bind> { + // identifier (for logging) + id: u32, + // provides access to the MTU value of the tun device // (otherwise owned solely by the router and a dedicated read IO thread) mtu: T::MTU, @@ -96,7 +102,13 @@ impl<B: Bind> PeerInner<B> { impl<T: Tun, B: Bind> fmt::Display for Peer<T, B> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "peer()") + write!(f, "peer(id = {})", self.id) + } +} + +impl<T: Tun, B: Bind> fmt::Display for WireguardInner<T, B> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "wireguard({:x})", self.id) } } @@ -209,7 +221,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { } pub fn new_peer(&self, pk: PublicKey) { + let mut rng = OsRng::new().unwrap(); let state = Arc::new(PeerInner { + id: rng.gen(), pk, last_handshake: Mutex::new(SystemTime::UNIX_EPOCH), handshake_queued: AtomicBool::new(false), @@ -277,11 +291,17 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { handshake::TYPE_COOKIE_REPLY | handshake::TYPE_INITIATION | handshake::TYPE_RESPONSE => { + debug!("{} : reader, received handshake message", wg); + + let pending = wg.pending.fetch_add(1, Ordering::SeqCst); + // update under_load flag - if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD { + if pending > THRESHOLD_UNDER_LOAD { + debug!("{} : reader, set under load (pending = {})", wg, pending); last_under_load = Instant::now(); wg.under_load.store(true, Ordering::SeqCst); } else if last_under_load.elapsed() > DURATION_UNDER_LOAD { + debug!("{} : reader, clear under load", wg); wg.under_load.store(false, Ordering::SeqCst); } @@ -291,6 +311,8 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { .unwrap(); } router::TYPE_TRANSPORT => { + debug!("{} : reader, received transport message", wg); + // transport message let _ = wg.router.recv(src, msg).map_err(|e| { debug!("Failed to handle incoming transport message: {}", e); @@ -313,6 +335,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { let mut rng = OsRng::new().unwrap(); let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE); let wg = Arc::new(WireguardInner { + id: rng.gen(), mtu: mtu.clone(), peers: RwLock::new(HashMap::new()), send: RwLock::new(None), @@ -331,12 +354,13 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { let wg = wg.clone(); let rx = rx.clone(); thread::spawn(move || { + debug!("{} : handshake worker, started", wg); + // prepare OsRng instance for this thread let mut rng = OsRng::new().unwrap(); // process elements from the handshake queue for job in rx { - wg.pending.fetch_sub(1, Ordering::SeqCst); let state = wg.handshake.read(); if !state.active { continue; @@ -344,6 +368,8 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { match job { HandshakeJob::Message(msg, src) => { + wg.pending.fetch_sub(1, Ordering::SeqCst); + // feed message to handshake device let src_validate = (&src).into_address(); // TODO avoid @@ -352,6 +378,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { &mut rng, &msg[..], if wg.under_load.load(Ordering::Relaxed) { + debug!("{} : handshake worker, under load", wg); Some(&src_validate) } else { None @@ -364,9 +391,14 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { resp_len = msg.len() as u64; let send: &Option<B::Writer> = &*wg.send.read(); if let Some(writer) = send.as_ref() { + debug!( + "{} : handshake worker, send response ({} bytes)", + wg, resp_len + ); let _ = writer.write(&msg[..], &src).map_err(|e| { debug!( - "handshake worker, failed to send response, error = {}", + "{} : handshake worker, failed to send response, error = {}", + wg, e ) }); @@ -387,11 +419,13 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { // update timers after sending handshake response if resp_len > 0 { + debug!("{} : handshake worker, handshake response sent", wg); peer.state.sent_handshake_response(); } // add resulting keypair to peer keypair.map(|kp| { + debug!("{} : handshake worker, new keypair", wg); // free any unused ids for id in peer.router.add_keypair(kp) { state.device.release(id); @@ -400,14 +434,15 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { } } } - Err(e) => debug!("handshake worker, error = {:?}", e), + Err(e) => debug!("{} : handshake worker, error = {:?}", wg, e), } } HandshakeJob::New(pk) => { + debug!("{} : handshake worker, new handshake requested", wg); let _ = state.device.begin(&mut rng, &pk).map(|msg| { if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { let _ = peer.router.send(&msg[..]).map_err(|e| { - debug!("handshake worker, failed to send handshake initiation, error = {}", e) + debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e) }); peer.state.sent_handshake_initiation(); } |