From 2f3ceab0364497a4a6cf866b505f74443ed6e3ae Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Wed, 16 Oct 2019 13:40:40 +0200 Subject: Work on porting timer semantics and linux platform --- src/platform/linux/mod.rs | 181 +------------------------------------- src/platform/linux/tun.rs | 188 ++++++++++++++++++++++++++++++++++++++++ src/platform/linux/udp.rs | 0 src/platform/mod.rs | 18 +--- src/wireguard/router/mod.rs | 8 +- src/wireguard/router/tests.rs | 22 ++--- src/wireguard/router/types.rs | 8 +- src/wireguard/router/workers.rs | 24 +++-- src/wireguard/timers.rs | 52 +++++++++-- src/wireguard/wireguard.rs | 29 +++++-- 10 files changed, 293 insertions(+), 237 deletions(-) create mode 100644 src/platform/linux/tun.rs create mode 100644 src/platform/linux/udp.rs diff --git a/src/platform/linux/mod.rs b/src/platform/linux/mod.rs index ad2b8be..7a456ad 100644 --- a/src/platform/linux/mod.rs +++ b/src/platform/linux/mod.rs @@ -1,179 +1,4 @@ -use super::Tun; -use super::TunBind; +mod tun; +mod udp; -use super::super::wireguard::tun::*; - -use libc::*; - -use std::os::raw::c_short; -use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; - -const IFNAMSIZ: usize = 16; -const TUNSETIFF: u64 = 0x4004_54ca; - -const IFF_UP: i16 = 0x1; -const IFF_RUNNING: i16 = 0x40; - -const IFF_TUN: c_short = 0x0001; -const IFF_NO_PI: c_short = 0x1000; - -use std::error::Error; -use std::fmt; -use std::sync::atomic::AtomicUsize; -use std::sync::Arc; - -const CLONE_DEVICE_PATH: &'static [u8] = b"/dev/net/tun\0"; - -const TUN_MAGIC: u8 = b'T'; -const TUN_SET_IFF: u8 = 202; - -#[repr(C)] -struct Ifreq { - name: [u8; libc::IFNAMSIZ], - flags: c_short, - _pad: [u8; 64], -} - -pub struct PlatformTun {} - -pub struct PlatformTunReader { - fd: RawFd, -} - -pub struct PlatformTunWriter { - fd: RawFd, -} - -/* Listens for netlink messages - * announcing an MTU update for the interface - */ -#[derive(Clone)] -pub struct PlatformTunMTU { - value: Arc, -} - -#[derive(Debug)] -pub enum LinuxTunError { - InvalidTunDeviceName, - FailedToOpenCloneDevice, - SetIFFIoctlFailed, - Closed, // TODO -} - -impl fmt::Display for LinuxTunError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - unimplemented!() - } -} - -impl Error for LinuxTunError { - fn description(&self) -> &str { - unimplemented!() - } - - fn source(&self) -> Option<&(dyn Error + 'static)> { - unimplemented!() - } -} - -impl MTU for PlatformTunMTU { - fn mtu(&self) -> usize { - 1500 - } -} - -impl Reader for PlatformTunReader { - type Error = LinuxTunError; - - fn read(&self, buf: &mut [u8], offset: usize) -> Result { - debug_assert!( - offset < buf.len(), - "There is no space for the body of the TUN read" - ); - let n = unsafe { read(self.fd, buf[offset..].as_mut_ptr() as _, buf.len() - offset) }; - if n < 0 { - Err(LinuxTunError::Closed) - } else { - // conversion is safe - Ok(n as usize) - } - } -} - -impl Writer for PlatformTunWriter { - type Error = LinuxTunError; - - fn write(&self, src: &[u8]) -> Result<(), Self::Error> { - match unsafe { write(self.fd, src.as_ptr() as _, src.len() as _) } { - -1 => Err(LinuxTunError::Closed), - _ => Ok(()), - } - } -} - -impl Tun for PlatformTun { - type Error = LinuxTunError; - type Reader = PlatformTunReader; - type Writer = PlatformTunWriter; - type MTU = PlatformTunMTU; -} - -impl TunBind for PlatformTun { - fn create(name: &str) -> Result<(Vec, Self::Writer, Self::MTU), Self::Error> { - // construct request struct - let mut req = Ifreq { - name: [0u8; libc::IFNAMSIZ], - flags: (libc::IFF_TUN | libc::IFF_NO_PI) as c_short, - _pad: [0u8; 64], - }; - - // sanity check length of device name - let bs = name.as_bytes(); - if bs.len() > libc::IFNAMSIZ - 1 { - return Err(LinuxTunError::InvalidTunDeviceName); - } - req.name[..bs.len()].copy_from_slice(bs); - - // open clone device - let fd = match unsafe { open(CLONE_DEVICE_PATH.as_ptr() as _, O_RDWR) } { - -1 => return Err(LinuxTunError::FailedToOpenCloneDevice), - fd => fd, - }; - - // create TUN device - if unsafe { ioctl(fd, TUNSETIFF as _, &req) } < 0 { - return Err(LinuxTunError::SetIFFIoctlFailed); - } - - // create PlatformTunMTU instance - - Ok(( - vec![PlatformTunReader { fd }], - PlatformTunWriter { fd }, - PlatformTunMTU { - value: Arc::new(AtomicUsize::new(1500)), - }, - )) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::env; - - fn is_root() -> bool { - match env::var("USER") { - Ok(val) => val == "root", - Err(e) => false, - } - } - - #[test] - fn test_tun_create() { - if !is_root() { - return; - } - let (readers, writers, mtu) = PlatformTun::create("test").unwrap(); - } -} +pub use tun::PlatformTun; diff --git a/src/platform/linux/tun.rs b/src/platform/linux/tun.rs new file mode 100644 index 0000000..17390a1 --- /dev/null +++ b/src/platform/linux/tun.rs @@ -0,0 +1,188 @@ +use super::super::super::wireguard::tun::*; +use super::super::Tun; +use super::super::TunBind; + +use libc::*; + +use std::error::Error; +use std::fmt; +use std::os::raw::c_short; +use std::os::unix::io::RawFd; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +const IFNAMSIZ: usize = 16; +const TUNSETIFF: u64 = 0x4004_54ca; + +const IFF_UP: i16 = 0x1; +const IFF_RUNNING: i16 = 0x40; + +const IFF_TUN: c_short = 0x0001; +const IFF_NO_PI: c_short = 0x1000; + +const CLONE_DEVICE_PATH: &'static [u8] = b"/dev/net/tun\0"; + +const TUN_MAGIC: u8 = b'T'; +const TUN_SET_IFF: u8 = 202; + +#[repr(C)] +struct Ifreq { + name: [u8; libc::IFNAMSIZ], + flags: c_short, + _pad: [u8; 64], +} + +pub struct PlatformTun {} + +pub struct PlatformTunReader { + fd: RawFd, +} + +pub struct PlatformTunWriter { + fd: RawFd, +} + +/* Listens for netlink messages + * announcing an MTU update for the interface + */ +#[derive(Clone)] +pub struct PlatformTunMTU { + value: Arc, +} + +#[derive(Debug)] +pub enum LinuxTunError { + InvalidTunDeviceName, + FailedToOpenCloneDevice, + SetIFFIoctlFailed, + Closed, // TODO +} + +impl fmt::Display for LinuxTunError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + LinuxTunError::InvalidTunDeviceName => write!(f, "Invalid name (too long)"), + LinuxTunError::FailedToOpenCloneDevice => { + write!(f, "Failed to obtain fd for clone device") + } + LinuxTunError::SetIFFIoctlFailed => { + write!(f, "set_iff ioctl failed (insufficient permissions?)") + } + LinuxTunError::Closed => write!(f, "The tunnel has been closed"), + } + } +} + +impl Error for LinuxTunError { + fn description(&self) -> &str { + unimplemented!() + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + unimplemented!() + } +} + +impl MTU for PlatformTunMTU { + #[inline(always)] + fn mtu(&self) -> usize { + self.value.load(Ordering::Relaxed) + } +} + +impl Reader for PlatformTunReader { + type Error = LinuxTunError; + + fn read(&self, buf: &mut [u8], offset: usize) -> Result { + debug_assert!( + offset < buf.len(), + "There is no space for the body of the read" + ); + let n: isize = + unsafe { read(self.fd, buf[offset..].as_mut_ptr() as _, buf.len() - offset) }; + if n < 0 { + Err(LinuxTunError::Closed) + } else { + // conversion is safe + Ok(n as usize) + } + } +} + +impl Writer for PlatformTunWriter { + type Error = LinuxTunError; + + fn write(&self, src: &[u8]) -> Result<(), Self::Error> { + match unsafe { write(self.fd, src.as_ptr() as _, src.len() as _) } { + -1 => Err(LinuxTunError::Closed), + _ => Ok(()), + } + } +} + +impl Tun for PlatformTun { + type Error = LinuxTunError; + type Reader = PlatformTunReader; + type Writer = PlatformTunWriter; + type MTU = PlatformTunMTU; +} + +impl TunBind for PlatformTun { + fn create(name: &str) -> Result<(Vec, Self::Writer, Self::MTU), Self::Error> { + // construct request struct + let mut req = Ifreq { + name: [0u8; libc::IFNAMSIZ], + flags: (libc::IFF_TUN | libc::IFF_NO_PI) as c_short, + _pad: [0u8; 64], + }; + + // sanity check length of device name + let bs = name.as_bytes(); + if bs.len() > libc::IFNAMSIZ - 1 { + return Err(LinuxTunError::InvalidTunDeviceName); + } + req.name[..bs.len()].copy_from_slice(bs); + + // open clone device + let fd: RawFd = match unsafe { open(CLONE_DEVICE_PATH.as_ptr() as _, O_RDWR) } { + -1 => return Err(LinuxTunError::FailedToOpenCloneDevice), + fd => fd, + }; + assert!(fd >= 0); + + // create TUN device + if unsafe { ioctl(fd, TUNSETIFF as _, &req) } < 0 { + return Err(LinuxTunError::SetIFFIoctlFailed); + } + + // create PlatformTunMTU instance + Ok(( + vec![PlatformTunReader { fd }], // TODO: enable multi-queue for Linux + PlatformTunWriter { fd }, + PlatformTunMTU { + value: Arc::new(AtomicUsize::new(1500)), + }, + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::env; + + fn is_root() -> bool { + match env::var("USER") { + Ok(val) => val == "root", + Err(e) => false, + } + } + + #[test] + fn test_tun_create() { + if !is_root() { + return; + } + let (readers, writers, mtu) = PlatformTun::create("test").unwrap(); + } +} diff --git a/src/platform/linux/udp.rs b/src/platform/linux/udp.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/platform/mod.rs b/src/platform/mod.rs index e83384c..de33714 100644 --- a/src/platform/mod.rs +++ b/src/platform/mod.rs @@ -9,26 +9,12 @@ mod linux; #[cfg(target_os = "linux")] pub use linux::PlatformTun; -/* Syntax is nasty here, due to open issue: - * https://github.com/rust-lang/rust/issues/38078 - */ -pub trait UDPBind { +pub trait UDPBind: Bind { type Closer; - type Error: Error; - type Bind: Bind; /// Bind to a new port, returning the reader/writer and /// an associated instance of the Closer type, which closes the UDP socket upon "drop". - fn bind( - port: u16, - ) -> Result< - ( - <::Bind as Bind>::Reader, - <::Bind as Bind>::Writer, - Self::Closer, - ), - Self::Error, - >; + fn bind(port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer), Self::Error>; } pub trait TunBind: Tun { diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs index 7a29cd9..4e748cb 100644 --- a/src/wireguard/router/mod.rs +++ b/src/wireguard/router/mod.rs @@ -14,9 +14,13 @@ use messages::TransportHeader; use std::mem; pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::(); -pub const CAPACITY_MESSAGE_POSTFIX: usize = 16; +pub const CAPACITY_MESSAGE_POSTFIX: usize = workers::SIZE_TAG; + +pub const fn message_data_len(payload: usize) -> usize { + payload + mem::size_of::() + workers::SIZE_TAG +} -pub use messages::TYPE_TRANSPORT; pub use device::Device; +pub use messages::TYPE_TRANSPORT; pub use peer::Peer; pub use types::Callbacks; diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs index fbee39e..93c0773 100644 --- a/src/wireguard/router/tests.rs +++ b/src/wireguard/router/tests.rs @@ -28,8 +28,8 @@ mod tests { // type for tracking events inside the router module struct Flags { - send: Mutex>, - recv: Mutex>, + send: Mutex>, + recv: Mutex>, need_key: Mutex>, key_confirmed: Mutex>, } @@ -56,11 +56,11 @@ mod tests { self.0.key_confirmed.lock().unwrap().clear(); } - fn send(&self) -> Option<(usize, bool, bool)> { + fn send(&self) -> Option<(usize, bool)> { self.0.send.lock().unwrap().pop() } - fn recv(&self) -> Option<(usize, bool, bool)> { + fn recv(&self) -> Option<(usize, bool)> { self.0.recv.lock().unwrap().pop() } @@ -85,12 +85,12 @@ mod tests { impl Callbacks for TestCallbacks { type Opaque = Opaque; - fn send(t: &Self::Opaque, size: usize, data: bool, sent: bool) { - t.0.send.lock().unwrap().push((size, data, sent)) + fn send(t: &Self::Opaque, size: usize, sent: bool) { + t.0.send.lock().unwrap().push((size, sent)) } - fn recv(t: &Self::Opaque, size: usize, data: bool, sent: bool) { - t.0.recv.lock().unwrap().push((size, data, sent)) + fn recv(t: &Self::Opaque, size: usize, sent: bool) { + t.0.recv.lock().unwrap().push((size, sent)) } fn need_key(t: &Self::Opaque) { @@ -135,10 +135,10 @@ mod tests { struct BencherCallbacks {} impl Callbacks for BencherCallbacks { type Opaque = Arc; - fn send(t: &Self::Opaque, size: usize, _data: bool, _sent: bool) { + fn send(t: &Self::Opaque, size: usize, _sent: bool) { t.fetch_add(size, Ordering::SeqCst); } - fn recv(_: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {} + fn recv(_: &Self::Opaque, _size: usize, _sent: bool) {} fn need_key(_: &Self::Opaque) {} fn key_confirmed(_: &Self::Opaque) {} } @@ -253,7 +253,7 @@ mod tests { assert_eq!( opaque.send(), if set_key { - Some((SIZE_KEEPALIVE, false, false)) + Some((SIZE_KEEPALIVE, false)) } else { None }, diff --git a/src/wireguard/router/types.rs b/src/wireguard/router/types.rs index b7c3ae0..52ee4f1 100644 --- a/src/wireguard/router/types.rs +++ b/src/wireguard/router/types.rs @@ -10,9 +10,9 @@ impl Opaque for T where T: Send + Sync + 'static {} /// * `0`, a reference to the opaque value assigned to the peer /// * `1`, a bool indicating whether the message contained data (not just keepalive) /// * `2`, a bool indicating whether the message was transmitted (i.e. did the peer have an associated endpoint?) -pub trait Callback: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {} +pub trait Callback: Fn(&T, usize, bool) -> () + Sync + Send + 'static {} -impl Callback for F where F: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {} +impl Callback for F where F: Fn(&T, usize, bool) -> () + Sync + Send + 'static {} /// A key callback takes 1 argument /// @@ -23,8 +23,8 @@ impl KeyCallback for F where F: Fn(&T) -> () + Sync + Send + 'static {} pub trait Callbacks: Send + Sync + 'static { type Opaque: Opaque; - fn send(opaque: &Self::Opaque, size: usize, data: bool, sent: bool); - fn recv(opaque: &Self::Opaque, size: usize, data: bool, sent: bool); + fn send(opaque: &Self::Opaque, size: usize, sent: bool); + fn recv(opaque: &Self::Opaque, size: usize, sent: bool); fn need_key(opaque: &Self::Opaque); fn key_confirmed(opaque: &Self::Opaque); } diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs index 2e89bb0..61a7620 100644 --- a/src/wireguard/router/workers.rs +++ b/src/wireguard/router/workers.rs @@ -17,10 +17,10 @@ use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::peer::PeerInner; use super::types::Callbacks; -use super::super::types::{Endpoint, tun, bind}; +use super::super::types::{bind, tun, Endpoint}; use super::ip::*; -const SIZE_TAG: usize = 16; +pub const SIZE_TAG: usize = 16; #[derive(PartialEq, Debug)] pub enum Operation { @@ -47,7 +47,7 @@ pub type JobInbound> = ( pub type JobOutbound = oneshot::Receiver; #[inline(always)] -fn check_route>( +fn check_route>( device: &Arc>, peer: &Arc>, packet: &[u8], @@ -93,7 +93,7 @@ fn check_route>( } } -pub fn worker_inbound>( +pub fn worker_inbound>( device: Arc>, // related device peer: Arc>, // related peer receiver: Receiver>, @@ -151,7 +151,8 @@ pub fn worker_inbound 0 { if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) { - debug_assert!(inner_len <= length, "should be validated"); + // TODO: Consider moving the cryptkey route check to parallel decryption worker + debug_assert!(inner_len <= length, "should be validated earlier"); if inner_len <= length { sent = match device.inbound.write(&packet[..inner_len]) { Err(e) => { @@ -167,7 +168,7 @@ pub fn worker_inbound>( +pub fn worker_outbound>( device: Arc>, // related device peer: Arc>, // related peer receiver: Receiver, @@ -198,7 +199,7 @@ pub fn worker_outbound = &*device.outbound.read(); + let send: &Option = &*device.outbound.read(); if let Some(writer) = send.as_ref() { match writer.write(&buf.msg[..], dst) { Err(e) => { @@ -215,12 +216,7 @@ pub fn worker_outbound SIZE_TAG + mem::size_of::(), - xmit, - ); + C::send(&peer.opaque, buf.msg.len(), xmit); } }) .wait(); diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index 2792c7b..1d9b8a0 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -1,14 +1,14 @@ use std::marker::PhantomData; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, SystemTime}; use log::info; use hjul::{Runner, Timer}; use super::constants::*; -use super::router::Callbacks; +use super::router::{Callbacks, message_data_len}; use super::types::{bind, tun}; use super::wireguard::{Peer, PeerInner}; @@ -32,7 +32,7 @@ impl Timers { } } -impl Peer { +impl PeerInner { /* should be called after an authenticated data packet is sent */ pub fn timers_data_sent(&self) { self.timers().new_handshake.start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT); @@ -90,11 +90,25 @@ impl Peer { * keepalive, data, or handshake is sent, or after one is received. */ pub fn timers_any_authenticated_packet_traversal(&self) { - let keepalive = self.state.keepalive.load(Ordering::Acquire); + let keepalive = self.keepalive.load(Ordering::Acquire); if keepalive > 0 { self.timers().send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64)); } } + + /* Called after a handshake worker sends a handshake initiation to the peer + */ + pub fn sent_handshake_initiation(&self) { + *self.last_handshake.lock() = SystemTime::now(); + self.handshake_queued.store(false, Ordering::Acquire); + self.timers_any_authenticated_packet_traversal(); + self.timers_any_authenticated_packet_sent(); + } + + pub fn sent_handshake_response(&self) { + self.timers_any_authenticated_packet_traversal(); + self.timers_any_authenticated_packet_sent(); + } } impl Timers { @@ -212,14 +226,40 @@ pub struct Events(PhantomData<(T, B)>); impl Callbacks for Events { type Opaque = Arc>; - fn send(peer: &Self::Opaque, size: usize, data: bool, sent: bool) { + /* Called after the router encrypts a transport message destined for the peer. + * This method is called, even if the encrypted payload is empty (keepalive) + */ + fn send(peer: &Self::Opaque, size: usize, sent: bool) { + peer.timers_any_authenticated_packet_traversal(); + peer.timers_any_authenticated_packet_sent(); peer.tx_bytes.fetch_add(size as u64, Ordering::Relaxed); + if size > message_data_len(0) && sent { + peer.timers_data_sent(); + } } - fn recv(peer: &Self::Opaque, size: usize, data: bool, sent: bool) { + /* Called after the router successfully decrypts a transport message from a peer. + * This method is called, even if the decrypted packet is: + * + * - A keepalive + * - A malformed IP packet + * - Fails to cryptkey route + */ + fn recv(peer: &Self::Opaque, size: usize, sent: bool) { + peer.timers_any_authenticated_packet_traversal(); + peer.timers_any_authenticated_packet_received(); peer.rx_bytes.fetch_add(size as u64, Ordering::Relaxed); + if size > 0 && sent { + peer.timers_data_received(); + } } + /* Called every time the router detects that a key is required, + * but no valid key-material is available for the particular peer. + * + * The message is called continuously + * (e.g. for every packet that must be encrypted, until a key becomes available) + */ fn need_key(peer: &Self::Opaque) { let timers = peer.timers(); if !timers.handshake_pending.swap(true, Ordering::SeqCst) { diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 7a22280..1363c27 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -15,7 +15,7 @@ use std::ops::Deref; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; use std::thread; -use std::time::{Duration, Instant}; +use std::time::{Duration, Instant, SystemTime}; use std::collections::HashMap; @@ -49,6 +49,10 @@ pub struct PeerInner { pub keepalive: AtomicUsize, // keepalive interval pub rx_bytes: AtomicU64, pub tx_bytes: AtomicU64, + + pub last_handshake: Mutex, + pub handshake_queued: AtomicBool, + pub queue: Mutex>>, // handshake queue pub pk: PublicKey, // DISCUSS: Change layout in handshake module (adopt pattern of router), to avoid this. pub timers: RwLock, // @@ -75,9 +79,13 @@ impl Deref for Peer { } impl PeerInner { + /* Queue a handshake request for the parallel workers + * (if one does not already exist) + */ pub fn new_handshake(&self) { - // TODO: clear endpoint source address ("unsticky") - self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap(); + if !self.handshake_queued.swap(true, Ordering::SeqCst) { + self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap(); + } } } @@ -165,6 +173,8 @@ impl Wireguard { pub fn new_peer(&self, pk: PublicKey) -> Peer { let state = Arc::new(PeerInner { pk, + last_handshake: Mutex::new(SystemTime::UNIX_EPOCH), + handshake_queued: AtomicBool::new(false), queue: Mutex::new(self.state.queue.lock().clone()), keepalive: AtomicUsize::new(0), rx_bytes: AtomicU64::new(0), @@ -180,6 +190,7 @@ impl Wireguard { * problem of the timer callbacks being able to set timers themselves. * * This is in fact the only place where the write lock is ever taken. + * TODO: Consider the ease of using atomic pointers instead. */ *peer.timers.write() = Timers::new(&self.runner, peer.clone()); peer @@ -301,7 +312,7 @@ impl Wireguard { }, ) { Ok((pk, resp, keypair)) => { - // send response + // send response (might be cookie reply or handshake response) let mut resp_len: u64 = 0; if let Some(msg) = resp { resp_len = msg.len() as u64; @@ -316,7 +327,7 @@ impl Wireguard { } } - // update timers + // update peer state if let Some(pk) = pk { // authenticated handshake packet received if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { @@ -328,7 +339,12 @@ impl Wireguard { // update endpoint peer.router.set_endpoint(src); - // add keypair to peer + // update timers after sending handshake response + if resp_len > 0 { + peer.state.sent_handshake_response(); + } + + // add resulting keypair to peer keypair.map(|kp| { // free any unused ids for id in peer.router.add_keypair(kp) { @@ -347,6 +363,7 @@ impl Wireguard { let _ = peer.router.send(&msg[..]).map_err(|e| { debug!("handshake worker, failed to send handshake initiation, error = {}", e) }); + peer.state.sent_handshake_initiation(); } }); } -- cgit v1.2.3-59-g8ed1b