aboutsummaryrefslogtreecommitdiffstats
path: root/src/wireguard
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-12-06 21:45:21 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-12-06 21:45:21 +0100
commit74e576a9c21b0de451e0588428fbbb99b24eb074 (patch)
tree381ad26325ae4bee1f6a17449110ac941c2a9192 /src/wireguard
parentMoving away from peer threads (diff)
downloadwireguard-rs-74e576a9c21b0de451e0588428fbbb99b24eb074.tar.xz
wireguard-rs-74e576a9c21b0de451e0588428fbbb99b24eb074.zip
Fixed inbound job bug (add to sequential queue)
Diffstat (limited to 'src/wireguard')
-rw-r--r--src/wireguard/router/device.rs63
-rw-r--r--src/wireguard/router/inbound.rs22
-rw-r--r--src/wireguard/router/mod.rs1
-rw-r--r--src/wireguard/router/outbound.rs4
-rw-r--r--src/wireguard/router/peer.rs4
-rw-r--r--src/wireguard/router/pool.rs1
-rw-r--r--src/wireguard/router/queue.rs46
-rw-r--r--src/wireguard/router/route.rs13
-rw-r--r--src/wireguard/router/tests.rs253
-rw-r--r--src/wireguard/timers.rs89
10 files changed, 289 insertions, 207 deletions
diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs
index 88eeae1..e405446 100644
--- a/src/wireguard/router/device.rs
+++ b/src/wireguard/router/device.rs
@@ -1,8 +1,6 @@
use std::collections::HashMap;
use std::ops::Deref;
-use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
-use std::sync::mpsc::sync_channel;
-use std::sync::mpsc::{Receiver, SyncSender};
+use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::thread;
use std::time::Instant;
@@ -25,47 +23,7 @@ use super::SIZE_MESSAGE_PREFIX;
use super::route::RoutingTable;
use super::super::{tun, udp, Endpoint, KeyPair};
-
-pub struct ParallelQueue<T> {
- next: AtomicUsize, // next round-robin index
- queues: Vec<Mutex<SyncSender<T>>>, // work queues (1 per thread)
-}
-
-impl<T> ParallelQueue<T> {
- fn new(queues: usize) -> (Vec<Receiver<T>>, Self) {
- let mut rxs = vec![];
- let mut txs = vec![];
-
- for _ in 0..queues {
- let (tx, rx) = sync_channel(128);
- txs.push(Mutex::new(tx));
- rxs.push(rx);
- }
-
- (
- rxs,
- ParallelQueue {
- next: AtomicUsize::new(0),
- queues: txs,
- },
- )
- }
-
- pub fn send(&self, v: T) {
- let len = self.queues.len();
- let idx = self.next.fetch_add(1, Ordering::SeqCst);
- let que = self.queues[idx % len].lock();
- que.send(v).unwrap();
- }
-
- pub fn close(&self) {
- for i in 0..self.queues.len() {
- let (tx, _) = sync_channel(0);
- let queue = &self.queues[i];
- *queue.lock() = tx;
- }
- }
-}
+use super::queue::ParallelQueue;
pub struct DeviceInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> {
// inbound writer (TUN)
@@ -171,16 +129,25 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
// start worker threads
let mut threads = Vec::with_capacity(num_workers);
+
for _ in 0..num_workers {
let rx = inrx.pop().unwrap();
- threads.push(thread::spawn(move || inbound::worker(rx)));
+ threads.push(thread::spawn(move || {
+ log::debug!("inbound router worker started");
+ inbound::worker(rx)
+ }));
}
for _ in 0..num_workers {
let rx = outrx.pop().unwrap();
- threads.push(thread::spawn(move || outbound::worker(rx)));
+ threads.push(thread::spawn(move || {
+ log::debug!("outbound router worker started");
+ outbound::worker(rx)
+ }));
}
+ debug_assert_eq!(threads.len(), num_workers * 2);
+
// return exported device handle
DeviceHandle {
state: Device {
@@ -274,7 +241,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
);
log::trace!(
- "Router, handle transport message: (receiver = {}, counter = {})",
+ "handle transport message: (receiver = {}, counter = {})",
header.f_receiver,
header.f_counter
);
@@ -287,9 +254,9 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
// schedule for decryption and TUN write
if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
+ log::trace!("schedule decryption of transport message");
self.state.inbound_queue.send(job);
}
-
Ok(())
}
diff --git a/src/wireguard/router/inbound.rs b/src/wireguard/router/inbound.rs
index d4ad307..3d47bb7 100644
--- a/src/wireguard/router/inbound.rs
+++ b/src/wireguard/router/inbound.rs
@@ -42,6 +42,8 @@ fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
peer: &Peer<E, C, T, B>,
body: &mut Inbound<E, C, T, B>,
) {
+ log::trace!("worker, parallel section, obtained job");
+
// cast to header followed by payload
let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
match LayoutVerified::new_from_prefix(&mut body.msg[..]) {
@@ -70,6 +72,7 @@ fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
Ok(_) => (),
Err(_) => {
// fault and return early
+ log::trace!("inbound worker: authentication failure");
body.failed = true;
return;
}
@@ -89,9 +92,15 @@ fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
// truncate to remove tag
match inner_len {
None => {
+ log::trace!("inbound worker: cryptokey routing failed");
body.failed = true;
}
Some(len) => {
+ log::trace!(
+ "inbound worker: good route, length = {} {}",
+ len,
+ if len == 0 { "(keepalive)" } else { "" }
+ );
body.msg.truncate(mem::size_of::<TransportHeader>() + len);
}
}
@@ -102,8 +111,11 @@ fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
peer: &Peer<E, C, T, B>,
body: &mut Inbound<E, C, T, B>,
) {
+ log::trace!("worker, sequential section, obtained job");
+
// decryption failed, return early
if body.failed {
+ log::trace!("job faulted, remove from queue and ignore");
return;
}
@@ -116,10 +128,6 @@ fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
return;
}
};
- debug_assert!(
- packet.len() >= CHACHA20_POLY1305.tag_len(),
- "this should be checked earlier in the pipeline (decryption should fail)"
- );
// check for replay
if !body.state.protector.lock().update(header.f_counter.get()) {
@@ -136,13 +144,9 @@ fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
// update endpoint
*peer.endpoint.lock() = body.endpoint.take();
- // calculate length of IP packet + padding
- let length = packet.len() - SIZE_TAG;
- log::debug!("inbound worker: plaintext length = {}", length);
-
// check if should be written to TUN
let mut sent = false;
- if length > 0 {
+ if packet.len() > 0 {
sent = match peer.device.inbound.write(&packet[..]) {
Err(e) => {
log::debug!("failed to write inbound packet to TUN: {:?}", e);
diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs
index 3243b88..bccb0a9 100644
--- a/src/wireguard/router/mod.rs
+++ b/src/wireguard/router/mod.rs
@@ -7,6 +7,7 @@ mod messages;
mod outbound;
mod peer;
mod pool;
+mod queue;
mod route;
mod types;
diff --git a/src/wireguard/router/outbound.rs b/src/wireguard/router/outbound.rs
index 30b7c2c..d08637b 100644
--- a/src/wireguard/router/outbound.rs
+++ b/src/wireguard/router/outbound.rs
@@ -35,6 +35,8 @@ fn parallel<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
_peer: &Peer<E, C, T, B>,
body: &mut Outbound,
) {
+ log::trace!("worker, parallel section, obtained job");
+
// make space for the tag
body.msg.extend([0u8; SIZE_TAG].iter());
@@ -77,6 +79,8 @@ fn sequential<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>>(
peer: &Peer<E, C, T, B>,
body: &mut Outbound,
) {
+ log::trace!("worker, sequential section, obtained job");
+
// send to peer
let xmit = peer.send(&body.msg[..]).is_ok();
diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs
index 192d4e2..40442a8 100644
--- a/src/wireguard/router/peer.rs
+++ b/src/wireguard/router/peer.rs
@@ -276,7 +276,9 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> Peer<E, C, T,
dec: Arc<DecryptionState<E, C, T, B>>,
msg: Vec<u8>,
) -> Option<Job<Self, Inbound<E, C, T, B>>> {
- Some(Job::new(self.clone(), Inbound::new(msg, dec, src)))
+ let job = Job::new(self.clone(), Inbound::new(msg, dec, src));
+ self.inbound.send(job.clone());
+ Some(job)
}
pub fn send_job(&self, msg: Vec<u8>, stage: bool) -> Option<Job<Self, Outbound>> {
diff --git a/src/wireguard/router/pool.rs b/src/wireguard/router/pool.rs
index 12956c1..9c72372 100644
--- a/src/wireguard/router/pool.rs
+++ b/src/wireguard/router/pool.rs
@@ -106,6 +106,7 @@ pub fn worker_template<
work_sequential: S, // perform sequential work on peer
queue: Q, // resolve a peer to an inorder queue
) {
+ log::trace!("router worker started");
loop {
// handle new job
let peer = {
diff --git a/src/wireguard/router/queue.rs b/src/wireguard/router/queue.rs
new file mode 100644
index 0000000..5d0165c
--- /dev/null
+++ b/src/wireguard/router/queue.rs
@@ -0,0 +1,46 @@
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::mpsc::sync_channel;
+use std::sync::mpsc::{Receiver, SyncSender};
+
+use spin::Mutex;
+
+pub struct ParallelQueue<T> {
+ next: AtomicUsize, // next round-robin index
+ queues: Vec<Mutex<SyncSender<T>>>, // work queues (1 per thread)
+}
+
+impl<T> ParallelQueue<T> {
+ pub fn new(queues: usize) -> (Vec<Receiver<T>>, Self) {
+ let mut rxs = vec![];
+ let mut txs = vec![];
+
+ for _ in 0..queues {
+ let (tx, rx) = sync_channel(128);
+ txs.push(Mutex::new(tx));
+ rxs.push(rx);
+ }
+
+ (
+ rxs,
+ ParallelQueue {
+ next: AtomicUsize::new(0),
+ queues: txs,
+ },
+ )
+ }
+
+ pub fn send(&self, v: T) {
+ let len = self.queues.len();
+ let idx = self.next.fetch_add(1, Ordering::SeqCst);
+ let que = self.queues[idx % len].lock();
+ que.send(v).unwrap();
+ }
+
+ pub fn close(&self) {
+ for i in 0..self.queues.len() {
+ let (tx, _) = sync_channel(0);
+ let queue = &self.queues[i];
+ *queue.lock() = tx;
+ }
+ }
+}
diff --git a/src/wireguard/router/route.rs b/src/wireguard/router/route.rs
index 40dc36b..56ad32f 100644
--- a/src/wireguard/router/route.rs
+++ b/src/wireguard/router/route.rs
@@ -81,7 +81,7 @@ impl<T: Eq + Clone> RoutingTable<T> {
LayoutVerified::new_from_prefix(packet)?;
log::trace!(
- "Router, get route for IPv4 destination: {:?}",
+ "router, get route for IPv4 destination: {:?}",
Ipv4Addr::from(header.f_destination)
);
@@ -97,7 +97,7 @@ impl<T: Eq + Clone> RoutingTable<T> {
LayoutVerified::new_from_prefix(packet)?;
log::trace!(
- "Router, get route for IPv6 destination: {:?}",
+ "router, get route for IPv6 destination: {:?}",
Ipv6Addr::from(header.f_destination)
);
@@ -107,7 +107,10 @@ impl<T: Eq + Clone> RoutingTable<T> {
.longest_match(Ipv6Addr::from(header.f_destination))
.and_then(|(_, _, p)| Some(p.clone()))
}
- _ => None,
+ v => {
+ log::trace!("router, invalid IP version {}", v);
+ None
+ },
}
}
@@ -120,7 +123,7 @@ impl<T: Eq + Clone> RoutingTable<T> {
LayoutVerified::new_from_prefix(packet)?;
log::trace!(
- "Router, check route for IPv4 source: {:?}",
+ "router, check route for IPv4 source: {:?}",
Ipv4Addr::from(header.f_source)
);
@@ -142,7 +145,7 @@ impl<T: Eq + Clone> RoutingTable<T> {
LayoutVerified::new_from_prefix(packet)?;
log::trace!(
- "Router, check route for IPv6 source: {:?}",
+ "router, check route for IPv6 source: {:?}",
Ipv6Addr::from(header.f_source)
);
diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs
index d96dc90..1f500c0 100644
--- a/src/wireguard/router/tests.rs
+++ b/src/wireguard/router/tests.rs
@@ -9,7 +9,7 @@ use num_cpus;
use super::super::dummy;
use super::super::dummy_keypair;
-use super::super::tests::make_packet_dst;
+use super::super::tests::make_packet;
use super::super::udp::*;
use super::KeyPair;
use super::SIZE_MESSAGE_PREFIX;
@@ -105,15 +105,15 @@ mod tests {
// wait for scheduling
fn wait() {
- thread::sleep(Duration::from_millis(50));
+ thread::sleep(Duration::from_millis(15));
}
fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}
- fn make_packet_dst_padded(size: usize, dst: IpAddr, id: u64) -> Vec<u8> {
- let p = make_packet_dst(size, dst, id);
+ fn make_packet_padded(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec<u8> {
+ let p = make_packet(size, src, dst, id);
let mut o = vec![0; p.len() + SIZE_MESSAGE_PREFIX];
o[SIZE_MESSAGE_PREFIX..SIZE_MESSAGE_PREFIX + p.len()].copy_from_slice(&p[..]);
o
@@ -149,15 +149,21 @@ mod tests {
peer.add_keypair(dummy_keypair(true));
// add subnet to peer
- let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20");
+ let (mask, len, dst) = ("192.168.1.0", 24, "192.168.1.20");
let mask: IpAddr = mask.parse().unwrap();
- let ip1: IpAddr = ip.parse().unwrap();
peer.add_allowed_ip(mask, len);
+ // create "IP packet"
+ let dst = dst.parse().unwrap();
+ let src = match dst {
+ IpAddr::V4(_) => "127.0.0.1".parse().unwrap(),
+ IpAddr::V6(_) => "::1".parse().unwrap()
+ };
+ let msg = make_packet_padded(1024, src, dst, 0);
+
// every iteration sends 10 GB
b.iter(|| {
opaque.store(0, Ordering::SeqCst);
- let msg = make_packet_dst_padded(1024, ip1, 0);
while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 {
router.send(msg.to_vec()).unwrap();
}
@@ -197,7 +203,8 @@ mod tests {
),
];
- for (num, (mask, len, ip, okay)) in tests.iter().enumerate() {
+ for (num, (mask, len, dst, okay)) in tests.iter().enumerate() {
+ println!("Check: {} {} {}/{}", dst, if *okay { "\\in" } else { "\\notin" }, mask, len);
for set_key in vec![true, false] {
debug!("index = {}, set_key = {}", num, set_key);
@@ -213,7 +220,12 @@ mod tests {
peer.add_allowed_ip(mask, *len);
// create "IP packet"
- let msg = make_packet_dst_padded(1024, ip.parse().unwrap(), 0);
+ let dst = dst.parse().unwrap();
+ let src = match dst {
+ IpAddr::V4(_) => "127.0.0.1".parse().unwrap(),
+ IpAddr::V6(_) => "::1".parse().unwrap()
+ };
+ let msg = make_packet_padded(1024, src, dst, 0);
// cryptkey route the IP packet
let res = router.send(msg);
@@ -269,17 +281,14 @@ mod tests {
let tests = [
(
- false, // confirm with keepalive
("192.168.1.0", 24, "192.168.1.20", true),
("172.133.133.133", 32, "172.133.133.133", true),
),
(
- true, // confirm with staged packet
("192.168.1.0", 24, "192.168.1.20", true),
("172.133.133.133", 32, "172.133.133.133", true),
),
(
- false, // confirm with keepalive
(
"2001:db8::ff00:42:8000",
113,
@@ -294,7 +303,6 @@ mod tests {
),
),
(
- false, // confirm with staged packet
(
"2001:db8::ff00:42:8000",
113,
@@ -310,117 +318,152 @@ mod tests {
),
];
- for (stage, p1, p2) in tests.iter() {
- let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) =
- dummy::PairBind::pair();
+ for stage in vec![true, false] {
+ for (p1, p2) in tests.iter() {
+ let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) =
+ dummy::PairBind::pair();
- // create matching device
- let (_fake, _, tun_writer1, _) = dummy::TunTest::create(false);
- let (_fake, _, tun_writer2, _) = dummy::TunTest::create(false);
+ // create matching device
+ let (_fake, _, tun_writer1, _) = dummy::TunTest::create(false);
+ let (_fake, _, tun_writer2, _) = dummy::TunTest::create(false);
- let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1);
- router1.set_outbound_writer(bind_writer1);
+ let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1);
+ router1.set_outbound_writer(bind_writer1);
- let router2: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer2);
- router2.set_outbound_writer(bind_writer2);
+ let router2: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer2);
+ router2.set_outbound_writer(bind_writer2);
- // prepare opaque values for tracing callbacks
+ // prepare opaque values for tracing callbacks
- let opaq1 = Opaque::new();
- let opaq2 = Opaque::new();
+ let opaque1 = Opaque::new();
+ let opaque2 = Opaque::new();
- // create peers with matching keypairs and assign subnets
+ // create peers with matching keypairs and assign subnets
- let (mask, len, _ip, _okay) = p1;
- let peer1 = router1.new_peer(opaq1.clone());
- let mask: IpAddr = mask.parse().unwrap();
- peer1.add_allowed_ip(mask, *len);
- peer1.add_keypair(dummy_keypair(false));
+ let peer1 = router1.new_peer(opaque1.clone());
+ let peer2 = router2.new_peer(opaque2.clone());
- let (mask, len, _ip, _okay) = p2;
- let peer2 = router2.new_peer(opaq2.clone());
- let mask: IpAddr = mask.parse().unwrap();
- peer2.add_allowed_ip(mask, *len);
- peer2.set_endpoint(dummy::UnitEndpoint::new());
+ {
+ let (mask, len, _ip, _okay) = p1;
+ let mask: IpAddr = mask.parse().unwrap();
+ peer1.add_allowed_ip(mask, *len);
+ peer1.add_keypair(dummy_keypair(false));
+ }
- if *stage {
- // stage a packet which can be used for confirmation (in place of a keepalive)
- let (_mask, _len, ip, _okay) = p2;
- let msg = make_packet_dst_padded(1024, ip.parse().unwrap(), 0);
- router2.send(msg).expect("failed to sent staged packet");
+ {
+ let (mask, len, _ip, _okay) = p2;
+ let mask: IpAddr = mask.parse().unwrap();
+ peer2.add_allowed_ip(mask, *len);
+ peer2.set_endpoint(dummy::UnitEndpoint::new());
+ }
- wait();
- assert!(opaq2.recv().is_none());
- assert!(
- opaq2.send().is_none(),
- "sending should fail as not key is set"
- );
- assert!(
- opaq2.need_key().is_some(),
- "a new key should be requested since a packet was attempted transmitted"
- );
- assert!(opaq2.is_empty(), "callbacks should only run once");
- }
+ if stage {
+ println!("confirm using staged packet");
+
+ // create IP packet
+ let (_mask, _len, ip1, _okay) = p1;
+ let (_mask, _len, ip2, _okay) = p2;
+ let msg = make_packet_padded(
+ 1024,
+ ip1.parse().unwrap(), // src
+ ip2.parse().unwrap(), // dst
+ 0,
+ );
- // this should cause a key-confirmation packet (keepalive or staged packet)
- // this also causes peer1 to learn the "endpoint" for peer2
- assert!(peer1.get_endpoint().is_none());
- peer2.add_keypair(dummy_keypair(true));
-
- wait();
- assert!(opaq2.send().is_some());
- assert!(opaq2.is_empty(), "events on peer2 should be 'send'");
- assert!(opaq1.is_empty(), "nothing should happened on peer1");
-
- // read confirming message received by the other end ("across the internet")
- let mut buf = vec![0u8; 2048];
- let (len, from) = bind_reader1.read(&mut buf).unwrap();
- buf.truncate(len);
- router1.recv(from, buf).unwrap();
-
- wait();
- assert!(opaq1.recv().is_some());
- assert!(opaq1.key_confirmed().is_some());
- assert!(
- opaq1.is_empty(),
- "events on peer1 should be 'recv' and 'key_confirmed'"
- );
- assert!(peer1.get_endpoint().is_some());
- assert!(opaq2.is_empty(), "nothing should happened on peer2");
-
- // now that peer1 has an endpoint
- // route packets : peer1 -> peer2
-
- for id in 0..10 {
- assert!(
- opaq1.is_empty(),
- "we should have asserted a value for every callback on peer1"
- );
- assert!(
- opaq2.is_empty(),
- "we should have asserted a value for every callback on peer2"
- );
+ // stage packet for sending
+ router2.send(msg).expect("failed to sent staged packet");
+ wait();
+
+ // validate events
+ assert!(opaque2.recv().is_none());
+ assert!(
+ opaque2.send().is_none(),
+ "sending should fail as not key is set"
+ );
+ assert!(
+ opaque2.need_key().is_some(),
+ "a new key should be requested since a packet was attempted transmitted"
+ );
+ assert!(opaque2.is_empty(), "callbacks should only run once");
+ }
- // pass IP packet to router
- let (_mask, _len, ip, _okay) = p1;
- let msg = make_packet_dst_padded(1024, ip.parse().unwrap(), id);
- router1.send(msg).unwrap();
+ // this should cause a key-confirmation packet (keepalive or staged packet)
+ // this also causes peer1 to learn the "endpoint" for peer2
+ assert!(peer1.get_endpoint().is_none());
+ peer2.add_keypair(dummy_keypair(true));
wait();
- assert!(opaq1.send().is_some());
- assert!(opaq1.recv().is_none());
- assert!(opaq1.need_key().is_none());
+ assert!(opaque2.send().is_some());
+ assert!(opaque2.is_empty(), "events on peer2 should be 'send'");
+ assert!(opaque1.is_empty(), "nothing should happened on peer1");
- // receive ("across the internet") on the other end
+ // read confirming message received by the other end ("across the internet")
let mut buf = vec![0u8; 2048];
- let (len, from) = bind_reader2.read(&mut buf).unwrap();
+ let (len, from) = bind_reader1.read(&mut buf).unwrap();
buf.truncate(len);
- router2.recv(from, buf).unwrap();
+ router1.recv(from, buf).unwrap();
wait();
- assert!(opaq2.send().is_none());
- assert!(opaq2.recv().is_some());
- assert!(opaq2.need_key().is_none());
+ assert!(opaque1.recv().is_some());
+ assert!(opaque1.key_confirmed().is_some());
+ assert!(
+ opaque1.is_empty(),
+ "events on peer1 should be 'recv' and 'key_confirmed'"
+ );
+ assert!(peer1.get_endpoint().is_some());
+ assert!(opaque2.is_empty(), "nothing should happened on peer2");
+
+ // now that peer1 has an endpoint
+ // route packets : peer1 -> peer2
+
+ for id in 1..11 {
+ println!("round: {}", id);
+ assert!(
+ opaque1.is_empty(),
+ "we should have asserted a value for every callback on peer1"
+ );
+ assert!(
+ opaque2.is_empty(),
+ "we should have asserted a value for every callback on peer2"
+ );
+
+ // pass IP packet to router
+ let (_mask, _len, ip1, _okay) = p1;
+ let (_mask, _len, ip2, _okay) = p2;
+ let msg =
+ make_packet_padded(
+ 1024,
+ ip2.parse().unwrap(), // src
+ ip1.parse().unwrap(), // dst
+ id
+ );
+ router1.send(msg).unwrap();
+
+ wait();
+ assert!(opaque1.send().is_some(), "encryption should succeed");
+ assert!(
+ opaque1.recv().is_none(),
+ "receiving callback should not be called"
+ );
+ assert!(opaque1.need_key().is_none());
+
+ // receive ("across the internet") on the other end
+ let mut buf = vec![0u8; 2048];
+ let (len, from) = bind_reader2.read(&mut buf).unwrap();
+ buf.truncate(len);
+ router2.recv(from, buf).unwrap();
+
+ wait();
+ assert!(
+ opaque2.send().is_none(),
+ "sending callback should not be called"
+ );
+ assert!(
+ opaque2.recv().is_some(),
+ "decryption and routing should succeed"
+ );
+ assert!(opaque2.need_key().is_none());
+ }
}
}
}
diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs
index e1aabad..f292afd 100644
--- a/src/wireguard/timers.rs
+++ b/src/wireguard/timers.rs
@@ -3,14 +3,14 @@ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
-use log::debug;
use hjul::{Runner, Timer};
+use log::debug;
use super::constants::*;
use super::router::{message_data_len, Callbacks};
-use super::{Peer, PeerInner};
-use super::{udp, tun};
use super::types::KeyPair;
+use super::{tun, udp};
+use super::{Peer, PeerInner};
pub struct Timers {
// only updated during configuration
@@ -36,7 +36,6 @@ impl Timers {
}
impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> {
-
pub fn get_keepalive_interval(&self) -> u64 {
self.timers().keepalive_interval
}
@@ -57,17 +56,19 @@ impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> {
timers.send_persistent_keepalive.stop();
timers.zero_key_material.stop();
timers.new_handshake.stop();
-
+
// reset all timer state
timers.handshake_attempts.store(0, Ordering::SeqCst);
- timers.sent_lastminute_handshake.store(false, Ordering::SeqCst);
+ timers
+ .sent_lastminute_handshake
+ .store(false, Ordering::SeqCst);
timers.need_another_keepalive.store(false, Ordering::SeqCst);
}
pub fn start_timers(&self) {
// take a write lock preventing simultaneous "stop_timers" call
let mut timers = self.timers_mut();
-
+
// set flag to reenable timer events
if timers.enabled {
return;
@@ -76,18 +77,20 @@ impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> {
// start send_persistent_keepalive
if timers.keepalive_interval > 0 {
- timers.send_persistent_keepalive.start(
- Duration::from_secs(timers.keepalive_interval)
- );
+ timers
+ .send_persistent_keepalive
+ .start(Duration::from_secs(timers.keepalive_interval));
}
}
/* should be called after an authenticated data packet is sent */
pub fn timers_data_sent(&self) {
- let timers = self.timers();
- if timers.enabled {
- timers.new_handshake.start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT);
- }
+ let timers = self.timers();
+ if timers.enabled {
+ timers
+ .new_handshake
+ .start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT);
+ }
}
/* should be called after an authenticated data packet is received */
@@ -139,7 +142,9 @@ impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> {
if timers.enabled {
timers.retransmit_handshake.stop();
timers.handshake_attempts.store(0, Ordering::SeqCst);
- timers.sent_lastminute_handshake.store(false, Ordering::SeqCst);
+ timers
+ .sent_lastminute_handshake
+ .store(false, Ordering::SeqCst);
*self.walltime_last_handshake.lock() = Some(SystemTime::now());
}
}
@@ -161,9 +166,9 @@ impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> {
let timers = self.timers();
if timers.enabled && timers.keepalive_interval > 0 {
// push persistent_keepalive into the future
- timers.send_persistent_keepalive.reset(Duration::from_secs(
- timers.keepalive_interval
- ));
+ timers
+ .send_persistent_keepalive
+ .reset(Duration::from_secs(timers.keepalive_interval));
}
}
@@ -179,7 +184,6 @@ impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> {
if timers.enabled {
timers.retransmit_handshake.reset(REKEY_TIMEOUT);
}
-
}
/* Called after a handshake worker sends a handshake initiation to the peer
@@ -195,7 +199,7 @@ impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> {
*self.last_handshake_sent.lock() = Instant::now();
self.timers_any_authenticated_packet_traversal();
self.timers_any_authenticated_packet_sent();
- }
+ }
pub fn set_persistent_keepalive_interval(&self, secs: u64) {
let mut timers = self.timers_mut();
@@ -205,10 +209,12 @@ impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> {
// stop the keepalive timer with the old interval
timers.send_persistent_keepalive.stop();
-
+
// restart the persistent_keepalive timer with the new interval
if secs > 0 && timers.enabled {
- timers.send_persistent_keepalive.start(Duration::from_secs(secs));
+ timers
+ .send_persistent_keepalive
+ .start(Duration::from_secs(secs));
}
}
@@ -220,7 +226,6 @@ impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> {
}
}
-
impl Timers {
pub fn new<T, B>(runner: &Runner, running: bool, peer: Peer<T, B>) -> Timers
where
@@ -242,9 +247,12 @@ impl Timers {
if !timers.enabled {
return;
}
-
+
// check if handshake attempts remaining
- let attempts = peer.timers().handshake_attempts.fetch_add(1, Ordering::SeqCst);
+ let attempts = peer
+ .timers()
+ .handshake_attempts
+ .fetch_add(1, Ordering::SeqCst);
if attempts > MAX_TIMER_HANDSHAKES {
debug!(
"Handshake for peer {} did not complete after {} attempts, giving up",
@@ -257,8 +265,8 @@ impl Timers {
} else {
debug!(
"Handshake for {} did not complete after {} seconds, retrying (try {})",
- peer,
- REKEY_TIMEOUT.as_secs(),
+ peer,
+ REKEY_TIMEOUT.as_secs(),
attempts
);
timers.retransmit_handshake.reset(REKEY_TIMEOUT);
@@ -287,7 +295,7 @@ impl Timers {
runner.timer(move || {
debug!(
"Retrying handshake with {} because we stopped hearing back after {} seconds",
- peer,
+ peer,
(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT).as_secs()
);
peer.router.clear_src();
@@ -307,9 +315,9 @@ impl Timers {
if timers.enabled && timers.keepalive_interval > 0 {
peer.router.send_keepalive();
timers.send_keepalive.stop();
- timers.send_persistent_keepalive.start(Duration::from_secs(
- timers.keepalive_interval
- ));
+ timers
+ .send_persistent_keepalive
+ .start(Duration::from_secs(timers.keepalive_interval));
}
})
},
@@ -318,7 +326,7 @@ impl Timers {
pub fn dummy(runner: &Runner) -> Timers {
Timers {
- enabled: false,
+ enabled: false,
keepalive_interval: 0,
need_another_keepalive: AtomicBool::new(false),
sent_lastminute_handshake: AtomicBool::new(false),
@@ -344,9 +352,8 @@ impl<T: tun::Tun, B: udp::UDP> Callbacks for Events<T, B> {
*/
#[inline(always)]
fn send(peer: &Self::Opaque, size: usize, sent: bool, keypair: &Arc<KeyPair>, counter: u64) {
-
// update timers and stats
-
+
peer.timers_any_authenticated_packet_traversal();
peer.timers_any_authenticated_packet_sent();
peer.tx_bytes.fetch_add(size as u64, Ordering::Relaxed);
@@ -375,7 +382,6 @@ impl<T: tun::Tun, B: udp::UDP> Callbacks for Events<T, B> {
*/
#[inline(always)]
fn recv(peer: &Self::Opaque, size: usize, sent: bool, keypair: &Arc<KeyPair>) {
-
// update timers and stats
peer.timers_any_authenticated_packet_traversal();
@@ -386,13 +392,18 @@ impl<T: tun::Tun, B: udp::UDP> Callbacks for Events<T, B> {
}
// keep_key_fresh
-
+
#[inline(always)]
fn keep_key_fresh(keypair: &Arc<KeyPair>) -> bool {
- Instant::now() - keypair.birth > REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT
+ Instant::now() - keypair.birth > REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT
}
- if keep_key_fresh(keypair) && !peer.timers().sent_lastminute_handshake.swap(true, Ordering::Acquire) {
+ if keep_key_fresh(keypair)
+ && !peer
+ .timers()
+ .sent_lastminute_handshake
+ .swap(true, Ordering::Acquire)
+ {
peer.packet_send_queued_handshake_initiation(false);
}
}
@@ -405,7 +416,7 @@ impl<T: tun::Tun, B: udp::UDP> Callbacks for Events<T, B> {
*/
#[inline(always)]
fn need_key(peer: &Self::Opaque) {
- peer.packet_send_queued_handshake_initiation(false);
+ peer.packet_send_queued_handshake_initiation(false);
}
#[inline(always)]