aboutsummaryrefslogtreecommitdiffstats
path: root/src/wireguard.rs
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-09-21 17:22:03 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-09-21 17:22:03 +0200
commit5cc108349968fbaa6998220631eb749276e64f45 (patch)
treede426ba6593a453503ec0a69349c9874e32db229 /src/wireguard.rs
parentWIP: TUN IO worker (diff)
downloadwireguard-rs-5cc108349968fbaa6998220631eb749276e64f45.tar.xz
wireguard-rs-5cc108349968fbaa6998220631eb749276e64f45.zip
Added zero_key to peer
Diffstat (limited to '')
-rw-r--r--src/wireguard.rs116
1 files changed, 65 insertions, 51 deletions
diff --git a/src/wireguard.rs b/src/wireguard.rs
index f98369f..3b4724e 100644
--- a/src/wireguard.rs
+++ b/src/wireguard.rs
@@ -1,5 +1,6 @@
use crate::handshake;
use crate::router;
+use crate::timers::{Events, Timers};
use crate::types::{Bind, Endpoint, Tun};
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
@@ -21,28 +22,19 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128;
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
-#[derive(Clone)]
-pub struct Peer<T: Tun, B: Bind>(Arc<PeerInner<T, B>>);
+type Peer<T: Tun, B: Bind> = Arc<PeerInner<T, B>>;
pub struct PeerInner<T: Tun, B: Bind> {
- router: router::Peer<Events, T, B>,
- timers: Timers,
- rx: AtomicU64,
- tx: AtomicU64,
+ queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue
+ router: router::Peer<Events, T, B>, // router peer
+ timers: Option<Timers>, //
}
-pub struct Timers {}
-
-pub struct Events();
-
-impl router::Callbacks for Events {
- type Opaque = Timers;
-
- fn send(t: &Timers, size: usize, data: bool, sent: bool) {}
-
- fn recv(t: &Timers, size: usize, data: bool, sent: bool) {}
-
- fn need_key(t: &Timers) {}
+impl<T: Tun, B: Bind> PeerInner<T, B> {
+ #[inline(always)]
+ fn timers(&self) -> &Timers {
+ self.timers.as_ref().unwrap()
+ }
}
struct Handshake {
@@ -50,6 +42,11 @@ struct Handshake {
active: bool,
}
+enum HandshakeJob<E> {
+ Message(Vec<u8>, E),
+ New(PublicKey),
+}
+
struct WireguardInner<T: Tun, B: Bind> {
// identify and configuration map
peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
@@ -61,7 +58,7 @@ struct WireguardInner<T: Tun, B: Bind> {
handshake: RwLock<Handshake>,
under_load: AtomicBool,
pending: AtomicUsize, // num of pending handshake packets in queue
- queue: Mutex<Sender<(Vec<u8>, B::Endpoint)>>,
+ queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>,
// IO
bind: B,
@@ -90,7 +87,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
fn new(tun: T, bind: B) -> Wireguard<T, B> {
// create device state
let mut rng = OsRng::new().unwrap();
- let (tx, rx): (Sender<(Vec<u8>, B::Endpoint)>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
+ let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
let wg = Arc::new(WireguardInner {
peers: RwLock::new(HashMap::new()),
router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()),
@@ -114,50 +111,64 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
let mut rng = OsRng::new().unwrap();
// process elements from the handshake queue
- for (msg, src) in rx {
+ for job in rx {
wg.pending.fetch_sub(1, Ordering::SeqCst);
-
- // feed message to handshake device
- let src_validate = (&src).into_address(); // TODO avoid
let state = wg.handshake.read();
if !state.active {
continue;
}
- // process message
- match state.device.process(
- &mut rng,
- &msg[..],
- if wg.under_load.load(Ordering::Relaxed) {
- Some(&src_validate)
- } else {
- None
- },
- ) {
- Ok((pk, msg, keypair)) => {
- // send response
- if let Some(msg) = msg {
- let _ = bind.send(&msg[..], &src).map_err(|e| {
- debug!(
+ match job {
+ HandshakeJob::Message(msg, src) => {
+ // feed message to handshake device
+ let src_validate = (&src).into_address(); // TODO avoid
+
+ // process message
+ match state.device.process(
+ &mut rng,
+ &msg[..],
+ if wg.under_load.load(Ordering::Relaxed) {
+ Some(&src_validate)
+ } else {
+ None
+ },
+ ) {
+ Ok((pk, msg, keypair)) => {
+ // send response
+ if let Some(msg) = msg {
+ let _ = bind.send(&msg[..], &src).map_err(|e| {
+ debug!(
"handshake worker, failed to send response, error = {:?}",
e
)
- });
- }
+ });
+ }
- // update timers
- if let Some(pk) = pk {
- // add keypair to peer and free any unused ids
- if let Some(keypair) = keypair {
- if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
- for id in peer.0.router.add_keypair(keypair) {
- state.device.release(id);
+ // update timers
+ if let Some(pk) = pk {
+ if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
+ // update endpoint (DISCUSS: right semantics?)
+ peer.router.set_endpoint(src_validate);
+
+ // add keypair to peer and free any unused ids
+ if let Some(keypair) = keypair {
+ for id in peer.router.add_keypair(keypair) {
+ state.device.release(id);
+ }
+ }
}
}
}
+ Err(e) => debug!("handshake worker, error = {:?}", e),
+ }
+ }
+ HandshakeJob::New(pk) => {
+ let msg = state.device.begin(&mut rng, &pk).unwrap(); // TODO handle
+ if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
+ peer.router.send(&msg[..]);
+ peer.timers().handshake_sent();
}
}
- Err(e) => debug!("handshake worker, error = {:?}", e),
}
}
});
@@ -197,7 +208,10 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
wg.under_load.store(false, Ordering::SeqCst);
}
- wg.queue.lock().send((msg, src)).unwrap();
+ wg.queue
+ .lock()
+ .send(HandshakeJob::Message(msg, src))
+ .unwrap();
}
router::TYPE_TRANSPORT => {
// transport message
@@ -223,7 +237,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
msg.truncate(size);
- // pad message to multiple of 16
+ // pad message to multiple of 16 bytes
while msg.len() < mtu && msg.len() % 16 != 0 {
msg.push(0);
}