aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-08-20 21:19:53 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-08-20 21:19:53 +0200
commit9cef264581ec8cf859113234b29ad5b58577ed8c (patch)
tree55320a3a619805b643e46d41068ceefe1628b0df
parentRemoved platform mod (diff)
downloadwireguard-rs-9cef264581ec8cf859113234b29ad5b58577ed8c.tar.xz
wireguard-rs-9cef264581ec8cf859113234b29ad5b58577ed8c.zip
Ensure peer threads are stopped on drop
-rw-r--r--src/main.rs6
-rw-r--r--src/router/peer.rs130
-rw-r--r--src/router/workers.rs121
3 files changed, 157 insertions, 100 deletions
diff --git a/src/main.rs b/src/main.rs
index aa73fd2..eab4b61 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -15,6 +15,8 @@ fn main() {
sodiumoxide::init().unwrap();
let mut router = router::Device::new(8);
-
- let peer = router.new_peer();
+ {
+ let peer = router.new_peer();
+ }
+ loop {}
}
diff --git a/src/router/peer.rs b/src/router/peer.rs
index f7b8bf4..1edb635 100644
--- a/src/router/peer.rs
+++ b/src/router/peer.rs
@@ -1,29 +1,29 @@
-use std::sync::atomic::{AtomicU64, AtomicBool, Ordering};
-use std::sync::{Weak, Arc};
+use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
+use std::sync::{Arc, Weak};
use std::thread;
-
+use std::mem;
use std::net::{IpAddr, SocketAddr};
-
use std::sync::mpsc::{sync_channel, SyncSender};
use spin;
use arraydeque::{ArrayDeque, Wrapping};
-use treebitmap::IpLookupTable;
use treebitmap::address::Address;
+use treebitmap::IpLookupTable;
-use super::super::types::KeyPair;
use super::super::constants::*;
+use super::super::types::KeyPair;
use super::anti_replay::AntiReplay;
+use super::device::DecryptionState;
use super::device::DeviceInner;
use super::device::EncryptionState;
-use super::device::DecryptionState;
+use super::workers::{worker_inbound, worker_outbound, JobInbound, JobOutbound};
const MAX_STAGED_PACKETS: usize = 128;
-struct KeyWheel {
+pub struct KeyWheel {
next: Option<Arc<KeyPair>>, // next key state (unconfirmed)
current: Option<Arc<KeyPair>>, // current key state (used for encryption)
previous: Option<Arc<KeyPair>>, // old key state (used for decryption)
@@ -31,18 +31,18 @@ struct KeyWheel {
}
pub struct PeerInner {
- stopped: AtomicBool,
- device: Arc<DeviceInner>,
- thread_outbound: spin::Mutex<thread::JoinHandle<()>>,
- thread_inbound: spin::Mutex<thread::JoinHandle<()>>,
- inorder_outbound: SyncSender<()>,
- inorder_inbound: SyncSender<()>,
- staged_packets: spin::Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>, // packets awaiting handshake
- rx_bytes: AtomicU64, // received bytes
- tx_bytes: AtomicU64, // transmitted bytes
- keys: spin::Mutex<KeyWheel>, // key-wheel
- ekey: spin::Mutex<Option<EncryptionState>>, // encryption state
- endpoint: spin::Mutex<Option<Arc<SocketAddr>>>,
+ pub stopped: AtomicBool,
+ pub device: Arc<DeviceInner>,
+ pub thread_outbound: spin::Mutex<Option<thread::JoinHandle<()>>>,
+ pub thread_inbound: spin::Mutex<Option<thread::JoinHandle<()>>>,
+ pub queue_outbound: SyncSender<JobOutbound>,
+ pub queue_inbound: SyncSender<JobInbound>,
+ pub staged_packets: spin::Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>, // packets awaiting handshake
+ pub rx_bytes: AtomicU64, // received bytes
+ pub tx_bytes: AtomicU64, // transmitted bytes
+ pub keys: spin::Mutex<KeyWheel>, // key-wheel
+ pub ekey: spin::Mutex<Option<EncryptionState>>, // encryption state
+ pub endpoint: spin::Mutex<Option<Arc<SocketAddr>>>,
}
pub struct Peer(Arc<PeerInner>);
@@ -93,6 +93,7 @@ where
impl Drop for Peer {
fn drop(&mut self) {
+
// mark peer as stopped
let peer = &self.0;
@@ -105,8 +106,19 @@ impl Drop for Peer {
// unpark threads
- peer.thread_inbound.lock().thread().unpark();
- peer.thread_outbound.lock().thread().unpark();
+ peer.thread_inbound
+ .lock()
+ .as_ref()
+ .unwrap()
+ .thread()
+ .unpark();
+
+ peer.thread_outbound
+ .lock()
+ .as_ref()
+ .unwrap()
+ .thread()
+ .unpark();
// release ids from the receiver map
@@ -132,42 +144,62 @@ impl Drop for Peer {
*peer.ekey.lock() = None;
*peer.endpoint.lock() = None;
+
}
}
pub fn new_peer(device: Arc<DeviceInner>) -> Peer {
+ // allocate in-order queues
+ let (send_inbound, recv_inbound) = sync_channel(MAX_STAGED_PACKETS);
+ let (send_outbound, recv_outbound) = sync_channel(MAX_STAGED_PACKETS);
+
+ // allocate peer object
+ let peer = {
+ let device = device.clone();
+ Arc::new(PeerInner {
+ stopped: AtomicBool::new(false),
+ device: device,
+ ekey: spin::Mutex::new(None),
+ endpoint: spin::Mutex::new(None),
+ queue_inbound: send_inbound,
+ queue_outbound: send_outbound,
+ keys: spin::Mutex::new(KeyWheel {
+ next: None,
+ current: None,
+ previous: None,
+ retired: None,
+ }),
+ rx_bytes: AtomicU64::new(0),
+ tx_bytes: AtomicU64::new(0),
+ staged_packets: spin::Mutex::new(ArrayDeque::new()),
+ thread_inbound: spin::Mutex::new(None),
+ thread_outbound: spin::Mutex::new(None),
+ })
+ };
+
// spawn inbound thread
- let (send_inbound, recv_inbound) = sync_channel(1);
- let handle_inbound = thread::spawn(move || {});
+ *peer.thread_inbound.lock() = {
+ let peer = peer.clone();
+ let device = device.clone();
+ Some(thread::spawn(move || {
+ worker_outbound(device, peer, recv_outbound)
+ }))
+ };
// spawn outbound thread
- let (send_outbound, recv_inbound) = sync_channel(1);
- let handle_outbound = thread::spawn(move || {});
-
- // allocate peer object
- Peer::new(PeerInner {
- stopped: AtomicBool::new(false),
- device: device,
- ekey: spin::Mutex::new(None),
- endpoint: spin::Mutex::new(None),
- inorder_inbound: send_inbound,
- inorder_outbound: send_outbound,
- keys: spin::Mutex::new(KeyWheel {
- next: None,
- current: None,
- previous: None,
- retired: None,
- }),
- rx_bytes: AtomicU64::new(0),
- tx_bytes: AtomicU64::new(0),
- staged_packets: spin::Mutex::new(ArrayDeque::new()),
- thread_inbound: spin::Mutex::new(handle_inbound),
- thread_outbound: spin::Mutex::new(handle_outbound),
- })
+ *peer.thread_outbound.lock() = {
+ let peer = peer.clone();
+ let device = device.clone();
+ Some(thread::spawn(move || {
+ worker_inbound(device, peer, recv_inbound)
+ }))
+ };
+
+ Peer(peer)
}
impl Peer {
- fn new(inner : PeerInner) -> Peer {
+ fn new(inner: PeerInner) -> Peer {
Peer(Arc::new(inner))
}
@@ -282,4 +314,4 @@ impl Peer {
));
res
}
-} \ No newline at end of file
+}
diff --git a/src/router/workers.rs b/src/router/workers.rs
index 2117190..da5b600 100644
--- a/src/router/workers.rs
+++ b/src/router/workers.rs
@@ -6,7 +6,7 @@ use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use spin;
use std::iter;
use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::mpsc::{sync_channel, Receiver};
+use std::sync::mpsc::{sync_channel, Receiver, TryRecvError};
use std::sync::{Arc, Weak};
use std::thread;
@@ -23,17 +23,17 @@ enum Status {
Waiting, // job awaiting completion
}
-struct JobInner {
+pub struct JobInner {
msg: Vec<u8>, // message buffer (nonce and receiver id set)
key: [u8; 32], // chacha20poly1305 key
status: Status, // state of the job
op: Operation, // should be buffer be encrypted / decrypted?
}
-type JobBuffer = Arc<spin::Mutex<JobInner>>;
-type JobParallel = (Arc<thread::JoinHandle<()>>, JobBuffer);
-type JobInbound = (Arc<DecryptionState>, JobBuffer);
-type JobOutbound = (Weak<PeerInner>, JobBuffer);
+pub type JobBuffer = Arc<spin::Mutex<JobInner>>;
+pub type JobParallel = (Arc<thread::JoinHandle<()>>, JobBuffer);
+pub type JobInbound = (Weak<DecryptionState>, JobBuffer);
+pub type JobOutbound = JobBuffer;
/* Strategy for workers acquiring a new job:
*
@@ -53,62 +53,85 @@ fn find_task<T>(local: &Worker<T>, global: &Injector<T>, stealers: &[Stealer<T>]
})
}
-fn worker_inbound(
- device: Arc<DeviceInner>, // related device
- peer: Arc<PeerInner>, // related peer
- recv: Receiver<JobInbound>, // in order queue
-) {
- // reads from in order channel
- for job in recv.recv().iter() {
- loop {
- let (state, buf) = job;
-
- // check if job is complete
- match buf.try_lock() {
- None => (),
- Some(buf) => {
- if buf.status != Status::Waiting {
- // check replay protector
-
- // check if confirms keypair
-
- // write to tun device
-
- // continue to next job (no parking)
- break;
- }
+fn wait_buffer(stopped: AtomicBool, buf: &JobBuffer) {
+ while !stopped.load(Ordering::Acquire) {
+ match buf.try_lock() {
+ None => (),
+ Some(buf) => {
+ if buf.status == Status::Waiting {
+ return;
}
}
+ };
+ thread::park();
+ }
+}
- // wait for job to complete
- thread::park();
- }
+fn wait_recv<T>(stopped: &AtomicBool, recv: &Receiver<T>) -> Result<T, TryRecvError> {
+ while !stopped.load(Ordering::Acquire) {
+ match recv.try_recv() {
+ Err(TryRecvError::Empty) => (),
+ value => {
+ return value;
+ }
+ };
+ thread::park();
}
+ return Err(TryRecvError::Disconnected);
}
-fn worker_outbound(
+pub fn worker_inbound(
device: Arc<DeviceInner>, // related device
peer: Arc<PeerInner>, // related peer
recv: Receiver<JobInbound>, // in order queue
) {
- // reads from in order channel
- for job in recv.recv().iter() {
- loop {
- let (peer, buf) = job;
-
- // check if job is complete
- match buf.try_lock() {
- None => (),
- Some(buf) => {
- if buf.status != Status::Waiting {
- // send buffer to peer endpoint
- break;
- }
+ loop {
+ match wait_recv(&peer.stopped, &recv) {
+ Ok((state, buf)) => {
+ while !peer.stopped.load(Ordering::Acquire) {
+ match buf.try_lock() {
+ None => (),
+ Some(buf) => {
+ if buf.status != Status::Waiting {
+ // consume
+ break;
+ }
+ }
+ };
+ thread::park();
}
}
+ Err(_) => {
+ break;
+ }
+ }
+ }
+}
- // wait for job to complete
- thread::park();
+pub fn worker_outbound(
+ device: Arc<DeviceInner>, // related device
+ peer: Arc<PeerInner>, // related peer
+ recv: Receiver<JobOutbound>, // in order queue
+) {
+ loop {
+ match wait_recv(&peer.stopped, &recv) {
+ Ok(buf) => {
+ while !peer.stopped.load(Ordering::Acquire) {
+ match buf.try_lock() {
+ None => (),
+ Some(buf) => {
+ if buf.status != Status::Waiting {
+ // consume
+ break;
+ }
+ }
+ };
+ thread::park();
+ }
+ }
+ Err(_) => {
+ break;
+ }
}
}
}