diff options
Diffstat (limited to 'src/router/device.rs')
-rw-r--r-- | src/router/device.rs | 45 |
1 files changed, 22 insertions, 23 deletions
diff --git a/src/router/device.rs b/src/router/device.rs index 73678cb..703fa55 100644 --- a/src/router/device.rs +++ b/src/router/device.rs @@ -16,8 +16,7 @@ use super::anti_replay::AntiReplay; use super::constants::*; use super::ip::*; use super::messages::{TransportHeader, TYPE_TRANSPORT}; -use super::peer; -use super::peer::{Peer, PeerInner}; +use super::peer::{new_peer, Peer, PeerInner}; use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks, RouterError}; use super::workers::{worker_parallel, JobParallel, Operation}; use super::SIZE_MESSAGE_PREFIX; @@ -36,6 +35,10 @@ pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> { pub recv: RwLock<HashMap<u32, Arc<DecryptionState<C, T, B>>>>, // receiver id -> decryption state pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing + + // work queues + pub queue_next: AtomicUsize, // next round-robin index + pub queues: Mutex<Vec<SyncSender<JobParallel>>>, // work queues (1 per thread) } pub struct EncryptionState { @@ -56,14 +59,15 @@ pub struct DecryptionState<C: Callbacks, T: Tun, B: Bind> { pub struct Device<C: Callbacks, T: Tun, B: Bind> { state: Arc<DeviceInner<C, T, B>>, // reference to device state handles: Vec<thread::JoinHandle<()>>, // join handles for workers - queue_next: AtomicUsize, // next round-robin index - queues: Vec<Mutex<SyncSender<JobParallel>>>, // work queues (1 per thread) } impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> { fn drop(&mut self) { // drop all queues - while self.queues.pop().is_some() {} + { + let mut queues = self.state.queues.lock(); + while queues.pop().is_some() {} + } // join all worker threads while match self.handles.pop() { @@ -91,32 +95,31 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi call_need_key: K, ) -> Device<PhantomCallbacks<O, R, S, K>, T, B> { // allocate shared device state - let inner = Arc::new(DeviceInner { + let mut inner = DeviceInner { tun, bind, call_recv, call_send, + queues: Mutex::new(Vec::with_capacity(num_workers)), + queue_next: AtomicUsize::new(0), call_need_key, recv: RwLock::new(HashMap::new()), ipv4: RwLock::new(IpLookupTable::new()), ipv6: RwLock::new(IpLookupTable::new()), - }); + }; // start worker threads - let mut queues = Vec::with_capacity(num_workers); let mut threads = Vec::with_capacity(num_workers); for _ in 0..num_workers { let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE); - queues.push(Mutex::new(tx)); + inner.queues.lock().push(tx); threads.push(thread::spawn(move || worker_parallel(rx))); } // return exported device handle Device { - state: inner, + state: Arc::new(inner), handles: threads, - queue_next: AtomicUsize::new(0), - queues: queues, } } } @@ -168,7 +171,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> { /// /// A atomic ref. counted peer (with liftime matching the device) pub fn new_peer(&self, opaque: C::Opaque) -> Peer<C, T, B> { - peer::new_peer(self.state.clone(), opaque) + new_peer(self.state.clone(), opaque) } /// Cryptkey routes and sends a plaintext message (IP packet) @@ -189,11 +192,9 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> { debug_assert_eq!(job.1.op, Operation::Encryption); // add job to worker queue - let idx = self.queue_next.fetch_add(1, Ordering::SeqCst); - self.queues[idx % self.queues.len()] - .lock() - .send(job) - .unwrap(); + let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); + let queues = self.state.queues.lock(); + queues[idx % queues.len()].send(job).unwrap(); } Ok(()) @@ -234,11 +235,9 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> { debug_assert_eq!(job.1.op, Operation::Decryption); // add job to worker queue - let idx = self.queue_next.fetch_add(1, Ordering::SeqCst); - self.queues[idx % self.queues.len()] - .lock() - .send(job) - .unwrap(); + let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst); + let queues = self.state.queues.lock(); + queues[idx % queues.len()].send(job).unwrap(); } Ok(()) |