aboutsummaryrefslogtreecommitdiffstats
path: root/src/wireguard
diff options
context:
space:
mode:
Diffstat (limited to 'src/wireguard')
-rw-r--r--src/wireguard/router/tests.rs8
-rw-r--r--src/wireguard/tests.rs8
-rw-r--r--src/wireguard/wireguard.rs12
3 files changed, 15 insertions, 13 deletions
diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs
index 2d6bb63..d96dc90 100644
--- a/src/wireguard/router/tests.rs
+++ b/src/wireguard/router/tests.rs
@@ -139,7 +139,7 @@ mod tests {
}
// create device
- let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false);
+ let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(false);
let router: Device<_, BencherCallbacks, dummy::TunWriter, dummy::VoidBind> =
Device::new(num_cpus::get(), tun_writer);
@@ -169,7 +169,7 @@ mod tests {
init();
// create device
- let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false);
+ let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(false);
let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer);
router.set_outbound_writer(dummy::VoidBind::new());
@@ -315,8 +315,8 @@ mod tests {
dummy::PairBind::pair();
// create matching device
- let (_fake, _, tun_writer1, _) = dummy::TunTest::create(1500, false);
- let (_fake, _, tun_writer2, _) = dummy::TunTest::create(1500, false);
+ let (_fake, _, tun_writer1, _) = dummy::TunTest::create(false);
+ let (_fake, _, tun_writer2, _) = dummy::TunTest::create(false);
let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1);
router1.set_outbound_writer(bind_writer1);
diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs
index 8217d72..7a18005 100644
--- a/src/wireguard/tests.rs
+++ b/src/wireguard/tests.rs
@@ -84,17 +84,17 @@ fn test_pure_wireguard() {
// create WG instances for dummy TUN devices
- let (fake1, tun_reader1, tun_writer1, _) = dummy::TunTest::create(1500, true);
+ let (fake1, tun_reader1, tun_writer1, _) = dummy::TunTest::create(true);
let wg1: Wireguard<dummy::TunTest, dummy::PairBind> =
Wireguard::new(vec![tun_reader1], tun_writer1);
- wg1.set_mtu(1500);
+ wg1.up(1500);
- let (fake2, tun_reader2, tun_writer2, _) = dummy::TunTest::create(1500, true);
+ let (fake2, tun_reader2, tun_writer2, _) = dummy::TunTest::create(true);
let wg2: Wireguard<dummy::TunTest, dummy::PairBind> =
Wireguard::new(vec![tun_reader2], tun_writer2);
- wg2.set_mtu(1500);
+ wg2.up(1500);
// create pair bind to connect the interfaces "over the internet"
diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs
index 41f6857..61f6428 100644
--- a/src/wireguard/wireguard.rs
+++ b/src/wireguard/wireguard.rs
@@ -147,6 +147,9 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
// ensure exclusive access (to avoid race with "up" call)
let peers = self.peers.write();
+ // set mtu
+ self.state.mtu.store(0, Ordering::Relaxed);
+
// avoid tranmission from router
self.router.down();
@@ -158,10 +161,13 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
/// Brings the WireGuard device up.
/// Usually called when the associated interface is brought up.
- pub fn up(&self) {
+ pub fn up(&self, mtu: usize) {
// ensure exclusive access (to avoid race with "down" call)
let peers = self.peers.write();
+ // set mtu
+ self.state.mtu.store(mtu, Ordering::Relaxed);
+
// enable tranmission from router
self.router.up();
@@ -338,10 +344,6 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
});
}
- pub fn set_mtu(&self, mtu: usize) {
- self.mtu.store(mtu, Ordering::Relaxed);
- }
-
pub fn set_writer(&self, writer: B::Writer) {
// TODO: Consider unifying these and avoid Clone requirement on writer
*self.state.send.write() = Some(writer.clone());