diff options
Diffstat (limited to 'src/wireguard.rs')
-rw-r--r-- | src/wireguard.rs | 161 |
1 files changed, 93 insertions, 68 deletions
diff --git a/src/wireguard.rs b/src/wireguard.rs index bcb8592..f14a053 100644 --- a/src/wireguard.rs +++ b/src/wireguard.rs @@ -3,8 +3,10 @@ use crate::handshake; use crate::router; use crate::timers::{Events, Timers}; +use crate::types::bind::Reader as BindReader; use crate::types::bind::{Bind, Writer}; use crate::types::tun::{Reader, Tun, MTU}; + use crate::types::Endpoint; use hjul::Runner; @@ -53,7 +55,7 @@ pub struct PeerInner<B: Bind> { pub timers: RwLock<Timers>, // } -impl <B:Bind > PeerInner<B> { +impl<B: Bind> PeerInner<B> { #[inline(always)] pub fn timers(&self) -> RwLockReadGuard<Timers> { self.timers.read() @@ -153,7 +155,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { } pub fn get_sk(&self) -> Option<StaticSecret> { - let mut handshake = self.state.handshake.read(); + let handshake = self.state.handshake.read(); if handshake.active { Some(handshake.device.get_sk()) } else { @@ -184,66 +186,73 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { peer } - pub fn new_bind(reader: B::Reader, writer: B::Writer, closer: B::Closer) { - - // drop existing closer - - // swap IO thread for new reader - - // start UDP read IO thread - - /* - { - let wg = wg.clone(); - let mtu = mtu.clone(); - thread::spawn(move || { - let mut last_under_load = - Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); - - loop { - // create vector big enough for any message given current MTU - let size = mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE; - let mut msg: Vec<u8> = Vec::with_capacity(size); - msg.resize(size, 0); - - // read UDP packet into vector - let (size, src) = reader.read(&mut msg).unwrap(); // TODO handle error - msg.truncate(size); + /* Begin consuming messages from the reader. + * + * Any previous reader thread is stopped by closing the previous reader, + * which unblocks the thread and causes an error on reader.read + */ + pub fn add_reader(&self, reader: B::Reader) { + let wg = self.state.clone(); + thread::spawn(move || { + let mut last_under_load = + Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); + + loop { + // create vector big enough for any message given current MTU + let size = wg.mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE; + let mut msg: Vec<u8> = Vec::with_capacity(size); + msg.resize(size, 0); - // message type de-multiplexer - if msg.len() < std::mem::size_of::<u32>() { - continue; + // read UDP packet into vector + let (size, src) = match reader.read(&mut msg) { + Err(e) => { + debug!("Bind reader closed with {}", e); + return; } - match LittleEndian::read_u32(&msg[..]) { - handshake::TYPE_COOKIE_REPLY - | handshake::TYPE_INITIATION - | handshake::TYPE_RESPONSE => { - // update under_load flag - if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD { - last_under_load = Instant::now(); - wg.under_load.store(true, Ordering::SeqCst); - } else if last_under_load.elapsed() > DURATION_UNDER_LOAD { - wg.under_load.store(false, Ordering::SeqCst); - } + Ok(v) => v, + }; + msg.truncate(size); - wg.queue - .lock() - .send(HandshakeJob::Message(msg, src)) - .unwrap(); - } - router::TYPE_TRANSPORT => { - // transport message - let _ = wg.router.recv(src, msg); + // message type de-multiplexer + if msg.len() < std::mem::size_of::<u32>() { + continue; + } + match LittleEndian::read_u32(&msg[..]) { + handshake::TYPE_COOKIE_REPLY + | handshake::TYPE_INITIATION + | handshake::TYPE_RESPONSE => { + // update under_load flag + if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD { + last_under_load = Instant::now(); + wg.under_load.store(true, Ordering::SeqCst); + } else if last_under_load.elapsed() > DURATION_UNDER_LOAD { + wg.under_load.store(false, Ordering::SeqCst); } - _ => (), + + wg.queue + .lock() + .send(HandshakeJob::Message(msg, src)) + .unwrap(); } + router::TYPE_TRANSPORT => { + // transport message + let _ = wg.router.recv(src, msg).map_err(|e| { + debug!("Failed to handle incoming transport message: {}", e); + }); + } + _ => (), } - }); - } - */ + } + }); } - pub fn new(reader: T::Reader, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> { + pub fn set_writer(&self, writer: B::Writer) { + // TODO: Consider unifying these and avoid Clone requirement on writer + *self.state.send.write() = Some(writer.clone()); + self.state.router.set_outbound_writer(writer); + } + + pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> { // create device state let mut rng = OsRng::new().unwrap(); let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE); @@ -292,14 +301,16 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { None }, ) { - Ok((pk, msg, keypair)) => { + Ok((pk, resp, keypair)) => { // send response - if let Some(msg) = msg { + let mut resp_len: u64 = 0; + if let Some(msg) = resp { + resp_len = msg.len() as u64; let send: &Option<B::Writer> = &*wg.send.read(); if let Some(writer) = send.as_ref() { let _ = writer.write(&msg[..], &src).map_err(|e| { debug!( - "handshake worker, failed to send response, error = {:?}", + "handshake worker, failed to send response, error = {}", e ) }); @@ -308,16 +319,23 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { // update timers if let Some(pk) = pk { + // authenticated handshake packet received if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { + // add to rx_bytes and tx_bytes + let req_len = msg.len() as u64; + peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed); + peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed); + // update endpoint peer.router.set_endpoint(src); - // add keypair to peer and free any unused ids - if let Some(keypair) = keypair { - for id in peer.router.add_keypair(keypair) { + // add keypair to peer + keypair.map(|kp| { + // free any unused ids + for id in peer.router.add_keypair(kp) { state.device.release(id); } - } + }); } } } @@ -325,20 +343,27 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { } } HandshakeJob::New(pk) => { - let msg = state.device.begin(&mut rng, &pk).unwrap(); // TODO handle - if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { - peer.router.send(&msg[..]); - peer.timers.read().handshake_sent(); - } + let _ = state.device.begin(&mut rng, &pk).map(|msg| { + if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { + let _ = peer.router.send(&msg[..]).map_err(|e| { + debug!("handshake worker, failed to send handshake initiation, error = {}", e) + }); + } + }); } } } }); } - // start TUN read IO thread - { + // start TUN read IO threads (multiple threads to support multi-queue interfaces) + debug_assert!( + readers.len() > 0, + "attempted to create WG device without TUN readers" + ); + while let Some(reader) = readers.pop() { let wg = wg.clone(); + let mtu = mtu.clone(); thread::spawn(move || loop { // create vector big enough for any transport message (based on MTU) let mtu = mtu.mtu(); |