aboutsummaryrefslogtreecommitdiffstats
path: root/src/wireguard/wireguard.rs
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/wireguard/wireguard.rs50
1 files changed, 35 insertions, 15 deletions
diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs
index eb43512..41f6857 100644
--- a/src/wireguard/wireguard.rs
+++ b/src/wireguard/wireguard.rs
@@ -4,9 +4,13 @@ use super::router;
use super::timers::{Events, Timers};
use super::{Peer, PeerInner};
-use super::bind::Reader as BindReader;
-use super::bind::{Bind, Writer};
-use super::tun::{Reader, Tun, MTU};
+use super::tun;
+use super::tun::Reader as TunReader;
+
+use super::udp;
+use super::udp::Reader as UDPReader;
+use super::udp::Writer as UDPWriter;
+
use super::Endpoint;
use hjul::Runner;
@@ -34,13 +38,15 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128;
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
-pub struct WireguardInner<T: Tun, B: Bind> {
+pub struct WireguardInner<T: tun::Tun, B: udp::UDP> {
// identifier (for logging)
id: u32,
start: Instant,
+ // current MTU
+ mtu: AtomicUsize,
+
// provides access to the MTU value of the tun device
- mtu: T::MTU,
send: RwLock<Option<B::Writer>>,
// identity and configuration map
@@ -56,7 +62,7 @@ pub struct WireguardInner<T: Tun, B: Bind> {
queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>,
}
-impl<T: Tun, B: Bind> PeerInner<T, B> {
+impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> {
/* Queue a handshake request for the parallel workers
* (if one does not already exist)
*
@@ -87,20 +93,20 @@ pub enum HandshakeJob<E> {
New(PublicKey),
}
-impl<T: Tun, B: Bind> fmt::Display for WireguardInner<T, B> {
+impl<T: tun::Tun, B: udp::UDP> fmt::Display for WireguardInner<T, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "wireguard({:x})", self.id)
}
}
-impl<T: Tun, B: Bind> Deref for Wireguard<T, B> {
+impl<T: tun::Tun, B: udp::UDP> Deref for Wireguard<T, B> {
type Target = Arc<WireguardInner<T, B>>;
fn deref(&self) -> &Self::Target {
&self.state
}
}
-pub struct Wireguard<T: Tun, B: Bind> {
+pub struct Wireguard<T: tun::Tun, B: udp::UDP> {
runner: Runner,
state: Arc<WireguardInner<T, B>>,
}
@@ -127,7 +133,7 @@ const fn padding(size: usize, mtu: usize) -> usize {
min(mtu, size + (pad - size % pad) % pad)
}
-impl<T: Tun, B: Bind> Wireguard<T, B> {
+impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
/// Brings the WireGuard device down.
/// Usually called when the associated interface is brought down.
///
@@ -269,7 +275,8 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
loop {
// create vector big enough for any message given current MTU
- let size = wg.mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
+ let mtu = wg.mtu.load(Ordering::Relaxed);
+ let size = mtu + handshake::MAX_HANDSHAKE_MSG_SIZE;
let mut msg: Vec<u8> = Vec::with_capacity(size);
msg.resize(size, 0);
@@ -283,6 +290,11 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
};
msg.truncate(size);
+ // TODO: start device down
+ if mtu == 0 {
+ continue;
+ }
+
// message type de-multiplexer
if msg.len() < std::mem::size_of::<u32>() {
continue;
@@ -326,13 +338,17 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
});
}
+ pub fn set_mtu(&self, mtu: usize) {
+ self.mtu.store(mtu, Ordering::Relaxed);
+ }
+
pub fn set_writer(&self, writer: B::Writer) {
// TODO: Consider unifying these and avoid Clone requirement on writer
*self.state.send.write() = Some(writer.clone());
self.state.router.set_outbound_writer(writer);
}
- pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> {
+ pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer) -> Wireguard<T, B> {
// create device state
let mut rng = OsRng::new().unwrap();
@@ -342,7 +358,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
let wg = Arc::new(WireguardInner {
start: Instant::now(),
id: rng.gen(),
- mtu: mtu.clone(),
+ mtu: AtomicUsize::new(0),
peers: RwLock::new(HashMap::new()),
send: RwLock::new(None),
router: router::Device::new(num_cpus::get(), writer), // router owns the writing half
@@ -475,10 +491,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
);
while let Some(reader) = readers.pop() {
let wg = wg.clone();
- let mtu = mtu.clone();
thread::spawn(move || loop {
// create vector big enough for any transport message (based on MTU)
- let mtu = mtu.mtu();
+ let mtu = wg.mtu.load(Ordering::Relaxed);
let size = mtu + router::SIZE_MESSAGE_PREFIX;
let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
msg.resize(size, 0);
@@ -493,6 +508,11 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
};
debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu);
+ // TODO: start device down
+ if mtu == 0 {
+ continue;
+ }
+
// truncate padding
let padded = padding(payload, mtu);
log::trace!(