diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-09-18 15:31:10 +0200 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-09-18 15:31:10 +0200 |
commit | 6311aa34022a24224b1dc8d0427cd72dd42e9396 (patch) | |
tree | 234937066c4429838dff270e944e95d32e58a862 /src/wireguard.rs | |
parent | WIP: Work on handshake worker (diff) | |
download | wireguard-rs-6311aa34022a24224b1dc8d0427cd72dd42e9396.tar.xz wireguard-rs-6311aa34022a24224b1dc8d0427cd72dd42e9396.zip |
WIP: TUN IO worker
Also removed the type parameters from the handshake device.
Diffstat (limited to 'src/wireguard.rs')
-rw-r--r-- | src/wireguard.rs | 222 |
1 files changed, 146 insertions, 76 deletions
diff --git a/src/wireguard.rs b/src/wireguard.rs index 2c166b4..f98369f 100644 --- a/src/wireguard.rs +++ b/src/wireguard.rs @@ -2,17 +2,20 @@ use crate::handshake; use crate::router; use crate::types::{Bind, Endpoint, Tun}; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; use std::thread; use std::time::{Duration, Instant}; +use std::collections::HashMap; + use log::debug; use rand::rngs::OsRng; +use spin::{Mutex, RwLock}; use byteorder::{ByteOrder, LittleEndian}; -use crossbeam_channel::bounded; -use x25519_dalek::StaticSecret; +use crossbeam_channel::{bounded, Sender}; +use x25519_dalek::{PublicKey, StaticSecret}; const SIZE_HANDSHAKE_QUEUE: usize = 128; const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4; @@ -22,8 +25,10 @@ const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000); pub struct Peer<T: Tun, B: Bind>(Arc<PeerInner<T, B>>); pub struct PeerInner<T: Tun, B: Bind> { - peer: router::Peer<Events, T, B>, + router: router::Peer<Events, T, B>, timers: Timers, + rx: AtomicU64, + tx: AtomicU64, } pub struct Timers {} @@ -40,96 +45,96 @@ impl router::Callbacks for Events { fn need_key(t: &Timers) {} } -pub struct Wireguard<T: Tun, B: Bind> { - router: Arc<router::Device<Events, T, B>>, - handshake: Option<Arc<handshake::Device<()>>>, +struct Handshake { + device: handshake::Device, + active: bool, } -impl<T: Tun, B: Bind> Wireguard<T, B> { - fn start(&self) {} - - fn new(tun: T, bind: B, sk: StaticSecret) -> Wireguard<T, B> { - let router = Arc::new(router::Device::new( - num_cpus::get(), - tun.clone(), - bind.clone(), - )); - - let handshake_staged = Arc::new(AtomicUsize::new(0)); - let handshake_device: Arc<handshake::Device<Peer<T, B>>> = - Arc::new(handshake::Device::new(sk)); +struct WireguardInner<T: Tun, B: Bind> { + // identify and configuration map + peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>, - // start UDP read IO thread - let (handshake_tx, handshake_rx) = bounded(128); - { - let tun = tun.clone(); - let bind = bind.clone(); - thread::spawn(move || { - let mut under_load = - Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); + // cryptkey routing + router: router::Device<Events, T, B>, - loop { - // read UDP packet into vector - let size = tun.mtu() + 148; // maximum message size - let mut msg: Vec<u8> = - Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); - msg.resize(size, 0); - let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error - msg.truncate(size); + // handshake related state + handshake: RwLock<Handshake>, + under_load: AtomicBool, + pending: AtomicUsize, // num of pending handshake packets in queue + queue: Mutex<Sender<(Vec<u8>, B::Endpoint)>>, - // message type de-multiplexer - if msg.len() < std::mem::size_of::<u32>() { - continue; - } + // IO + bind: B, +} - match LittleEndian::read_u32(&msg[..]) { - handshake::TYPE_COOKIE_REPLY - | handshake::TYPE_INITIATION - | handshake::TYPE_RESPONSE => { - // detect if under load - if handshake_staged.fetch_add(1, Ordering::SeqCst) - > THRESHOLD_UNDER_LOAD - { - under_load = Instant::now() - } +pub struct Wireguard<T: Tun, B: Bind> { + state: Arc<WireguardInner<T, B>>, +} - // pass source address along if under load - handshake_tx - .send((msg, src, under_load.elapsed() < DURATION_UNDER_LOAD)) - .unwrap(); - } - router::TYPE_TRANSPORT => { - // transport message - } - _ => (), - } - } - }); +impl<T: Tun, B: Bind> Wireguard<T, B> { + fn set_key(&self, sk: Option<StaticSecret>) { + let mut handshake = self.state.handshake.write(); + match sk { + None => { + let mut rng = OsRng::new().unwrap(); + handshake.device.set_sk(StaticSecret::new(&mut rng)); + handshake.active = false; + } + Some(sk) => { + handshake.device.set_sk(sk); + handshake.active = true; + } } + } + + fn new(tun: T, bind: B) -> Wireguard<T, B> { + // create device state + let mut rng = OsRng::new().unwrap(); + let (tx, rx): (Sender<(Vec<u8>, B::Endpoint)>, _) = bounded(SIZE_HANDSHAKE_QUEUE); + let wg = Arc::new(WireguardInner { + peers: RwLock::new(HashMap::new()), + router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()), + pending: AtomicUsize::new(0), + handshake: RwLock::new(Handshake { + device: handshake::Device::new(StaticSecret::new(&mut rng)), + active: false, + }), + under_load: AtomicBool::new(false), + bind: bind.clone(), + queue: Mutex::new(tx), + }); // start handshake workers for _ in 0..num_cpus::get() { + let wg = wg.clone(); + let rx = rx.clone(); let bind = bind.clone(); - let handshake_rx = handshake_rx.clone(); - let handshake_device = handshake_device.clone(); thread::spawn(move || { // prepare OsRng instance for this thread let mut rng = OsRng::new().unwrap(); // process elements from the handshake queue - for (msg, src, under_load) in handshake_rx { + for (msg, src) in rx { + wg.pending.fetch_sub(1, Ordering::SeqCst); + // feed message to handshake device let src_validate = (&src).into_address(); // TODO avoid - match handshake_device.process( + let state = wg.handshake.read(); + if !state.active { + continue; + } + + // process message + match state.device.process( &mut rng, &msg[..], - if under_load { + if wg.under_load.load(Ordering::Relaxed) { Some(&src_validate) } else { None }, ) { - Ok((identity, msg, keypair)) => { + Ok((pk, msg, keypair)) => { // send response if let Some(msg) = msg { let _ = bind.send(&msg[..], &src).map_err(|e| { @@ -141,11 +146,13 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { } // update timers - if let Some(identity) = identity { + if let Some(pk) = pk { // add keypair to peer and free any unused ids if let Some(keypair) = keypair { - for id in identity.0.peer.add_keypair(keypair) { - handshake_device.release(id); + if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { + for id in peer.0.router.add_keypair(keypair) { + state.device.release(id); + } } } } @@ -156,13 +163,76 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { }); } - // start TUN read IO thread + // start UDP read IO thread + { + let wg = wg.clone(); + let tun = tun.clone(); + let bind = bind.clone(); + thread::spawn(move || { + let mut last_under_load = + Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); + + loop { + // read UDP packet into vector + let size = tun.mtu() + 148; // maximum message size + let mut msg: Vec<u8> = Vec::with_capacity(size); + msg.resize(size, 0); + let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error + msg.truncate(size); - thread::spawn(move || {}); + // message type de-multiplexer + if msg.len() < std::mem::size_of::<u32>() { + continue; + } - Wireguard { - router, - handshake: None, + match LittleEndian::read_u32(&msg[..]) { + handshake::TYPE_COOKIE_REPLY + | handshake::TYPE_INITIATION + | handshake::TYPE_RESPONSE => { + // update under_load flag + if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD { + last_under_load = Instant::now(); + wg.under_load.store(true, Ordering::SeqCst); + } else if last_under_load.elapsed() > DURATION_UNDER_LOAD { + wg.under_load.store(false, Ordering::SeqCst); + } + + wg.queue.lock().send((msg, src)).unwrap(); + } + router::TYPE_TRANSPORT => { + // transport message + + // pad the message + + let _ = wg.router.recv(src, msg); + } + _ => (), + } + } + }); + } + + // start TUN read IO thread + { + let wg = wg.clone(); + thread::spawn(move || loop { + // read a new IP packet + let mtu = tun.mtu(); + let size = mtu + 148; + let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); + let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap(); + msg.truncate(size); + + // pad message to multiple of 16 + while msg.len() < mtu && msg.len() % 16 != 0 { + msg.push(0); + } + + // crypt-key route + let _ = wg.router.send(msg); + }); } + + Wireguard { state: wg } } } |