diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-10-09 20:22:16 +0200 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-10-09 20:22:16 +0200 |
commit | 7ce5415169097839cf711b02ff4188f9a585b7a2 (patch) | |
tree | 7e0158b52fb700c1d53d3c0bf2fc9a5ac3592493 /src | |
parent | Restructure IO traits. (diff) | |
download | wireguard-rs-7ce5415169097839cf711b02ff4188f9a585b7a2.tar.xz wireguard-rs-7ce5415169097839cf711b02ff4188f9a585b7a2.zip |
Start porting kernel timer semantics
Diffstat (limited to 'src')
-rw-r--r-- | src/main.rs | 1 | ||||
-rw-r--r-- | src/router/peer.rs | 34 | ||||
-rw-r--r-- | src/timers.rs | 79 | ||||
-rw-r--r-- | src/wireguard.rs | 51 |
4 files changed, 115 insertions, 50 deletions
diff --git a/src/main.rs b/src/main.rs index 7a31119..9b69f54 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ #![feature(test)] +#![allow(dead_code)] extern crate jemallocator; diff --git a/src/router/peer.rs b/src/router/peer.rs index 189904c..13e5af4 100644 --- a/src/router/peer.rs +++ b/src/router/peer.rs @@ -14,7 +14,7 @@ use treebitmap::IpLookupTable; use zerocopy::LayoutVerified; use super::super::constants::*; -use super::super::types::{Endpoint, KeyPair, bind, tun}; +use super::super::types::{bind, tun, Endpoint, KeyPair}; use super::anti_replay::AntiReplay; use super::device::DecryptionState; @@ -39,7 +39,7 @@ pub struct KeyWheel { retired: Vec<u32>, // retired ids } -pub struct PeerInner<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { +pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { pub device: Arc<DeviceInner<E, C, T, B>>, pub opaque: C::Opaque, pub outbound: Mutex<SyncSender<JobOutbound>>, @@ -50,13 +50,13 @@ pub struct PeerInner<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer pub endpoint: Mutex<Option<E>>, } -pub struct Peer<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { +pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { state: Arc<PeerInner<E, C, T, B>>, thread_outbound: Option<thread::JoinHandle<()>>, thread_inbound: Option<thread::JoinHandle<()>>, } -fn treebit_list<A, R, E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( +fn treebit_list<A, R, E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( peer: &Arc<PeerInner<E, C, T, B>>, table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>, callback: Box<dyn Fn(A, u32) -> R>, @@ -74,7 +74,7 @@ where res } -fn treebit_remove<E : Endpoint, A: Address, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( +fn treebit_remove<E: Endpoint, A: Address, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( peer: &Peer<E, C, T, B>, table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>, ) { @@ -107,8 +107,11 @@ impl EncryptionState { } } -impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionState<E, C, T, B> { - fn new(peer: &Arc<PeerInner<E, C, T, B>>, keypair: &Arc<KeyPair>) -> DecryptionState<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionState<E, C, T, B> { + fn new( + peer: &Arc<PeerInner<E, C, T, B>>, + keypair: &Arc<KeyPair>, + ) -> DecryptionState<E, C, T, B> { DecryptionState { confirmed: AtomicBool::new(keypair.initiator), keypair: keypair.clone(), @@ -119,7 +122,7 @@ impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionS } } -impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Peer<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Peer<E, C, T, B> { fn drop(&mut self) { let peer = &self.state; @@ -167,7 +170,7 @@ impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Pe } } -pub fn new_peer<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( +pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( device: Arc<DeviceInner<E, C, T, B>>, opaque: C::Opaque, ) -> Peer<E, C, T, B> { @@ -215,7 +218,7 @@ pub fn new_peer<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( } } -impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E, C, T, B> { fn send_staged(&self) -> bool { debug!("peer.send_staged"); let mut sent = false; @@ -370,7 +373,7 @@ impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E } } -impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T, B> { +impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T, B> { /// Set the endpoint of the peer /// /// # Arguments @@ -591,13 +594,18 @@ impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, debug!("peer.send"); let inner = &self.state; match inner.endpoint.lock().as_ref() { - Some(endpoint) => inner.device + Some(endpoint) => inner + .device .outbound .read() .as_ref() .ok_or(RouterError::SendError) - .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError) ), + .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError)), None => Err(RouterError::NoEndpoint), } } + + pub fn purge_staged_packets(&self) { + self.state.staged_packets.lock().clear(); + } } diff --git a/src/timers.rs b/src/timers.rs index 67ece06..23cbb87 100644 --- a/src/timers.rs +++ b/src/timers.rs @@ -3,11 +3,13 @@ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; +use log::info; + use hjul::{Runner, Timer}; use crate::constants::*; use crate::router::Callbacks; -use crate::types::{tun, bind}; +use crate::types::{bind, tun}; use crate::wireguard::{Peer, PeerInner}; pub struct Timers { @@ -16,8 +18,17 @@ pub struct Timers { retransmit_handshake: Timer, send_keepalive: Timer, + send_persistent_keepalive: Timer, zero_key_material: Timer, new_handshake: Timer, + need_another_keepalive: AtomicBool, +} + +impl Timers { + #[inline(always)] + fn need_another_keepalive(&self) -> bool { + self.need_another_keepalive.swap(false, Ordering::SeqCst) + } } impl Timers { @@ -28,34 +39,42 @@ impl Timers { { // create a timer instance for the provided peer Timers { + need_another_keepalive: AtomicBool::new(false), handshake_pending: AtomicBool::new(false), handshake_attempts: AtomicUsize::new(0), retransmit_handshake: { let peer = peer.clone(); runner.timer(move || { - if peer.timers.read().handshake_retry() { + if peer.timers().handshake_retry() { + info!("Retransmit handshake for {}", peer); peer.new_handshake(); + } else { + info!("Failed to complete handshake for {}", peer); + peer.router.purge_staged_packets(); + peer.timers().send_keepalive.stop(); + peer.timers().zero_key_material.start(REJECT_AFTER_TIME * 3); } }) }, - new_handshake: { + send_keepalive: { let peer = peer.clone(); runner.timer(move || { - peer.new_handshake(); - peer.timers.read().handshake_begun(); + peer.router.send_keepalive(); + if peer.timers().need_another_keepalive() { + peer.timers().send_keepalive.start(KEEPALIVE_TIMEOUT); + } }) }, - send_keepalive: { + new_handshake: { let peer = peer.clone(); runner.timer(move || { - peer.router.send_keepalive(); - let keepalive = peer.keepalive.load(Ordering::Acquire); - if keepalive > 0 { - peer.timers - .read() - .send_keepalive - .reset(Duration::from_secs(keepalive as u64)) - } + info!( + "Retrying handshake with {}, because we stopped hearing back after {} seconds", + peer, + (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT).as_secs() + ); + peer.new_handshake(); + peer.timers.read().handshake_begun(); }) }, zero_key_material: { @@ -64,6 +83,17 @@ impl Timers { peer.router.zero_keys(); }) }, + send_persistent_keepalive: { + let peer = peer.clone(); + runner.timer(move || { + let keepalive = peer.state.keepalive.load(Ordering::Acquire); + if keepalive > 0 { + peer.router.send_keepalive(); + peer.timers().send_keepalive.stop(); + peer.timers().send_persistent_keepalive.start(Duration::from_secs(keepalive as u64)); + } + }) + } } } @@ -83,6 +113,12 @@ impl Timers { } } + pub fn updated_persistent_keepalive(&self, keepalive: usize) { + if keepalive > 0 { + self.send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64)); + } + } + pub fn dummy(runner: &Runner) -> Timers { Timers { handshake_pending: AtomicBool::new(false), @@ -90,13 +126,28 @@ impl Timers { retransmit_handshake: runner.timer(|| {}), new_handshake: runner.timer(|| {}), send_keepalive: runner.timer(|| {}), + send_persistent_keepalive: runner.timer(|| {}), zero_key_material: runner.timer(|| {}), + need_another_keepalive: AtomicBool::new(false), } } pub fn handshake_sent(&self) { self.send_keepalive.stop(); } + + + pub fn any_authenticatec_packet_recieved(&self) { + + } + + pub fn handshake_initiated(&self) { + + } + + pub fn handhsake_complete(&self) { + + } } /* Instance of the router callbacks */ diff --git a/src/wireguard.rs b/src/wireguard.rs index ba81f47..bcb8592 100644 --- a/src/wireguard.rs +++ b/src/wireguard.rs @@ -3,12 +3,13 @@ use crate::handshake; use crate::router; use crate::timers::{Events, Timers}; -use crate::types::Endpoint; -use crate::types::tun::{Tun, Reader, MTU}; use crate::types::bind::{Bind, Writer}; +use crate::types::tun::{Reader, Tun, MTU}; +use crate::types::Endpoint; use hjul::Runner; +use std::fmt; use std::ops::Deref; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; @@ -19,7 +20,7 @@ use std::collections::HashMap; use log::debug; use rand::rngs::OsRng; -use spin::{Mutex, RwLock}; +use spin::{Mutex, RwLock, RwLockReadGuard}; use byteorder::{ByteOrder, LittleEndian}; use crossbeam_channel::{bounded, Sender}; @@ -34,11 +35,11 @@ pub struct Peer<T: Tun, B: Bind> { pub state: Arc<PeerInner<B>>, } -impl <T : Tun, B : Bind> Clone for Peer<T, B > { +impl<T: Tun, B: Bind> Clone for Peer<T, B> { fn clone(&self) -> Peer<T, B> { - Peer{ + Peer { router: self.router.clone(), - state: self.state.clone() + state: self.state.clone(), } } } @@ -52,6 +53,19 @@ pub struct PeerInner<B: Bind> { pub timers: RwLock<Timers>, // } +impl <B:Bind > PeerInner<B> { + #[inline(always)] + pub fn timers(&self) -> RwLockReadGuard<Timers> { + self.timers.read() + } +} + +impl<T: Tun, B: Bind> fmt::Display for Peer<T, B> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "peer()") + } +} + impl<T: Tun, B: Bind> Deref for Peer<T, B> { type Target = PeerInner<B>; fn deref(&self) -> &Self::Target { @@ -61,6 +75,7 @@ impl<T: Tun, B: Bind> Deref for Peer<T, B> { impl<B: Bind> PeerInner<B> { pub fn new_handshake(&self) { + // TODO: clear endpoint source address ("unsticky") self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap(); } } @@ -76,7 +91,7 @@ pub enum HandshakeJob<E> { } struct WireguardInner<T: Tun, B: Bind> { - // provides access to the MTU value of the tun device + // provides access to the MTU value of the tun device // (otherwise owned solely by the router and a dedicated read IO thread) mtu: T::MTU, send: RwLock<Option<B::Writer>>, @@ -169,18 +184,12 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { peer } - pub fn new_bind( - reader: B::Reader, - writer: B::Writer, - closer: B::Closer - ) { + pub fn new_bind(reader: B::Reader, writer: B::Writer, closer: B::Closer) { // drop existing closer - // swap IO thread for new reader - // start UDP read IO thread /* @@ -232,15 +241,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { }); } */ - - } - pub fn new( - reader: T::Reader, - writer: T::Writer, - mtu: T::MTU, - ) -> Wireguard<T, B> { + pub fn new(reader: T::Reader, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> { // create device state let mut rng = OsRng::new().unwrap(); let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE); @@ -292,7 +295,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { Ok((pk, msg, keypair)) => { // send response if let Some(msg) = msg { - let send : &Option<B::Writer> = &*wg.send.read(); + let send: &Option<B::Writer> = &*wg.send.read(); if let Some(writer) = send.as_ref() { let _ = writer.write(&msg[..], &src).map_err(|e| { debug!( @@ -344,7 +347,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { msg.resize(size, 0); // read a new IP packet - let payload = reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap(); + let payload = reader + .read(&mut msg[..], router::SIZE_MESSAGE_PREFIX) + .unwrap(); debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu); // truncate padding |