summaryrefslogtreecommitdiffstats
path: root/src/wireguard
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-10-29 16:53:59 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-10-29 16:53:59 +0100
commite04a11a8cae5f4f8d29febdb38b93d236c700def (patch)
treeec756778c5462069ccdad7a30840829391132f19 /src/wireguard
parentFirst full test of pure WireGuard (diff)
downloadwireguard-rs-e04a11a8cae5f4f8d29febdb38b93d236c700def.tar.xz
wireguard-rs-e04a11a8cae5f4f8d29febdb38b93d236c700def.zip
Unified use of make_packet during tests
Diffstat (limited to 'src/wireguard')
-rw-r--r--src/wireguard/router/device.rs39
-rw-r--r--src/wireguard/router/mod.rs1
-rw-r--r--src/wireguard/router/route.rs101
-rw-r--r--src/wireguard/router/tests.rs40
-rw-r--r--src/wireguard/router/workers.rs52
-rw-r--r--src/wireguard/tests.rs30
6 files changed, 144 insertions, 119 deletions
diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs
index 254b3de..0818637 100644
--- a/src/wireguard/router/device.rs
+++ b/src/wireguard/router/device.rs
@@ -1,4 +1,5 @@
use std::collections::HashMap;
+
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::mpsc::sync_channel;
@@ -14,13 +15,15 @@ use zerocopy::LayoutVerified;
use super::anti_replay::AntiReplay;
use super::constants::*;
-use super::ip::*;
+
use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::{new_peer, Peer, PeerInner};
use super::types::{Callbacks, RouterError};
use super::workers::{worker_parallel, JobParallel, Operation};
use super::SIZE_MESSAGE_PREFIX;
+use super::route::get_route;
+
use super::super::{bind, tun, Endpoint, KeyPair};
pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
@@ -84,40 +87,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Dev
}
}
-#[inline(always)]
-fn get_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
- device: &Arc<DeviceInner<E, C, T, B>>,
- packet: &[u8],
-) -> Option<Arc<PeerInner<E, C, T, B>>> {
- match packet.get(0)? >> 4 {
- VERSION_IP4 => {
- // check length and cast to IPv4 header
- let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
- LayoutVerified::new_from_prefix(packet)?;
-
- // lookup destination address
- device
- .ipv4
- .read()
- .longest_match(Ipv4Addr::from(header.f_destination))
- .and_then(|(_, _, p)| Some(p.clone()))
- }
- VERSION_IP6 => {
- // check length and cast to IPv6 header
- let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
- LayoutVerified::new_from_prefix(packet)?;
-
- // lookup destination address
- device
- .ipv6
- .read()
- .longest_match(Ipv6Addr::from(header.f_destination))
- .and_then(|(_, _, p)| Some(p.clone()))
- }
- _ => None,
- }
-}
-
impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, T, B> {
pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> {
// allocate shared device state
diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs
index 4e748cb..7b317f2 100644
--- a/src/wireguard/router/mod.rs
+++ b/src/wireguard/router/mod.rs
@@ -4,6 +4,7 @@ mod device;
mod ip;
mod messages;
mod peer;
+mod route;
mod types;
mod workers;
diff --git a/src/wireguard/router/route.rs b/src/wireguard/router/route.rs
new file mode 100644
index 0000000..94c7e23
--- /dev/null
+++ b/src/wireguard/router/route.rs
@@ -0,0 +1,101 @@
+use super::super::{bind, tun, Endpoint};
+use super::device::DeviceInner;
+use super::ip::*;
+use super::peer::PeerInner;
+use super::types::Callbacks;
+
+use log::trace;
+use zerocopy::LayoutVerified;
+
+use std::mem;
+use std::net::{Ipv4Addr, Ipv6Addr};
+use std::sync::Arc;
+
+#[inline(always)]
+pub fn get_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ device: &Arc<DeviceInner<E, C, T, B>>,
+ packet: &[u8],
+) -> Option<Arc<PeerInner<E, C, T, B>>> {
+ match packet.get(0)? >> 4 {
+ VERSION_IP4 => {
+ trace!("cryptokey router, get route for IPv4 packet");
+
+ // check length and cast to IPv4 header
+ let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
+ LayoutVerified::new_from_prefix(packet)?;
+
+ // check IPv4 source address
+ device
+ .ipv4
+ .read()
+ .longest_match(Ipv4Addr::from(header.f_destination))
+ .and_then(|(_, _, p)| Some(p.clone()))
+ }
+ VERSION_IP6 => {
+ trace!("cryptokey router, get route for IPv6 packet");
+
+ // check length and cast to IPv6 header
+ let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
+ LayoutVerified::new_from_prefix(packet)?;
+
+ // check IPv6 source address
+ device
+ .ipv6
+ .read()
+ .longest_match(Ipv6Addr::from(header.f_destination))
+ .and_then(|(_, _, p)| Some(p.clone()))
+ }
+ _ => None,
+ }
+}
+
+#[inline(always)]
+pub fn check_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ device: &Arc<DeviceInner<E, C, T, B>>,
+ peer: &Arc<PeerInner<E, C, T, B>>,
+ packet: &[u8],
+) -> Option<usize> {
+ match packet.get(0)? >> 4 {
+ VERSION_IP4 => {
+ trace!("cryptokey route, check route for IPv4 packet");
+
+ // check length and cast to IPv4 header
+ let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
+ LayoutVerified::new_from_prefix(packet)?;
+
+ // check IPv4 source address
+ device
+ .ipv4
+ .read()
+ .longest_match(Ipv4Addr::from(header.f_source))
+ .and_then(|(_, _, p)| {
+ if Arc::ptr_eq(p, peer) {
+ Some(header.f_total_len.get() as usize)
+ } else {
+ None
+ }
+ })
+ }
+ VERSION_IP6 => {
+ trace!("cryptokey route, check route for IPv6 packet");
+
+ // check length and cast to IPv6 header
+ let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
+ LayoutVerified::new_from_prefix(packet)?;
+
+ // check IPv6 source address
+ device
+ .ipv6
+ .read()
+ .longest_match(Ipv6Addr::from(header.f_source))
+ .and_then(|(_, _, p)| {
+ if Arc::ptr_eq(p, peer) {
+ Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>())
+ } else {
+ None
+ }
+ })
+ }
+ _ => None,
+ }
+}
diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs
index 6184993..a14640c 100644
--- a/src/wireguard/router/tests.rs
+++ b/src/wireguard/router/tests.rs
@@ -6,13 +6,13 @@ use std::thread;
use std::time::Duration;
use num_cpus;
-use pnet::packet::ipv4::MutableIpv4Packet;
-use pnet::packet::ipv6::MutableIpv6Packet;
use super::super::bind::*;
use super::super::dummy;
use super::super::dummy_keypair;
-use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX};
+use super::super::tests::make_packet_dst;
+use super::SIZE_MESSAGE_PREFIX;
+use super::{Callbacks, Device};
extern crate test;
@@ -111,23 +111,11 @@ mod tests {
let _ = env_logger::builder().is_test(true).try_init();
}
- fn make_packet(size: usize, ip: IpAddr) -> Vec<u8> {
- // create "IP packet"
- let mut msg = Vec::with_capacity(SIZE_MESSAGE_PREFIX + size + 16);
- msg.resize(SIZE_MESSAGE_PREFIX + size, 0);
- match ip {
- IpAddr::V4(ip) => {
- let mut packet = MutableIpv4Packet::new(&mut msg[SIZE_MESSAGE_PREFIX..]).unwrap();
- packet.set_destination(ip);
- packet.set_version(4);
- }
- IpAddr::V6(ip) => {
- let mut packet = MutableIpv6Packet::new(&mut msg[SIZE_MESSAGE_PREFIX..]).unwrap();
- packet.set_destination(ip);
- packet.set_version(6);
- }
- }
- msg
+ fn make_packet_dst_padded(size: usize, dst: IpAddr, id: u64) -> Vec<u8> {
+ let p = make_packet_dst(size, dst, id);
+ let mut o = vec![0; p.len() + SIZE_MESSAGE_PREFIX];
+ o[SIZE_MESSAGE_PREFIX..SIZE_MESSAGE_PREFIX + p.len()].copy_from_slice(&p[..]);
+ o
}
#[bench]
@@ -162,7 +150,7 @@ mod tests {
// every iteration sends 10 GB
b.iter(|| {
opaque.store(0, Ordering::SeqCst);
- let msg = make_packet(1024, ip1);
+ let msg = make_packet_dst_padded(1024, ip1, 0);
while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 {
router.send(msg.to_vec()).unwrap();
}
@@ -218,7 +206,7 @@ mod tests {
peer.add_allowed_ips(mask, *len);
// create "IP packet"
- let msg = make_packet(1024, ip.parse().unwrap());
+ let msg = make_packet_dst_padded(1024, ip.parse().unwrap(), 0);
// cryptkey route the IP packet
let res = router.send(msg);
@@ -228,7 +216,7 @@ mod tests {
if *okay {
// cryptkey routing succeeded
- assert!(res.is_ok(), "crypt-key routing should succeed");
+ assert!(res.is_ok(), "crypt-key routing should succeed: {:?}", res);
assert_eq!(
opaque.need_key().is_some(),
!set_key,
@@ -351,7 +339,7 @@ mod tests {
if *stage {
// stage a packet which can be used for confirmation (in place of a keepalive)
let (_mask, _len, ip, _okay) = p2;
- let msg = make_packet(1024, ip.parse().unwrap());
+ let msg = make_packet_dst_padded(1024, ip.parse().unwrap(), 0);
router2.send(msg).expect("failed to sent staged packet");
wait();
@@ -396,7 +384,7 @@ mod tests {
// now that peer1 has an endpoint
// route packets : peer1 -> peer2
- for _ in 0..10 {
+ for id in 0..10 {
assert!(
opaq1.is_empty(),
"we should have asserted a value for every callback on peer1"
@@ -408,7 +396,7 @@ mod tests {
// pass IP packet to router
let (_mask, _len, ip, _okay) = p1;
- let msg = make_packet(1024, ip.parse().unwrap());
+ let msg = make_packet_dst_padded(1024, ip.parse().unwrap(), id);
router1.send(msg).unwrap();
wait();
diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs
index 8ebb246..70334c1 100644
--- a/src/wireguard/router/workers.rs
+++ b/src/wireguard/router/workers.rs
@@ -1,4 +1,3 @@
-use std::mem;
use std::sync::mpsc::Receiver;
use std::sync::Arc;
@@ -8,17 +7,17 @@ use futures::*;
use log::debug;
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
-use std::net::{Ipv4Addr, Ipv6Addr};
+
use std::sync::atomic::Ordering;
use zerocopy::{AsBytes, LayoutVerified};
use super::device::{DecryptionState, DeviceInner};
use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::PeerInner;
+use super::route::check_route;
use super::types::Callbacks;
use super::super::{bind, tun, Endpoint};
-use super::ip::*;
pub const SIZE_TAG: usize = 16;
@@ -46,53 +45,6 @@ pub type JobInbound<E, C, T, B: bind::Writer<E>> = (
pub type JobOutbound = oneshot::Receiver<JobBuffer>;
-#[inline(always)]
-fn check_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
- device: &Arc<DeviceInner<E, C, T, B>>,
- peer: &Arc<PeerInner<E, C, T, B>>,
- packet: &[u8],
-) -> Option<usize> {
- match packet[0] >> 4 {
- VERSION_IP4 => {
- // check length and cast to IPv4 header
- let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
- LayoutVerified::new_from_prefix(packet)?;
-
- // check IPv4 source address
- device
- .ipv4
- .read()
- .longest_match(Ipv4Addr::from(header.f_source))
- .and_then(|(_, _, p)| {
- if Arc::ptr_eq(p, &peer) {
- Some(header.f_total_len.get() as usize)
- } else {
- None
- }
- })
- }
- VERSION_IP6 => {
- // check length and cast to IPv6 header
- let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
- LayoutVerified::new_from_prefix(packet)?;
-
- // check IPv6 source address
- device
- .ipv6
- .read()
- .longest_match(Ipv6Addr::from(header.f_source))
- .and_then(|(_, _, p)| {
- if Arc::ptr_eq(p, &peer) {
- Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>())
- } else {
- None
- }
- })
- }
- _ => None,
- }
-}
-
pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: Arc<DeviceInner<E, C, T, B>>, // related device
peer: Arc<PeerInner<E, C, T, B>>, // related peer
diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs
index 28dedec..3ecb979 100644
--- a/src/wireguard/tests.rs
+++ b/src/wireguard/tests.rs
@@ -14,19 +14,32 @@ use x25519_dalek::{PublicKey, StaticSecret};
use pnet::packet::ipv4::MutableIpv4Packet;
use pnet::packet::ipv6::MutableIpv6Packet;
-fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec<u8> {
+pub fn make_packet_src(size: usize, src: IpAddr, id: u64) -> Vec<u8> {
+ match src {
+ IpAddr::V4(_) => make_packet(size, src, "127.0.0.1".parse().unwrap(), id),
+ IpAddr::V6(_) => make_packet(size, src, "::1".parse().unwrap(), id),
+ }
+}
+
+pub fn make_packet_dst(size: usize, dst: IpAddr, id: u64) -> Vec<u8> {
+ match dst {
+ IpAddr::V4(_) => make_packet(size, "127.0.0.1".parse().unwrap(), dst, id),
+ IpAddr::V6(_) => make_packet(size, "::1".parse().unwrap(), dst, id),
+ }
+}
+
+pub fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec<u8> {
// expand pseudo random payload
let mut rng: _ = ChaCha8Rng::seed_from_u64(id);
- let mut p: Vec<u8> = vec![];
- for _ in 0..size {
- p.push(rng.next_u32() as u8);
- }
+ let mut p: Vec<u8> = vec![0; size];
+ rng.fill_bytes(&mut p[..]);
// create "IP packet"
let mut msg = Vec::with_capacity(size);
msg.resize(size, 0);
match dst {
IpAddr::V4(dst) => {
+ let length = size - MutableIpv4Packet::minimum_packet_size();
let mut packet = MutableIpv4Packet::new(&mut msg[..]).unwrap();
packet.set_destination(dst);
packet.set_total_length(size as u16);
@@ -35,19 +48,20 @@ fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec<u8> {
} else {
panic!("src.version != dst.version")
});
- packet.set_payload(&p[..]);
+ packet.set_payload(&p[..length]);
packet.set_version(4);
}
IpAddr::V6(dst) => {
+ let length = size - MutableIpv6Packet::minimum_packet_size();
let mut packet = MutableIpv6Packet::new(&mut msg[..]).unwrap();
packet.set_destination(dst);
- packet.set_payload_length((size - MutableIpv6Packet::minimum_packet_size()) as u16);
+ packet.set_payload_length(length as u16);
packet.set_source(if let IpAddr::V6(src) = src {
src
} else {
panic!("src.version != dst.version")
});
- packet.set_payload(&p[..]);
+ packet.set_payload(&p[..length]);
packet.set_version(6);
}
}