diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-09-28 18:01:55 +0200 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-09-28 18:01:55 +0200 |
commit | edfd2f235a7954c2a2b846d112a468156ceddfa6 (patch) | |
tree | c5490b795c4776110ddf5d2374ee437152afb30d /src/router | |
parent | Work on peer timers (diff) | |
download | wireguard-rs-edfd2f235a7954c2a2b846d112a468156ceddfa6.tar.xz wireguard-rs-edfd2f235a7954c2a2b846d112a468156ceddfa6.zip |
Added key_confirmed callback
Diffstat (limited to 'src/router')
-rw-r--r-- | src/router/device.rs | 6 | ||||
-rw-r--r-- | src/router/peer.rs | 159 | ||||
-rw-r--r-- | src/router/tests.rs | 84 | ||||
-rw-r--r-- | src/router/types.rs | 7 |
4 files changed, 167 insertions, 89 deletions
diff --git a/src/router/device.rs b/src/router/device.rs index e8250cb..d126959 100644 --- a/src/router/device.rs +++ b/src/router/device.rs @@ -60,6 +60,8 @@ pub struct Device<C: Callbacks, T: Tun, B: Bind> { impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> { fn drop(&mut self) { + debug!("router: dropping device"); + // drop all queues { let mut queues = self.state.queues.lock(); @@ -76,7 +78,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> { _ => false, } {} - debug!("device dropped"); + debug!("router: device dropped"); } } @@ -175,7 +177,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> { let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?; // schedule for encryption and transmission to peer - if let Some(job) = peer.send_job(msg) { + if let Some(job) = peer.send_job(msg, true) { debug_assert_eq!(job.1.op, Operation::Encryption); // add job to worker queue diff --git a/src/router/peer.rs b/src/router/peer.rs index 7a3ede8..86723bb 100644 --- a/src/router/peer.rs +++ b/src/router/peer.rs @@ -217,6 +217,7 @@ pub fn new_peer<C: Callbacks, T: Tun, B: Bind>( impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> { fn send_staged(&self) -> bool { + debug!("peer.send_staged"); let mut sent = false; let mut staged = self.staged_packets.lock(); loop { @@ -230,8 +231,11 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> { } } + // Treat the msg as the payload of a transport message + // Unlike device.send, peer.send_raw does not buffer messages when a key is not available. fn send_raw(&self, msg: Vec<u8>) -> bool { - match self.send_job(msg) { + debug!("peer.send_raw"); + match self.send_job(msg, false) { Some(job) => { debug!("send_raw: got obtained send_job"); let index = self.device.queue_next.fetch_add(1, Ordering::SeqCst); @@ -246,29 +250,35 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> { } pub fn confirm_key(&self, keypair: &Arc<KeyPair>) { - // take lock and check keypair = keys.next - let mut keys = self.keys.lock(); - let next = match keys.next.as_ref() { - Some(next) => next, - None => { + debug!("peer.confirm_key"); + { + // take lock and check keypair = keys.next + let mut keys = self.keys.lock(); + let next = match keys.next.as_ref() { + Some(next) => next, + None => { + return; + } + }; + if !Arc::ptr_eq(&next, keypair) { return; } - }; - if !Arc::ptr_eq(&next, keypair) { - return; - } - // allocate new encryption state - let ekey = Some(EncryptionState::new(&next)); + // allocate new encryption state + let ekey = Some(EncryptionState::new(&next)); - // rotate key-wheel - let mut swap = None; - mem::swap(&mut keys.next, &mut swap); - mem::swap(&mut keys.current, &mut swap); - mem::swap(&mut keys.previous, &mut swap); + // rotate key-wheel + let mut swap = None; + mem::swap(&mut keys.next, &mut swap); + mem::swap(&mut keys.current, &mut swap); + mem::swap(&mut keys.previous, &mut swap); - // set new encryption key - *self.ekey.lock() = ekey; + // tell the world outside the router that a key was confirmed + C::key_confirmed(&self.opaque); + + // set new key for encryption + *self.ekey.lock() = ekey; + } // start transmission of staged packets self.send_staged(); @@ -296,7 +306,8 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> { } } - pub fn send_job(&self, mut msg: Vec<u8>) -> Option<JobParallel> { + pub fn send_job(&self, mut msg: Vec<u8>, stage: bool) -> Option<JobParallel> { + debug!("peer.send_job"); debug_assert!( msg.len() >= mem::size_of::<TransportHeader>(), "received message with size: {:}", @@ -319,7 +330,6 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> { None } else { // there should be no stacked packets lingering around - debug_assert_eq!(self.staged_packets.lock().len(), 0); debug!("encryption state available, nonce = {}", state.nonce); // set transport message fields @@ -334,7 +344,7 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> { // If not suitable key was found: // 1. Stage packet for later transmission // 2. Request new key - if key.is_none() { + if key.is_none() && stage { self.staged_packets.lock().push_back(msg); C::need_key(&self.opaque); return None; @@ -372,6 +382,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { /// This API still permits support for the "sticky socket" behavior, /// as sockets should be "unsticked" when manually updating the endpoint pub fn set_endpoint(&self, address: SocketAddr) { + debug!("peer.set_endpoint"); *self.state.endpoint.lock() = Some(B::Endpoint::from_address(address)); } @@ -381,6 +392,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { /// /// Does not convey potential "sticky socket" information pub fn get_endpoint(&self) -> Option<SocketAddr> { + debug!("peer.get_endpoint"); self.state .endpoint .lock() @@ -390,6 +402,8 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { /// Zero all key-material related to the peer pub fn zero_keys(&self) { + debug!("peer.zero_keys"); + let mut release: Vec<u32> = Vec::with_capacity(3); let mut keys = self.state.keys.lock(); @@ -429,57 +443,74 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { /// since the only way to add additional keys to the peer is by using this method /// and a peer can have at most 3 keys allocated in the router at any time. pub fn add_keypair(&self, new: KeyPair) -> Vec<u32> { - let new = Arc::new(new); - let mut keys = self.state.keys.lock(); - let mut release = mem::replace(&mut keys.retired, vec![]); + debug!("peer.add_keypair"); + + let initiator = new.initiator; + let release = { + let new = Arc::new(new); + let mut keys = self.state.keys.lock(); + let mut release = mem::replace(&mut keys.retired, vec![]); + + // update key-wheel + if new.initiator { + // start using key for encryption + *self.state.ekey.lock() = Some(EncryptionState::new(&new)); + + // move current into previous + keys.previous = keys.current.as_ref().map(|v| v.clone()); + keys.current = Some(new.clone()); + } else { + // store the key and await confirmation + keys.previous = keys.next.as_ref().map(|v| v.clone()); + keys.next = Some(new.clone()); + }; - // update key-wheel - if new.initiator { - // start using key for encryption - *self.state.ekey.lock() = Some(EncryptionState::new(&new)); - - // move current into previous - keys.previous = keys.current.as_ref().map(|v| v.clone()); - keys.current = Some(new.clone()); - } else { - // store the key and await confirmation - keys.previous = keys.next.as_ref().map(|v| v.clone()); - keys.next = Some(new.clone()); + // update incoming packet id map + { + debug!("peer.add_keypair: updating inbound id map"); + let mut recv = self.state.device.recv.write(); + + // purge recv map of previous id + keys.previous.as_ref().map(|k| { + recv.remove(&k.local_id()); + release.push(k.local_id()); + }); + + // map new id to decryption state + debug_assert!(!recv.contains_key(&new.recv.id)); + recv.insert( + new.recv.id, + Arc::new(DecryptionState::new(&self.state, &new)), + ); + } + release }; - // update incoming packet id map - { - let mut recv = self.state.device.recv.write(); - - // purge recv map of previous id - keys.previous.as_ref().map(|k| { - recv.remove(&k.local_id()); - release.push(k.local_id()); - }); - - // map new id to decryption state - debug_assert!(!recv.contains_key(&new.recv.id)); - recv.insert( - new.recv.id, - Arc::new(DecryptionState::new(&self.state, &new)), - ); - } - // schedule confirmation - if new.initiator { - // fall back to keepalive packet + if initiator { + debug_assert!(self.state.ekey.lock().is_some()); + debug!("peer.add_keypair: is initiator, must confirm the key"); + // attempt to confirm using staged packets if !self.state.send_staged() { - let ok = self.keepalive(); - debug!("keepalive for confirmation, sent = {}", ok); + // fall back to keepalive packet + let ok = self.send_keepalive(); + debug!( + "peer.add_keypair: keepalive for confirmation, sent = {}", + ok + ); } + debug!("peer.add_keypair: key attempted confirmed"); } - debug_assert!(release.len() <= 3); + debug_assert!( + release.len() <= 3, + "since the key-wheel contains at most 3 keys" + ); release } - pub fn keepalive(&self) -> bool { - debug!("send keepalive"); + pub fn send_keepalive(&self) -> bool { + debug!("peer.send_keepalive"); self.state.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX]) } @@ -498,6 +529,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { /// If an identical value already exists as part of a prior peer, /// the allowed IP entry will be removed from that peer and added to this peer. pub fn add_subnet(&self, ip: IpAddr, masklen: u32) { + debug!("peer.add_subnet"); match ip { IpAddr::V4(v4) => { self.state @@ -522,6 +554,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { /// /// A vector of subnets, represented by as mask/size pub fn list_subnets(&self) -> Vec<(IpAddr, u32)> { + debug!("peer.list_subnets"); let mut res = Vec::new(); res.append(&mut treebit_list( &self.state, @@ -540,6 +573,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { /// After the call, no subnets will be cryptkey routed to the peer. /// Used for the UAPI command "replace_allowed_ips=true" pub fn remove_subnets(&self) { + debug!("peer.remove_subnets"); treebit_remove(self, &self.state.device.ipv4); treebit_remove(self, &self.state.device.ipv6); } @@ -554,6 +588,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { /// /// Unit if packet was sent, or an error indicating why sending failed pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> { + debug!("peer.send"); let inner = &self.state; match inner.endpoint.lock().as_ref() { Some(endpoint) => inner diff --git a/src/router/tests.rs b/src/router/tests.rs index ca6312d..07afa5d 100644 --- a/src/router/tests.rs +++ b/src/router/tests.rs @@ -1,7 +1,7 @@ use std::error::Error; use std::fmt; use std::net::{IpAddr, SocketAddr}; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::atomic::Ordering; use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::Arc; use std::sync::Mutex; @@ -228,6 +228,7 @@ mod tests { send: Mutex<Vec<(usize, bool, bool)>>, recv: Mutex<Vec<(usize, bool, bool)>>, need_key: Mutex<Vec<()>>, + key_confirmed: Mutex<Vec<()>>, } #[derive(Clone)] @@ -241,6 +242,7 @@ mod tests { send: Mutex::new(vec![]), recv: Mutex::new(vec![]), need_key: Mutex::new(vec![]), + key_confirmed: Mutex::new(vec![]), })) } @@ -248,6 +250,7 @@ mod tests { self.0.send.lock().unwrap().clear(); self.0.recv.lock().unwrap().clear(); self.0.need_key.lock().unwrap().clear(); + self.0.key_confirmed.lock().unwrap().clear(); } fn send(&self) -> Option<(usize, bool, bool)> { @@ -262,11 +265,17 @@ mod tests { self.0.need_key.lock().unwrap().pop() } + fn key_confirmed(&self) -> Option<()> { + self.0.key_confirmed.lock().unwrap().pop() + } + + // has all events been accounted for by assertions? fn is_empty(&self) -> bool { let send = self.0.send.lock().unwrap(); let recv = self.0.recv.lock().unwrap(); let need_key = self.0.need_key.lock().unwrap(); - send.is_empty() && recv.is_empty() && need_key.is_empty() + let key_confirmed = self.0.key_confirmed.lock().unwrap(); + send.is_empty() && recv.is_empty() && need_key.is_empty() & key_confirmed.is_empty() } } @@ -284,6 +293,15 @@ mod tests { fn need_key(t: &Self::Opaque) { t.0.need_key.lock().unwrap().push(()); } + + fn key_confirmed(t: &Self::Opaque) { + t.0.key_confirmed.lock().unwrap().push(()); + } + } + + // wait for scheduling + fn wait() { + thread::sleep(Duration::from_millis(50)); } fn init() { @@ -319,6 +337,7 @@ mod tests { } fn recv(_: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {} fn need_key(_: &Self::Opaque) {} + fn key_confirmed(_: &Self::Opaque) {} } // create device @@ -336,7 +355,7 @@ mod tests { let ip1: IpAddr = ip.parse().unwrap(); peer.add_subnet(mask, len); - // every iteration sends 50 GB + // every iteration sends 10 GB b.iter(|| { opaque.store(0, Ordering::SeqCst); let msg = make_packet(1024, ip1); @@ -400,7 +419,7 @@ mod tests { let res = router.send(msg); // allow some scheduling - thread::sleep(Duration::from_millis(20)); + wait(); if *okay { // cryptkey routing succeeded @@ -444,12 +463,8 @@ mod tests { } } - fn wait() { - thread::sleep(Duration::from_millis(20)); - } - #[test] - fn test_outbound_inbound() { + fn test_bidirectional() { init(); let tests = [ @@ -463,15 +478,42 @@ mod tests { ("192.168.1.0", 24, "192.168.1.20", true), ("172.133.133.133", 32, "172.133.133.133", true), ), + ( + false, // confirm with keepalive + ( + "2001:db8::ff00:42:8000", + 113, + "2001:db8::ff00:42:ffff", + true, + ), + ( + "2001:db8::ff40:42:8000", + 113, + "2001:db8::ff40:42:ffff", + true, + ), + ), + ( + false, // confirm with staged packet + ( + "2001:db8::ff00:42:8000", + 113, + "2001:db8::ff00:42:ffff", + true, + ), + ( + "2001:db8::ff40:42:8000", + 113, + "2001:db8::ff40:42:ffff", + true, + ), + ), ]; for (stage, p1, p2) in tests.iter() { - let (bind1, bind2) = bind_pair(); - // create matching devices - + let (bind1, bind2) = bind_pair(); let router1: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, bind1.clone()); - let router2: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, bind2.clone()); // prepare opaque values for tracing callbacks @@ -519,9 +561,7 @@ mod tests { wait(); assert!(opaq2.send().is_some()); - assert!(opaq2.recv().is_none()); - assert!(opaq2.need_key().is_none()); - assert!(opaq2.is_empty()); + assert!(opaq2.is_empty(), "events on peer2 should be 'send'"); assert!(opaq1.is_empty(), "nothing should happened on peer1"); // read confirming message received by the other end ("across the internet") @@ -531,14 +571,16 @@ mod tests { router1.recv(from, buf).unwrap(); wait(); - assert!(opaq1.send().is_none()); assert!(opaq1.recv().is_some()); - assert!(opaq1.need_key().is_none()); - assert!(opaq1.is_empty()); + assert!(opaq1.key_confirmed().is_some()); + assert!( + opaq1.is_empty(), + "events on peer1 should be 'recv' and 'key_confirmed'" + ); assert!(peer1.get_endpoint().is_some()); assert!(opaq2.is_empty(), "nothing should happened on peer2"); - // how that peer1 has an endpoint + // now that peer1 has an endpoint // route packets : peer1 -> peer2 for _ in 0..10 { @@ -572,8 +614,6 @@ mod tests { assert!(opaq2.recv().is_some()); assert!(opaq2.need_key().is_none()); } - - // route packets : peer2 -> peer1 } } } diff --git a/src/router/types.rs b/src/router/types.rs index 736e7c8..b7c3ae0 100644 --- a/src/router/types.rs +++ b/src/router/types.rs @@ -23,9 +23,10 @@ impl<T, F> KeyCallback<T> for F where F: Fn(&T) -> () + Sync + Send + 'static {} pub trait Callbacks: Send + Sync + 'static { type Opaque: Opaque; - fn send(_opaque: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {} - fn recv(_opaque: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {} - fn need_key(_opaque: &Self::Opaque) {} + fn send(opaque: &Self::Opaque, size: usize, data: bool, sent: bool); + fn recv(opaque: &Self::Opaque, size: usize, data: bool, sent: bool); + fn need_key(opaque: &Self::Opaque); + fn key_confirmed(opaque: &Self::Opaque); } #[derive(Debug)] |