aboutsummaryrefslogtreecommitdiffstats
path: root/src/router/peer.rs
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-09-07 18:38:19 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-09-07 18:38:19 +0200
commit7b61ee4c2db87e195f5291fb1a3927648d38a2a4 (patch)
tree410c0609c3f4d1afbd0d87791b9156a538f59398 /src/router/peer.rs
parentAdded outbound benchmark (diff)
downloadwireguard-rs-7b61ee4c2db87e195f5291fb1a3927648d38a2a4.tar.xz
wireguard-rs-7b61ee4c2db87e195f5291fb1a3927648d38a2a4.zip
Write inbound packets to TUN device
Diffstat (limited to '')
-rw-r--r--src/router/peer.rs114
1 files changed, 80 insertions, 34 deletions
diff --git a/src/router/peer.rs b/src/router/peer.rs
index 634f980..0cd588d 100644
--- a/src/router/peer.rs
+++ b/src/router/peer.rs
@@ -30,7 +30,7 @@ use super::workers::Operation;
use super::workers::{worker_inbound, worker_outbound};
use super::workers::{JobBuffer, JobInbound, JobOutbound, JobParallel};
-use super::constants::MAX_STAGED_PACKETS;
+use super::constants::*;
use super::types::Callbacks;
pub struct KeyWheel {
@@ -50,7 +50,7 @@ pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> {
pub tx_bytes: AtomicU64, // transmitted bytes
pub keys: Mutex<KeyWheel>, // key-wheel
pub ekey: Mutex<Option<EncryptionState>>, // encryption state
- pub endpoint: Mutex<Option<Arc<SocketAddr>>>,
+ pub endpoint: Mutex<Option<B::Endpoint>>,
}
pub struct Peer<C: Callbacks, T: Tun, B: Bind> {
@@ -61,7 +61,7 @@ pub struct Peer<C: Callbacks, T: Tun, B: Bind> {
fn treebit_list<A, E, C: Callbacks, T: Tun, B: Bind>(
peer: &Arc<PeerInner<C, T, B>>,
- table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T, B>>>>,
+ table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>,
callback: Box<dyn Fn(A, u32) -> E>,
) -> Vec<E>
where
@@ -70,10 +70,8 @@ where
let mut res = Vec::new();
for subnet in table.read().iter() {
let (ip, masklen, p) = subnet;
- if let Some(p) = p.upgrade() {
- if Arc::ptr_eq(&p, &peer) {
- res.push(callback(ip, masklen))
- }
+ if Arc::ptr_eq(&p, &peer) {
+ res.push(callback(ip, masklen))
}
}
res
@@ -81,7 +79,7 @@ where
fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
peer: &Peer<C, T, B>,
- table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T, B>>>>,
+ table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>,
) {
let mut m = table.write();
@@ -89,10 +87,8 @@ fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
let mut subnets = vec![];
for subnet in m.iter() {
let (ip, masklen, p) = subnet;
- if let Some(p) = p.upgrade() {
- if Arc::ptr_eq(&p, &peer.state) {
- subnets.push((ip, masklen))
- }
+ if Arc::ptr_eq(&p, &peer.state) {
+ subnets.push((ip, masklen))
}
}
@@ -103,6 +99,29 @@ fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
}
}
+impl EncryptionState {
+ fn new(keypair: &Arc<KeyPair>) -> EncryptionState {
+ EncryptionState {
+ id: keypair.send.id,
+ key: keypair.send.key,
+ nonce: 0,
+ death: keypair.birth + REJECT_AFTER_TIME,
+ }
+ }
+}
+
+impl<C: Callbacks, T: Tun, B: Bind> DecryptionState<C, T, B> {
+ fn new(peer: &Arc<PeerInner<C, T, B>>, keypair: &Arc<KeyPair>) -> DecryptionState<C, T, B> {
+ DecryptionState {
+ confirmed: AtomicBool::new(keypair.initiator),
+ keypair: keypair.clone(),
+ protector: spin::Mutex::new(AntiReplay::new()),
+ peer: peer.clone(),
+ death: keypair.birth + REJECT_AFTER_TIME,
+ }
+ }
+}
+
impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> {
fn drop(&mut self) {
let peer = &self.state;
@@ -202,12 +221,52 @@ pub fn new_peer<C: Callbacks, T: Tun, B: Bind>(
}
impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
- pub fn confirm_key(&self, kp: Weak<KeyPair>) {
- // upgrade key-pair to strong reference
+ pub fn confirm_key(&self, keypair: &Arc<KeyPair>) {
+ // take lock and check keypair = keys.next
+ let mut keys = self.keys.lock();
+ let next = match keys.next.as_ref() {
+ Some(next) => next,
+ None => {
+ return;
+ }
+ };
+ if !Arc::ptr_eq(&next, keypair) {
+ return;
+ }
- // check it is the new unconfirmed key
+ // allocate new encryption state
+ let ekey = Some(EncryptionState::new(&next));
// rotate key-wheel
+ let mut swap = None;
+ mem::swap(&mut keys.next, &mut swap);
+ mem::swap(&mut keys.current, &mut swap);
+ mem::swap(&mut keys.previous, &mut swap);
+
+ // set new encryption key
+ *self.ekey.lock() = ekey;
+ }
+
+ pub fn recv_job(
+ &self,
+ src: B::Endpoint,
+ dec: Arc<DecryptionState<C, T, B>>,
+ mut msg: Vec<u8>,
+ ) -> Option<JobParallel> {
+ let (tx, rx) = oneshot();
+ let key = dec.keypair.send.key;
+ match self.inbound.lock().try_send((dec, src, rx)) {
+ Ok(_) => Some((
+ tx,
+ JobBuffer {
+ msg,
+ key: key,
+ okay: false,
+ op: Operation::Decryption,
+ },
+ )),
+ Err(_) => None,
+ }
}
pub fn send_job(&self, mut msg: Vec<u8>) -> Option<JobParallel> {
@@ -260,7 +319,7 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
pub fn set_endpoint(&self, endpoint: SocketAddr) {
- *self.state.endpoint.lock() = Some(Arc::new(endpoint))
+ *self.state.endpoint.lock() = Some(endpoint.into());
}
/// Add a new keypair
@@ -285,12 +344,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
// update key-wheel
if new.initiator {
// start using key for encryption
- *self.state.ekey.lock() = Some(EncryptionState {
- id: new.send.id,
- key: new.send.key,
- nonce: 0,
- death: new.birth + REJECT_AFTER_TIME,
- });
+ *self.state.ekey.lock() = Some(EncryptionState::new(&new));
// move current into previous
keys.previous = keys.current.as_ref().map(|v| v.clone());
@@ -310,19 +364,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
recv.remove(&id);
}
- // map new id to keypair
+ // map new id to decryption state
debug_assert!(!recv.contains_key(&new.recv.id));
-
recv.insert(
new.recv.id,
- DecryptionState {
- confirmed: AtomicBool::new(new.initiator),
- keypair: Arc::downgrade(&new),
- key: new.recv.key,
- protector: spin::Mutex::new(AntiReplay::new()),
- peer: Arc::downgrade(&self.state),
- death: new.birth + REJECT_AFTER_TIME,
- },
+ Arc::new(DecryptionState::new(&self.state, &new)),
);
}
@@ -345,14 +391,14 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
.device
.ipv4
.write()
- .insert(v4, masklen, Arc::downgrade(&self.state))
+ .insert(v4, masklen, self.state.clone())
}
IpAddr::V6(v6) => {
self.state
.device
.ipv6
.write()
- .insert(v6, masklen, Arc::downgrade(&self.state))
+ .insert(v6, masklen, self.state.clone())
}
};
}