From edfd2f235a7954c2a2b846d112a468156ceddfa6 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sat, 28 Sep 2019 18:01:55 +0200 Subject: Added key_confirmed callback --- src/constants.rs | 5 ++ src/main.rs | 3 + src/router/device.rs | 6 +- src/router/peer.rs | 159 +++++++++++++++++++++++++++++++-------------------- src/router/tests.rs | 84 ++++++++++++++++++++------- src/router/types.rs | 7 ++- src/timers.rs | 18 ++++-- src/wireguard.rs | 49 ++++++++++------ 8 files changed, 217 insertions(+), 114 deletions(-) diff --git a/src/constants.rs b/src/constants.rs index 5a895e5..c4e3ae7 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -11,3 +11,8 @@ pub const REKEY_TIMEOUT: Duration = Duration::from_secs(5); pub const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10); pub const MAX_TIMER_HANDSHAKES: usize = 18; + +pub const TIMER_MAX_DURATION: Duration = Duration::from_secs(200); +pub const TIMERS_TICK: Duration = Duration::from_millis(100); +pub const TIMERS_SLOTS: usize = (TIMER_MAX_DURATION.as_micros() / TIMERS_TICK.as_micros()) as usize; +pub const TIMERS_CAPACITY: usize = 1024; diff --git a/src/main.rs b/src/main.rs index a52eecc..26b39a2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,4 +12,7 @@ mod timers; mod types; mod wireguard; +#[test] +fn test_pure_wireguard() {} + fn main() {} diff --git a/src/router/device.rs b/src/router/device.rs index e8250cb..d126959 100644 --- a/src/router/device.rs +++ b/src/router/device.rs @@ -60,6 +60,8 @@ pub struct Device { impl Drop for Device { fn drop(&mut self) { + debug!("router: dropping device"); + // drop all queues { let mut queues = self.state.queues.lock(); @@ -76,7 +78,7 @@ impl Drop for Device { _ => false, } {} - debug!("device dropped"); + debug!("router: device dropped"); } } @@ -175,7 +177,7 @@ impl Device { let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?; // schedule for encryption and transmission to peer - if let Some(job) = peer.send_job(msg) { + if let Some(job) = peer.send_job(msg, true) { debug_assert_eq!(job.1.op, Operation::Encryption); // add job to worker queue diff --git a/src/router/peer.rs b/src/router/peer.rs index 7a3ede8..86723bb 100644 --- a/src/router/peer.rs +++ b/src/router/peer.rs @@ -217,6 +217,7 @@ pub fn new_peer( impl PeerInner { fn send_staged(&self) -> bool { + debug!("peer.send_staged"); let mut sent = false; let mut staged = self.staged_packets.lock(); loop { @@ -230,8 +231,11 @@ impl PeerInner { } } + // Treat the msg as the payload of a transport message + // Unlike device.send, peer.send_raw does not buffer messages when a key is not available. fn send_raw(&self, msg: Vec) -> bool { - match self.send_job(msg) { + debug!("peer.send_raw"); + match self.send_job(msg, false) { Some(job) => { debug!("send_raw: got obtained send_job"); let index = self.device.queue_next.fetch_add(1, Ordering::SeqCst); @@ -246,29 +250,35 @@ impl PeerInner { } pub fn confirm_key(&self, keypair: &Arc) { - // take lock and check keypair = keys.next - let mut keys = self.keys.lock(); - let next = match keys.next.as_ref() { - Some(next) => next, - None => { + debug!("peer.confirm_key"); + { + // take lock and check keypair = keys.next + let mut keys = self.keys.lock(); + let next = match keys.next.as_ref() { + Some(next) => next, + None => { + return; + } + }; + if !Arc::ptr_eq(&next, keypair) { return; } - }; - if !Arc::ptr_eq(&next, keypair) { - return; - } - // allocate new encryption state - let ekey = Some(EncryptionState::new(&next)); + // allocate new encryption state + let ekey = Some(EncryptionState::new(&next)); - // rotate key-wheel - let mut swap = None; - mem::swap(&mut keys.next, &mut swap); - mem::swap(&mut keys.current, &mut swap); - mem::swap(&mut keys.previous, &mut swap); + // rotate key-wheel + let mut swap = None; + mem::swap(&mut keys.next, &mut swap); + mem::swap(&mut keys.current, &mut swap); + mem::swap(&mut keys.previous, &mut swap); - // set new encryption key - *self.ekey.lock() = ekey; + // tell the world outside the router that a key was confirmed + C::key_confirmed(&self.opaque); + + // set new key for encryption + *self.ekey.lock() = ekey; + } // start transmission of staged packets self.send_staged(); @@ -296,7 +306,8 @@ impl PeerInner { } } - pub fn send_job(&self, mut msg: Vec) -> Option { + pub fn send_job(&self, mut msg: Vec, stage: bool) -> Option { + debug!("peer.send_job"); debug_assert!( msg.len() >= mem::size_of::(), "received message with size: {:}", @@ -319,7 +330,6 @@ impl PeerInner { None } else { // there should be no stacked packets lingering around - debug_assert_eq!(self.staged_packets.lock().len(), 0); debug!("encryption state available, nonce = {}", state.nonce); // set transport message fields @@ -334,7 +344,7 @@ impl PeerInner { // If not suitable key was found: // 1. Stage packet for later transmission // 2. Request new key - if key.is_none() { + if key.is_none() && stage { self.staged_packets.lock().push_back(msg); C::need_key(&self.opaque); return None; @@ -372,6 +382,7 @@ impl Peer { /// This API still permits support for the "sticky socket" behavior, /// as sockets should be "unsticked" when manually updating the endpoint pub fn set_endpoint(&self, address: SocketAddr) { + debug!("peer.set_endpoint"); *self.state.endpoint.lock() = Some(B::Endpoint::from_address(address)); } @@ -381,6 +392,7 @@ impl Peer { /// /// Does not convey potential "sticky socket" information pub fn get_endpoint(&self) -> Option { + debug!("peer.get_endpoint"); self.state .endpoint .lock() @@ -390,6 +402,8 @@ impl Peer { /// Zero all key-material related to the peer pub fn zero_keys(&self) { + debug!("peer.zero_keys"); + let mut release: Vec = Vec::with_capacity(3); let mut keys = self.state.keys.lock(); @@ -429,57 +443,74 @@ impl Peer { /// since the only way to add additional keys to the peer is by using this method /// and a peer can have at most 3 keys allocated in the router at any time. pub fn add_keypair(&self, new: KeyPair) -> Vec { - let new = Arc::new(new); - let mut keys = self.state.keys.lock(); - let mut release = mem::replace(&mut keys.retired, vec![]); + debug!("peer.add_keypair"); + + let initiator = new.initiator; + let release = { + let new = Arc::new(new); + let mut keys = self.state.keys.lock(); + let mut release = mem::replace(&mut keys.retired, vec![]); + + // update key-wheel + if new.initiator { + // start using key for encryption + *self.state.ekey.lock() = Some(EncryptionState::new(&new)); + + // move current into previous + keys.previous = keys.current.as_ref().map(|v| v.clone()); + keys.current = Some(new.clone()); + } else { + // store the key and await confirmation + keys.previous = keys.next.as_ref().map(|v| v.clone()); + keys.next = Some(new.clone()); + }; - // update key-wheel - if new.initiator { - // start using key for encryption - *self.state.ekey.lock() = Some(EncryptionState::new(&new)); - - // move current into previous - keys.previous = keys.current.as_ref().map(|v| v.clone()); - keys.current = Some(new.clone()); - } else { - // store the key and await confirmation - keys.previous = keys.next.as_ref().map(|v| v.clone()); - keys.next = Some(new.clone()); + // update incoming packet id map + { + debug!("peer.add_keypair: updating inbound id map"); + let mut recv = self.state.device.recv.write(); + + // purge recv map of previous id + keys.previous.as_ref().map(|k| { + recv.remove(&k.local_id()); + release.push(k.local_id()); + }); + + // map new id to decryption state + debug_assert!(!recv.contains_key(&new.recv.id)); + recv.insert( + new.recv.id, + Arc::new(DecryptionState::new(&self.state, &new)), + ); + } + release }; - // update incoming packet id map - { - let mut recv = self.state.device.recv.write(); - - // purge recv map of previous id - keys.previous.as_ref().map(|k| { - recv.remove(&k.local_id()); - release.push(k.local_id()); - }); - - // map new id to decryption state - debug_assert!(!recv.contains_key(&new.recv.id)); - recv.insert( - new.recv.id, - Arc::new(DecryptionState::new(&self.state, &new)), - ); - } - // schedule confirmation - if new.initiator { - // fall back to keepalive packet + if initiator { + debug_assert!(self.state.ekey.lock().is_some()); + debug!("peer.add_keypair: is initiator, must confirm the key"); + // attempt to confirm using staged packets if !self.state.send_staged() { - let ok = self.keepalive(); - debug!("keepalive for confirmation, sent = {}", ok); + // fall back to keepalive packet + let ok = self.send_keepalive(); + debug!( + "peer.add_keypair: keepalive for confirmation, sent = {}", + ok + ); } + debug!("peer.add_keypair: key attempted confirmed"); } - debug_assert!(release.len() <= 3); + debug_assert!( + release.len() <= 3, + "since the key-wheel contains at most 3 keys" + ); release } - pub fn keepalive(&self) -> bool { - debug!("send keepalive"); + pub fn send_keepalive(&self) -> bool { + debug!("peer.send_keepalive"); self.state.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX]) } @@ -498,6 +529,7 @@ impl Peer { /// If an identical value already exists as part of a prior peer, /// the allowed IP entry will be removed from that peer and added to this peer. pub fn add_subnet(&self, ip: IpAddr, masklen: u32) { + debug!("peer.add_subnet"); match ip { IpAddr::V4(v4) => { self.state @@ -522,6 +554,7 @@ impl Peer { /// /// A vector of subnets, represented by as mask/size pub fn list_subnets(&self) -> Vec<(IpAddr, u32)> { + debug!("peer.list_subnets"); let mut res = Vec::new(); res.append(&mut treebit_list( &self.state, @@ -540,6 +573,7 @@ impl Peer { /// After the call, no subnets will be cryptkey routed to the peer. /// Used for the UAPI command "replace_allowed_ips=true" pub fn remove_subnets(&self) { + debug!("peer.remove_subnets"); treebit_remove(self, &self.state.device.ipv4); treebit_remove(self, &self.state.device.ipv6); } @@ -554,6 +588,7 @@ impl Peer { /// /// Unit if packet was sent, or an error indicating why sending failed pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> { + debug!("peer.send"); let inner = &self.state; match inner.endpoint.lock().as_ref() { Some(endpoint) => inner diff --git a/src/router/tests.rs b/src/router/tests.rs index ca6312d..07afa5d 100644 --- a/src/router/tests.rs +++ b/src/router/tests.rs @@ -1,7 +1,7 @@ use std::error::Error; use std::fmt; use std::net::{IpAddr, SocketAddr}; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::atomic::Ordering; use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::Arc; use std::sync::Mutex; @@ -228,6 +228,7 @@ mod tests { send: Mutex>, recv: Mutex>, need_key: Mutex>, + key_confirmed: Mutex>, } #[derive(Clone)] @@ -241,6 +242,7 @@ mod tests { send: Mutex::new(vec![]), recv: Mutex::new(vec![]), need_key: Mutex::new(vec![]), + key_confirmed: Mutex::new(vec![]), })) } @@ -248,6 +250,7 @@ mod tests { self.0.send.lock().unwrap().clear(); self.0.recv.lock().unwrap().clear(); self.0.need_key.lock().unwrap().clear(); + self.0.key_confirmed.lock().unwrap().clear(); } fn send(&self) -> Option<(usize, bool, bool)> { @@ -262,11 +265,17 @@ mod tests { self.0.need_key.lock().unwrap().pop() } + fn key_confirmed(&self) -> Option<()> { + self.0.key_confirmed.lock().unwrap().pop() + } + + // has all events been accounted for by assertions? fn is_empty(&self) -> bool { let send = self.0.send.lock().unwrap(); let recv = self.0.recv.lock().unwrap(); let need_key = self.0.need_key.lock().unwrap(); - send.is_empty() && recv.is_empty() && need_key.is_empty() + let key_confirmed = self.0.key_confirmed.lock().unwrap(); + send.is_empty() && recv.is_empty() && need_key.is_empty() & key_confirmed.is_empty() } } @@ -284,6 +293,15 @@ mod tests { fn need_key(t: &Self::Opaque) { t.0.need_key.lock().unwrap().push(()); } + + fn key_confirmed(t: &Self::Opaque) { + t.0.key_confirmed.lock().unwrap().push(()); + } + } + + // wait for scheduling + fn wait() { + thread::sleep(Duration::from_millis(50)); } fn init() { @@ -319,6 +337,7 @@ mod tests { } fn recv(_: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {} fn need_key(_: &Self::Opaque) {} + fn key_confirmed(_: &Self::Opaque) {} } // create device @@ -336,7 +355,7 @@ mod tests { let ip1: IpAddr = ip.parse().unwrap(); peer.add_subnet(mask, len); - // every iteration sends 50 GB + // every iteration sends 10 GB b.iter(|| { opaque.store(0, Ordering::SeqCst); let msg = make_packet(1024, ip1); @@ -400,7 +419,7 @@ mod tests { let res = router.send(msg); // allow some scheduling - thread::sleep(Duration::from_millis(20)); + wait(); if *okay { // cryptkey routing succeeded @@ -444,12 +463,8 @@ mod tests { } } - fn wait() { - thread::sleep(Duration::from_millis(20)); - } - #[test] - fn test_outbound_inbound() { + fn test_bidirectional() { init(); let tests = [ @@ -463,15 +478,42 @@ mod tests { ("192.168.1.0", 24, "192.168.1.20", true), ("172.133.133.133", 32, "172.133.133.133", true), ), + ( + false, // confirm with keepalive + ( + "2001:db8::ff00:42:8000", + 113, + "2001:db8::ff00:42:ffff", + true, + ), + ( + "2001:db8::ff40:42:8000", + 113, + "2001:db8::ff40:42:ffff", + true, + ), + ), + ( + false, // confirm with staged packet + ( + "2001:db8::ff00:42:8000", + 113, + "2001:db8::ff00:42:ffff", + true, + ), + ( + "2001:db8::ff40:42:8000", + 113, + "2001:db8::ff40:42:ffff", + true, + ), + ), ]; for (stage, p1, p2) in tests.iter() { - let (bind1, bind2) = bind_pair(); - // create matching devices - + let (bind1, bind2) = bind_pair(); let router1: Device = Device::new(1, TunTest {}, bind1.clone()); - let router2: Device = Device::new(1, TunTest {}, bind2.clone()); // prepare opaque values for tracing callbacks @@ -519,9 +561,7 @@ mod tests { wait(); assert!(opaq2.send().is_some()); - assert!(opaq2.recv().is_none()); - assert!(opaq2.need_key().is_none()); - assert!(opaq2.is_empty()); + assert!(opaq2.is_empty(), "events on peer2 should be 'send'"); assert!(opaq1.is_empty(), "nothing should happened on peer1"); // read confirming message received by the other end ("across the internet") @@ -531,14 +571,16 @@ mod tests { router1.recv(from, buf).unwrap(); wait(); - assert!(opaq1.send().is_none()); assert!(opaq1.recv().is_some()); - assert!(opaq1.need_key().is_none()); - assert!(opaq1.is_empty()); + assert!(opaq1.key_confirmed().is_some()); + assert!( + opaq1.is_empty(), + "events on peer1 should be 'recv' and 'key_confirmed'" + ); assert!(peer1.get_endpoint().is_some()); assert!(opaq2.is_empty(), "nothing should happened on peer2"); - // how that peer1 has an endpoint + // now that peer1 has an endpoint // route packets : peer1 -> peer2 for _ in 0..10 { @@ -572,8 +614,6 @@ mod tests { assert!(opaq2.recv().is_some()); assert!(opaq2.need_key().is_none()); } - - // route packets : peer2 -> peer1 } } } diff --git a/src/router/types.rs b/src/router/types.rs index 736e7c8..b7c3ae0 100644 --- a/src/router/types.rs +++ b/src/router/types.rs @@ -23,9 +23,10 @@ impl KeyCallback for F where F: Fn(&T) -> () + Sync + Send + 'static {} pub trait Callbacks: Send + Sync + 'static { type Opaque: Opaque; - fn send(_opaque: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {} - fn recv(_opaque: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {} - fn need_key(_opaque: &Self::Opaque) {} + fn send(opaque: &Self::Opaque, size: usize, data: bool, sent: bool); + fn recv(opaque: &Self::Opaque, size: usize, data: bool, sent: bool); + fn need_key(opaque: &Self::Opaque); + fn key_confirmed(opaque: &Self::Opaque); } #[derive(Debug)] diff --git a/src/timers.rs b/src/timers.rs index fc00d85..303fd35 100644 --- a/src/timers.rs +++ b/src/timers.rs @@ -1,5 +1,6 @@ use std::marker::PhantomData; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; use std::time::Duration; use hjul::{Runner, Timer}; @@ -7,7 +8,7 @@ use hjul::{Runner, Timer}; use crate::constants::*; use crate::router::Callbacks; use crate::types::{Bind, Tun}; -use crate::wireguard::Peer; +use crate::wireguard::{Peer, PeerInner}; pub struct Timers { handshake_pending: AtomicBool, @@ -47,7 +48,7 @@ impl Timers { send_keepalive: { let peer = peer.clone(); runner.timer(move || { - peer.router.keepalive(); + peer.router.send_keepalive(); let keepalive = peer.keepalive.load(Ordering::Acquire); if keepalive > 0 { peer.timers @@ -103,21 +104,26 @@ impl Timers { pub struct Events(PhantomData<(T, B)>); impl Callbacks for Events { - type Opaque = Peer; + type Opaque = Arc>; - fn send(peer: &Peer, size: usize, data: bool, sent: bool) { + fn send(peer: &Self::Opaque, size: usize, data: bool, sent: bool) { peer.tx_bytes.fetch_add(size as u64, Ordering::Relaxed); } - fn recv(peer: &Peer, size: usize, data: bool, sent: bool) { + fn recv(peer: &Self::Opaque, size: usize, data: bool, sent: bool) { peer.rx_bytes.fetch_add(size as u64, Ordering::Relaxed); } - fn need_key(peer: &Peer) { + fn need_key(peer: &Self::Opaque) { let timers = peer.timers.read(); if !timers.handshake_pending.swap(true, Ordering::SeqCst) { timers.handshake_attempts.store(0, Ordering::SeqCst); timers.new_handshake.fire(); } } + + fn key_confirmed(peer: &Self::Opaque) { + let timers = peer.timers.read(); + timers.retransmit_handshake.stop(); + } } diff --git a/src/wireguard.rs b/src/wireguard.rs index cd61cf0..182cec2 100644 --- a/src/wireguard.rs +++ b/src/wireguard.rs @@ -1,8 +1,12 @@ +use crate::constants::*; use crate::handshake; use crate::router; use crate::timers::{Events, Timers}; use crate::types::{Bind, Endpoint, Tun}; +use hjul::Runner; + +use std::ops::Deref; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; use std::thread; @@ -22,28 +26,32 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128; const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4; const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000); -pub type Peer = Arc>; +#[derive(Clone)] +pub struct Peer { + pub router: Arc, T, B>>, + pub state: Arc>, +} -pub struct PeerInner { +pub struct PeerInner { pub keepalive: AtomicUsize, // keepalive interval pub rx_bytes: AtomicU64, pub tx_bytes: AtomicU64, - pub pk: PublicKey, // DISCUSS: Change layout in handshake module (adopt pattern of router), to avoid this. pub queue: Mutex>>, // handshake queue - pub router: router::Peer, T, B>, // router peer + pub pk: PublicKey, // DISCUSS: Change layout in handshake module (adopt pattern of router), to avoid this. pub timers: RwLock, // } -impl PeerInner { - pub fn new_handshake(&self) { - self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap(); +impl Deref for Peer { + type Target = PeerInner; + fn deref(&self) -> &Self::Target { + &self.state } } -macro_rules! timers { - ($peer:expr) => { - $peer.timers.read() - }; +impl PeerInner { + pub fn new_handshake(&self) { + self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap(); + } } struct Handshake { @@ -74,6 +82,7 @@ struct WireguardInner { } pub struct Wireguard { + runner: Runner, state: Arc>, } @@ -93,19 +102,18 @@ impl Wireguard { } } - /* fn new_peer(&self, pk: PublicKey) -> Peer { - let router = self.state.router.new_peer(); - - Arc::new(PeerInner { + let state = Arc::new(PeerInner { pk, queue: Mutex::new(self.state.queue.lock().clone()), keepalive: AtomicUsize::new(0), rx_bytes: AtomicU64::new(0), tx_bytes: AtomicU64::new(0), - }) + timers: RwLock::new(Timers::dummy(&self.runner)), + }); + let router = Arc::new(self.state.router.new_peer(state.clone())); + Peer { router, state } } - */ fn new(tun: T, bind: B) -> Wireguard { // create device state @@ -189,7 +197,7 @@ impl Wireguard { let msg = state.device.begin(&mut rng, &pk).unwrap(); // TODO handle if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { peer.router.send(&msg[..]); - timers!(peer).handshake_sent(); + peer.timers.read().handshake_sent(); } } } @@ -270,6 +278,9 @@ impl Wireguard { }); } - Wireguard { state: wg } + Wireguard { + state: wg, + runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY), + } } } -- cgit v1.2.3-59-g8ed1b