aboutsummaryrefslogtreecommitdiffstats
path: root/src/wireguard.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/wireguard.rs')
-rw-r--r--src/wireguard.rs70
1 files changed, 51 insertions, 19 deletions
diff --git a/src/wireguard.rs b/src/wireguard.rs
index 182cec2..ea600d0 100644
--- a/src/wireguard.rs
+++ b/src/wireguard.rs
@@ -6,6 +6,7 @@ use crate::types::{Bind, Endpoint, Tun};
use hjul::Runner;
+use std::cmp;
use std::ops::Deref;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
@@ -86,8 +87,19 @@ pub struct Wireguard<T: Tun, B: Bind> {
state: Arc<WireguardInner<T, B>>,
}
+#[inline(always)]
+const fn padding(size: usize, mtu: usize) -> usize {
+ #[inline(always)]
+ const fn min(a: usize, b: usize) -> usize {
+ let m = (a > b) as usize;
+ a * m + (1 - m) * b
+ }
+ let pad = MESSAGE_PADDING_MULTIPLE;
+ min(mtu, size + (pad - size % pad) % pad)
+}
+
impl<T: Tun, B: Bind> Wireguard<T, B> {
- fn set_key(&self, sk: Option<StaticSecret>) {
+ pub fn set_key(&self, sk: Option<StaticSecret>) {
let mut handshake = self.state.handshake.write();
match sk {
None => {
@@ -102,7 +114,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}
}
- fn new_peer(&self, pk: PublicKey) -> Peer<T, B> {
+ pub fn new_peer(&self, pk: PublicKey) -> Peer<T, B> {
let state = Arc::new(PeerInner {
pk,
queue: Mutex::new(self.state.queue.lock().clone()),
@@ -111,11 +123,21 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
tx_bytes: AtomicU64::new(0),
timers: RwLock::new(Timers::dummy(&self.runner)),
});
+
let router = Arc::new(self.state.router.new_peer(state.clone()));
- Peer { router, state }
+
+ let peer = Peer { router, state };
+
+ /* The need for dummy timers arises from the chicken-egg
+ * problem of the timer callbacks being able to set timers themselves.
+ *
+ * This is in fact the only place where the write lock is ever taken.
+ */
+ *peer.timers.write() = Timers::new(&self.runner, peer.clone());
+ peer
}
- fn new(tun: T, bind: B) -> Wireguard<T, B> {
+ pub fn new(tun: T, bind: B) -> Wireguard<T, B> {
// create device state
let mut rng = OsRng::new().unwrap();
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
@@ -215,10 +237,12 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
loop {
- // read UDP packet into vector
- let size = tun.mtu() + 148; // maximum message size
+ // create vector big enough for any message given current MTU
+ let size = tun.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
let mut msg: Vec<u8> = Vec::with_capacity(size);
msg.resize(size, 0);
+
+ // read UDP packet into vector
let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
msg.truncate(size);
@@ -226,7 +250,6 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
if msg.len() < std::mem::size_of::<u32>() {
continue;
}
-
match LittleEndian::read_u32(&msg[..]) {
handshake::TYPE_COOKIE_REPLY
| handshake::TYPE_INITIATION
@@ -246,9 +269,6 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}
router::TYPE_TRANSPORT => {
// transport message
-
- // pad the message
-
let _ = wg.router.recv(src, msg);
}
_ => (),
@@ -261,20 +281,32 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
{
let wg = wg.clone();
thread::spawn(move || loop {
- // read a new IP packet
+ // create vector big enough for any transport message (based on MTU)
let mtu = tun.mtu();
- let size = mtu + 148;
+ let size = mtu + router::SIZE_MESSAGE_PREFIX;
let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
- let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
- msg.truncate(size);
+ msg.resize(size, 0);
- // pad message to multiple of 16 bytes
- while msg.len() < mtu && msg.len() % 16 != 0 {
- msg.push(0);
- }
+ // read a new IP packet
+ let payload = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
+ debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu);
+
+ // truncate padding
+ let payload = padding(payload, mtu);
+ msg.truncate(router::SIZE_MESSAGE_PREFIX + payload);
+ debug_assert!(payload <= mtu);
+ debug_assert_eq!(
+ if payload < mtu {
+ (msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE
+ } else {
+ 0
+ },
+ 0
+ );
// crypt-key route
- let _ = wg.router.send(msg);
+ let e = wg.router.send(msg);
+ debug!("TUN worker, router returned {:?}", e);
});
}