diff options
Diffstat (limited to '')
-rw-r--r-- | src/wireguard/wireguard.rs | 224 |
1 files changed, 149 insertions, 75 deletions
diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 61f6428..2b0e779 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -22,6 +22,10 @@ use std::sync::Arc; use std::thread; use std::time::{Duration, Instant}; +// TODO: avoid +use std::sync::Condvar; +use std::sync::Mutex as StdMutex; + use std::collections::hash_map::Entry; use std::collections::HashMap; @@ -38,15 +42,51 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128; const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4; const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000); +#[derive(Clone)] +pub struct WaitHandle(Arc<(StdMutex<usize>, Condvar)>); + +impl WaitHandle { + pub fn wait(&self) { + let (lock, cvar) = &*self.0; + let mut nread = lock.lock().unwrap(); + while *nread > 0 { + nread = cvar.wait(nread).unwrap(); + } + } + + fn new() -> Self { + Self(Arc::new((StdMutex::new(0), Condvar::new()))) + } + + fn decrease(&self) { + let (lock, cvar) = &*self.0; + let mut nread = lock.lock().unwrap(); + assert!(*nread > 0); + *nread -= 1; + cvar.notify_all(); + } + + fn increase(&self) { + let (lock, _) = &*self.0; + let mut nread = lock.lock().unwrap(); + *nread += 1; + } +} + pub struct WireguardInner<T: tun::Tun, B: udp::UDP> { // identifier (for logging) id: u32, - start: Instant, + + // device enabled + enabled: RwLock<bool>, + + // enables waiting for all readers to finish + tun_readers: WaitHandle, // current MTU mtu: AtomicUsize, - // provides access to the MTU value of the tun device + // outbound writer send: RwLock<Option<B::Writer>>, // identity and configuration map @@ -145,7 +185,12 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> { /// on both ends of the device. pub fn down(&self) { // ensure exclusive access (to avoid race with "up" call) - let peers = self.peers.write(); + let mut enabled = self.enabled.write(); + + // check if already down + if *enabled == false { + return; + } // set mtu self.state.mtu.store(0, Ordering::Relaxed); @@ -154,27 +199,36 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> { self.router.down(); // set all peers down (stops timers) - for peer in peers.values() { + for peer in self.peers.write().values() { peer.down(); } + + *enabled = false; } /// Brings the WireGuard device up. /// Usually called when the associated interface is brought up. pub fn up(&self, mtu: usize) { - // ensure exclusive access (to avoid race with "down" call) - let peers = self.peers.write(); + // ensure exclusive access (to avoid race with "up" call) + let mut enabled = self.enabled.write(); // set mtu self.state.mtu.store(mtu, Ordering::Relaxed); + // check if already up + if *enabled { + return; + } + // enable tranmission from router self.router.up(); // set all peers up (restarts timers) - for peer in peers.values() { + for peer in self.peers.write().values() { peer.up(); } + + *enabled = true; } pub fn clear_peers(&self) { @@ -232,7 +286,7 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> { pk, wg: self.state.clone(), walltime_last_handshake: Mutex::new(None), - last_handshake_sent: Mutex::new(self.state.start - TIME_HORIZON), + last_handshake_sent: Mutex::new(Instant::now() - TIME_HORIZON), handshake_queued: AtomicBool::new(false), queue: Mutex::new(self.state.queue.lock().clone()), rx_bytes: AtomicU64::new(0), @@ -246,24 +300,31 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> { // form WireGuard peer let peer = Peer { router, state }; - /* The need for dummy timers arises from the chicken-egg - * problem of the timer callbacks being able to set timers themselves. - * - * This is in fact the only place where the write lock is ever taken. - * TODO: Consider the ease of using atomic pointers instead. - */ - *peer.timers.write() = Timers::new(&self.runner, peer.clone()); - // finally, add the peer to the wireguard device let mut peers = self.state.peers.write(); match peers.entry(*pk.as_bytes()) { Entry::Occupied(_) => false, Entry::Vacant(vacancy) => { + // check that the public key does not cause conflict with the private key of the device let ok_pk = self.state.handshake.write().add(pk).is_ok(); - if ok_pk { - vacancy.insert(peer); + if !ok_pk { + return false; } - ok_pk + + // prevent up/down while inserting + let enabled = self.enabled.read(); + + /* The need for dummy timers arises from the chicken-egg + * problem of the timer callbacks being able to set timers themselves. + * + * This is in fact the only place where the write lock is ever taken. + * TODO: Consider the ease of using atomic pointers instead. + */ + *peer.timers.write() = Timers::new(&self.runner, *enabled, peer.clone()); + + // insert into peer map (takes ownership and ensures that the peer is not dropped) + vacancy.insert(peer); + true } } } @@ -273,7 +334,7 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> { /// /// Any previous reader thread is stopped by closing the previous reader, /// which unblocks the thread and causes an error on reader.read - pub fn add_reader(&self, reader: B::Reader) { + pub fn add_udp_reader(&self, reader: B::Reader) { let wg = self.state.clone(); thread::spawn(move || { let mut last_under_load = @@ -350,7 +411,72 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> { self.state.router.set_outbound_writer(writer); } - pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer) -> Wireguard<T, B> { + pub fn add_tun_reader(&self, reader: T::Reader) { + fn worker<T: tun::Tun, B: udp::UDP>(wg: &Arc<WireguardInner<T, B>>, reader: T::Reader) { + loop { + // create vector big enough for any transport message (based on MTU) + let mtu = wg.mtu.load(Ordering::Relaxed); + let size = mtu + router::SIZE_MESSAGE_PREFIX + 1; + let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); + msg.resize(size, 0); + + // read a new IP packet + let payload = match reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX) { + Ok(payload) => payload, + Err(e) => { + debug!("TUN worker, failed to read from tun device: {}", e); + break; + } + }; + debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu); + + // TODO: start device down + if mtu == 0 { + continue; + } + + // truncate padding + let padded = padding(payload, mtu); + log::trace!( + "TUN worker, payload length = {}, padded length = {}", + payload, + padded + ); + msg.truncate(router::SIZE_MESSAGE_PREFIX + padded); + debug_assert!(padded <= mtu); + debug_assert_eq!( + if padded < mtu { + (msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE + } else { + 0 + }, + 0 + ); + + // crypt-key route + let e = wg.router.send(msg); + debug!("TUN worker, router returned {:?}", e); + } + } + + // start a thread for every reader + let wg = self.state.clone(); + + // increment reader count + wg.tun_readers.increase(); + + // start worker + thread::spawn(move || { + worker(&wg, reader); + wg.tun_readers.decrease(); + }); + } + + pub fn wait(&self) -> WaitHandle { + self.state.tun_readers.clone() + } + + pub fn new(writer: T::Writer) -> Wireguard<T, B> { // create device state let mut rng = OsRng::new().unwrap(); @@ -358,7 +484,8 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> { let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE); let wg = Arc::new(WireguardInner { - start: Instant::now(), + enabled: RwLock::new(false), + tun_readers: WaitHandle::new(), id: rng.gen(), mtu: AtomicUsize::new(0), peers: RwLock::new(HashMap::new()), @@ -486,59 +613,6 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> { }); } - // start TUN read IO threads (multiple threads to support multi-queue interfaces) - debug_assert!( - readers.len() > 0, - "attempted to create WG device without TUN readers" - ); - while let Some(reader) = readers.pop() { - let wg = wg.clone(); - thread::spawn(move || loop { - // create vector big enough for any transport message (based on MTU) - let mtu = wg.mtu.load(Ordering::Relaxed); - let size = mtu + router::SIZE_MESSAGE_PREFIX; - let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); - msg.resize(size, 0); - - // read a new IP packet - let payload = match reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX) { - Ok(payload) => payload, - Err(e) => { - debug!("TUN worker, failed to read from tun device: {}", e); - return; - } - }; - debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu); - - // TODO: start device down - if mtu == 0 { - continue; - } - - // truncate padding - let padded = padding(payload, mtu); - log::trace!( - "TUN worker, payload length = {}, padded length = {}", - payload, - padded - ); - msg.truncate(router::SIZE_MESSAGE_PREFIX + padded); - debug_assert!(padded <= mtu); - debug_assert_eq!( - if padded < mtu { - (msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE - } else { - 0 - }, - 0 - ); - - // crypt-key route - let e = wg.router.send(msg); - debug!("TUN worker, router returned {:?}", e); - }); - } - Wireguard { state: wg, runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY), |