diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-09-07 18:38:19 +0200 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-09-07 18:38:19 +0200 |
commit | 7b61ee4c2db87e195f5291fb1a3927648d38a2a4 (patch) | |
tree | 410c0609c3f4d1afbd0d87791b9156a538f59398 /src/router/peer.rs | |
parent | Added outbound benchmark (diff) | |
download | wireguard-rs-7b61ee4c2db87e195f5291fb1a3927648d38a2a4.tar.xz wireguard-rs-7b61ee4c2db87e195f5291fb1a3927648d38a2a4.zip |
Write inbound packets to TUN device
Diffstat (limited to '')
-rw-r--r-- | src/router/peer.rs | 114 |
1 files changed, 80 insertions, 34 deletions
diff --git a/src/router/peer.rs b/src/router/peer.rs index 634f980..0cd588d 100644 --- a/src/router/peer.rs +++ b/src/router/peer.rs @@ -30,7 +30,7 @@ use super::workers::Operation; use super::workers::{worker_inbound, worker_outbound}; use super::workers::{JobBuffer, JobInbound, JobOutbound, JobParallel}; -use super::constants::MAX_STAGED_PACKETS; +use super::constants::*; use super::types::Callbacks; pub struct KeyWheel { @@ -50,7 +50,7 @@ pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> { pub tx_bytes: AtomicU64, // transmitted bytes pub keys: Mutex<KeyWheel>, // key-wheel pub ekey: Mutex<Option<EncryptionState>>, // encryption state - pub endpoint: Mutex<Option<Arc<SocketAddr>>>, + pub endpoint: Mutex<Option<B::Endpoint>>, } pub struct Peer<C: Callbacks, T: Tun, B: Bind> { @@ -61,7 +61,7 @@ pub struct Peer<C: Callbacks, T: Tun, B: Bind> { fn treebit_list<A, E, C: Callbacks, T: Tun, B: Bind>( peer: &Arc<PeerInner<C, T, B>>, - table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T, B>>>>, + table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>, callback: Box<dyn Fn(A, u32) -> E>, ) -> Vec<E> where @@ -70,10 +70,8 @@ where let mut res = Vec::new(); for subnet in table.read().iter() { let (ip, masklen, p) = subnet; - if let Some(p) = p.upgrade() { - if Arc::ptr_eq(&p, &peer) { - res.push(callback(ip, masklen)) - } + if Arc::ptr_eq(&p, &peer) { + res.push(callback(ip, masklen)) } } res @@ -81,7 +79,7 @@ where fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>( peer: &Peer<C, T, B>, - table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T, B>>>>, + table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>, ) { let mut m = table.write(); @@ -89,10 +87,8 @@ fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>( let mut subnets = vec![]; for subnet in m.iter() { let (ip, masklen, p) = subnet; - if let Some(p) = p.upgrade() { - if Arc::ptr_eq(&p, &peer.state) { - subnets.push((ip, masklen)) - } + if Arc::ptr_eq(&p, &peer.state) { + subnets.push((ip, masklen)) } } @@ -103,6 +99,29 @@ fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>( } } +impl EncryptionState { + fn new(keypair: &Arc<KeyPair>) -> EncryptionState { + EncryptionState { + id: keypair.send.id, + key: keypair.send.key, + nonce: 0, + death: keypair.birth + REJECT_AFTER_TIME, + } + } +} + +impl<C: Callbacks, T: Tun, B: Bind> DecryptionState<C, T, B> { + fn new(peer: &Arc<PeerInner<C, T, B>>, keypair: &Arc<KeyPair>) -> DecryptionState<C, T, B> { + DecryptionState { + confirmed: AtomicBool::new(keypair.initiator), + keypair: keypair.clone(), + protector: spin::Mutex::new(AntiReplay::new()), + peer: peer.clone(), + death: keypair.birth + REJECT_AFTER_TIME, + } + } +} + impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> { fn drop(&mut self) { let peer = &self.state; @@ -202,12 +221,52 @@ pub fn new_peer<C: Callbacks, T: Tun, B: Bind>( } impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> { - pub fn confirm_key(&self, kp: Weak<KeyPair>) { - // upgrade key-pair to strong reference + pub fn confirm_key(&self, keypair: &Arc<KeyPair>) { + // take lock and check keypair = keys.next + let mut keys = self.keys.lock(); + let next = match keys.next.as_ref() { + Some(next) => next, + None => { + return; + } + }; + if !Arc::ptr_eq(&next, keypair) { + return; + } - // check it is the new unconfirmed key + // allocate new encryption state + let ekey = Some(EncryptionState::new(&next)); // rotate key-wheel + let mut swap = None; + mem::swap(&mut keys.next, &mut swap); + mem::swap(&mut keys.current, &mut swap); + mem::swap(&mut keys.previous, &mut swap); + + // set new encryption key + *self.ekey.lock() = ekey; + } + + pub fn recv_job( + &self, + src: B::Endpoint, + dec: Arc<DecryptionState<C, T, B>>, + mut msg: Vec<u8>, + ) -> Option<JobParallel> { + let (tx, rx) = oneshot(); + let key = dec.keypair.send.key; + match self.inbound.lock().try_send((dec, src, rx)) { + Ok(_) => Some(( + tx, + JobBuffer { + msg, + key: key, + okay: false, + op: Operation::Decryption, + }, + )), + Err(_) => None, + } } pub fn send_job(&self, mut msg: Vec<u8>) -> Option<JobParallel> { @@ -260,7 +319,7 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> { impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { pub fn set_endpoint(&self, endpoint: SocketAddr) { - *self.state.endpoint.lock() = Some(Arc::new(endpoint)) + *self.state.endpoint.lock() = Some(endpoint.into()); } /// Add a new keypair @@ -285,12 +344,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { // update key-wheel if new.initiator { // start using key for encryption - *self.state.ekey.lock() = Some(EncryptionState { - id: new.send.id, - key: new.send.key, - nonce: 0, - death: new.birth + REJECT_AFTER_TIME, - }); + *self.state.ekey.lock() = Some(EncryptionState::new(&new)); // move current into previous keys.previous = keys.current.as_ref().map(|v| v.clone()); @@ -310,19 +364,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { recv.remove(&id); } - // map new id to keypair + // map new id to decryption state debug_assert!(!recv.contains_key(&new.recv.id)); - recv.insert( new.recv.id, - DecryptionState { - confirmed: AtomicBool::new(new.initiator), - keypair: Arc::downgrade(&new), - key: new.recv.key, - protector: spin::Mutex::new(AntiReplay::new()), - peer: Arc::downgrade(&self.state), - death: new.birth + REJECT_AFTER_TIME, - }, + Arc::new(DecryptionState::new(&self.state, &new)), ); } @@ -345,14 +391,14 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { .device .ipv4 .write() - .insert(v4, masklen, Arc::downgrade(&self.state)) + .insert(v4, masklen, self.state.clone()) } IpAddr::V6(v6) => { self.state .device .ipv6 .write() - .insert(v6, masklen, Arc::downgrade(&self.state)) + .insert(v6, masklen, self.state.clone()) } }; } |