aboutsummaryrefslogtreecommitdiffstats
path: root/src/wireguard
diff options
context:
space:
mode:
Diffstat (limited to 'src/wireguard')
-rw-r--r--src/wireguard/router/device.rs12
-rw-r--r--src/wireguard/router/mod.rs1
-rw-r--r--src/wireguard/router/receive.rs7
-rw-r--r--src/wireguard/router/send.rs10
-rw-r--r--src/wireguard/router/tests.rs78
-rw-r--r--src/wireguard/tests.rs33
6 files changed, 77 insertions, 64 deletions
diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs
index 8bfa261..7c90f22 100644
--- a/src/wireguard/router/device.rs
+++ b/src/wireguard/router/device.rs
@@ -121,14 +121,14 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
// start worker threads
let mut threads = Vec::with_capacity(num_workers);
while let Some(rx) = consumers.pop() {
- println!("spawn");
- threads.push(thread::spawn(move || {
- println!("spawned");
- worker(rx);
- }));
+ threads.push(thread::spawn(move || worker(rx)));
}
debug_assert!(num_workers > 0, "zero worker threads");
- debug_assert_eq!(threads.len(), num_workers);
+ debug_assert_eq!(
+ threads.len(),
+ num_workers,
+ "workers does not match consumers"
+ );
// return exported device handle
DeviceHandle {
diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs
index 19e037f..699c621 100644
--- a/src/wireguard/router/mod.rs
+++ b/src/wireguard/router/mod.rs
@@ -24,7 +24,6 @@ use super::types::*;
pub const SIZE_TAG: usize = 16;
pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>();
-pub const SIZE_KEEPALIVE: usize = mem::size_of::<TransportHeader>() + SIZE_TAG;
pub const CAPACITY_MESSAGE_POSTFIX: usize = SIZE_TAG;
pub const fn message_data_len(payload: usize) -> usize {
diff --git a/src/wireguard/router/receive.rs b/src/wireguard/router/receive.rs
index 0e5cb0f..45ef423 100644
--- a/src/wireguard/router/receive.rs
+++ b/src/wireguard/router/receive.rs
@@ -3,7 +3,7 @@ use super::ip::inner_length;
use super::messages::TransportHeader;
use super::queue::{ParallelJob, Queue, SequentialJob};
use super::types::Callbacks;
-use super::{REJECT_AFTER_MESSAGES, SIZE_KEEPALIVE};
+use super::{REJECT_AFTER_MESSAGES, SIZE_TAG};
use super::super::{tun, udp, Endpoint};
@@ -93,7 +93,6 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ParallelJob
debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len());
nonce[4..].copy_from_slice(header.f_counter.as_bytes());
let nonce = Nonce::assume_unique_for_key(nonce);
-
// do the weird ring AEAD dance
let key = LessSafeKey::new(
UnboundKey::new(&CHACHA20_POLY1305, &job.state.keypair.recv.key[..]).unwrap(),
@@ -111,7 +110,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ParallelJob
}
// check crypto-key router
- packet.len() == SIZE_KEEPALIVE || peer.device.table.check_route(&peer, &packet)
+ packet.len() == SIZE_TAG || peer.device.table.check_route(&peer, &packet)
})();
// remove message in case of failure:
@@ -174,7 +173,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> SequentialJob
// check if should be written to TUN
// (keep-alive and malformed packets will have no inner length)
if let Some(inner) = inner_length(packet) {
- if inner >= packet.len() {
+ if inner + SIZE_TAG <= packet.len() {
let _ = peer.device.inbound.write(&packet[..inner]).map_err(|e| {
log::debug!("failed to write inbound packet to TUN: {:?}", e);
});
diff --git a/src/wireguard/router/send.rs b/src/wireguard/router/send.rs
index db6b079..0472e11 100644
--- a/src/wireguard/router/send.rs
+++ b/src/wireguard/router/send.rs
@@ -91,19 +91,17 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> ParallelJob
nonce[4..].copy_from_slice(header.f_counter.as_bytes());
let nonce = Nonce::assume_unique_for_key(nonce);
- // do the weird ring AEAD dance
+ // encrypt contents of transport message in-place
+ let tag_offset = packet.len() - SIZE_TAG;
let key = LessSafeKey::new(
UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.send.key[..]).unwrap(),
);
-
- // encrypt contents of transport message in-place
- let end = packet.len() - SIZE_TAG;
let tag = key
- .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end])
+ .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..tag_offset])
.unwrap();
// append tag
- packet[end..].copy_from_slice(tag.as_ref());
+ packet[tag_offset..].copy_from_slice(tag.as_ref());
}
// mark ready
diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs
index 3afa422..842dd52 100644
--- a/src/wireguard/router/tests.rs
+++ b/src/wireguard/router/tests.rs
@@ -2,7 +2,7 @@ use super::KeyPair;
use super::SIZE_MESSAGE_PREFIX;
use super::{Callbacks, Device};
-use super::SIZE_KEEPALIVE;
+use super::message_data_len;
use super::super::dummy;
use super::super::dummy_keypair;
@@ -21,12 +21,13 @@ use std::time::Duration;
use env_logger;
use num_cpus;
+use rand::Rng;
use test::Bencher;
extern crate test;
const SIZE_MSG: usize = 1024;
-
+const SIZE_KEEPALIVE: usize = message_data_len(0);
const TIMEOUT: Duration = Duration::from_millis(1000);
struct EventTracker<E> {
@@ -133,10 +134,9 @@ fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
-fn make_packet_padded(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec<u8> {
- let p = make_packet(size, src, dst, id);
- let mut o = vec![0; p.len() + SIZE_MESSAGE_PREFIX];
- o[SIZE_MESSAGE_PREFIX..SIZE_MESSAGE_PREFIX + p.len()].copy_from_slice(&p[..]);
+fn pad(msg: &[u8]) -> Vec<u8> {
+ let mut o = vec![0; msg.len() + SIZE_MESSAGE_PREFIX];
+ o[SIZE_MESSAGE_PREFIX..SIZE_MESSAGE_PREFIX + msg.len()].copy_from_slice(msg);
o
}
@@ -180,7 +180,7 @@ fn bench_outbound(b: &mut Bencher) {
IpAddr::V4(_) => "127.0.0.1".parse().unwrap(),
IpAddr::V6(_) => "::1".parse().unwrap(),
};
- let msg = make_packet_padded(1024, src, dst, 0);
+ let msg = pad(&make_packet(1024, src, dst, 0));
// every iteration sends 10 GB
b.iter(|| {
@@ -266,10 +266,10 @@ fn test_outbound() {
IpAddr::V4(_) => "127.0.0.1".parse().unwrap(),
IpAddr::V6(_) => "::1".parse().unwrap(),
};
- let msg = make_packet_padded(SIZE_MSG, src, dst, 0);
+ let msg = make_packet(SIZE_MSG, src, dst, 0);
// crypto-key route the IP packet
- let res = router.send(msg);
+ let res = router.send(pad(&msg));
assert_eq!(
res.is_ok(),
okay,
@@ -303,7 +303,7 @@ fn test_outbound() {
if send_payload {
assert_eq!(
opaque.send.wait(TIMEOUT),
- Some((SIZE_KEEPALIVE + SIZE_MSG, false)),
+ Some((SIZE_KEEPALIVE + msg.len(), false)),
"message buffer should be encrypted"
)
}
@@ -319,6 +319,8 @@ fn test_outbound() {
fn test_bidirectional() {
init();
+ const MAX_SIZE_BODY: usize = 1 << 15;
+
let tests = [
(
("192.168.1.0", 24, "192.168.1.20", true),
@@ -358,6 +360,8 @@ fn test_bidirectional() {
),
];
+ let mut rng = rand::thread_rng();
+
for (p1, p2) in tests.iter() {
for confirm_with_staged_packet in vec![true, false] {
println!(
@@ -368,11 +372,7 @@ fn test_bidirectional() {
let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) =
dummy::PairBind::pair();
- let confirm_packet_size = if confirm_with_staged_packet {
- SIZE_KEEPALIVE + SIZE_MSG
- } else {
- SIZE_KEEPALIVE
- };
+ let mut confirm_packet_size = SIZE_KEEPALIVE;
// create matching device
let (_fake, _, tun_writer1, _) = dummy::TunTest::create(false);
@@ -412,15 +412,21 @@ fn test_bidirectional() {
// create IP packet
let (_mask, _len, ip1, _okay) = p1;
let (_mask, _len, ip2, _okay) = p2;
- let msg = make_packet_padded(
+
+ let msg = make_packet(
SIZE_MSG,
ip1.parse().unwrap(), // src
ip2.parse().unwrap(), // dst
0,
);
+ // calculate size of encapsulated IP packet
+ confirm_packet_size = msg.len() + SIZE_KEEPALIVE;
+
// stage packet for sending
- router2.send(msg).expect("failed to sent staged packet");
+ router2
+ .send(pad(&msg))
+ .expect("failed to sent staged packet");
// a new key should have been requested from the handshake machine
assert_eq!(
@@ -429,6 +435,7 @@ fn test_bidirectional() {
"a new key should be requested since a packet was attempted transmitted"
);
+ // no other events should fire
no_events!(opaque1);
no_events!(opaque2);
}
@@ -454,12 +461,7 @@ fn test_bidirectional() {
buf.truncate(len);
assert_eq!(
- len,
- if confirm_with_staged_packet {
- SIZE_MSG + SIZE_KEEPALIVE
- } else {
- SIZE_KEEPALIVE
- },
+ len, confirm_packet_size,
"unexpected size of confirmation message"
);
@@ -491,31 +493,39 @@ fn test_bidirectional() {
// no other events should fire
no_events!(opaque1);
no_events!(opaque2);
+
// now that peer1 has an endpoint
// route packets in the other direction: peer1 -> peer2
- for id in 1..11 {
- println!("packet: {}", id);
-
- let message_size = 1024;
+ let mut sizes = vec![0, 1, 1500, MAX_SIZE_BODY];
+ for _ in 0..100 {
+ let body_size: usize = rng.gen();
+ let body_size = body_size % MAX_SIZE_BODY;
+ sizes.push(body_size);
+ }
+ for (id, body_size) in sizes.iter().enumerate() {
+ println!("packet: id = {}, body_size = {}", id, body_size);
// pass IP packet to router
let (_mask, _len, ip1, _okay) = p1;
let (_mask, _len, ip2, _okay) = p2;
- let msg = make_packet_padded(
- message_size,
+ let msg = make_packet(
+ *body_size,
ip2.parse().unwrap(), // src
ip1.parse().unwrap(), // dst
- id,
+ id as u64,
);
+ // calculate encrypted size
+ let encrypted_size = msg.len() + SIZE_KEEPALIVE;
+
router1
- .send(msg)
+ .send(pad(&msg))
.expect("we expect routing to be successful");
// encryption succeeds and the correct size is logged
assert_eq!(
opaque1.send.wait(TIMEOUT),
- Some((message_size + SIZE_KEEPALIVE, true)),
+ Some((encrypted_size, true)),
"expected send event for peer1 -> peer2 payload"
);
@@ -524,7 +534,7 @@ fn test_bidirectional() {
no_events!(opaque2);
// receive ("across the internet") on the other end
- let mut buf = vec![0u8; 2048];
+ let mut buf = vec![0u8; MAX_SIZE_BODY + 512];
let (len, from) = bind_reader2.read(&mut buf).unwrap();
buf.truncate(len);
router2.recv(from, buf).unwrap();
@@ -532,7 +542,7 @@ fn test_bidirectional() {
// check that decryption succeeds
assert_eq!(
opaque2.recv.wait(TIMEOUT),
- Some((message_size + SIZE_KEEPALIVE, true)),
+ Some((msg.len() + SIZE_KEEPALIVE, true)),
"decryption and routing should succeed"
);
diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs
index 2ed2202..4cc441e 100644
--- a/src/wireguard/tests.rs
+++ b/src/wireguard/tests.rs
@@ -1,3 +1,7 @@
+use super::dummy;
+use super::wireguard::WireGuard;
+
+use std::convert::TryInto;
use std::net::IpAddr;
use hex;
@@ -8,43 +12,43 @@ use x25519_dalek::{PublicKey, StaticSecret};
use pnet::packet::ipv4::MutableIpv4Packet;
use pnet::packet::ipv6::MutableIpv6Packet;
-use super::dummy;
-use super::wireguard::WireGuard;
-
pub fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec<u8> {
// expand pseudo random payload
let mut rng: _ = ChaCha8Rng::seed_from_u64(id);
let mut p: Vec<u8> = vec![0; size];
- rng.fill_bytes(&mut p[..]);
+ rng.fill_bytes(&mut p);
// create "IP packet"
let mut msg = Vec::with_capacity(size);
- msg.resize(size, 0);
match dst {
IpAddr::V4(dst) => {
- let length = size - MutableIpv4Packet::minimum_packet_size();
+ let length = size + MutableIpv4Packet::minimum_packet_size();
+ msg.resize(length, 0);
+
let mut packet = MutableIpv4Packet::new(&mut msg[..]).unwrap();
packet.set_destination(dst);
- packet.set_total_length(size as u16);
+ packet.set_total_length(length.try_into().expect("length too great for IPv4 packet"));
packet.set_source(if let IpAddr::V4(src) = src {
src
} else {
panic!("src.version != dst.version")
});
- packet.set_payload(&p[..length]);
+ packet.set_payload(&p);
packet.set_version(4);
}
IpAddr::V6(dst) => {
- let length = size - MutableIpv6Packet::minimum_packet_size();
+ let length = size + MutableIpv6Packet::minimum_packet_size();
+ msg.resize(length, 0);
+
let mut packet = MutableIpv6Packet::new(&mut msg[..]).unwrap();
packet.set_destination(dst);
- packet.set_payload_length(length as u16);
+ packet.set_payload_length(size.try_into().expect("length too great for IPv6 packet"));
packet.set_source(if let IpAddr::V6(src) = src {
src
} else {
panic!("src.version != dst.version")
});
- packet.set_payload(&p[..length]);
+ packet.set_payload(&p);
packet.set_version(6);
}
}
@@ -83,7 +87,7 @@ fn test_pure_wireguard() {
wg1.add_udp_reader(bind_reader1);
wg2.add_udp_reader(bind_reader2);
- // generate (public, private) key pairs
+ // configure (public, private) key pairs
let sk1 = StaticSecret::from([
0x3f, 0x69, 0x86, 0xd1, 0xc0, 0xec, 0x25, 0xa0, 0x9c, 0x8e, 0x56, 0xb5, 0x1d, 0xb7, 0x3c,
@@ -107,7 +111,7 @@ fn test_pure_wireguard() {
wg1.set_key(Some(sk1));
wg2.set_key(Some(sk2));
- // configure cryptkey router
+ // configure crypto-key router
let peer2 = wg1.lookup_peer(&pk2).unwrap();
let peer1 = wg2.lookup_peer(&pk1).unwrap();
@@ -143,10 +147,13 @@ fn test_pure_wireguard() {
let mut backup = packets.clone();
while let Some(p) = packets.pop() {
+ println!("send");
fake1.write(p);
}
while let Some(p) = backup.pop() {
+ println!("read");
+
assert_eq!(
hex::encode(fake2.read()),
hex::encode(p),