aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-08-28 16:27:26 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-08-28 16:27:26 +0200
commit6785aa4cb56833131b69f4d2b44301908b1a1b4c (patch)
treed9dc6ec7a3c45291f44ae3e8d707200a1b1da410 /src
parentRenamed confirmed -> initator on keypair (diff)
downloadwireguard-rs-6785aa4cb56833131b69f4d2b44301908b1a1b4c.tar.xz
wireguard-rs-6785aa4cb56833131b69f4d2b44301908b1a1b4c.zip
Join with worker threads on device drop
Diffstat (limited to 'src')
-rw-r--r--src/main.rs78
-rw-r--r--src/router/device.rs18
-rw-r--r--src/router/types.rs12
-rw-r--r--src/router/workers.rs3
-rw-r--r--src/types/endpoint.rs2
-rw-r--r--src/types/udp.rs8
6 files changed, 91 insertions, 30 deletions
diff --git a/src/main.rs b/src/main.rs
index 5c58b24..fc1a26a 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -7,11 +7,60 @@ mod types;
use hjul::*;
+use std::error::Error;
+use std::fmt;
+use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use sodiumoxide;
-use types::KeyPair;
+use types::{Bind, KeyPair};
+
+struct Test {}
+
+impl Bind for Test {
+ type Error = BindError;
+ type Endpoint = SocketAddr;
+
+ fn new() -> Test {
+ Test {}
+ }
+
+ fn set_port(&self, port: u16) -> Result<(), Self::Error> {
+ Ok(())
+ }
+
+ fn get_port(&self) -> Option<u16> {
+ None
+ }
+
+ fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
+ Ok((0, "127.0.0.1:8080".parse().unwrap()))
+ }
+
+ fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error> {
+ Ok(())
+ }
+}
+
+#[derive(Debug)]
+enum BindError {}
+
+impl Error for BindError {
+ fn description(&self) -> &str {
+ "Generic Bind Error"
+ }
+
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ None
+ }
+}
+
+impl fmt::Display for BindError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "Not Possible")
+ }
+}
#[derive(Debug, Clone)]
struct PeerTimer {
@@ -24,20 +73,23 @@ fn main() {
// choose optimal crypto implementations for platform
sodiumoxide::init().unwrap();
+ {
+ let router = router::Device::new(
+ 4,
+ |t: &PeerTimer, data: bool, sent: bool| t.a.reset(Duration::from_millis(1000)),
+ |t: &PeerTimer, data: bool, sent: bool| t.b.reset(Duration::from_millis(1000)),
+ |t: &PeerTimer| println!("new key requested"),
+ );
- let router = router::Device::new(
- 4,
- |t: &PeerTimer, data: bool, sent: bool| t.a.reset(Duration::from_millis(1000)),
- |t: &PeerTimer, data: bool, sent: bool| t.b.reset(Duration::from_millis(1000)),
- |t: &PeerTimer| println!("new key requested"),
- );
+ let pt = PeerTimer {
+ a: runner.timer(|| println!("timer-a fired for peer")),
+ b: runner.timer(|| println!("timer-b fired for peer")),
+ };
- let pt = PeerTimer {
- a: runner.timer(|| println!("timer-a fired for peer")),
- b: runner.timer(|| println!("timer-b fired for peer")),
- };
+ let peer = router.new_peer(pt.clone());
- let peer = router.new_peer(pt.clone());
+ println!("{:?}", pt);
+ }
- println!("{:?}", pt);
+ println!("joined");
}
diff --git a/src/router/device.rs b/src/router/device.rs
index bee4ad4..a7f0590 100644
--- a/src/router/device.rs
+++ b/src/router/device.rs
@@ -9,7 +9,8 @@ use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use spin;
use treebitmap::IpLookupTable;
-use super::super::types::KeyPair;
+use super::super::types::{Bind, KeyPair, Tun};
+
use super::anti_replay::AntiReplay;
use super::peer;
use super::peer::{Peer, PeerInner};
@@ -62,16 +63,15 @@ impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> Drop for Devi
let device = &self.0;
device.running.store(false, Ordering::SeqCst);
- // eat all parallel jobs
- while match device.injector.steal() {
- Steal::Empty => true,
+ // join all worker threads
+ while match self.1.pop() {
+ Some(handle) => {
+ handle.thread().unpark();
+ handle.join().unwrap();
+ true
+ }
_ => false,
} {}
-
- // unpark all threads
- for handle in &self.1 {
- handle.thread().unpark();
- }
}
}
diff --git a/src/router/types.rs b/src/router/types.rs
index 2ed011b..3d486bc 100644
--- a/src/router/types.rs
+++ b/src/router/types.rs
@@ -3,7 +3,7 @@ pub trait Opaque: Send + Sync + 'static {}
impl<T> Opaque for T where T: Send + Sync + 'static {}
/// A send/recv callback takes 3 arguments:
-///
+///
/// * `0`, a reference to the opaque value assigned to the peer
/// * `1`, a bool indicating whether the message contained data (not just keepalive)
/// * `2`, a bool indicating whether the message was transmitted (i.e. did the peer have an associated endpoint?)
@@ -12,8 +12,14 @@ pub trait Callback<T>: Fn(&T, bool, bool) -> () + Sync + Send + 'static {}
impl<T, F> Callback<T> for F where F: Fn(&T, bool, bool) -> () + Sync + Send + 'static {}
/// A key callback takes 1 argument
-///
+///
/// * `0`, a reference to the opaque value assigned to the peer
pub trait KeyCallback<T>: Fn(&T) -> () + Sync + Send + 'static {}
-impl<T, F> KeyCallback<T> for F where F: Fn(&T) -> () + Sync + Send + 'static {} \ No newline at end of file
+impl<T, F> KeyCallback<T> for F where F: Fn(&T) -> () + Sync + Send + 'static {}
+
+pub trait TunCallback<T>: Fn(&T, bool, bool) -> () + Sync + Send + 'static {}
+
+pub trait BindCallback<T>: Fn(&T, bool, bool) -> () + Sync + Send + 'static {}
+
+pub trait Endpoint: Send + Sync {}
diff --git a/src/router/workers.rs b/src/router/workers.rs
index 1fd2cdf..4861847 100644
--- a/src/router/workers.rs
+++ b/src/router/workers.rs
@@ -208,7 +208,7 @@ pub fn worker_parallel<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback
local: Worker<JobParallel>, // local job queue (local to thread)
stealers: Vec<Stealer<JobParallel>>, // stealers (from other threads)
) {
- while !device.running.load(Ordering::SeqCst) {
+ while device.running.load(Ordering::SeqCst) {
match find_task(&local, &device.injector, &stealers) {
Some(job) => {
let (handle, buf) = job;
@@ -262,7 +262,6 @@ pub fn worker_parallel<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback
handle.thread().unpark();
}
None => {
- // no jobs, park the worker
device.parked.store(true, Ordering::Release);
thread::park();
}
diff --git a/src/types/endpoint.rs b/src/types/endpoint.rs
index d97905a..aa4dfd7 100644
--- a/src/types/endpoint.rs
+++ b/src/types/endpoint.rs
@@ -4,3 +4,5 @@ use std::net::SocketAddr;
* is to simply use SocketAddr directly as the endpoint.
*/
pub trait Endpoint: Into<SocketAddr> {}
+
+impl<T> Endpoint for T where T: Into<SocketAddr> {}
diff --git a/src/types/udp.rs b/src/types/udp.rs
index 00e218f..4bf0a9c 100644
--- a/src/types/udp.rs
+++ b/src/types/udp.rs
@@ -21,7 +21,9 @@ pub trait Bind: Send + Sync {
fn set_port(&self, port: u16) -> Result<(), Self::Error>;
/// Returns the current port of the bind
- fn get_port(&self) -> u16;
- fn recv(&self, dst: &mut [u8]) -> Self::Endpoint;
- fn send(&self, src: &[u8], dst: &Self::Endpoint);
+ fn get_port(&self) -> Option<u16>;
+
+ fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error>;
+
+ fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error>;
}