From e04a11a8cae5f4f8d29febdb38b93d236c700def Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Tue, 29 Oct 2019 16:53:59 +0100 Subject: Unified use of make_packet during tests --- src/wireguard/router/device.rs | 39 ++-------------- src/wireguard/router/mod.rs | 1 + src/wireguard/router/route.rs | 101 ++++++++++++++++++++++++++++++++++++++++ src/wireguard/router/tests.rs | 40 ++++++---------- src/wireguard/router/workers.rs | 52 +-------------------- src/wireguard/tests.rs | 30 ++++++++---- 6 files changed, 144 insertions(+), 119 deletions(-) create mode 100644 src/wireguard/router/route.rs (limited to 'src') diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index 254b3de..0818637 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; + use std::net::{Ipv4Addr, Ipv6Addr}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::mpsc::sync_channel; @@ -14,13 +15,15 @@ use zerocopy::LayoutVerified; use super::anti_replay::AntiReplay; use super::constants::*; -use super::ip::*; + use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::peer::{new_peer, Peer, PeerInner}; use super::types::{Callbacks, RouterError}; use super::workers::{worker_parallel, JobParallel, Operation}; use super::SIZE_MESSAGE_PREFIX; +use super::route::get_route; + use super::super::{bind, tun, Endpoint, KeyPair}; pub struct DeviceInner> { @@ -84,40 +87,6 @@ impl> Drop for Dev } } -#[inline(always)] -fn get_route>( - device: &Arc>, - packet: &[u8], -) -> Option>> { - match packet.get(0)? >> 4 { - VERSION_IP4 => { - // check length and cast to IPv4 header - let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = - LayoutVerified::new_from_prefix(packet)?; - - // lookup destination address - device - .ipv4 - .read() - .longest_match(Ipv4Addr::from(header.f_destination)) - .and_then(|(_, _, p)| Some(p.clone())) - } - VERSION_IP6 => { - // check length and cast to IPv6 header - let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = - LayoutVerified::new_from_prefix(packet)?; - - // lookup destination address - device - .ipv6 - .read() - .longest_match(Ipv6Addr::from(header.f_destination)) - .and_then(|(_, _, p)| Some(p.clone())) - } - _ => None, - } -} - impl> Device { pub fn new(num_workers: usize, tun: T) -> Device { // allocate shared device state diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs index 4e748cb..7b317f2 100644 --- a/src/wireguard/router/mod.rs +++ b/src/wireguard/router/mod.rs @@ -4,6 +4,7 @@ mod device; mod ip; mod messages; mod peer; +mod route; mod types; mod workers; diff --git a/src/wireguard/router/route.rs b/src/wireguard/router/route.rs new file mode 100644 index 0000000..94c7e23 --- /dev/null +++ b/src/wireguard/router/route.rs @@ -0,0 +1,101 @@ +use super::super::{bind, tun, Endpoint}; +use super::device::DeviceInner; +use super::ip::*; +use super::peer::PeerInner; +use super::types::Callbacks; + +use log::trace; +use zerocopy::LayoutVerified; + +use std::mem; +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::sync::Arc; + +#[inline(always)] +pub fn get_route>( + device: &Arc>, + packet: &[u8], +) -> Option>> { + match packet.get(0)? >> 4 { + VERSION_IP4 => { + trace!("cryptokey router, get route for IPv4 packet"); + + // check length and cast to IPv4 header + let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + // check IPv4 source address + device + .ipv4 + .read() + .longest_match(Ipv4Addr::from(header.f_destination)) + .and_then(|(_, _, p)| Some(p.clone())) + } + VERSION_IP6 => { + trace!("cryptokey router, get route for IPv6 packet"); + + // check length and cast to IPv6 header + let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + // check IPv6 source address + device + .ipv6 + .read() + .longest_match(Ipv6Addr::from(header.f_destination)) + .and_then(|(_, _, p)| Some(p.clone())) + } + _ => None, + } +} + +#[inline(always)] +pub fn check_route>( + device: &Arc>, + peer: &Arc>, + packet: &[u8], +) -> Option { + match packet.get(0)? >> 4 { + VERSION_IP4 => { + trace!("cryptokey route, check route for IPv4 packet"); + + // check length and cast to IPv4 header + let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + // check IPv4 source address + device + .ipv4 + .read() + .longest_match(Ipv4Addr::from(header.f_source)) + .and_then(|(_, _, p)| { + if Arc::ptr_eq(p, peer) { + Some(header.f_total_len.get() as usize) + } else { + None + } + }) + } + VERSION_IP6 => { + trace!("cryptokey route, check route for IPv6 packet"); + + // check length and cast to IPv6 header + let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + // check IPv6 source address + device + .ipv6 + .read() + .longest_match(Ipv6Addr::from(header.f_source)) + .and_then(|(_, _, p)| { + if Arc::ptr_eq(p, peer) { + Some(header.f_len.get() as usize + mem::size_of::()) + } else { + None + } + }) + } + _ => None, + } +} diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs index 6184993..a14640c 100644 --- a/src/wireguard/router/tests.rs +++ b/src/wireguard/router/tests.rs @@ -6,13 +6,13 @@ use std::thread; use std::time::Duration; use num_cpus; -use pnet::packet::ipv4::MutableIpv4Packet; -use pnet::packet::ipv6::MutableIpv6Packet; use super::super::bind::*; use super::super::dummy; use super::super::dummy_keypair; -use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX}; +use super::super::tests::make_packet_dst; +use super::SIZE_MESSAGE_PREFIX; +use super::{Callbacks, Device}; extern crate test; @@ -111,23 +111,11 @@ mod tests { let _ = env_logger::builder().is_test(true).try_init(); } - fn make_packet(size: usize, ip: IpAddr) -> Vec { - // create "IP packet" - let mut msg = Vec::with_capacity(SIZE_MESSAGE_PREFIX + size + 16); - msg.resize(SIZE_MESSAGE_PREFIX + size, 0); - match ip { - IpAddr::V4(ip) => { - let mut packet = MutableIpv4Packet::new(&mut msg[SIZE_MESSAGE_PREFIX..]).unwrap(); - packet.set_destination(ip); - packet.set_version(4); - } - IpAddr::V6(ip) => { - let mut packet = MutableIpv6Packet::new(&mut msg[SIZE_MESSAGE_PREFIX..]).unwrap(); - packet.set_destination(ip); - packet.set_version(6); - } - } - msg + fn make_packet_dst_padded(size: usize, dst: IpAddr, id: u64) -> Vec { + let p = make_packet_dst(size, dst, id); + let mut o = vec![0; p.len() + SIZE_MESSAGE_PREFIX]; + o[SIZE_MESSAGE_PREFIX..SIZE_MESSAGE_PREFIX + p.len()].copy_from_slice(&p[..]); + o } #[bench] @@ -162,7 +150,7 @@ mod tests { // every iteration sends 10 GB b.iter(|| { opaque.store(0, Ordering::SeqCst); - let msg = make_packet(1024, ip1); + let msg = make_packet_dst_padded(1024, ip1, 0); while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 { router.send(msg.to_vec()).unwrap(); } @@ -218,7 +206,7 @@ mod tests { peer.add_allowed_ips(mask, *len); // create "IP packet" - let msg = make_packet(1024, ip.parse().unwrap()); + let msg = make_packet_dst_padded(1024, ip.parse().unwrap(), 0); // cryptkey route the IP packet let res = router.send(msg); @@ -228,7 +216,7 @@ mod tests { if *okay { // cryptkey routing succeeded - assert!(res.is_ok(), "crypt-key routing should succeed"); + assert!(res.is_ok(), "crypt-key routing should succeed: {:?}", res); assert_eq!( opaque.need_key().is_some(), !set_key, @@ -351,7 +339,7 @@ mod tests { if *stage { // stage a packet which can be used for confirmation (in place of a keepalive) let (_mask, _len, ip, _okay) = p2; - let msg = make_packet(1024, ip.parse().unwrap()); + let msg = make_packet_dst_padded(1024, ip.parse().unwrap(), 0); router2.send(msg).expect("failed to sent staged packet"); wait(); @@ -396,7 +384,7 @@ mod tests { // now that peer1 has an endpoint // route packets : peer1 -> peer2 - for _ in 0..10 { + for id in 0..10 { assert!( opaq1.is_empty(), "we should have asserted a value for every callback on peer1" @@ -408,7 +396,7 @@ mod tests { // pass IP packet to router let (_mask, _len, ip, _okay) = p1; - let msg = make_packet(1024, ip.parse().unwrap()); + let msg = make_packet_dst_padded(1024, ip.parse().unwrap(), id); router1.send(msg).unwrap(); wait(); diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs index 8ebb246..70334c1 100644 --- a/src/wireguard/router/workers.rs +++ b/src/wireguard/router/workers.rs @@ -1,4 +1,3 @@ -use std::mem; use std::sync::mpsc::Receiver; use std::sync::Arc; @@ -8,17 +7,17 @@ use futures::*; use log::debug; use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; -use std::net::{Ipv4Addr, Ipv6Addr}; + use std::sync::atomic::Ordering; use zerocopy::{AsBytes, LayoutVerified}; use super::device::{DecryptionState, DeviceInner}; use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::peer::PeerInner; +use super::route::check_route; use super::types::Callbacks; use super::super::{bind, tun, Endpoint}; -use super::ip::*; pub const SIZE_TAG: usize = 16; @@ -46,53 +45,6 @@ pub type JobInbound> = ( pub type JobOutbound = oneshot::Receiver; -#[inline(always)] -fn check_route>( - device: &Arc>, - peer: &Arc>, - packet: &[u8], -) -> Option { - match packet[0] >> 4 { - VERSION_IP4 => { - // check length and cast to IPv4 header - let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = - LayoutVerified::new_from_prefix(packet)?; - - // check IPv4 source address - device - .ipv4 - .read() - .longest_match(Ipv4Addr::from(header.f_source)) - .and_then(|(_, _, p)| { - if Arc::ptr_eq(p, &peer) { - Some(header.f_total_len.get() as usize) - } else { - None - } - }) - } - VERSION_IP6 => { - // check length and cast to IPv6 header - let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = - LayoutVerified::new_from_prefix(packet)?; - - // check IPv6 source address - device - .ipv6 - .read() - .longest_match(Ipv6Addr::from(header.f_source)) - .and_then(|(_, _, p)| { - if Arc::ptr_eq(p, &peer) { - Some(header.f_len.get() as usize + mem::size_of::()) - } else { - None - } - }) - } - _ => None, - } -} - pub fn worker_inbound>( device: Arc>, // related device peer: Arc>, // related peer diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs index 28dedec..3ecb979 100644 --- a/src/wireguard/tests.rs +++ b/src/wireguard/tests.rs @@ -14,19 +14,32 @@ use x25519_dalek::{PublicKey, StaticSecret}; use pnet::packet::ipv4::MutableIpv4Packet; use pnet::packet::ipv6::MutableIpv6Packet; -fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec { +pub fn make_packet_src(size: usize, src: IpAddr, id: u64) -> Vec { + match src { + IpAddr::V4(_) => make_packet(size, src, "127.0.0.1".parse().unwrap(), id), + IpAddr::V6(_) => make_packet(size, src, "::1".parse().unwrap(), id), + } +} + +pub fn make_packet_dst(size: usize, dst: IpAddr, id: u64) -> Vec { + match dst { + IpAddr::V4(_) => make_packet(size, "127.0.0.1".parse().unwrap(), dst, id), + IpAddr::V6(_) => make_packet(size, "::1".parse().unwrap(), dst, id), + } +} + +pub fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec { // expand pseudo random payload let mut rng: _ = ChaCha8Rng::seed_from_u64(id); - let mut p: Vec = vec![]; - for _ in 0..size { - p.push(rng.next_u32() as u8); - } + let mut p: Vec = vec![0; size]; + rng.fill_bytes(&mut p[..]); // create "IP packet" let mut msg = Vec::with_capacity(size); msg.resize(size, 0); match dst { IpAddr::V4(dst) => { + let length = size - MutableIpv4Packet::minimum_packet_size(); let mut packet = MutableIpv4Packet::new(&mut msg[..]).unwrap(); packet.set_destination(dst); packet.set_total_length(size as u16); @@ -35,19 +48,20 @@ fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec { } else { panic!("src.version != dst.version") }); - packet.set_payload(&p[..]); + packet.set_payload(&p[..length]); packet.set_version(4); } IpAddr::V6(dst) => { + let length = size - MutableIpv6Packet::minimum_packet_size(); let mut packet = MutableIpv6Packet::new(&mut msg[..]).unwrap(); packet.set_destination(dst); - packet.set_payload_length((size - MutableIpv6Packet::minimum_packet_size()) as u16); + packet.set_payload_length(length as u16); packet.set_source(if let IpAddr::V6(src) = src { src } else { panic!("src.version != dst.version") }); - packet.set_payload(&p[..]); + packet.set_payload(&p[..length]); packet.set_version(6); } } -- cgit v1.2.3-59-g8ed1b