aboutsummaryrefslogtreecommitdiffstats
path: root/src/wireguard.rs
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-09-16 22:33:46 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-09-16 22:33:46 +0200
commitdfe4a22920e31f30f0e7ceb7c0d588dd48af13ad (patch)
tree129f8955ee34c8b36fd27305471c41349282b0bc /src/wireguard.rs
parentWIP: Handshake queue and workers (diff)
downloadwireguard-rs-dfe4a22920e31f30f0e7ceb7c0d588dd48af13ad.tar.xz
wireguard-rs-dfe4a22920e31f30f0e7ceb7c0d588dd48af13ad.zip
WIP: Work on handshake worker
Diffstat (limited to 'src/wireguard.rs')
-rw-r--r--src/wireguard.rs72
1 files changed, 62 insertions, 10 deletions
diff --git a/src/wireguard.rs b/src/wireguard.rs
index 71b981e..2c166b4 100644
--- a/src/wireguard.rs
+++ b/src/wireguard.rs
@@ -1,13 +1,15 @@
use crate::handshake;
use crate::router;
-use crate::types::{Bind, Tun};
+use crate::types::{Bind, Endpoint, Tun};
use std::sync::atomic::{AtomicUsize, Ordering};
-use std::sync::mpsc::sync_channel;
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
+use log::debug;
+use rand::rngs::OsRng;
+
use byteorder::{ByteOrder, LittleEndian};
use crossbeam_channel::bounded;
use x25519_dalek::StaticSecret;
@@ -16,6 +18,14 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128;
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
+#[derive(Clone)]
+pub struct Peer<T: Tun, B: Bind>(Arc<PeerInner<T, B>>);
+
+pub struct PeerInner<T: Tun, B: Bind> {
+ peer: router::Peer<Events, T, B>,
+ timers: Timers,
+}
+
pub struct Timers {}
pub struct Events();
@@ -38,7 +48,7 @@ pub struct Wireguard<T: Tun, B: Bind> {
impl<T: Tun, B: Bind> Wireguard<T, B> {
fn start(&self) {}
- fn new(tun: T, bind: B) -> Wireguard<T, B> {
+ fn new(tun: T, bind: B, sk: StaticSecret) -> Wireguard<T, B> {
let router = Arc::new(router::Device::new(
num_cpus::get(),
tun.clone(),
@@ -46,11 +56,14 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
));
let handshake_staged = Arc::new(AtomicUsize::new(0));
+ let handshake_device: Arc<handshake::Device<Peer<T, B>>> =
+ Arc::new(handshake::Device::new(sk));
// start UDP read IO thread
let (handshake_tx, handshake_rx) = bounded(128);
{
let tun = tun.clone();
+ let bind = bind.clone();
thread::spawn(move || {
let mut under_load =
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
@@ -81,11 +94,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}
// pass source address along if under load
- if under_load.elapsed() < DURATION_UNDER_LOAD {
- handshake_tx.send((msg, Some(src))).unwrap();
- } else {
- handshake_tx.send((msg, None)).unwrap();
- }
+ handshake_tx
+ .send((msg, src, under_load.elapsed() < DURATION_UNDER_LOAD))
+ .unwrap();
}
router::TYPE_TRANSPORT => {
// transport message
@@ -98,9 +109,50 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
// start handshake workers
for _ in 0..num_cpus::get() {
+ let bind = bind.clone();
let handshake_rx = handshake_rx.clone();
- thread::spawn(move || loop {
- let (msg, src) = handshake_rx.recv().unwrap(); // TODO handle error
+ let handshake_device = handshake_device.clone();
+ thread::spawn(move || {
+ // prepare OsRng instance for this thread
+ let mut rng = OsRng::new().unwrap();
+
+ // process elements from the handshake queue
+ for (msg, src, under_load) in handshake_rx {
+ // feed message to handshake device
+ let src_validate = (&src).into_address(); // TODO avoid
+ match handshake_device.process(
+ &mut rng,
+ &msg[..],
+ if under_load {
+ Some(&src_validate)
+ } else {
+ None
+ },
+ ) {
+ Ok((identity, msg, keypair)) => {
+ // send response
+ if let Some(msg) = msg {
+ let _ = bind.send(&msg[..], &src).map_err(|e| {
+ debug!(
+ "handshake worker, failed to send response, error = {:?}",
+ e
+ )
+ });
+ }
+
+ // update timers
+ if let Some(identity) = identity {
+ // add keypair to peer and free any unused ids
+ if let Some(keypair) = keypair {
+ for id in identity.0.peer.add_keypair(keypair) {
+ handshake_device.release(id);
+ }
+ }
+ }
+ }
+ Err(e) => debug!("handshake worker, error = {:?}", e),
+ }
+ }
});
}