diff options
Diffstat (limited to 'src/wireguard.rs')
-rw-r--r-- | src/wireguard.rs | 70 |
1 files changed, 51 insertions, 19 deletions
diff --git a/src/wireguard.rs b/src/wireguard.rs index 182cec2..ea600d0 100644 --- a/src/wireguard.rs +++ b/src/wireguard.rs @@ -6,6 +6,7 @@ use crate::types::{Bind, Endpoint, Tun}; use hjul::Runner; +use std::cmp; use std::ops::Deref; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; @@ -86,8 +87,19 @@ pub struct Wireguard<T: Tun, B: Bind> { state: Arc<WireguardInner<T, B>>, } +#[inline(always)] +const fn padding(size: usize, mtu: usize) -> usize { + #[inline(always)] + const fn min(a: usize, b: usize) -> usize { + let m = (a > b) as usize; + a * m + (1 - m) * b + } + let pad = MESSAGE_PADDING_MULTIPLE; + min(mtu, size + (pad - size % pad) % pad) +} + impl<T: Tun, B: Bind> Wireguard<T, B> { - fn set_key(&self, sk: Option<StaticSecret>) { + pub fn set_key(&self, sk: Option<StaticSecret>) { let mut handshake = self.state.handshake.write(); match sk { None => { @@ -102,7 +114,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { } } - fn new_peer(&self, pk: PublicKey) -> Peer<T, B> { + pub fn new_peer(&self, pk: PublicKey) -> Peer<T, B> { let state = Arc::new(PeerInner { pk, queue: Mutex::new(self.state.queue.lock().clone()), @@ -111,11 +123,21 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { tx_bytes: AtomicU64::new(0), timers: RwLock::new(Timers::dummy(&self.runner)), }); + let router = Arc::new(self.state.router.new_peer(state.clone())); - Peer { router, state } + + let peer = Peer { router, state }; + + /* The need for dummy timers arises from the chicken-egg + * problem of the timer callbacks being able to set timers themselves. + * + * This is in fact the only place where the write lock is ever taken. + */ + *peer.timers.write() = Timers::new(&self.runner, peer.clone()); + peer } - fn new(tun: T, bind: B) -> Wireguard<T, B> { + pub fn new(tun: T, bind: B) -> Wireguard<T, B> { // create device state let mut rng = OsRng::new().unwrap(); let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE); @@ -215,10 +237,12 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); loop { - // read UDP packet into vector - let size = tun.mtu() + 148; // maximum message size + // create vector big enough for any message given current MTU + let size = tun.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE; let mut msg: Vec<u8> = Vec::with_capacity(size); msg.resize(size, 0); + + // read UDP packet into vector let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error msg.truncate(size); @@ -226,7 +250,6 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { if msg.len() < std::mem::size_of::<u32>() { continue; } - match LittleEndian::read_u32(&msg[..]) { handshake::TYPE_COOKIE_REPLY | handshake::TYPE_INITIATION @@ -246,9 +269,6 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { } router::TYPE_TRANSPORT => { // transport message - - // pad the message - let _ = wg.router.recv(src, msg); } _ => (), @@ -261,20 +281,32 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { { let wg = wg.clone(); thread::spawn(move || loop { - // read a new IP packet + // create vector big enough for any transport message (based on MTU) let mtu = tun.mtu(); - let size = mtu + 148; + let size = mtu + router::SIZE_MESSAGE_PREFIX; let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); - let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap(); - msg.truncate(size); + msg.resize(size, 0); - // pad message to multiple of 16 bytes - while msg.len() < mtu && msg.len() % 16 != 0 { - msg.push(0); - } + // read a new IP packet + let payload = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap(); + debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu); + + // truncate padding + let payload = padding(payload, mtu); + msg.truncate(router::SIZE_MESSAGE_PREFIX + payload); + debug_assert!(payload <= mtu); + debug_assert_eq!( + if payload < mtu { + (msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE + } else { + 0 + }, + 0 + ); // crypt-key route - let _ = wg.router.send(msg); + let e = wg.router.send(msg); + debug!("TUN worker, router returned {:?}", e); }); } |