diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-10-09 15:08:26 +0200 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-10-09 15:08:26 +0200 |
commit | 761c46064d7510303f08cde27c9e13b07293f3af (patch) | |
tree | 7b914169725952e557223972b3f0b611c54e6829 /src/wireguard.rs | |
parent | Restructure dummy implementations (diff) | |
download | wireguard-rs-761c46064d7510303f08cde27c9e13b07293f3af.tar.xz wireguard-rs-761c46064d7510303f08cde27c9e13b07293f3af.zip |
Restructure IO traits.
Diffstat (limited to '')
-rw-r--r-- | src/wireguard.rs | 200 |
1 files changed, 128 insertions, 72 deletions
diff --git a/src/wireguard.rs b/src/wireguard.rs index ea600d0..ba81f47 100644 --- a/src/wireguard.rs +++ b/src/wireguard.rs @@ -2,11 +2,13 @@ use crate::constants::*; use crate::handshake; use crate::router; use crate::timers::{Events, Timers}; -use crate::types::{Bind, Endpoint, Tun}; + +use crate::types::Endpoint; +use crate::types::tun::{Tun, Reader, MTU}; +use crate::types::bind::{Bind, Writer}; use hjul::Runner; -use std::cmp; use std::ops::Deref; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; @@ -27,12 +29,20 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128; const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4; const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000); -#[derive(Clone)] pub struct Peer<T: Tun, B: Bind> { - pub router: Arc<router::Peer<Events<T, B>, T, B>>, + pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>, pub state: Arc<PeerInner<B>>, } +impl <T : Tun, B : Bind> Clone for Peer<T, B > { + fn clone(&self) -> Peer<T, B> { + Peer{ + router: self.router.clone(), + state: self.state.clone() + } + } +} + pub struct PeerInner<B: Bind> { pub keepalive: AtomicUsize, // keepalive interval pub rx_bytes: AtomicU64, @@ -66,20 +76,22 @@ pub enum HandshakeJob<E> { } struct WireguardInner<T: Tun, B: Bind> { + // provides access to the MTU value of the tun device + // (otherwise owned solely by the router and a dedicated read IO thread) + mtu: T::MTU, + send: RwLock<Option<B::Writer>>, + // identify and configuration map peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>, // cryptkey router - router: router::Device<Events<T, B>, T, B>, + router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>, // handshake related state handshake: RwLock<Handshake>, under_load: AtomicBool, pending: AtomicUsize, // num of pending handshake packets in queue queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, - - // IO - bind: B, } pub struct Wireguard<T: Tun, B: Bind> { @@ -87,6 +99,17 @@ pub struct Wireguard<T: Tun, B: Bind> { state: Arc<WireguardInner<T, B>>, } +/* Returns the padded length of a message: + * + * # Arguments + * + * - `size` : Size of unpadded message + * - `mtu` : Maximum transmission unit of the device + * + * # Returns + * + * The padded length (always less than or equal to the MTU) + */ #[inline(always)] const fn padding(size: usize, mtu: usize) -> usize { #[inline(always)] @@ -114,6 +137,15 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { } } + pub fn get_sk(&self) -> Option<StaticSecret> { + let mut handshake = self.state.handshake.read(); + if handshake.active { + Some(handshake.device.get_sk()) + } else { + None + } + } + pub fn new_peer(&self, pk: PublicKey) -> Peer<T, B> { let state = Arc::new(PeerInner { pk, @@ -137,20 +169,92 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { peer } - pub fn new(tun: T, bind: B) -> Wireguard<T, B> { + pub fn new_bind( + reader: B::Reader, + writer: B::Writer, + closer: B::Closer + ) { + + // drop existing closer + + + // swap IO thread for new reader + + + // start UDP read IO thread + + /* + { + let wg = wg.clone(); + let mtu = mtu.clone(); + thread::spawn(move || { + let mut last_under_load = + Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); + + loop { + // create vector big enough for any message given current MTU + let size = mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE; + let mut msg: Vec<u8> = Vec::with_capacity(size); + msg.resize(size, 0); + + // read UDP packet into vector + let (size, src) = reader.read(&mut msg).unwrap(); // TODO handle error + msg.truncate(size); + + // message type de-multiplexer + if msg.len() < std::mem::size_of::<u32>() { + continue; + } + match LittleEndian::read_u32(&msg[..]) { + handshake::TYPE_COOKIE_REPLY + | handshake::TYPE_INITIATION + | handshake::TYPE_RESPONSE => { + // update under_load flag + if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD { + last_under_load = Instant::now(); + wg.under_load.store(true, Ordering::SeqCst); + } else if last_under_load.elapsed() > DURATION_UNDER_LOAD { + wg.under_load.store(false, Ordering::SeqCst); + } + + wg.queue + .lock() + .send(HandshakeJob::Message(msg, src)) + .unwrap(); + } + router::TYPE_TRANSPORT => { + // transport message + let _ = wg.router.recv(src, msg); + } + _ => (), + } + } + }); + } + */ + + + } + + pub fn new( + reader: T::Reader, + writer: T::Writer, + mtu: T::MTU, + ) -> Wireguard<T, B> { // create device state let mut rng = OsRng::new().unwrap(); let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE); let wg = Arc::new(WireguardInner { + mtu: mtu.clone(), peers: RwLock::new(HashMap::new()), - router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()), + send: RwLock::new(None), + router: router::Device::new(num_cpus::get(), writer), // router owns the writing half pending: AtomicUsize::new(0), handshake: RwLock::new(Handshake { device: handshake::Device::new(StaticSecret::new(&mut rng)), active: false, }), under_load: AtomicBool::new(false), - bind: bind.clone(), queue: Mutex::new(tx), }); @@ -158,7 +262,6 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { for _ in 0..num_cpus::get() { let wg = wg.clone(); let rx = rx.clone(); - let bind = bind.clone(); thread::spawn(move || { // prepare OsRng instance for this thread let mut rng = OsRng::new().unwrap(); @@ -189,19 +292,22 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { Ok((pk, msg, keypair)) => { // send response if let Some(msg) = msg { - let _ = bind.send(&msg[..], &src).map_err(|e| { - debug!( - "handshake worker, failed to send response, error = {:?}", - e - ) - }); + let send : &Option<B::Writer> = &*wg.send.read(); + if let Some(writer) = send.as_ref() { + let _ = writer.write(&msg[..], &src).map_err(|e| { + debug!( + "handshake worker, failed to send response, error = {:?}", + e + ) + }); + } } // update timers if let Some(pk) = pk { if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { - // update endpoint (DISCUSS: right semantics?) - peer.router.set_endpoint(src_validate); + // update endpoint + peer.router.set_endpoint(src); // add keypair to peer and free any unused ids if let Some(keypair) = keypair { @@ -227,68 +333,18 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { }); } - // start UDP read IO thread - { - let wg = wg.clone(); - let tun = tun.clone(); - let bind = bind.clone(); - thread::spawn(move || { - let mut last_under_load = - Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); - - loop { - // create vector big enough for any message given current MTU - let size = tun.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE; - let mut msg: Vec<u8> = Vec::with_capacity(size); - msg.resize(size, 0); - - // read UDP packet into vector - let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error - msg.truncate(size); - - // message type de-multiplexer - if msg.len() < std::mem::size_of::<u32>() { - continue; - } - match LittleEndian::read_u32(&msg[..]) { - handshake::TYPE_COOKIE_REPLY - | handshake::TYPE_INITIATION - | handshake::TYPE_RESPONSE => { - // update under_load flag - if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD { - last_under_load = Instant::now(); - wg.under_load.store(true, Ordering::SeqCst); - } else if last_under_load.elapsed() > DURATION_UNDER_LOAD { - wg.under_load.store(false, Ordering::SeqCst); - } - - wg.queue - .lock() - .send(HandshakeJob::Message(msg, src)) - .unwrap(); - } - router::TYPE_TRANSPORT => { - // transport message - let _ = wg.router.recv(src, msg); - } - _ => (), - } - } - }); - } - // start TUN read IO thread { let wg = wg.clone(); thread::spawn(move || loop { // create vector big enough for any transport message (based on MTU) - let mtu = tun.mtu(); + let mtu = mtu.mtu(); let size = mtu + router::SIZE_MESSAGE_PREFIX; let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); msg.resize(size, 0); // read a new IP packet - let payload = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap(); + let payload = reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap(); debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu); // truncate padding |