aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/platform/linux/mod.rs181
-rw-r--r--src/platform/linux/tun.rs188
-rw-r--r--src/platform/linux/udp.rs0
-rw-r--r--src/platform/mod.rs18
-rw-r--r--src/wireguard/router/mod.rs8
-rw-r--r--src/wireguard/router/tests.rs22
-rw-r--r--src/wireguard/router/types.rs8
-rw-r--r--src/wireguard/router/workers.rs24
-rw-r--r--src/wireguard/timers.rs52
-rw-r--r--src/wireguard/wireguard.rs29
10 files changed, 293 insertions, 237 deletions
diff --git a/src/platform/linux/mod.rs b/src/platform/linux/mod.rs
index ad2b8be..7a456ad 100644
--- a/src/platform/linux/mod.rs
+++ b/src/platform/linux/mod.rs
@@ -1,179 +1,4 @@
-use super::Tun;
-use super::TunBind;
+mod tun;
+mod udp;
-use super::super::wireguard::tun::*;
-
-use libc::*;
-
-use std::os::raw::c_short;
-use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
-
-const IFNAMSIZ: usize = 16;
-const TUNSETIFF: u64 = 0x4004_54ca;
-
-const IFF_UP: i16 = 0x1;
-const IFF_RUNNING: i16 = 0x40;
-
-const IFF_TUN: c_short = 0x0001;
-const IFF_NO_PI: c_short = 0x1000;
-
-use std::error::Error;
-use std::fmt;
-use std::sync::atomic::AtomicUsize;
-use std::sync::Arc;
-
-const CLONE_DEVICE_PATH: &'static [u8] = b"/dev/net/tun\0";
-
-const TUN_MAGIC: u8 = b'T';
-const TUN_SET_IFF: u8 = 202;
-
-#[repr(C)]
-struct Ifreq {
- name: [u8; libc::IFNAMSIZ],
- flags: c_short,
- _pad: [u8; 64],
-}
-
-pub struct PlatformTun {}
-
-pub struct PlatformTunReader {
- fd: RawFd,
-}
-
-pub struct PlatformTunWriter {
- fd: RawFd,
-}
-
-/* Listens for netlink messages
- * announcing an MTU update for the interface
- */
-#[derive(Clone)]
-pub struct PlatformTunMTU {
- value: Arc<AtomicUsize>,
-}
-
-#[derive(Debug)]
-pub enum LinuxTunError {
- InvalidTunDeviceName,
- FailedToOpenCloneDevice,
- SetIFFIoctlFailed,
- Closed, // TODO
-}
-
-impl fmt::Display for LinuxTunError {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- unimplemented!()
- }
-}
-
-impl Error for LinuxTunError {
- fn description(&self) -> &str {
- unimplemented!()
- }
-
- fn source(&self) -> Option<&(dyn Error + 'static)> {
- unimplemented!()
- }
-}
-
-impl MTU for PlatformTunMTU {
- fn mtu(&self) -> usize {
- 1500
- }
-}
-
-impl Reader for PlatformTunReader {
- type Error = LinuxTunError;
-
- fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error> {
- debug_assert!(
- offset < buf.len(),
- "There is no space for the body of the TUN read"
- );
- let n = unsafe { read(self.fd, buf[offset..].as_mut_ptr() as _, buf.len() - offset) };
- if n < 0 {
- Err(LinuxTunError::Closed)
- } else {
- // conversion is safe
- Ok(n as usize)
- }
- }
-}
-
-impl Writer for PlatformTunWriter {
- type Error = LinuxTunError;
-
- fn write(&self, src: &[u8]) -> Result<(), Self::Error> {
- match unsafe { write(self.fd, src.as_ptr() as _, src.len() as _) } {
- -1 => Err(LinuxTunError::Closed),
- _ => Ok(()),
- }
- }
-}
-
-impl Tun for PlatformTun {
- type Error = LinuxTunError;
- type Reader = PlatformTunReader;
- type Writer = PlatformTunWriter;
- type MTU = PlatformTunMTU;
-}
-
-impl TunBind for PlatformTun {
- fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error> {
- // construct request struct
- let mut req = Ifreq {
- name: [0u8; libc::IFNAMSIZ],
- flags: (libc::IFF_TUN | libc::IFF_NO_PI) as c_short,
- _pad: [0u8; 64],
- };
-
- // sanity check length of device name
- let bs = name.as_bytes();
- if bs.len() > libc::IFNAMSIZ - 1 {
- return Err(LinuxTunError::InvalidTunDeviceName);
- }
- req.name[..bs.len()].copy_from_slice(bs);
-
- // open clone device
- let fd = match unsafe { open(CLONE_DEVICE_PATH.as_ptr() as _, O_RDWR) } {
- -1 => return Err(LinuxTunError::FailedToOpenCloneDevice),
- fd => fd,
- };
-
- // create TUN device
- if unsafe { ioctl(fd, TUNSETIFF as _, &req) } < 0 {
- return Err(LinuxTunError::SetIFFIoctlFailed);
- }
-
- // create PlatformTunMTU instance
-
- Ok((
- vec![PlatformTunReader { fd }],
- PlatformTunWriter { fd },
- PlatformTunMTU {
- value: Arc::new(AtomicUsize::new(1500)),
- },
- ))
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use std::env;
-
- fn is_root() -> bool {
- match env::var("USER") {
- Ok(val) => val == "root",
- Err(e) => false,
- }
- }
-
- #[test]
- fn test_tun_create() {
- if !is_root() {
- return;
- }
- let (readers, writers, mtu) = PlatformTun::create("test").unwrap();
- }
-}
+pub use tun::PlatformTun;
diff --git a/src/platform/linux/tun.rs b/src/platform/linux/tun.rs
new file mode 100644
index 0000000..17390a1
--- /dev/null
+++ b/src/platform/linux/tun.rs
@@ -0,0 +1,188 @@
+use super::super::super::wireguard::tun::*;
+use super::super::Tun;
+use super::super::TunBind;
+
+use libc::*;
+
+use std::error::Error;
+use std::fmt;
+use std::os::raw::c_short;
+use std::os::unix::io::RawFd;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+
+const IFNAMSIZ: usize = 16;
+const TUNSETIFF: u64 = 0x4004_54ca;
+
+const IFF_UP: i16 = 0x1;
+const IFF_RUNNING: i16 = 0x40;
+
+const IFF_TUN: c_short = 0x0001;
+const IFF_NO_PI: c_short = 0x1000;
+
+const CLONE_DEVICE_PATH: &'static [u8] = b"/dev/net/tun\0";
+
+const TUN_MAGIC: u8 = b'T';
+const TUN_SET_IFF: u8 = 202;
+
+#[repr(C)]
+struct Ifreq {
+ name: [u8; libc::IFNAMSIZ],
+ flags: c_short,
+ _pad: [u8; 64],
+}
+
+pub struct PlatformTun {}
+
+pub struct PlatformTunReader {
+ fd: RawFd,
+}
+
+pub struct PlatformTunWriter {
+ fd: RawFd,
+}
+
+/* Listens for netlink messages
+ * announcing an MTU update for the interface
+ */
+#[derive(Clone)]
+pub struct PlatformTunMTU {
+ value: Arc<AtomicUsize>,
+}
+
+#[derive(Debug)]
+pub enum LinuxTunError {
+ InvalidTunDeviceName,
+ FailedToOpenCloneDevice,
+ SetIFFIoctlFailed,
+ Closed, // TODO
+}
+
+impl fmt::Display for LinuxTunError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ LinuxTunError::InvalidTunDeviceName => write!(f, "Invalid name (too long)"),
+ LinuxTunError::FailedToOpenCloneDevice => {
+ write!(f, "Failed to obtain fd for clone device")
+ }
+ LinuxTunError::SetIFFIoctlFailed => {
+ write!(f, "set_iff ioctl failed (insufficient permissions?)")
+ }
+ LinuxTunError::Closed => write!(f, "The tunnel has been closed"),
+ }
+ }
+}
+
+impl Error for LinuxTunError {
+ fn description(&self) -> &str {
+ unimplemented!()
+ }
+
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ unimplemented!()
+ }
+}
+
+impl MTU for PlatformTunMTU {
+ #[inline(always)]
+ fn mtu(&self) -> usize {
+ self.value.load(Ordering::Relaxed)
+ }
+}
+
+impl Reader for PlatformTunReader {
+ type Error = LinuxTunError;
+
+ fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error> {
+ debug_assert!(
+ offset < buf.len(),
+ "There is no space for the body of the read"
+ );
+ let n: isize =
+ unsafe { read(self.fd, buf[offset..].as_mut_ptr() as _, buf.len() - offset) };
+ if n < 0 {
+ Err(LinuxTunError::Closed)
+ } else {
+ // conversion is safe
+ Ok(n as usize)
+ }
+ }
+}
+
+impl Writer for PlatformTunWriter {
+ type Error = LinuxTunError;
+
+ fn write(&self, src: &[u8]) -> Result<(), Self::Error> {
+ match unsafe { write(self.fd, src.as_ptr() as _, src.len() as _) } {
+ -1 => Err(LinuxTunError::Closed),
+ _ => Ok(()),
+ }
+ }
+}
+
+impl Tun for PlatformTun {
+ type Error = LinuxTunError;
+ type Reader = PlatformTunReader;
+ type Writer = PlatformTunWriter;
+ type MTU = PlatformTunMTU;
+}
+
+impl TunBind for PlatformTun {
+ fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error> {
+ // construct request struct
+ let mut req = Ifreq {
+ name: [0u8; libc::IFNAMSIZ],
+ flags: (libc::IFF_TUN | libc::IFF_NO_PI) as c_short,
+ _pad: [0u8; 64],
+ };
+
+ // sanity check length of device name
+ let bs = name.as_bytes();
+ if bs.len() > libc::IFNAMSIZ - 1 {
+ return Err(LinuxTunError::InvalidTunDeviceName);
+ }
+ req.name[..bs.len()].copy_from_slice(bs);
+
+ // open clone device
+ let fd: RawFd = match unsafe { open(CLONE_DEVICE_PATH.as_ptr() as _, O_RDWR) } {
+ -1 => return Err(LinuxTunError::FailedToOpenCloneDevice),
+ fd => fd,
+ };
+ assert!(fd >= 0);
+
+ // create TUN device
+ if unsafe { ioctl(fd, TUNSETIFF as _, &req) } < 0 {
+ return Err(LinuxTunError::SetIFFIoctlFailed);
+ }
+
+ // create PlatformTunMTU instance
+ Ok((
+ vec![PlatformTunReader { fd }], // TODO: enable multi-queue for Linux
+ PlatformTunWriter { fd },
+ PlatformTunMTU {
+ value: Arc::new(AtomicUsize::new(1500)),
+ },
+ ))
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::env;
+
+ fn is_root() -> bool {
+ match env::var("USER") {
+ Ok(val) => val == "root",
+ Err(e) => false,
+ }
+ }
+
+ #[test]
+ fn test_tun_create() {
+ if !is_root() {
+ return;
+ }
+ let (readers, writers, mtu) = PlatformTun::create("test").unwrap();
+ }
+}
diff --git a/src/platform/linux/udp.rs b/src/platform/linux/udp.rs
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/platform/linux/udp.rs
diff --git a/src/platform/mod.rs b/src/platform/mod.rs
index e83384c..de33714 100644
--- a/src/platform/mod.rs
+++ b/src/platform/mod.rs
@@ -9,26 +9,12 @@ mod linux;
#[cfg(target_os = "linux")]
pub use linux::PlatformTun;
-/* Syntax is nasty here, due to open issue:
- * https://github.com/rust-lang/rust/issues/38078
- */
-pub trait UDPBind {
+pub trait UDPBind: Bind {
type Closer;
- type Error: Error;
- type Bind: Bind;
/// Bind to a new port, returning the reader/writer and
/// an associated instance of the Closer type, which closes the UDP socket upon "drop".
- fn bind(
- port: u16,
- ) -> Result<
- (
- <<Self as UDPBind>::Bind as Bind>::Reader,
- <<Self as UDPBind>::Bind as Bind>::Writer,
- Self::Closer,
- ),
- Self::Error,
- >;
+ fn bind(port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer), Self::Error>;
}
pub trait TunBind: Tun {
diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs
index 7a29cd9..4e748cb 100644
--- a/src/wireguard/router/mod.rs
+++ b/src/wireguard/router/mod.rs
@@ -14,9 +14,13 @@ use messages::TransportHeader;
use std::mem;
pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>();
-pub const CAPACITY_MESSAGE_POSTFIX: usize = 16;
+pub const CAPACITY_MESSAGE_POSTFIX: usize = workers::SIZE_TAG;
+
+pub const fn message_data_len(payload: usize) -> usize {
+ payload + mem::size_of::<TransportHeader>() + workers::SIZE_TAG
+}
-pub use messages::TYPE_TRANSPORT;
pub use device::Device;
+pub use messages::TYPE_TRANSPORT;
pub use peer::Peer;
pub use types::Callbacks;
diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs
index fbee39e..93c0773 100644
--- a/src/wireguard/router/tests.rs
+++ b/src/wireguard/router/tests.rs
@@ -28,8 +28,8 @@ mod tests {
// type for tracking events inside the router module
struct Flags {
- send: Mutex<Vec<(usize, bool, bool)>>,
- recv: Mutex<Vec<(usize, bool, bool)>>,
+ send: Mutex<Vec<(usize, bool)>>,
+ recv: Mutex<Vec<(usize, bool)>>,
need_key: Mutex<Vec<()>>,
key_confirmed: Mutex<Vec<()>>,
}
@@ -56,11 +56,11 @@ mod tests {
self.0.key_confirmed.lock().unwrap().clear();
}
- fn send(&self) -> Option<(usize, bool, bool)> {
+ fn send(&self) -> Option<(usize, bool)> {
self.0.send.lock().unwrap().pop()
}
- fn recv(&self) -> Option<(usize, bool, bool)> {
+ fn recv(&self) -> Option<(usize, bool)> {
self.0.recv.lock().unwrap().pop()
}
@@ -85,12 +85,12 @@ mod tests {
impl Callbacks for TestCallbacks {
type Opaque = Opaque;
- fn send(t: &Self::Opaque, size: usize, data: bool, sent: bool) {
- t.0.send.lock().unwrap().push((size, data, sent))
+ fn send(t: &Self::Opaque, size: usize, sent: bool) {
+ t.0.send.lock().unwrap().push((size, sent))
}
- fn recv(t: &Self::Opaque, size: usize, data: bool, sent: bool) {
- t.0.recv.lock().unwrap().push((size, data, sent))
+ fn recv(t: &Self::Opaque, size: usize, sent: bool) {
+ t.0.recv.lock().unwrap().push((size, sent))
}
fn need_key(t: &Self::Opaque) {
@@ -135,10 +135,10 @@ mod tests {
struct BencherCallbacks {}
impl Callbacks for BencherCallbacks {
type Opaque = Arc<AtomicUsize>;
- fn send(t: &Self::Opaque, size: usize, _data: bool, _sent: bool) {
+ fn send(t: &Self::Opaque, size: usize, _sent: bool) {
t.fetch_add(size, Ordering::SeqCst);
}
- fn recv(_: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {}
+ fn recv(_: &Self::Opaque, _size: usize, _sent: bool) {}
fn need_key(_: &Self::Opaque) {}
fn key_confirmed(_: &Self::Opaque) {}
}
@@ -253,7 +253,7 @@ mod tests {
assert_eq!(
opaque.send(),
if set_key {
- Some((SIZE_KEEPALIVE, false, false))
+ Some((SIZE_KEEPALIVE, false))
} else {
None
},
diff --git a/src/wireguard/router/types.rs b/src/wireguard/router/types.rs
index b7c3ae0..52ee4f1 100644
--- a/src/wireguard/router/types.rs
+++ b/src/wireguard/router/types.rs
@@ -10,9 +10,9 @@ impl<T> Opaque for T where T: Send + Sync + 'static {}
/// * `0`, a reference to the opaque value assigned to the peer
/// * `1`, a bool indicating whether the message contained data (not just keepalive)
/// * `2`, a bool indicating whether the message was transmitted (i.e. did the peer have an associated endpoint?)
-pub trait Callback<T>: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {}
+pub trait Callback<T>: Fn(&T, usize, bool) -> () + Sync + Send + 'static {}
-impl<T, F> Callback<T> for F where F: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {}
+impl<T, F> Callback<T> for F where F: Fn(&T, usize, bool) -> () + Sync + Send + 'static {}
/// A key callback takes 1 argument
///
@@ -23,8 +23,8 @@ impl<T, F> KeyCallback<T> for F where F: Fn(&T) -> () + Sync + Send + 'static {}
pub trait Callbacks: Send + Sync + 'static {
type Opaque: Opaque;
- fn send(opaque: &Self::Opaque, size: usize, data: bool, sent: bool);
- fn recv(opaque: &Self::Opaque, size: usize, data: bool, sent: bool);
+ fn send(opaque: &Self::Opaque, size: usize, sent: bool);
+ fn recv(opaque: &Self::Opaque, size: usize, sent: bool);
fn need_key(opaque: &Self::Opaque);
fn key_confirmed(opaque: &Self::Opaque);
}
diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs
index 2e89bb0..61a7620 100644
--- a/src/wireguard/router/workers.rs
+++ b/src/wireguard/router/workers.rs
@@ -17,10 +17,10 @@ use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::PeerInner;
use super::types::Callbacks;
-use super::super::types::{Endpoint, tun, bind};
+use super::super::types::{bind, tun, Endpoint};
use super::ip::*;
-const SIZE_TAG: usize = 16;
+pub const SIZE_TAG: usize = 16;
#[derive(PartialEq, Debug)]
pub enum Operation {
@@ -47,7 +47,7 @@ pub type JobInbound<E, C, T, B: bind::Writer<E>> = (
pub type JobOutbound = oneshot::Receiver<JobBuffer>;
#[inline(always)]
-fn check_route<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+fn check_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: &Arc<DeviceInner<E, C, T, B>>,
peer: &Arc<PeerInner<E, C, T, B>>,
packet: &[u8],
@@ -93,7 +93,7 @@ fn check_route<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
}
}
-pub fn worker_inbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: Arc<DeviceInner<E, C, T, B>>, // related device
peer: Arc<PeerInner<E, C, T, B>>, // related peer
receiver: Receiver<JobInbound<E, C, T, B>>,
@@ -151,7 +151,8 @@ pub fn worker_inbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Write
let mut sent = false;
if length > 0 {
if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) {
- debug_assert!(inner_len <= length, "should be validated");
+ // TODO: Consider moving the cryptkey route check to parallel decryption worker
+ debug_assert!(inner_len <= length, "should be validated earlier");
if inner_len <= length {
sent = match device.inbound.write(&packet[..inner_len]) {
Err(e) => {
@@ -167,7 +168,7 @@ pub fn worker_inbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Write
}
// trigger callback
- C::recv(&peer.opaque, buf.msg.len(), length == 0, sent);
+ C::recv(&peer.opaque, buf.msg.len(), sent);
} else {
debug!("inbound worker: authentication failure")
}
@@ -176,7 +177,7 @@ pub fn worker_inbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Write
}
}
-pub fn worker_outbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: Arc<DeviceInner<E, C, T, B>>, // related device
peer: Arc<PeerInner<E, C, T, B>>, // related peer
receiver: Receiver<JobOutbound>,
@@ -198,7 +199,7 @@ pub fn worker_outbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writ
if buf.okay {
// write to UDP bind
let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() {
- let send : &Option<B> = &*device.outbound.read();
+ let send: &Option<B> = &*device.outbound.read();
if let Some(writer) = send.as_ref() {
match writer.write(&buf.msg[..], dst) {
Err(e) => {
@@ -215,12 +216,7 @@ pub fn worker_outbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writ
};
// trigger callback
- C::send(
- &peer.opaque,
- buf.msg.len(),
- buf.msg.len() > SIZE_TAG + mem::size_of::<TransportHeader>(),
- xmit,
- );
+ C::send(&peer.opaque, buf.msg.len(), xmit);
}
})
.wait();
diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs
index 2792c7b..1d9b8a0 100644
--- a/src/wireguard/timers.rs
+++ b/src/wireguard/timers.rs
@@ -1,14 +1,14 @@
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
-use std::time::Duration;
+use std::time::{Duration, SystemTime};
use log::info;
use hjul::{Runner, Timer};
use super::constants::*;
-use super::router::Callbacks;
+use super::router::{Callbacks, message_data_len};
use super::types::{bind, tun};
use super::wireguard::{Peer, PeerInner};
@@ -32,7 +32,7 @@ impl Timers {
}
}
-impl <T: tun::Tun, B: bind::Bind>Peer<T, B> {
+impl <B: bind::Bind>PeerInner<B> {
/* should be called after an authenticated data packet is sent */
pub fn timers_data_sent(&self) {
self.timers().new_handshake.start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT);
@@ -90,11 +90,25 @@ impl <T: tun::Tun, B: bind::Bind>Peer<T, B> {
* keepalive, data, or handshake is sent, or after one is received.
*/
pub fn timers_any_authenticated_packet_traversal(&self) {
- let keepalive = self.state.keepalive.load(Ordering::Acquire);
+ let keepalive = self.keepalive.load(Ordering::Acquire);
if keepalive > 0 {
self.timers().send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64));
}
}
+
+ /* Called after a handshake worker sends a handshake initiation to the peer
+ */
+ pub fn sent_handshake_initiation(&self) {
+ *self.last_handshake.lock() = SystemTime::now();
+ self.handshake_queued.store(false, Ordering::Acquire);
+ self.timers_any_authenticated_packet_traversal();
+ self.timers_any_authenticated_packet_sent();
+ }
+
+ pub fn sent_handshake_response(&self) {
+ self.timers_any_authenticated_packet_traversal();
+ self.timers_any_authenticated_packet_sent();
+ }
}
impl Timers {
@@ -212,14 +226,40 @@ pub struct Events<T, B>(PhantomData<(T, B)>);
impl<T: tun::Tun, B: bind::Bind> Callbacks for Events<T, B> {
type Opaque = Arc<PeerInner<B>>;
- fn send(peer: &Self::Opaque, size: usize, data: bool, sent: bool) {
+ /* Called after the router encrypts a transport message destined for the peer.
+ * This method is called, even if the encrypted payload is empty (keepalive)
+ */
+ fn send(peer: &Self::Opaque, size: usize, sent: bool) {
+ peer.timers_any_authenticated_packet_traversal();
+ peer.timers_any_authenticated_packet_sent();
peer.tx_bytes.fetch_add(size as u64, Ordering::Relaxed);
+ if size > message_data_len(0) && sent {
+ peer.timers_data_sent();
+ }
}
- fn recv(peer: &Self::Opaque, size: usize, data: bool, sent: bool) {
+ /* Called after the router successfully decrypts a transport message from a peer.
+ * This method is called, even if the decrypted packet is:
+ *
+ * - A keepalive
+ * - A malformed IP packet
+ * - Fails to cryptkey route
+ */
+ fn recv(peer: &Self::Opaque, size: usize, sent: bool) {
+ peer.timers_any_authenticated_packet_traversal();
+ peer.timers_any_authenticated_packet_received();
peer.rx_bytes.fetch_add(size as u64, Ordering::Relaxed);
+ if size > 0 && sent {
+ peer.timers_data_received();
+ }
}
+ /* Called every time the router detects that a key is required,
+ * but no valid key-material is available for the particular peer.
+ *
+ * The message is called continuously
+ * (e.g. for every packet that must be encrypted, until a key becomes available)
+ */
fn need_key(peer: &Self::Opaque) {
let timers = peer.timers();
if !timers.handshake_pending.swap(true, Ordering::SeqCst) {
diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs
index 7a22280..1363c27 100644
--- a/src/wireguard/wireguard.rs
+++ b/src/wireguard/wireguard.rs
@@ -15,7 +15,7 @@ use std::ops::Deref;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
-use std::time::{Duration, Instant};
+use std::time::{Duration, Instant, SystemTime};
use std::collections::HashMap;
@@ -49,6 +49,10 @@ pub struct PeerInner<B: Bind> {
pub keepalive: AtomicUsize, // keepalive interval
pub rx_bytes: AtomicU64,
pub tx_bytes: AtomicU64,
+
+ pub last_handshake: Mutex<SystemTime>,
+ pub handshake_queued: AtomicBool,
+
pub queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue
pub pk: PublicKey, // DISCUSS: Change layout in handshake module (adopt pattern of router), to avoid this.
pub timers: RwLock<Timers>, //
@@ -75,9 +79,13 @@ impl<T: Tun, B: Bind> Deref for Peer<T, B> {
}
impl<B: Bind> PeerInner<B> {
+ /* Queue a handshake request for the parallel workers
+ * (if one does not already exist)
+ */
pub fn new_handshake(&self) {
- // TODO: clear endpoint source address ("unsticky")
- self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap();
+ if !self.handshake_queued.swap(true, Ordering::SeqCst) {
+ self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap();
+ }
}
}
@@ -165,6 +173,8 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
pub fn new_peer(&self, pk: PublicKey) -> Peer<T, B> {
let state = Arc::new(PeerInner {
pk,
+ last_handshake: Mutex::new(SystemTime::UNIX_EPOCH),
+ handshake_queued: AtomicBool::new(false),
queue: Mutex::new(self.state.queue.lock().clone()),
keepalive: AtomicUsize::new(0),
rx_bytes: AtomicU64::new(0),
@@ -180,6 +190,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
* problem of the timer callbacks being able to set timers themselves.
*
* This is in fact the only place where the write lock is ever taken.
+ * TODO: Consider the ease of using atomic pointers instead.
*/
*peer.timers.write() = Timers::new(&self.runner, peer.clone());
peer
@@ -301,7 +312,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
},
) {
Ok((pk, resp, keypair)) => {
- // send response
+ // send response (might be cookie reply or handshake response)
let mut resp_len: u64 = 0;
if let Some(msg) = resp {
resp_len = msg.len() as u64;
@@ -316,7 +327,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}
}
- // update timers
+ // update peer state
if let Some(pk) = pk {
// authenticated handshake packet received
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
@@ -328,7 +339,12 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
// update endpoint
peer.router.set_endpoint(src);
- // add keypair to peer
+ // update timers after sending handshake response
+ if resp_len > 0 {
+ peer.state.sent_handshake_response();
+ }
+
+ // add resulting keypair to peer
keypair.map(|kp| {
// free any unused ids
for id in peer.router.add_keypair(kp) {
@@ -347,6 +363,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
let _ = peer.router.send(&msg[..]).map_err(|e| {
debug!("handshake worker, failed to send handshake initiation, error = {}", e)
});
+ peer.state.sent_handshake_initiation();
}
});
}