diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-10-16 13:40:40 +0200 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-10-16 13:40:40 +0200 |
commit | 2f3ceab0364497a4a6cf866b505f74443ed6e3ae (patch) | |
tree | 5ed11473dc4b4d6f265fc739c0600db972a28ed5 /src/wireguard | |
parent | Work on Linux platform code (diff) | |
download | wireguard-rs-2f3ceab0364497a4a6cf866b505f74443ed6e3ae.tar.xz wireguard-rs-2f3ceab0364497a4a6cf866b505f74443ed6e3ae.zip |
Work on porting timer semantics and linux platform
Diffstat (limited to 'src/wireguard')
-rw-r--r-- | src/wireguard/router/mod.rs | 8 | ||||
-rw-r--r-- | src/wireguard/router/tests.rs | 22 | ||||
-rw-r--r-- | src/wireguard/router/types.rs | 8 | ||||
-rw-r--r-- | src/wireguard/router/workers.rs | 24 | ||||
-rw-r--r-- | src/wireguard/timers.rs | 52 | ||||
-rw-r--r-- | src/wireguard/wireguard.rs | 29 |
6 files changed, 100 insertions, 43 deletions
diff --git a/src/wireguard/router/mod.rs b/src/wireguard/router/mod.rs index 7a29cd9..4e748cb 100644 --- a/src/wireguard/router/mod.rs +++ b/src/wireguard/router/mod.rs @@ -14,9 +14,13 @@ use messages::TransportHeader; use std::mem; pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>(); -pub const CAPACITY_MESSAGE_POSTFIX: usize = 16; +pub const CAPACITY_MESSAGE_POSTFIX: usize = workers::SIZE_TAG; + +pub const fn message_data_len(payload: usize) -> usize { + payload + mem::size_of::<TransportHeader>() + workers::SIZE_TAG +} -pub use messages::TYPE_TRANSPORT; pub use device::Device; +pub use messages::TYPE_TRANSPORT; pub use peer::Peer; pub use types::Callbacks; diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs index fbee39e..93c0773 100644 --- a/src/wireguard/router/tests.rs +++ b/src/wireguard/router/tests.rs @@ -28,8 +28,8 @@ mod tests { // type for tracking events inside the router module struct Flags { - send: Mutex<Vec<(usize, bool, bool)>>, - recv: Mutex<Vec<(usize, bool, bool)>>, + send: Mutex<Vec<(usize, bool)>>, + recv: Mutex<Vec<(usize, bool)>>, need_key: Mutex<Vec<()>>, key_confirmed: Mutex<Vec<()>>, } @@ -56,11 +56,11 @@ mod tests { self.0.key_confirmed.lock().unwrap().clear(); } - fn send(&self) -> Option<(usize, bool, bool)> { + fn send(&self) -> Option<(usize, bool)> { self.0.send.lock().unwrap().pop() } - fn recv(&self) -> Option<(usize, bool, bool)> { + fn recv(&self) -> Option<(usize, bool)> { self.0.recv.lock().unwrap().pop() } @@ -85,12 +85,12 @@ mod tests { impl Callbacks for TestCallbacks { type Opaque = Opaque; - fn send(t: &Self::Opaque, size: usize, data: bool, sent: bool) { - t.0.send.lock().unwrap().push((size, data, sent)) + fn send(t: &Self::Opaque, size: usize, sent: bool) { + t.0.send.lock().unwrap().push((size, sent)) } - fn recv(t: &Self::Opaque, size: usize, data: bool, sent: bool) { - t.0.recv.lock().unwrap().push((size, data, sent)) + fn recv(t: &Self::Opaque, size: usize, sent: bool) { + t.0.recv.lock().unwrap().push((size, sent)) } fn need_key(t: &Self::Opaque) { @@ -135,10 +135,10 @@ mod tests { struct BencherCallbacks {} impl Callbacks for BencherCallbacks { type Opaque = Arc<AtomicUsize>; - fn send(t: &Self::Opaque, size: usize, _data: bool, _sent: bool) { + fn send(t: &Self::Opaque, size: usize, _sent: bool) { t.fetch_add(size, Ordering::SeqCst); } - fn recv(_: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {} + fn recv(_: &Self::Opaque, _size: usize, _sent: bool) {} fn need_key(_: &Self::Opaque) {} fn key_confirmed(_: &Self::Opaque) {} } @@ -253,7 +253,7 @@ mod tests { assert_eq!( opaque.send(), if set_key { - Some((SIZE_KEEPALIVE, false, false)) + Some((SIZE_KEEPALIVE, false)) } else { None }, diff --git a/src/wireguard/router/types.rs b/src/wireguard/router/types.rs index b7c3ae0..52ee4f1 100644 --- a/src/wireguard/router/types.rs +++ b/src/wireguard/router/types.rs @@ -10,9 +10,9 @@ impl<T> Opaque for T where T: Send + Sync + 'static {} /// * `0`, a reference to the opaque value assigned to the peer /// * `1`, a bool indicating whether the message contained data (not just keepalive) /// * `2`, a bool indicating whether the message was transmitted (i.e. did the peer have an associated endpoint?) -pub trait Callback<T>: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {} +pub trait Callback<T>: Fn(&T, usize, bool) -> () + Sync + Send + 'static {} -impl<T, F> Callback<T> for F where F: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {} +impl<T, F> Callback<T> for F where F: Fn(&T, usize, bool) -> () + Sync + Send + 'static {} /// A key callback takes 1 argument /// @@ -23,8 +23,8 @@ impl<T, F> KeyCallback<T> for F where F: Fn(&T) -> () + Sync + Send + 'static {} pub trait Callbacks: Send + Sync + 'static { type Opaque: Opaque; - fn send(opaque: &Self::Opaque, size: usize, data: bool, sent: bool); - fn recv(opaque: &Self::Opaque, size: usize, data: bool, sent: bool); + fn send(opaque: &Self::Opaque, size: usize, sent: bool); + fn recv(opaque: &Self::Opaque, size: usize, sent: bool); fn need_key(opaque: &Self::Opaque); fn key_confirmed(opaque: &Self::Opaque); } diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs index 2e89bb0..61a7620 100644 --- a/src/wireguard/router/workers.rs +++ b/src/wireguard/router/workers.rs @@ -17,10 +17,10 @@ use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::peer::PeerInner; use super::types::Callbacks; -use super::super::types::{Endpoint, tun, bind}; +use super::super::types::{bind, tun, Endpoint}; use super::ip::*; -const SIZE_TAG: usize = 16; +pub const SIZE_TAG: usize = 16; #[derive(PartialEq, Debug)] pub enum Operation { @@ -47,7 +47,7 @@ pub type JobInbound<E, C, T, B: bind::Writer<E>> = ( pub type JobOutbound = oneshot::Receiver<JobBuffer>; #[inline(always)] -fn check_route<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( +fn check_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( device: &Arc<DeviceInner<E, C, T, B>>, peer: &Arc<PeerInner<E, C, T, B>>, packet: &[u8], @@ -93,7 +93,7 @@ fn check_route<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( } } -pub fn worker_inbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( +pub fn worker_inbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( device: Arc<DeviceInner<E, C, T, B>>, // related device peer: Arc<PeerInner<E, C, T, B>>, // related peer receiver: Receiver<JobInbound<E, C, T, B>>, @@ -151,7 +151,8 @@ pub fn worker_inbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Write let mut sent = false; if length > 0 { if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) { - debug_assert!(inner_len <= length, "should be validated"); + // TODO: Consider moving the cryptkey route check to parallel decryption worker + debug_assert!(inner_len <= length, "should be validated earlier"); if inner_len <= length { sent = match device.inbound.write(&packet[..inner_len]) { Err(e) => { @@ -167,7 +168,7 @@ pub fn worker_inbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Write } // trigger callback - C::recv(&peer.opaque, buf.msg.len(), length == 0, sent); + C::recv(&peer.opaque, buf.msg.len(), sent); } else { debug!("inbound worker: authentication failure") } @@ -176,7 +177,7 @@ pub fn worker_inbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Write } } -pub fn worker_outbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( +pub fn worker_outbound<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( device: Arc<DeviceInner<E, C, T, B>>, // related device peer: Arc<PeerInner<E, C, T, B>>, // related peer receiver: Receiver<JobOutbound>, @@ -198,7 +199,7 @@ pub fn worker_outbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writ if buf.okay { // write to UDP bind let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() { - let send : &Option<B> = &*device.outbound.read(); + let send: &Option<B> = &*device.outbound.read(); if let Some(writer) = send.as_ref() { match writer.write(&buf.msg[..], dst) { Err(e) => { @@ -215,12 +216,7 @@ pub fn worker_outbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writ }; // trigger callback - C::send( - &peer.opaque, - buf.msg.len(), - buf.msg.len() > SIZE_TAG + mem::size_of::<TransportHeader>(), - xmit, - ); + C::send(&peer.opaque, buf.msg.len(), xmit); } }) .wait(); diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index 2792c7b..1d9b8a0 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -1,14 +1,14 @@ use std::marker::PhantomData; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, SystemTime}; use log::info; use hjul::{Runner, Timer}; use super::constants::*; -use super::router::Callbacks; +use super::router::{Callbacks, message_data_len}; use super::types::{bind, tun}; use super::wireguard::{Peer, PeerInner}; @@ -32,7 +32,7 @@ impl Timers { } } -impl <T: tun::Tun, B: bind::Bind>Peer<T, B> { +impl <B: bind::Bind>PeerInner<B> { /* should be called after an authenticated data packet is sent */ pub fn timers_data_sent(&self) { self.timers().new_handshake.start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT); @@ -90,11 +90,25 @@ impl <T: tun::Tun, B: bind::Bind>Peer<T, B> { * keepalive, data, or handshake is sent, or after one is received. */ pub fn timers_any_authenticated_packet_traversal(&self) { - let keepalive = self.state.keepalive.load(Ordering::Acquire); + let keepalive = self.keepalive.load(Ordering::Acquire); if keepalive > 0 { self.timers().send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64)); } } + + /* Called after a handshake worker sends a handshake initiation to the peer + */ + pub fn sent_handshake_initiation(&self) { + *self.last_handshake.lock() = SystemTime::now(); + self.handshake_queued.store(false, Ordering::Acquire); + self.timers_any_authenticated_packet_traversal(); + self.timers_any_authenticated_packet_sent(); + } + + pub fn sent_handshake_response(&self) { + self.timers_any_authenticated_packet_traversal(); + self.timers_any_authenticated_packet_sent(); + } } impl Timers { @@ -212,14 +226,40 @@ pub struct Events<T, B>(PhantomData<(T, B)>); impl<T: tun::Tun, B: bind::Bind> Callbacks for Events<T, B> { type Opaque = Arc<PeerInner<B>>; - fn send(peer: &Self::Opaque, size: usize, data: bool, sent: bool) { + /* Called after the router encrypts a transport message destined for the peer. + * This method is called, even if the encrypted payload is empty (keepalive) + */ + fn send(peer: &Self::Opaque, size: usize, sent: bool) { + peer.timers_any_authenticated_packet_traversal(); + peer.timers_any_authenticated_packet_sent(); peer.tx_bytes.fetch_add(size as u64, Ordering::Relaxed); + if size > message_data_len(0) && sent { + peer.timers_data_sent(); + } } - fn recv(peer: &Self::Opaque, size: usize, data: bool, sent: bool) { + /* Called after the router successfully decrypts a transport message from a peer. + * This method is called, even if the decrypted packet is: + * + * - A keepalive + * - A malformed IP packet + * - Fails to cryptkey route + */ + fn recv(peer: &Self::Opaque, size: usize, sent: bool) { + peer.timers_any_authenticated_packet_traversal(); + peer.timers_any_authenticated_packet_received(); peer.rx_bytes.fetch_add(size as u64, Ordering::Relaxed); + if size > 0 && sent { + peer.timers_data_received(); + } } + /* Called every time the router detects that a key is required, + * but no valid key-material is available for the particular peer. + * + * The message is called continuously + * (e.g. for every packet that must be encrypted, until a key becomes available) + */ fn need_key(peer: &Self::Opaque) { let timers = peer.timers(); if !timers.handshake_pending.swap(true, Ordering::SeqCst) { diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 7a22280..1363c27 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -15,7 +15,7 @@ use std::ops::Deref; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; use std::thread; -use std::time::{Duration, Instant}; +use std::time::{Duration, Instant, SystemTime}; use std::collections::HashMap; @@ -49,6 +49,10 @@ pub struct PeerInner<B: Bind> { pub keepalive: AtomicUsize, // keepalive interval pub rx_bytes: AtomicU64, pub tx_bytes: AtomicU64, + + pub last_handshake: Mutex<SystemTime>, + pub handshake_queued: AtomicBool, + pub queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue pub pk: PublicKey, // DISCUSS: Change layout in handshake module (adopt pattern of router), to avoid this. pub timers: RwLock<Timers>, // @@ -75,9 +79,13 @@ impl<T: Tun, B: Bind> Deref for Peer<T, B> { } impl<B: Bind> PeerInner<B> { + /* Queue a handshake request for the parallel workers + * (if one does not already exist) + */ pub fn new_handshake(&self) { - // TODO: clear endpoint source address ("unsticky") - self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap(); + if !self.handshake_queued.swap(true, Ordering::SeqCst) { + self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap(); + } } } @@ -165,6 +173,8 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { pub fn new_peer(&self, pk: PublicKey) -> Peer<T, B> { let state = Arc::new(PeerInner { pk, + last_handshake: Mutex::new(SystemTime::UNIX_EPOCH), + handshake_queued: AtomicBool::new(false), queue: Mutex::new(self.state.queue.lock().clone()), keepalive: AtomicUsize::new(0), rx_bytes: AtomicU64::new(0), @@ -180,6 +190,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { * problem of the timer callbacks being able to set timers themselves. * * This is in fact the only place where the write lock is ever taken. + * TODO: Consider the ease of using atomic pointers instead. */ *peer.timers.write() = Timers::new(&self.runner, peer.clone()); peer @@ -301,7 +312,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { }, ) { Ok((pk, resp, keypair)) => { - // send response + // send response (might be cookie reply or handshake response) let mut resp_len: u64 = 0; if let Some(msg) = resp { resp_len = msg.len() as u64; @@ -316,7 +327,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { } } - // update timers + // update peer state if let Some(pk) = pk { // authenticated handshake packet received if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { @@ -328,7 +339,12 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { // update endpoint peer.router.set_endpoint(src); - // add keypair to peer + // update timers after sending handshake response + if resp_len > 0 { + peer.state.sent_handshake_response(); + } + + // add resulting keypair to peer keypair.map(|kp| { // free any unused ids for id in peer.router.add_keypair(kp) { @@ -347,6 +363,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { let _ = peer.router.send(&msg[..]).map_err(|e| { debug!("handshake worker, failed to send handshake initiation, error = {}", e) }); + peer.state.sent_handshake_initiation(); } }); } |