diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-08-17 16:31:08 +0200 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-08-17 16:31:08 +0200 |
commit | 78ab1a93e6d519bf404fbe61fc7ec3c3ab35a72a (patch) | |
tree | 75106e1ff89a03a6869184994b902a70315dfc30 /src/router/device.rs | |
parent | Begin drafting cross-platform interface (diff) | |
download | wireguard-rs-78ab1a93e6d519bf404fbe61fc7ec3c3ab35a72a.tar.xz wireguard-rs-78ab1a93e6d519bf404fbe61fc7ec3c3ab35a72a.zip |
Remove peer from cryptkey router on drop
Diffstat (limited to 'src/router/device.rs')
-rw-r--r-- | src/router/device.rs | 182 |
1 files changed, 129 insertions, 53 deletions
diff --git a/src/router/device.rs b/src/router/device.rs index 5dfd22c..4dd6539 100644 --- a/src/router/device.rs +++ b/src/router/device.rs @@ -1,37 +1,38 @@ use arraydeque::{ArrayDeque, Wrapping}; +use treebitmap::address::Address; use treebitmap::IpLookupTable; use crossbeam_deque::{Injector, Steal}; use std::collections::HashMap; -use std::mem; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; +use std::sync::mpsc::SyncSender; use std::sync::{Arc, Mutex, Weak}; use std::thread; -use std::time::{Duration, Instant}; +use std::time::Instant; use spin; +use super::super::constants::*; use super::super::types::KeyPair; use super::anti_replay::AntiReplay; use std::u64; -const REJECT_AFTER_MESSAGES: u64 = u64::MAX - (1 << 4); const MAX_STAGED_PACKETS: usize = 128; struct DeviceInner { stopped: AtomicBool, - injector: Injector<()>, // parallel enc/dec task injector - threads: Vec<thread::JoinHandle<()>>, - recv: spin::RwLock<HashMap<u32, DecryptionState>>, - ipv4: IpLookupTable<Ipv4Addr, Weak<PeerInner>>, - ipv6: IpLookupTable<Ipv6Addr, Weak<PeerInner>>, + injector: Injector<()>, // parallel enc/dec task injector + threads: Vec<thread::JoinHandle<()>>, // join handles of worker threads + recv: spin::RwLock<HashMap<u32, DecryptionState>>, // receiver id -> decryption state + ipv4: spin::RwLock<IpLookupTable<Ipv4Addr, Weak<PeerInner>>>, // ipv4 cryptkey routing + ipv6: spin::RwLock<IpLookupTable<Ipv6Addr, Weak<PeerInner>>>, // ipv6 cryptkey routing } struct PeerInner { stopped: AtomicBool, + device: Arc<DeviceInner>, thread_outbound: spin::Mutex<thread::JoinHandle<()>>, thread_inbound: spin::Mutex<thread::JoinHandle<()>>, inorder_outbound: SyncSender<()>, @@ -40,7 +41,7 @@ struct PeerInner { rx_bytes: AtomicU64, // received bytes tx_bytes: AtomicU64, // transmitted bytes keys: spin::Mutex<KeyWheel>, // key-wheel - ekey: spin::Mutex<EncryptionState>, // encryption state + ekey: spin::Mutex<Option<EncryptionState>>, // encryption state endpoint: spin::Mutex<Option<Arc<SocketAddr>>>, } @@ -68,26 +69,104 @@ struct KeyWheel { pub struct Peer(Arc<PeerInner>); pub struct Device(DeviceInner); +fn treebit_list<A, R>( + peer: &Peer, + table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner>>>, + callback: Box<dyn Fn(A, u32) -> R>, +) -> Vec<R> +where + A: Address, +{ + let mut res = Vec::new(); + for subnet in table.read().iter() { + let (ip, masklen, p) = subnet; + if let Some(p) = p.upgrade() { + if Arc::ptr_eq(&p, &peer.0) { + res.push(callback(ip, masklen)) + } + } + } + res +} + +fn treebit_remove<A>(peer: &Peer, table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner>>>) +where + A: Address, +{ + let mut m = table.write(); + + // collect keys for value + let mut subnets = vec![]; + for subnet in m.iter() { + let (ip, masklen, p) = subnet; + if let Some(p) = p.upgrade() { + if Arc::ptr_eq(&p, &peer.0) { + subnets.push((ip, masklen)) + } + } + } + + // remove all key mappings + for subnet in subnets { + let r = m.remove(subnet.0, subnet.1); + debug_assert!(r.is_some()); + } +} + impl Drop for Peer { fn drop(&mut self) { // mark peer as stopped - let inner = &self.0; - inner.stopped.store(true, Ordering::SeqCst); + let peer = &self.0; + peer.stopped.store(true, Ordering::SeqCst); + + // remove from cryptkey router + treebit_remove(self, &peer.device.ipv4); + treebit_remove(self, &peer.device.ipv6); + + // unpark threads + + peer.thread_inbound.lock().thread().unpark(); + peer.thread_outbound.lock().thread().unpark(); + // collect ids to release + let mut keys = peer.keys.lock(); + let mut release = Vec::with_capacity(3); + + keys.next.map(|k| release.push(k.recv.id)); + keys.current.map(|k| release.push(k.recv.id)); + keys.previous.map(|k| release.push(k.recv.id)); + + // remove from receive id map + if release.len() > 0 { + let mut recv = peer.device.recv.write(); + for id in &release { + recv.remove(id); + } + } + + // null key-material (TODO: extend) - // unpark threads to stop - inner.thread_inbound.lock().thread().unpark(); - inner.thread_outbound.lock().thread().unpark(); + keys.next = None; + keys.current = None; + keys.previous = None; + + *peer.ekey.lock() = None; + *peer.endpoint.lock() = None; } } impl Drop for Device { fn drop(&mut self) { // mark device as stopped - let inner = &self.0; - inner.stopped.store(true, Ordering::SeqCst); + let device = &self.0; + device.stopped.store(true, Ordering::SeqCst); // eat all parallel jobs - while inner.injector.steal() != Steal::Empty {} + while device.injector.steal() != Steal::Empty {} + + // unpark all threads + for handle in &device.threads { + handle.thread().unpark(); + } } } @@ -97,12 +176,12 @@ impl Peer { } pub fn keypair_confirm(&self, ks: Arc<KeyPair>) { - *self.0.ekey.lock() = EncryptionState { + *self.0.ekey.lock() = Some(EncryptionState { id: ks.send.id, key: ks.send.key, nonce: 0, - death: ks.birth + Duration::from_millis(1337), // todo - }; + death: ks.birth + REJECT_AFTER_TIME, + }); } fn keypair_add(&self, new: KeyPair) -> Option<u32> { @@ -112,12 +191,12 @@ impl Peer { // update key-wheel if new.confirmed { // start using key for encryption - *self.0.ekey.lock() = EncryptionState { + *self.0.ekey.lock() = Some(EncryptionState { id: new.send.id, key: new.send.key, nonce: 0, - death: new.birth + Duration::from_millis(1337), // todo - }; + death: new.birth + REJECT_AFTER_TIME, + }); // move current into previous keys.previous = keys.current; @@ -148,42 +227,39 @@ impl Device { stopped: AtomicBool::new(false), injector: Injector::new(), recv: spin::RwLock::new(HashMap::new()), - ipv4: IpLookupTable::new(), - ipv6: IpLookupTable::new(), + ipv4: spin::RwLock::new(IpLookupTable::new()), + ipv6: spin::RwLock::new(IpLookupTable::new()), }) } pub fn add_subnet(&mut self, ip: IpAddr, masklen: u32, peer: Peer) { match ip { - IpAddr::V4(v4) => self.0.ipv4.insert(v4, masklen, Arc::downgrade(&peer.0)), - IpAddr::V6(v6) => self.0.ipv6.insert(v6, masklen, Arc::downgrade(&peer.0)), + IpAddr::V4(v4) => self + .0 + .ipv4 + .write() + .insert(v4, masklen, Arc::downgrade(&peer.0)), + IpAddr::V6(v6) => self + .0 + .ipv6 + .write() + .insert(v6, masklen, Arc::downgrade(&peer.0)), }; } - pub fn subnets(&self, peer: Peer) -> Vec<(IpAddr, u32)> { - let mut subnets = Vec::new(); - - // extract ipv4 entries - for subnet in self.0.ipv4.iter() { - let (ip, masklen, p) = subnet; - if let Some(p) = p.upgrade() { - if Arc::ptr_eq(&p, &peer.0) { - subnets.push((IpAddr::V4(ip), masklen)) - } - } - } - - // extract ipv6 entries - for subnet in self.0.ipv6.iter() { - let (ip, masklen, p) = subnet; - if let Some(p) = p.upgrade() { - if Arc::ptr_eq(&p, &peer.0) { - subnets.push((IpAddr::V6(ip), masklen)) - } - } - } - - subnets + pub fn list_subnets(&self, peer: Peer) -> Vec<(IpAddr, u32)> { + let mut res = Vec::new(); + res.append(&mut treebit_list( + &peer, + &self.0.ipv4, + Box::new(|ip, masklen| (IpAddr::V4(ip), masklen)), + )); + res.append(&mut treebit_list( + &peer, + &self.0.ipv6, + Box::new(|ip, masklen| (IpAddr::V6(ip), masklen)), + )); + res } pub fn keypair_add(&self, peer: Peer, new: KeyPair) -> Option<u32> { @@ -208,7 +284,7 @@ impl Device { key: new.recv.key, protector: Arc::new(spin::Mutex::new(AntiReplay::new())), peer: Arc::downgrade(&peer.0), - death: new.birth + Duration::from_millis(2600), // todo + death: new.birth + REJECT_AFTER_TIME, }, ); |