From 6785aa4cb56833131b69f4d2b44301908b1a1b4c Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Wed, 28 Aug 2019 16:27:26 +0200 Subject: Join with worker threads on device drop --- src/main.rs | 78 ++++++++++++++++++++++++++++++++++++++++++--------- src/router/device.rs | 18 ++++++------ src/router/types.rs | 12 ++++++-- src/router/workers.rs | 3 +- src/types/endpoint.rs | 2 ++ src/types/udp.rs | 8 ++++-- 6 files changed, 91 insertions(+), 30 deletions(-) diff --git a/src/main.rs b/src/main.rs index 5c58b24..fc1a26a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,11 +7,60 @@ mod types; use hjul::*; +use std::error::Error; +use std::fmt; +use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use sodiumoxide; -use types::KeyPair; +use types::{Bind, KeyPair}; + +struct Test {} + +impl Bind for Test { + type Error = BindError; + type Endpoint = SocketAddr; + + fn new() -> Test { + Test {} + } + + fn set_port(&self, port: u16) -> Result<(), Self::Error> { + Ok(()) + } + + fn get_port(&self) -> Option { + None + } + + fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> { + Ok((0, "127.0.0.1:8080".parse().unwrap())) + } + + fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error> { + Ok(()) + } +} + +#[derive(Debug)] +enum BindError {} + +impl Error for BindError { + fn description(&self) -> &str { + "Generic Bind Error" + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} + +impl fmt::Display for BindError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Not Possible") + } +} #[derive(Debug, Clone)] struct PeerTimer { @@ -24,20 +73,23 @@ fn main() { // choose optimal crypto implementations for platform sodiumoxide::init().unwrap(); + { + let router = router::Device::new( + 4, + |t: &PeerTimer, data: bool, sent: bool| t.a.reset(Duration::from_millis(1000)), + |t: &PeerTimer, data: bool, sent: bool| t.b.reset(Duration::from_millis(1000)), + |t: &PeerTimer| println!("new key requested"), + ); - let router = router::Device::new( - 4, - |t: &PeerTimer, data: bool, sent: bool| t.a.reset(Duration::from_millis(1000)), - |t: &PeerTimer, data: bool, sent: bool| t.b.reset(Duration::from_millis(1000)), - |t: &PeerTimer| println!("new key requested"), - ); + let pt = PeerTimer { + a: runner.timer(|| println!("timer-a fired for peer")), + b: runner.timer(|| println!("timer-b fired for peer")), + }; - let pt = PeerTimer { - a: runner.timer(|| println!("timer-a fired for peer")), - b: runner.timer(|| println!("timer-b fired for peer")), - }; + let peer = router.new_peer(pt.clone()); - let peer = router.new_peer(pt.clone()); + println!("{:?}", pt); + } - println!("{:?}", pt); + println!("joined"); } diff --git a/src/router/device.rs b/src/router/device.rs index bee4ad4..a7f0590 100644 --- a/src/router/device.rs +++ b/src/router/device.rs @@ -9,7 +9,8 @@ use crossbeam_deque::{Injector, Steal, Stealer, Worker}; use spin; use treebitmap::IpLookupTable; -use super::super::types::KeyPair; +use super::super::types::{Bind, KeyPair, Tun}; + use super::anti_replay::AntiReplay; use super::peer; use super::peer::{Peer, PeerInner}; @@ -62,16 +63,15 @@ impl, R: Callback, K: KeyCallback> Drop for Devi let device = &self.0; device.running.store(false, Ordering::SeqCst); - // eat all parallel jobs - while match device.injector.steal() { - Steal::Empty => true, + // join all worker threads + while match self.1.pop() { + Some(handle) => { + handle.thread().unpark(); + handle.join().unwrap(); + true + } _ => false, } {} - - // unpark all threads - for handle in &self.1 { - handle.thread().unpark(); - } } } diff --git a/src/router/types.rs b/src/router/types.rs index 2ed011b..3d486bc 100644 --- a/src/router/types.rs +++ b/src/router/types.rs @@ -3,7 +3,7 @@ pub trait Opaque: Send + Sync + 'static {} impl Opaque for T where T: Send + Sync + 'static {} /// A send/recv callback takes 3 arguments: -/// +/// /// * `0`, a reference to the opaque value assigned to the peer /// * `1`, a bool indicating whether the message contained data (not just keepalive) /// * `2`, a bool indicating whether the message was transmitted (i.e. did the peer have an associated endpoint?) @@ -12,8 +12,14 @@ pub trait Callback: Fn(&T, bool, bool) -> () + Sync + Send + 'static {} impl Callback for F where F: Fn(&T, bool, bool) -> () + Sync + Send + 'static {} /// A key callback takes 1 argument -/// +/// /// * `0`, a reference to the opaque value assigned to the peer pub trait KeyCallback: Fn(&T) -> () + Sync + Send + 'static {} -impl KeyCallback for F where F: Fn(&T) -> () + Sync + Send + 'static {} \ No newline at end of file +impl KeyCallback for F where F: Fn(&T) -> () + Sync + Send + 'static {} + +pub trait TunCallback: Fn(&T, bool, bool) -> () + Sync + Send + 'static {} + +pub trait BindCallback: Fn(&T, bool, bool) -> () + Sync + Send + 'static {} + +pub trait Endpoint: Send + Sync {} diff --git a/src/router/workers.rs b/src/router/workers.rs index 1fd2cdf..4861847 100644 --- a/src/router/workers.rs +++ b/src/router/workers.rs @@ -208,7 +208,7 @@ pub fn worker_parallel, R: Callback, K: KeyCallback local: Worker, // local job queue (local to thread) stealers: Vec>, // stealers (from other threads) ) { - while !device.running.load(Ordering::SeqCst) { + while device.running.load(Ordering::SeqCst) { match find_task(&local, &device.injector, &stealers) { Some(job) => { let (handle, buf) = job; @@ -262,7 +262,6 @@ pub fn worker_parallel, R: Callback, K: KeyCallback handle.thread().unpark(); } None => { - // no jobs, park the worker device.parked.store(true, Ordering::Release); thread::park(); } diff --git a/src/types/endpoint.rs b/src/types/endpoint.rs index d97905a..aa4dfd7 100644 --- a/src/types/endpoint.rs +++ b/src/types/endpoint.rs @@ -4,3 +4,5 @@ use std::net::SocketAddr; * is to simply use SocketAddr directly as the endpoint. */ pub trait Endpoint: Into {} + +impl Endpoint for T where T: Into {} diff --git a/src/types/udp.rs b/src/types/udp.rs index 00e218f..4bf0a9c 100644 --- a/src/types/udp.rs +++ b/src/types/udp.rs @@ -21,7 +21,9 @@ pub trait Bind: Send + Sync { fn set_port(&self, port: u16) -> Result<(), Self::Error>; /// Returns the current port of the bind - fn get_port(&self) -> u16; - fn recv(&self, dst: &mut [u8]) -> Self::Endpoint; - fn send(&self, src: &[u8], dst: &Self::Endpoint); + fn get_port(&self) -> Option; + + fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error>; + + fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error>; } -- cgit v1.2.3-59-g8ed1b