aboutsummaryrefslogtreecommitdiffstats
path: root/src/wireguard.rs
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-09-18 15:31:10 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-09-18 15:31:10 +0200
commit6311aa34022a24224b1dc8d0427cd72dd42e9396 (patch)
tree234937066c4429838dff270e944e95d32e58a862 /src/wireguard.rs
parentWIP: Work on handshake worker (diff)
downloadwireguard-rs-6311aa34022a24224b1dc8d0427cd72dd42e9396.tar.xz
wireguard-rs-6311aa34022a24224b1dc8d0427cd72dd42e9396.zip
WIP: TUN IO worker
Also removed the type parameters from the handshake device.
Diffstat (limited to 'src/wireguard.rs')
-rw-r--r--src/wireguard.rs222
1 files changed, 146 insertions, 76 deletions
diff --git a/src/wireguard.rs b/src/wireguard.rs
index 2c166b4..f98369f 100644
--- a/src/wireguard.rs
+++ b/src/wireguard.rs
@@ -2,17 +2,20 @@ use crate::handshake;
use crate::router;
use crate::types::{Bind, Endpoint, Tun};
-use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
+use std::collections::HashMap;
+
use log::debug;
use rand::rngs::OsRng;
+use spin::{Mutex, RwLock};
use byteorder::{ByteOrder, LittleEndian};
-use crossbeam_channel::bounded;
-use x25519_dalek::StaticSecret;
+use crossbeam_channel::{bounded, Sender};
+use x25519_dalek::{PublicKey, StaticSecret};
const SIZE_HANDSHAKE_QUEUE: usize = 128;
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
@@ -22,8 +25,10 @@ const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
pub struct Peer<T: Tun, B: Bind>(Arc<PeerInner<T, B>>);
pub struct PeerInner<T: Tun, B: Bind> {
- peer: router::Peer<Events, T, B>,
+ router: router::Peer<Events, T, B>,
timers: Timers,
+ rx: AtomicU64,
+ tx: AtomicU64,
}
pub struct Timers {}
@@ -40,96 +45,96 @@ impl router::Callbacks for Events {
fn need_key(t: &Timers) {}
}
-pub struct Wireguard<T: Tun, B: Bind> {
- router: Arc<router::Device<Events, T, B>>,
- handshake: Option<Arc<handshake::Device<()>>>,
+struct Handshake {
+ device: handshake::Device,
+ active: bool,
}
-impl<T: Tun, B: Bind> Wireguard<T, B> {
- fn start(&self) {}
-
- fn new(tun: T, bind: B, sk: StaticSecret) -> Wireguard<T, B> {
- let router = Arc::new(router::Device::new(
- num_cpus::get(),
- tun.clone(),
- bind.clone(),
- ));
-
- let handshake_staged = Arc::new(AtomicUsize::new(0));
- let handshake_device: Arc<handshake::Device<Peer<T, B>>> =
- Arc::new(handshake::Device::new(sk));
+struct WireguardInner<T: Tun, B: Bind> {
+ // identify and configuration map
+ peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
- // start UDP read IO thread
- let (handshake_tx, handshake_rx) = bounded(128);
- {
- let tun = tun.clone();
- let bind = bind.clone();
- thread::spawn(move || {
- let mut under_load =
- Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
+ // cryptkey routing
+ router: router::Device<Events, T, B>,
- loop {
- // read UDP packet into vector
- let size = tun.mtu() + 148; // maximum message size
- let mut msg: Vec<u8> =
- Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
- msg.resize(size, 0);
- let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
- msg.truncate(size);
+ // handshake related state
+ handshake: RwLock<Handshake>,
+ under_load: AtomicBool,
+ pending: AtomicUsize, // num of pending handshake packets in queue
+ queue: Mutex<Sender<(Vec<u8>, B::Endpoint)>>,
- // message type de-multiplexer
- if msg.len() < std::mem::size_of::<u32>() {
- continue;
- }
+ // IO
+ bind: B,
+}
- match LittleEndian::read_u32(&msg[..]) {
- handshake::TYPE_COOKIE_REPLY
- | handshake::TYPE_INITIATION
- | handshake::TYPE_RESPONSE => {
- // detect if under load
- if handshake_staged.fetch_add(1, Ordering::SeqCst)
- > THRESHOLD_UNDER_LOAD
- {
- under_load = Instant::now()
- }
+pub struct Wireguard<T: Tun, B: Bind> {
+ state: Arc<WireguardInner<T, B>>,
+}
- // pass source address along if under load
- handshake_tx
- .send((msg, src, under_load.elapsed() < DURATION_UNDER_LOAD))
- .unwrap();
- }
- router::TYPE_TRANSPORT => {
- // transport message
- }
- _ => (),
- }
- }
- });
+impl<T: Tun, B: Bind> Wireguard<T, B> {
+ fn set_key(&self, sk: Option<StaticSecret>) {
+ let mut handshake = self.state.handshake.write();
+ match sk {
+ None => {
+ let mut rng = OsRng::new().unwrap();
+ handshake.device.set_sk(StaticSecret::new(&mut rng));
+ handshake.active = false;
+ }
+ Some(sk) => {
+ handshake.device.set_sk(sk);
+ handshake.active = true;
+ }
}
+ }
+
+ fn new(tun: T, bind: B) -> Wireguard<T, B> {
+ // create device state
+ let mut rng = OsRng::new().unwrap();
+ let (tx, rx): (Sender<(Vec<u8>, B::Endpoint)>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
+ let wg = Arc::new(WireguardInner {
+ peers: RwLock::new(HashMap::new()),
+ router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()),
+ pending: AtomicUsize::new(0),
+ handshake: RwLock::new(Handshake {
+ device: handshake::Device::new(StaticSecret::new(&mut rng)),
+ active: false,
+ }),
+ under_load: AtomicBool::new(false),
+ bind: bind.clone(),
+ queue: Mutex::new(tx),
+ });
// start handshake workers
for _ in 0..num_cpus::get() {
+ let wg = wg.clone();
+ let rx = rx.clone();
let bind = bind.clone();
- let handshake_rx = handshake_rx.clone();
- let handshake_device = handshake_device.clone();
thread::spawn(move || {
// prepare OsRng instance for this thread
let mut rng = OsRng::new().unwrap();
// process elements from the handshake queue
- for (msg, src, under_load) in handshake_rx {
+ for (msg, src) in rx {
+ wg.pending.fetch_sub(1, Ordering::SeqCst);
+
// feed message to handshake device
let src_validate = (&src).into_address(); // TODO avoid
- match handshake_device.process(
+ let state = wg.handshake.read();
+ if !state.active {
+ continue;
+ }
+
+ // process message
+ match state.device.process(
&mut rng,
&msg[..],
- if under_load {
+ if wg.under_load.load(Ordering::Relaxed) {
Some(&src_validate)
} else {
None
},
) {
- Ok((identity, msg, keypair)) => {
+ Ok((pk, msg, keypair)) => {
// send response
if let Some(msg) = msg {
let _ = bind.send(&msg[..], &src).map_err(|e| {
@@ -141,11 +146,13 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}
// update timers
- if let Some(identity) = identity {
+ if let Some(pk) = pk {
// add keypair to peer and free any unused ids
if let Some(keypair) = keypair {
- for id in identity.0.peer.add_keypair(keypair) {
- handshake_device.release(id);
+ if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
+ for id in peer.0.router.add_keypair(keypair) {
+ state.device.release(id);
+ }
}
}
}
@@ -156,13 +163,76 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
});
}
- // start TUN read IO thread
+ // start UDP read IO thread
+ {
+ let wg = wg.clone();
+ let tun = tun.clone();
+ let bind = bind.clone();
+ thread::spawn(move || {
+ let mut last_under_load =
+ Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
+
+ loop {
+ // read UDP packet into vector
+ let size = tun.mtu() + 148; // maximum message size
+ let mut msg: Vec<u8> = Vec::with_capacity(size);
+ msg.resize(size, 0);
+ let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
+ msg.truncate(size);
- thread::spawn(move || {});
+ // message type de-multiplexer
+ if msg.len() < std::mem::size_of::<u32>() {
+ continue;
+ }
- Wireguard {
- router,
- handshake: None,
+ match LittleEndian::read_u32(&msg[..]) {
+ handshake::TYPE_COOKIE_REPLY
+ | handshake::TYPE_INITIATION
+ | handshake::TYPE_RESPONSE => {
+ // update under_load flag
+ if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
+ last_under_load = Instant::now();
+ wg.under_load.store(true, Ordering::SeqCst);
+ } else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
+ wg.under_load.store(false, Ordering::SeqCst);
+ }
+
+ wg.queue.lock().send((msg, src)).unwrap();
+ }
+ router::TYPE_TRANSPORT => {
+ // transport message
+
+ // pad the message
+
+ let _ = wg.router.recv(src, msg);
+ }
+ _ => (),
+ }
+ }
+ });
+ }
+
+ // start TUN read IO thread
+ {
+ let wg = wg.clone();
+ thread::spawn(move || loop {
+ // read a new IP packet
+ let mtu = tun.mtu();
+ let size = mtu + 148;
+ let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
+ let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
+ msg.truncate(size);
+
+ // pad message to multiple of 16
+ while msg.len() < mtu && msg.len() % 16 != 0 {
+ msg.push(0);
+ }
+
+ // crypt-key route
+ let _ = wg.router.send(msg);
+ });
}
+
+ Wireguard { state: wg }
}
}