aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-10-09 20:22:16 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-10-09 20:22:16 +0200
commit7ce5415169097839cf711b02ff4188f9a585b7a2 (patch)
tree7e0158b52fb700c1d53d3c0bf2fc9a5ac3592493 /src
parentRestructure IO traits. (diff)
downloadwireguard-rs-7ce5415169097839cf711b02ff4188f9a585b7a2.tar.xz
wireguard-rs-7ce5415169097839cf711b02ff4188f9a585b7a2.zip
Start porting kernel timer semantics
Diffstat (limited to 'src')
-rw-r--r--src/main.rs1
-rw-r--r--src/router/peer.rs34
-rw-r--r--src/timers.rs79
-rw-r--r--src/wireguard.rs51
4 files changed, 115 insertions, 50 deletions
diff --git a/src/main.rs b/src/main.rs
index 7a31119..9b69f54 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,4 +1,5 @@
#![feature(test)]
+#![allow(dead_code)]
extern crate jemallocator;
diff --git a/src/router/peer.rs b/src/router/peer.rs
index 189904c..13e5af4 100644
--- a/src/router/peer.rs
+++ b/src/router/peer.rs
@@ -14,7 +14,7 @@ use treebitmap::IpLookupTable;
use zerocopy::LayoutVerified;
use super::super::constants::*;
-use super::super::types::{Endpoint, KeyPair, bind, tun};
+use super::super::types::{bind, tun, Endpoint, KeyPair};
use super::anti_replay::AntiReplay;
use super::device::DecryptionState;
@@ -39,7 +39,7 @@ pub struct KeyWheel {
retired: Vec<u32>, // retired ids
}
-pub struct PeerInner<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
+pub struct PeerInner<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
pub device: Arc<DeviceInner<E, C, T, B>>,
pub opaque: C::Opaque,
pub outbound: Mutex<SyncSender<JobOutbound>>,
@@ -50,13 +50,13 @@ pub struct PeerInner<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer
pub endpoint: Mutex<Option<E>>,
}
-pub struct Peer<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
+pub struct Peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
state: Arc<PeerInner<E, C, T, B>>,
thread_outbound: Option<thread::JoinHandle<()>>,
thread_inbound: Option<thread::JoinHandle<()>>,
}
-fn treebit_list<A, R, E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+fn treebit_list<A, R, E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
peer: &Arc<PeerInner<E, C, T, B>>,
table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>,
callback: Box<dyn Fn(A, u32) -> R>,
@@ -74,7 +74,7 @@ where
res
}
-fn treebit_remove<E : Endpoint, A: Address, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+fn treebit_remove<E: Endpoint, A: Address, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
peer: &Peer<E, C, T, B>,
table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>,
) {
@@ -107,8 +107,11 @@ impl EncryptionState {
}
}
-impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionState<E, C, T, B> {
- fn new(peer: &Arc<PeerInner<E, C, T, B>>, keypair: &Arc<KeyPair>) -> DecryptionState<E, C, T, B> {
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionState<E, C, T, B> {
+ fn new(
+ peer: &Arc<PeerInner<E, C, T, B>>,
+ keypair: &Arc<KeyPair>,
+ ) -> DecryptionState<E, C, T, B> {
DecryptionState {
confirmed: AtomicBool::new(keypair.initiator),
keypair: keypair.clone(),
@@ -119,7 +122,7 @@ impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionS
}
}
-impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Peer<E, C, T, B> {
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Peer<E, C, T, B> {
fn drop(&mut self) {
let peer = &self.state;
@@ -167,7 +170,7 @@ impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Pe
}
}
-pub fn new_peer<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+pub fn new_peer<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: Arc<DeviceInner<E, C, T, B>>,
opaque: C::Opaque,
) -> Peer<E, C, T, B> {
@@ -215,7 +218,7 @@ pub fn new_peer<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
}
}
-impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E, C, T, B> {
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E, C, T, B> {
fn send_staged(&self) -> bool {
debug!("peer.send_staged");
let mut sent = false;
@@ -370,7 +373,7 @@ impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E
}
}
-impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T, B> {
+impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T, B> {
/// Set the endpoint of the peer
///
/// # Arguments
@@ -591,13 +594,18 @@ impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C,
debug!("peer.send");
let inner = &self.state;
match inner.endpoint.lock().as_ref() {
- Some(endpoint) => inner.device
+ Some(endpoint) => inner
+ .device
.outbound
.read()
.as_ref()
.ok_or(RouterError::SendError)
- .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError) ),
+ .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError)),
None => Err(RouterError::NoEndpoint),
}
}
+
+ pub fn purge_staged_packets(&self) {
+ self.state.staged_packets.lock().clear();
+ }
}
diff --git a/src/timers.rs b/src/timers.rs
index 67ece06..23cbb87 100644
--- a/src/timers.rs
+++ b/src/timers.rs
@@ -3,11 +3,13 @@ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
+use log::info;
+
use hjul::{Runner, Timer};
use crate::constants::*;
use crate::router::Callbacks;
-use crate::types::{tun, bind};
+use crate::types::{bind, tun};
use crate::wireguard::{Peer, PeerInner};
pub struct Timers {
@@ -16,8 +18,17 @@ pub struct Timers {
retransmit_handshake: Timer,
send_keepalive: Timer,
+ send_persistent_keepalive: Timer,
zero_key_material: Timer,
new_handshake: Timer,
+ need_another_keepalive: AtomicBool,
+}
+
+impl Timers {
+ #[inline(always)]
+ fn need_another_keepalive(&self) -> bool {
+ self.need_another_keepalive.swap(false, Ordering::SeqCst)
+ }
}
impl Timers {
@@ -28,34 +39,42 @@ impl Timers {
{
// create a timer instance for the provided peer
Timers {
+ need_another_keepalive: AtomicBool::new(false),
handshake_pending: AtomicBool::new(false),
handshake_attempts: AtomicUsize::new(0),
retransmit_handshake: {
let peer = peer.clone();
runner.timer(move || {
- if peer.timers.read().handshake_retry() {
+ if peer.timers().handshake_retry() {
+ info!("Retransmit handshake for {}", peer);
peer.new_handshake();
+ } else {
+ info!("Failed to complete handshake for {}", peer);
+ peer.router.purge_staged_packets();
+ peer.timers().send_keepalive.stop();
+ peer.timers().zero_key_material.start(REJECT_AFTER_TIME * 3);
}
})
},
- new_handshake: {
+ send_keepalive: {
let peer = peer.clone();
runner.timer(move || {
- peer.new_handshake();
- peer.timers.read().handshake_begun();
+ peer.router.send_keepalive();
+ if peer.timers().need_another_keepalive() {
+ peer.timers().send_keepalive.start(KEEPALIVE_TIMEOUT);
+ }
})
},
- send_keepalive: {
+ new_handshake: {
let peer = peer.clone();
runner.timer(move || {
- peer.router.send_keepalive();
- let keepalive = peer.keepalive.load(Ordering::Acquire);
- if keepalive > 0 {
- peer.timers
- .read()
- .send_keepalive
- .reset(Duration::from_secs(keepalive as u64))
- }
+ info!(
+ "Retrying handshake with {}, because we stopped hearing back after {} seconds",
+ peer,
+ (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT).as_secs()
+ );
+ peer.new_handshake();
+ peer.timers.read().handshake_begun();
})
},
zero_key_material: {
@@ -64,6 +83,17 @@ impl Timers {
peer.router.zero_keys();
})
},
+ send_persistent_keepalive: {
+ let peer = peer.clone();
+ runner.timer(move || {
+ let keepalive = peer.state.keepalive.load(Ordering::Acquire);
+ if keepalive > 0 {
+ peer.router.send_keepalive();
+ peer.timers().send_keepalive.stop();
+ peer.timers().send_persistent_keepalive.start(Duration::from_secs(keepalive as u64));
+ }
+ })
+ }
}
}
@@ -83,6 +113,12 @@ impl Timers {
}
}
+ pub fn updated_persistent_keepalive(&self, keepalive: usize) {
+ if keepalive > 0 {
+ self.send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64));
+ }
+ }
+
pub fn dummy(runner: &Runner) -> Timers {
Timers {
handshake_pending: AtomicBool::new(false),
@@ -90,13 +126,28 @@ impl Timers {
retransmit_handshake: runner.timer(|| {}),
new_handshake: runner.timer(|| {}),
send_keepalive: runner.timer(|| {}),
+ send_persistent_keepalive: runner.timer(|| {}),
zero_key_material: runner.timer(|| {}),
+ need_another_keepalive: AtomicBool::new(false),
}
}
pub fn handshake_sent(&self) {
self.send_keepalive.stop();
}
+
+
+ pub fn any_authenticatec_packet_recieved(&self) {
+
+ }
+
+ pub fn handshake_initiated(&self) {
+
+ }
+
+ pub fn handhsake_complete(&self) {
+
+ }
}
/* Instance of the router callbacks */
diff --git a/src/wireguard.rs b/src/wireguard.rs
index ba81f47..bcb8592 100644
--- a/src/wireguard.rs
+++ b/src/wireguard.rs
@@ -3,12 +3,13 @@ use crate::handshake;
use crate::router;
use crate::timers::{Events, Timers};
-use crate::types::Endpoint;
-use crate::types::tun::{Tun, Reader, MTU};
use crate::types::bind::{Bind, Writer};
+use crate::types::tun::{Reader, Tun, MTU};
+use crate::types::Endpoint;
use hjul::Runner;
+use std::fmt;
use std::ops::Deref;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
@@ -19,7 +20,7 @@ use std::collections::HashMap;
use log::debug;
use rand::rngs::OsRng;
-use spin::{Mutex, RwLock};
+use spin::{Mutex, RwLock, RwLockReadGuard};
use byteorder::{ByteOrder, LittleEndian};
use crossbeam_channel::{bounded, Sender};
@@ -34,11 +35,11 @@ pub struct Peer<T: Tun, B: Bind> {
pub state: Arc<PeerInner<B>>,
}
-impl <T : Tun, B : Bind> Clone for Peer<T, B > {
+impl<T: Tun, B: Bind> Clone for Peer<T, B> {
fn clone(&self) -> Peer<T, B> {
- Peer{
+ Peer {
router: self.router.clone(),
- state: self.state.clone()
+ state: self.state.clone(),
}
}
}
@@ -52,6 +53,19 @@ pub struct PeerInner<B: Bind> {
pub timers: RwLock<Timers>, //
}
+impl <B:Bind > PeerInner<B> {
+ #[inline(always)]
+ pub fn timers(&self) -> RwLockReadGuard<Timers> {
+ self.timers.read()
+ }
+}
+
+impl<T: Tun, B: Bind> fmt::Display for Peer<T, B> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "peer()")
+ }
+}
+
impl<T: Tun, B: Bind> Deref for Peer<T, B> {
type Target = PeerInner<B>;
fn deref(&self) -> &Self::Target {
@@ -61,6 +75,7 @@ impl<T: Tun, B: Bind> Deref for Peer<T, B> {
impl<B: Bind> PeerInner<B> {
pub fn new_handshake(&self) {
+ // TODO: clear endpoint source address ("unsticky")
self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap();
}
}
@@ -76,7 +91,7 @@ pub enum HandshakeJob<E> {
}
struct WireguardInner<T: Tun, B: Bind> {
- // provides access to the MTU value of the tun device
+ // provides access to the MTU value of the tun device
// (otherwise owned solely by the router and a dedicated read IO thread)
mtu: T::MTU,
send: RwLock<Option<B::Writer>>,
@@ -169,18 +184,12 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
peer
}
- pub fn new_bind(
- reader: B::Reader,
- writer: B::Writer,
- closer: B::Closer
- ) {
+ pub fn new_bind(reader: B::Reader, writer: B::Writer, closer: B::Closer) {
// drop existing closer
-
// swap IO thread for new reader
-
// start UDP read IO thread
/*
@@ -232,15 +241,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
});
}
*/
-
-
}
- pub fn new(
- reader: T::Reader,
- writer: T::Writer,
- mtu: T::MTU,
- ) -> Wireguard<T, B> {
+ pub fn new(reader: T::Reader, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> {
// create device state
let mut rng = OsRng::new().unwrap();
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
@@ -292,7 +295,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
Ok((pk, msg, keypair)) => {
// send response
if let Some(msg) = msg {
- let send : &Option<B::Writer> = &*wg.send.read();
+ let send: &Option<B::Writer> = &*wg.send.read();
if let Some(writer) = send.as_ref() {
let _ = writer.write(&msg[..], &src).map_err(|e| {
debug!(
@@ -344,7 +347,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
msg.resize(size, 0);
// read a new IP packet
- let payload = reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
+ let payload = reader
+ .read(&mut msg[..], router::SIZE_MESSAGE_PREFIX)
+ .unwrap();
debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu);
// truncate padding