From 4ff328b7da876fb3305fefd83865553af9c8ab2c Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Mon, 28 Oct 2019 14:48:24 +0100 Subject: First full test of pure WireGuard --- src/platform/dummy/bind.rs | 29 +++++++++- src/platform/dummy/tun.rs | 27 ++++++++- src/wireguard/handshake/noise.rs | 8 ++- src/wireguard/router/device.rs | 10 +--- src/wireguard/router/peer.rs | 12 ++-- src/wireguard/router/tests.rs | 8 +-- src/wireguard/router/types.rs | 4 +- src/wireguard/tests.rs | 116 +++++++++++++++++++++++++++++++++------ src/wireguard/timers.rs | 39 +++++++------ src/wireguard/wireguard.rs | 47 ++++++++++++++-- 10 files changed, 238 insertions(+), 62 deletions(-) (limited to 'src') diff --git a/src/platform/dummy/bind.rs b/src/platform/dummy/bind.rs index 984b886..3497656 100644 --- a/src/platform/dummy/bind.rs +++ b/src/platform/dummy/bind.rs @@ -1,7 +1,12 @@ +use hex; use std::error::Error; use std::fmt; use std::marker; +use log::debug; +use rand::rngs::OsRng; +use rand::Rng; + use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::Arc; use std::sync::Mutex; @@ -95,6 +100,7 @@ impl VoidBind { #[derive(Clone)] pub struct PairReader { + id: u32, recv: Arc>>>, _marker: marker::PhantomData, } @@ -110,13 +116,25 @@ impl Reader for PairReader { .map_err(|_| BindError::Disconnected)?; let len = vec.len(); buf[..len].copy_from_slice(&vec[..]); - Ok((vec.len(), UnitEndpoint {})) + debug!( + "dummy({}): read ({}, {})", + self.id, + len, + hex::encode(&buf[..len]) + ); + Ok((len, UnitEndpoint {})) } } impl Writer for PairWriter { type Error = BindError; fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { + debug!( + "dummy({}): write ({}, {})", + self.id, + buf.len(), + hex::encode(buf) + ); let owned = buf.to_owned(); match self.send.lock().unwrap().send(owned) { Err(_) => Err(BindError::Disconnected), @@ -127,6 +145,7 @@ impl Writer for PairWriter { #[derive(Clone)] pub struct PairWriter { + id: u32, send: Arc>>>, _marker: marker::PhantomData, } @@ -139,25 +158,33 @@ impl PairBind { (PairReader, PairWriter), (PairReader, PairWriter), ) { + let mut rng = OsRng::new().unwrap(); + let id1: u32 = rng.gen(); + let id2: u32 = rng.gen(); + let (tx1, rx1) = sync_channel(128); let (tx2, rx2) = sync_channel(128); ( ( PairReader { + id: id1, recv: Arc::new(Mutex::new(rx1)), _marker: marker::PhantomData, }, PairWriter { + id: id1, send: Arc::new(Mutex::new(tx2)), _marker: marker::PhantomData, }, ), ( PairReader { + id: id2, recv: Arc::new(Mutex::new(rx2)), _marker: marker::PhantomData, }, PairWriter { + id: id2, send: Arc::new(Mutex::new(tx1)), _marker: marker::PhantomData, }, diff --git a/src/platform/dummy/tun.rs b/src/platform/dummy/tun.rs index fb87d2f..185b328 100644 --- a/src/platform/dummy/tun.rs +++ b/src/platform/dummy/tun.rs @@ -1,3 +1,8 @@ +use hex; +use log::debug; +use rand::rngs::OsRng; +use rand::Rng; + use std::cmp::min; use std::error::Error; use std::fmt; @@ -61,16 +66,19 @@ impl fmt::Display for TunError { pub struct TunTest {} pub struct TunFakeIO { + id: u32, store: bool, tx: SyncSender>, rx: Receiver>, } pub struct TunReader { + id: u32, rx: Receiver>, } pub struct TunWriter { + id: u32, store: bool, tx: Mutex>>, } @@ -88,6 +96,12 @@ impl Reader for TunReader { Ok(msg) => { let n = min(buf.len() - offset, msg.len()); buf[offset..offset + n].copy_from_slice(&msg[..n]); + debug!( + "dummy::TUN({}) : read ({}, {})", + self.id, + n, + hex::encode(&buf[offset..offset + n]) + ); Ok(n) } Err(_) => Err(TunError::Disconnected), @@ -99,6 +113,12 @@ impl Writer for TunWriter { type Error = TunError; fn write(&self, src: &[u8]) -> Result<(), Self::Error> { + debug!( + "dummy::TUN({}) : write ({}, {})", + self.id, + src.len(), + hex::encode(src) + ); if self.store { let m = src.to_owned(); match self.tx.lock().unwrap().send(m) { @@ -149,13 +169,18 @@ impl TunTest { sync_channel(1) }; + let mut rng = OsRng::new().unwrap(); + let id: u32 = rng.gen(); + let fake = TunFakeIO { + id, tx: tx1, rx: rx2, store, }; - let reader = TunReader { rx: rx1 }; + let reader = TunReader { id, rx: rx1 }; let writer = TunWriter { + id, tx: Mutex::new(tx2), store, }; diff --git a/src/wireguard/handshake/noise.rs b/src/wireguard/handshake/noise.rs index a2a84b0..68e738d 100644 --- a/src/wireguard/handshake/noise.rs +++ b/src/wireguard/handshake/noise.rs @@ -12,6 +12,8 @@ use chacha20poly1305::ChaCha20Poly1305; use rand::{CryptoRng, RngCore}; +use log::debug; + use generic_array::typenum::*; use generic_array::*; @@ -27,7 +29,7 @@ use super::peer::{Peer, State}; use super::timestamp; use super::types::*; -use super::super::types::{KeyPair, Key}; +use super::super::types::{Key, KeyPair}; use std::time::Instant; @@ -222,6 +224,7 @@ pub fn create_initiation( sender: u32, msg: &mut NoiseInitiation, ) -> Result<(), HandshakeError> { + debug!("create initation"); clear_stack_on_return(CLEAR_PAGES, || { // initialize state @@ -300,6 +303,7 @@ pub fn consume_initiation<'a>( device: &'a Device, msg: &NoiseInitiation, ) -> Result<(&'a Peer, TemporaryState), HandshakeError> { + debug!("consume initation"); clear_stack_on_return(CLEAR_PAGES, || { // initialize new state @@ -377,6 +381,7 @@ pub fn create_response( state: TemporaryState, // state from "consume_initiation" msg: &mut NoiseResponse, // resulting response ) -> Result { + debug!("create response"); clear_stack_on_return(CLEAR_PAGES, || { // unpack state @@ -457,6 +462,7 @@ pub fn create_response( * in order to better mitigate DoS from malformed response messages. */ pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result { + debug!("consume response"); clear_stack_on_return(CLEAR_PAGES, || { // retrieve peer and copy initiation state let peer = device.lookup_id(msg.f_receiver.get())?; diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index b122bf4..254b3de 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -89,13 +89,7 @@ fn get_route>( device: &Arc>, packet: &[u8], ) -> Option>> { - // ensure version access within bounds - if packet.len() < 1 { - return None; - }; - - // cast to correct IP header - match packet[0] >> 4 { + match packet.get(0)? >> 4 { VERSION_IP4 => { // check length and cast to IPv4 header let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = @@ -176,7 +170,7 @@ impl> Device> Peer { self.state @@ -556,8 +556,8 @@ impl> Peer Vec<(IpAddr, u32)> { - debug!("peer.list_subnets"); + pub fn list_allowed_ips(&self) -> Vec<(IpAddr, u32)> { + debug!("peer.list_allowed_ips"); let mut res = Vec::new(); res.append(&mut treebit_list( &self.state, @@ -575,8 +575,8 @@ impl> Peer) -> fmt::Result { match self { - RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"), + RouterError::NoCryptoKeyRoute => write!(f, "No cryptokey route configured for subnet"), RouterError::MalformedIPHeader => write!(f, "IP header is malformed"), RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"), RouterError::UnknownReceiverId => { diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs index 7c87056..28dedec 100644 --- a/src/wireguard/tests.rs +++ b/src/wireguard/tests.rs @@ -5,13 +5,23 @@ use std::net::IpAddr; use std::thread; use std::time::Duration; -use rand::rngs::OsRng; +use hex; + +use rand_chacha::ChaCha8Rng; +use rand_core::{RngCore, SeedableRng}; use x25519_dalek::{PublicKey, StaticSecret}; use pnet::packet::ipv4::MutableIpv4Packet; use pnet::packet::ipv6::MutableIpv6Packet; -fn make_packet(size: usize, src: IpAddr, dst: IpAddr) -> Vec { +fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec { + // expand pseudo random payload + let mut rng: _ = ChaCha8Rng::seed_from_u64(id); + let mut p: Vec = vec![]; + for _ in 0..size { + p.push(rng.next_u32() as u8); + } + // create "IP packet" let mut msg = Vec::with_capacity(size); msg.resize(size, 0); @@ -19,21 +29,25 @@ fn make_packet(size: usize, src: IpAddr, dst: IpAddr) -> Vec { IpAddr::V4(dst) => { let mut packet = MutableIpv4Packet::new(&mut msg[..]).unwrap(); packet.set_destination(dst); + packet.set_total_length(size as u16); packet.set_source(if let IpAddr::V4(src) = src { src } else { panic!("src.version != dst.version") }); + packet.set_payload(&p[..]); packet.set_version(4); } IpAddr::V6(dst) => { let mut packet = MutableIpv6Packet::new(&mut msg[..]).unwrap(); packet.set_destination(dst); + packet.set_payload_length((size - MutableIpv6Packet::minimum_packet_size()) as u16); packet.set_source(if let IpAddr::V6(src) = src { src } else { panic!("src.version != dst.version") }); + packet.set_payload(&p[..]); packet.set_version(6); } } @@ -55,7 +69,7 @@ fn wait() { fn test_pure_wireguard() { init(); - // create WG instances for fake TUN devices + // create WG instances for dummy TUN devices let (fake1, tun_reader1, tun_writer1, mtu1) = dummy::TunTest::create(1500, true); let wg1: Wireguard = @@ -77,10 +91,20 @@ fn test_pure_wireguard() { // generate (public, pivate) key pairs - let mut rng = OsRng::new().unwrap(); - let sk1 = StaticSecret::new(&mut rng); - let sk2 = StaticSecret::new(&mut rng); + let sk1 = StaticSecret::from([ + 0x3f, 0x69, 0x86, 0xd1, 0xc0, 0xec, 0x25, 0xa0, 0x9c, 0x8e, 0x56, 0xb5, 0x1d, 0xb7, 0x3c, + 0xed, 0x56, 0x8e, 0x59, 0x9d, 0xd9, 0xc3, 0x98, 0x67, 0x74, 0x69, 0x90, 0xc3, 0x43, 0x36, + 0x78, 0x89, + ]); + + let sk2 = StaticSecret::from([ + 0xfb, 0xd1, 0xd6, 0xe4, 0x65, 0x06, 0xd2, 0xe5, 0xc5, 0xdf, 0x6e, 0xab, 0x51, 0x71, 0xd8, + 0x70, 0xb5, 0xb7, 0x77, 0x51, 0xb4, 0xbe, 0xfb, 0xbc, 0x88, 0x62, 0x40, 0xca, 0x2c, 0xc2, + 0x66, 0xe2, + ]); + let pk1 = PublicKey::from(&sk1); + let pk2 = PublicKey::from(&sk2); wg1.new_peer(pk2); @@ -94,21 +118,79 @@ fn test_pure_wireguard() { let peer2 = wg1.lookup_peer(&pk2).unwrap(); let peer1 = wg2.lookup_peer(&pk1).unwrap(); - peer1.router.add_subnet("192.168.2.0".parse().unwrap(), 24); - peer2.router.add_subnet("192.168.1.0".parse().unwrap(), 24); + peer1 + .router + .add_allowed_ips("192.168.1.0".parse().unwrap(), 24); + + peer2 + .router + .add_allowed_ips("192.168.2.0".parse().unwrap(), 24); - // set endpoints + // set endpoint (the other should be learned dynamically) - peer1.router.set_endpoint(dummy::UnitEndpoint::new()); peer2.router.set_endpoint(dummy::UnitEndpoint::new()); - // create IP packets (causing a new handshake) + let num_packets = 20; + + // send IP packets (causing a new handshake) + + { + let mut packets: Vec> = Vec::with_capacity(num_packets); + + for id in 0..num_packets { + packets.push(make_packet( + 50 + 50 * id as usize, // size + "192.168.1.20".parse().unwrap(), // src + "192.168.2.10".parse().unwrap(), // dst + id as u64, // prng seed + )); + } + + let mut backup = packets.clone(); + + while let Some(p) = packets.pop() { + fake1.write(p); + } - let packet_p1_to_p2 = make_packet( - 1000, - "192.168.2.20".parse().unwrap(), // src - "192.168.1.10".parse().unwrap(), // dst - ); + wait(); - fake1.write(packet_p1_to_p2); + while let Some(p) = backup.pop() { + assert_eq!( + hex::encode(fake2.read()), + hex::encode(p), + "Failed to receive valid IPv4 packet unmodified and in-order" + ); + } + } + + // send IP packets (other direction) + + { + let mut packets: Vec> = Vec::with_capacity(num_packets); + + for id in 0..num_packets { + packets.push(make_packet( + 50 + 50 * id as usize, // size + "192.168.2.10".parse().unwrap(), // src + "192.168.1.20".parse().unwrap(), // dst + (id + 100) as u64, // prng seed + )); + } + + let mut backup = packets.clone(); + + while let Some(p) = packets.pop() { + fake2.write(p); + } + + wait(); + + while let Some(p) = backup.pop() { + assert_eq!( + hex::encode(fake1.read()), + hex::encode(p), + "Failed to receive valid IPv4 packet unmodified and in-order" + ); + } + } } diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index 5ebc746..3b16bf6 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -7,10 +7,10 @@ use log::info; use hjul::{Runner, Timer}; -use super::{bind, tun}; use super::constants::*; -use super::router::{Callbacks, message_data_len}; +use super::router::{message_data_len, Callbacks}; use super::wireguard::{Peer, PeerInner}; +use super::{bind, tun}; pub struct Timers { handshake_pending: AtomicBool, @@ -32,16 +32,20 @@ impl Timers { } } -impl PeerInner { +impl PeerInner { /* should be called after an authenticated data packet is sent */ pub fn timers_data_sent(&self) { - self.timers().new_handshake.start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT); + self.timers() + .new_handshake + .start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT); } /* should be called after an authenticated data packet is received */ pub fn timers_data_received(&self) { if !self.timers().send_keepalive.start(KEEPALIVE_TIMEOUT) { - self.timers().need_another_keepalive.store(true, Ordering::SeqCst) + self.timers() + .need_another_keepalive + .store(true, Ordering::SeqCst) } } @@ -74,7 +78,9 @@ impl PeerInner { */ pub fn timers_handshake_complete(&self) { self.timers().handshake_attempts.store(0, Ordering::SeqCst); - self.timers().sent_lastminute_handshake.store(false, Ordering::SeqCst); + self.timers() + .sent_lastminute_handshake + .store(false, Ordering::SeqCst); // TODO: Store time in peer for config // self.walltime_last_handshake } @@ -92,7 +98,9 @@ impl PeerInner { pub fn timers_any_authenticated_packet_traversal(&self) { let keepalive = self.keepalive.load(Ordering::Acquire); if keepalive > 0 { - self.timers().send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64)); + self.timers() + .send_persistent_keepalive + .reset(Duration::from_secs(keepalive as u64)); } } @@ -149,11 +157,7 @@ impl Timers { new_handshake: { let peer = peer.clone(); runner.timer(move || { - info!( - "Retrying handshake with {}, because we stopped hearing back after {} seconds", - peer, - (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT).as_secs() - ); + info!("Initiate new handshake with {}", peer); peer.new_handshake(); peer.timers.read().handshake_begun(); }) @@ -171,10 +175,12 @@ impl Timers { if keepalive > 0 { peer.router.send_keepalive(); peer.timers().send_keepalive.stop(); - peer.timers().send_persistent_keepalive.start(Duration::from_secs(keepalive as u64)); + peer.timers() + .send_persistent_keepalive + .start(Duration::from_secs(keepalive as u64)); } }) - } + }, } } @@ -196,7 +202,8 @@ impl Timers { pub fn updated_persistent_keepalive(&self, keepalive: usize) { if keepalive > 0 { - self.send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64)); + self.send_persistent_keepalive + .reset(Duration::from_secs(keepalive as u64)); } } @@ -210,7 +217,7 @@ impl Timers { new_handshake: runner.timer(|| {}), send_keepalive: runner.timer(|| {}), send_persistent_keepalive: runner.timer(|| {}), - zero_key_material: runner.timer(|| {}) + zero_key_material: runner.timer(|| {}), } } diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 25544d9..233559e 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -21,6 +21,7 @@ use std::collections::HashMap; use log::debug; use rand::rngs::OsRng; +use rand::Rng; use spin::{Mutex, RwLock, RwLockReadGuard}; use byteorder::{ByteOrder, LittleEndian}; @@ -37,6 +38,8 @@ pub struct Peer { } pub struct PeerInner { + pub id: u64, + pub keepalive: AtomicUsize, // keepalive interval pub rx_bytes: AtomicU64, pub tx_bytes: AtomicU64, @@ -50,6 +53,9 @@ pub struct PeerInner { } pub struct WireguardInner { + // identifier (for logging) + id: u32, + // provides access to the MTU value of the tun device // (otherwise owned solely by the router and a dedicated read IO thread) mtu: T::MTU, @@ -96,7 +102,13 @@ impl PeerInner { impl fmt::Display for Peer { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "peer()") + write!(f, "peer(id = {})", self.id) + } +} + +impl fmt::Display for WireguardInner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "wireguard({:x})", self.id) } } @@ -209,7 +221,9 @@ impl Wireguard { } pub fn new_peer(&self, pk: PublicKey) { + let mut rng = OsRng::new().unwrap(); let state = Arc::new(PeerInner { + id: rng.gen(), pk, last_handshake: Mutex::new(SystemTime::UNIX_EPOCH), handshake_queued: AtomicBool::new(false), @@ -277,11 +291,17 @@ impl Wireguard { handshake::TYPE_COOKIE_REPLY | handshake::TYPE_INITIATION | handshake::TYPE_RESPONSE => { + debug!("{} : reader, received handshake message", wg); + + let pending = wg.pending.fetch_add(1, Ordering::SeqCst); + // update under_load flag - if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD { + if pending > THRESHOLD_UNDER_LOAD { + debug!("{} : reader, set under load (pending = {})", wg, pending); last_under_load = Instant::now(); wg.under_load.store(true, Ordering::SeqCst); } else if last_under_load.elapsed() > DURATION_UNDER_LOAD { + debug!("{} : reader, clear under load", wg); wg.under_load.store(false, Ordering::SeqCst); } @@ -291,6 +311,8 @@ impl Wireguard { .unwrap(); } router::TYPE_TRANSPORT => { + debug!("{} : reader, received transport message", wg); + // transport message let _ = wg.router.recv(src, msg).map_err(|e| { debug!("Failed to handle incoming transport message: {}", e); @@ -313,6 +335,7 @@ impl Wireguard { let mut rng = OsRng::new().unwrap(); let (tx, rx): (Sender>, _) = bounded(SIZE_HANDSHAKE_QUEUE); let wg = Arc::new(WireguardInner { + id: rng.gen(), mtu: mtu.clone(), peers: RwLock::new(HashMap::new()), send: RwLock::new(None), @@ -331,12 +354,13 @@ impl Wireguard { let wg = wg.clone(); let rx = rx.clone(); thread::spawn(move || { + debug!("{} : handshake worker, started", wg); + // prepare OsRng instance for this thread let mut rng = OsRng::new().unwrap(); // process elements from the handshake queue for job in rx { - wg.pending.fetch_sub(1, Ordering::SeqCst); let state = wg.handshake.read(); if !state.active { continue; @@ -344,6 +368,8 @@ impl Wireguard { match job { HandshakeJob::Message(msg, src) => { + wg.pending.fetch_sub(1, Ordering::SeqCst); + // feed message to handshake device let src_validate = (&src).into_address(); // TODO avoid @@ -352,6 +378,7 @@ impl Wireguard { &mut rng, &msg[..], if wg.under_load.load(Ordering::Relaxed) { + debug!("{} : handshake worker, under load", wg); Some(&src_validate) } else { None @@ -364,9 +391,14 @@ impl Wireguard { resp_len = msg.len() as u64; let send: &Option = &*wg.send.read(); if let Some(writer) = send.as_ref() { + debug!( + "{} : handshake worker, send response ({} bytes)", + wg, resp_len + ); let _ = writer.write(&msg[..], &src).map_err(|e| { debug!( - "handshake worker, failed to send response, error = {}", + "{} : handshake worker, failed to send response, error = {}", + wg, e ) }); @@ -387,11 +419,13 @@ impl Wireguard { // update timers after sending handshake response if resp_len > 0 { + debug!("{} : handshake worker, handshake response sent", wg); peer.state.sent_handshake_response(); } // add resulting keypair to peer keypair.map(|kp| { + debug!("{} : handshake worker, new keypair", wg); // free any unused ids for id in peer.router.add_keypair(kp) { state.device.release(id); @@ -400,14 +434,15 @@ impl Wireguard { } } } - Err(e) => debug!("handshake worker, error = {:?}", e), + Err(e) => debug!("{} : handshake worker, error = {:?}", wg, e), } } HandshakeJob::New(pk) => { + debug!("{} : handshake worker, new handshake requested", wg); let _ = state.device.begin(&mut rng, &pk).map(|msg| { if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { let _ = peer.router.send(&msg[..]).map_err(|e| { - debug!("handshake worker, failed to send handshake initiation, error = {}", e) + debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e) }); peer.state.sent_handshake_initiation(); } -- cgit v1.2.3-59-g8ed1b