summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/platform/dummy/bind.rs29
-rw-r--r--src/platform/dummy/tun.rs27
-rw-r--r--src/wireguard/handshake/noise.rs8
-rw-r--r--src/wireguard/router/device.rs10
-rw-r--r--src/wireguard/router/peer.rs12
-rw-r--r--src/wireguard/router/tests.rs8
-rw-r--r--src/wireguard/router/types.rs4
-rw-r--r--src/wireguard/tests.rs116
-rw-r--r--src/wireguard/timers.rs39
-rw-r--r--src/wireguard/wireguard.rs47
10 files changed, 238 insertions, 62 deletions
diff --git a/src/platform/dummy/bind.rs b/src/platform/dummy/bind.rs
index 984b886..3497656 100644
--- a/src/platform/dummy/bind.rs
+++ b/src/platform/dummy/bind.rs
@@ -1,7 +1,12 @@
+use hex;
use std::error::Error;
use std::fmt;
use std::marker;
+use log::debug;
+use rand::rngs::OsRng;
+use rand::Rng;
+
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::sync::Arc;
use std::sync::Mutex;
@@ -95,6 +100,7 @@ impl VoidBind {
#[derive(Clone)]
pub struct PairReader<E> {
+ id: u32,
recv: Arc<Mutex<Receiver<Vec<u8>>>>,
_marker: marker::PhantomData<E>,
}
@@ -110,13 +116,25 @@ impl Reader<UnitEndpoint> for PairReader<UnitEndpoint> {
.map_err(|_| BindError::Disconnected)?;
let len = vec.len();
buf[..len].copy_from_slice(&vec[..]);
- Ok((vec.len(), UnitEndpoint {}))
+ debug!(
+ "dummy({}): read ({}, {})",
+ self.id,
+ len,
+ hex::encode(&buf[..len])
+ );
+ Ok((len, UnitEndpoint {}))
}
}
impl Writer<UnitEndpoint> for PairWriter<UnitEndpoint> {
type Error = BindError;
fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
+ debug!(
+ "dummy({}): write ({}, {})",
+ self.id,
+ buf.len(),
+ hex::encode(buf)
+ );
let owned = buf.to_owned();
match self.send.lock().unwrap().send(owned) {
Err(_) => Err(BindError::Disconnected),
@@ -127,6 +145,7 @@ impl Writer<UnitEndpoint> for PairWriter<UnitEndpoint> {
#[derive(Clone)]
pub struct PairWriter<E> {
+ id: u32,
send: Arc<Mutex<SyncSender<Vec<u8>>>>,
_marker: marker::PhantomData<E>,
}
@@ -139,25 +158,33 @@ impl PairBind {
(PairReader<E>, PairWriter<E>),
(PairReader<E>, PairWriter<E>),
) {
+ let mut rng = OsRng::new().unwrap();
+ let id1: u32 = rng.gen();
+ let id2: u32 = rng.gen();
+
let (tx1, rx1) = sync_channel(128);
let (tx2, rx2) = sync_channel(128);
(
(
PairReader {
+ id: id1,
recv: Arc::new(Mutex::new(rx1)),
_marker: marker::PhantomData,
},
PairWriter {
+ id: id1,
send: Arc::new(Mutex::new(tx2)),
_marker: marker::PhantomData,
},
),
(
PairReader {
+ id: id2,
recv: Arc::new(Mutex::new(rx2)),
_marker: marker::PhantomData,
},
PairWriter {
+ id: id2,
send: Arc::new(Mutex::new(tx1)),
_marker: marker::PhantomData,
},
diff --git a/src/platform/dummy/tun.rs b/src/platform/dummy/tun.rs
index fb87d2f..185b328 100644
--- a/src/platform/dummy/tun.rs
+++ b/src/platform/dummy/tun.rs
@@ -1,3 +1,8 @@
+use hex;
+use log::debug;
+use rand::rngs::OsRng;
+use rand::Rng;
+
use std::cmp::min;
use std::error::Error;
use std::fmt;
@@ -61,16 +66,19 @@ impl fmt::Display for TunError {
pub struct TunTest {}
pub struct TunFakeIO {
+ id: u32,
store: bool,
tx: SyncSender<Vec<u8>>,
rx: Receiver<Vec<u8>>,
}
pub struct TunReader {
+ id: u32,
rx: Receiver<Vec<u8>>,
}
pub struct TunWriter {
+ id: u32,
store: bool,
tx: Mutex<SyncSender<Vec<u8>>>,
}
@@ -88,6 +96,12 @@ impl Reader for TunReader {
Ok(msg) => {
let n = min(buf.len() - offset, msg.len());
buf[offset..offset + n].copy_from_slice(&msg[..n]);
+ debug!(
+ "dummy::TUN({}) : read ({}, {})",
+ self.id,
+ n,
+ hex::encode(&buf[offset..offset + n])
+ );
Ok(n)
}
Err(_) => Err(TunError::Disconnected),
@@ -99,6 +113,12 @@ impl Writer for TunWriter {
type Error = TunError;
fn write(&self, src: &[u8]) -> Result<(), Self::Error> {
+ debug!(
+ "dummy::TUN({}) : write ({}, {})",
+ self.id,
+ src.len(),
+ hex::encode(src)
+ );
if self.store {
let m = src.to_owned();
match self.tx.lock().unwrap().send(m) {
@@ -149,13 +169,18 @@ impl TunTest {
sync_channel(1)
};
+ let mut rng = OsRng::new().unwrap();
+ let id: u32 = rng.gen();
+
let fake = TunFakeIO {
+ id,
tx: tx1,
rx: rx2,
store,
};
- let reader = TunReader { rx: rx1 };
+ let reader = TunReader { id, rx: rx1 };
let writer = TunWriter {
+ id,
tx: Mutex::new(tx2),
store,
};
diff --git a/src/wireguard/handshake/noise.rs b/src/wireguard/handshake/noise.rs
index a2a84b0..68e738d 100644
--- a/src/wireguard/handshake/noise.rs
+++ b/src/wireguard/handshake/noise.rs
@@ -12,6 +12,8 @@ use chacha20poly1305::ChaCha20Poly1305;
use rand::{CryptoRng, RngCore};
+use log::debug;
+
use generic_array::typenum::*;
use generic_array::*;
@@ -27,7 +29,7 @@ use super::peer::{Peer, State};
use super::timestamp;
use super::types::*;
-use super::super::types::{KeyPair, Key};
+use super::super::types::{Key, KeyPair};
use std::time::Instant;
@@ -222,6 +224,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>(
sender: u32,
msg: &mut NoiseInitiation,
) -> Result<(), HandshakeError> {
+ debug!("create initation");
clear_stack_on_return(CLEAR_PAGES, || {
// initialize state
@@ -300,6 +303,7 @@ pub fn consume_initiation<'a>(
device: &'a Device,
msg: &NoiseInitiation,
) -> Result<(&'a Peer, TemporaryState), HandshakeError> {
+ debug!("consume initation");
clear_stack_on_return(CLEAR_PAGES, || {
// initialize new state
@@ -377,6 +381,7 @@ pub fn create_response<R: RngCore + CryptoRng>(
state: TemporaryState, // state from "consume_initiation"
msg: &mut NoiseResponse, // resulting response
) -> Result<KeyPair, HandshakeError> {
+ debug!("create response");
clear_stack_on_return(CLEAR_PAGES, || {
// unpack state
@@ -457,6 +462,7 @@ pub fn create_response<R: RngCore + CryptoRng>(
* in order to better mitigate DoS from malformed response messages.
*/
pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result<Output, HandshakeError> {
+ debug!("consume response");
clear_stack_on_return(CLEAR_PAGES, || {
// retrieve peer and copy initiation state
let peer = device.lookup_id(msg.f_receiver.get())?;
diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs
index b122bf4..254b3de 100644
--- a/src/wireguard/router/device.rs
+++ b/src/wireguard/router/device.rs
@@ -89,13 +89,7 @@ fn get_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
device: &Arc<DeviceInner<E, C, T, B>>,
packet: &[u8],
) -> Option<Arc<PeerInner<E, C, T, B>>> {
- // ensure version access within bounds
- if packet.len() < 1 {
- return None;
- };
-
- // cast to correct IP header
- match packet[0] >> 4 {
+ match packet.get(0)? >> 4 {
VERSION_IP4 => {
// check length and cast to IPv4 header
let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
@@ -176,7 +170,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C,
let packet = &msg[SIZE_MESSAGE_PREFIX..];
// lookup peer based on IP packet destination address
- let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?;
+ let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptoKeyRoute)?;
// schedule for encryption and transmission to peer
if let Some(job) = peer.send_job(msg, true) {
diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs
index 0b193a4..66a6e9f 100644
--- a/src/wireguard/router/peer.rs
+++ b/src/wireguard/router/peer.rs
@@ -531,8 +531,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T
///
/// If an identical value already exists as part of a prior peer,
/// the allowed IP entry will be removed from that peer and added to this peer.
- pub fn add_subnet(&self, ip: IpAddr, masklen: u32) {
- debug!("peer.add_subnet");
+ pub fn add_allowed_ips(&self, ip: IpAddr, masklen: u32) {
+ debug!("peer.add_allowed_ips");
match ip {
IpAddr::V4(v4) => {
self.state
@@ -556,8 +556,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T
/// # Returns
///
/// A vector of subnets, represented by as mask/size
- pub fn list_subnets(&self) -> Vec<(IpAddr, u32)> {
- debug!("peer.list_subnets");
+ pub fn list_allowed_ips(&self) -> Vec<(IpAddr, u32)> {
+ debug!("peer.list_allowed_ips");
let mut res = Vec::new();
res.append(&mut treebit_list(
&self.state,
@@ -575,8 +575,8 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T
/// Clear subnets mapped to the peer.
/// After the call, no subnets will be cryptkey routed to the peer.
/// Used for the UAPI command "replace_allowed_ips=true"
- pub fn remove_subnets(&self) {
- debug!("peer.remove_subnets");
+ pub fn remove_allowed_ips(&self) {
+ debug!("peer.remove_allowed_ips");
treebit_remove(self, &self.state.device.ipv4);
treebit_remove(self, &self.state.device.ipv6);
}
diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs
index d44a612..6184993 100644
--- a/src/wireguard/router/tests.rs
+++ b/src/wireguard/router/tests.rs
@@ -157,7 +157,7 @@ mod tests {
let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20");
let mask: IpAddr = mask.parse().unwrap();
let ip1: IpAddr = ip.parse().unwrap();
- peer.add_subnet(mask, len);
+ peer.add_allowed_ips(mask, len);
// every iteration sends 10 GB
b.iter(|| {
@@ -215,7 +215,7 @@ mod tests {
}
// map subnet to peer
- peer.add_subnet(mask, *len);
+ peer.add_allowed_ips(mask, *len);
// create "IP packet"
let msg = make_packet(1024, ip.parse().unwrap());
@@ -339,13 +339,13 @@ mod tests {
let (mask, len, _ip, _okay) = p1;
let peer1 = router1.new_peer(opaq1.clone());
let mask: IpAddr = mask.parse().unwrap();
- peer1.add_subnet(mask, *len);
+ peer1.add_allowed_ips(mask, *len);
peer1.add_keypair(dummy_keypair(false));
let (mask, len, _ip, _okay) = p2;
let peer2 = router2.new_peer(opaq2.clone());
let mask: IpAddr = mask.parse().unwrap();
- peer2.add_subnet(mask, *len);
+ peer2.add_allowed_ips(mask, *len);
peer2.set_endpoint(dummy::UnitEndpoint::new());
if *stage {
diff --git a/src/wireguard/router/types.rs b/src/wireguard/router/types.rs
index 52ee4f1..9f769fe 100644
--- a/src/wireguard/router/types.rs
+++ b/src/wireguard/router/types.rs
@@ -31,7 +31,7 @@ pub trait Callbacks: Send + Sync + 'static {
#[derive(Debug)]
pub enum RouterError {
- NoCryptKeyRoute,
+ NoCryptoKeyRoute,
MalformedIPHeader,
MalformedTransportMessage,
UnknownReceiverId,
@@ -42,7 +42,7 @@ pub enum RouterError {
impl fmt::Display for RouterError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
- RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"),
+ RouterError::NoCryptoKeyRoute => write!(f, "No cryptokey route configured for subnet"),
RouterError::MalformedIPHeader => write!(f, "IP header is malformed"),
RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"),
RouterError::UnknownReceiverId => {
diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs
index 7c87056..28dedec 100644
--- a/src/wireguard/tests.rs
+++ b/src/wireguard/tests.rs
@@ -5,13 +5,23 @@ use std::net::IpAddr;
use std::thread;
use std::time::Duration;
-use rand::rngs::OsRng;
+use hex;
+
+use rand_chacha::ChaCha8Rng;
+use rand_core::{RngCore, SeedableRng};
use x25519_dalek::{PublicKey, StaticSecret};
use pnet::packet::ipv4::MutableIpv4Packet;
use pnet::packet::ipv6::MutableIpv6Packet;
-fn make_packet(size: usize, src: IpAddr, dst: IpAddr) -> Vec<u8> {
+fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec<u8> {
+ // expand pseudo random payload
+ let mut rng: _ = ChaCha8Rng::seed_from_u64(id);
+ let mut p: Vec<u8> = vec![];
+ for _ in 0..size {
+ p.push(rng.next_u32() as u8);
+ }
+
// create "IP packet"
let mut msg = Vec::with_capacity(size);
msg.resize(size, 0);
@@ -19,21 +29,25 @@ fn make_packet(size: usize, src: IpAddr, dst: IpAddr) -> Vec<u8> {
IpAddr::V4(dst) => {
let mut packet = MutableIpv4Packet::new(&mut msg[..]).unwrap();
packet.set_destination(dst);
+ packet.set_total_length(size as u16);
packet.set_source(if let IpAddr::V4(src) = src {
src
} else {
panic!("src.version != dst.version")
});
+ packet.set_payload(&p[..]);
packet.set_version(4);
}
IpAddr::V6(dst) => {
let mut packet = MutableIpv6Packet::new(&mut msg[..]).unwrap();
packet.set_destination(dst);
+ packet.set_payload_length((size - MutableIpv6Packet::minimum_packet_size()) as u16);
packet.set_source(if let IpAddr::V6(src) = src {
src
} else {
panic!("src.version != dst.version")
});
+ packet.set_payload(&p[..]);
packet.set_version(6);
}
}
@@ -55,7 +69,7 @@ fn wait() {
fn test_pure_wireguard() {
init();
- // create WG instances for fake TUN devices
+ // create WG instances for dummy TUN devices
let (fake1, tun_reader1, tun_writer1, mtu1) = dummy::TunTest::create(1500, true);
let wg1: Wireguard<dummy::TunTest, dummy::PairBind> =
@@ -77,10 +91,20 @@ fn test_pure_wireguard() {
// generate (public, pivate) key pairs
- let mut rng = OsRng::new().unwrap();
- let sk1 = StaticSecret::new(&mut rng);
- let sk2 = StaticSecret::new(&mut rng);
+ let sk1 = StaticSecret::from([
+ 0x3f, 0x69, 0x86, 0xd1, 0xc0, 0xec, 0x25, 0xa0, 0x9c, 0x8e, 0x56, 0xb5, 0x1d, 0xb7, 0x3c,
+ 0xed, 0x56, 0x8e, 0x59, 0x9d, 0xd9, 0xc3, 0x98, 0x67, 0x74, 0x69, 0x90, 0xc3, 0x43, 0x36,
+ 0x78, 0x89,
+ ]);
+
+ let sk2 = StaticSecret::from([
+ 0xfb, 0xd1, 0xd6, 0xe4, 0x65, 0x06, 0xd2, 0xe5, 0xc5, 0xdf, 0x6e, 0xab, 0x51, 0x71, 0xd8,
+ 0x70, 0xb5, 0xb7, 0x77, 0x51, 0xb4, 0xbe, 0xfb, 0xbc, 0x88, 0x62, 0x40, 0xca, 0x2c, 0xc2,
+ 0x66, 0xe2,
+ ]);
+
let pk1 = PublicKey::from(&sk1);
+
let pk2 = PublicKey::from(&sk2);
wg1.new_peer(pk2);
@@ -94,21 +118,79 @@ fn test_pure_wireguard() {
let peer2 = wg1.lookup_peer(&pk2).unwrap();
let peer1 = wg2.lookup_peer(&pk1).unwrap();
- peer1.router.add_subnet("192.168.2.0".parse().unwrap(), 24);
- peer2.router.add_subnet("192.168.1.0".parse().unwrap(), 24);
+ peer1
+ .router
+ .add_allowed_ips("192.168.1.0".parse().unwrap(), 24);
+
+ peer2
+ .router
+ .add_allowed_ips("192.168.2.0".parse().unwrap(), 24);
- // set endpoints
+ // set endpoint (the other should be learned dynamically)
- peer1.router.set_endpoint(dummy::UnitEndpoint::new());
peer2.router.set_endpoint(dummy::UnitEndpoint::new());
- // create IP packets (causing a new handshake)
+ let num_packets = 20;
+
+ // send IP packets (causing a new handshake)
+
+ {
+ let mut packets: Vec<Vec<u8>> = Vec::with_capacity(num_packets);
+
+ for id in 0..num_packets {
+ packets.push(make_packet(
+ 50 + 50 * id as usize, // size
+ "192.168.1.20".parse().unwrap(), // src
+ "192.168.2.10".parse().unwrap(), // dst
+ id as u64, // prng seed
+ ));
+ }
+
+ let mut backup = packets.clone();
+
+ while let Some(p) = packets.pop() {
+ fake1.write(p);
+ }
- let packet_p1_to_p2 = make_packet(
- 1000,
- "192.168.2.20".parse().unwrap(), // src
- "192.168.1.10".parse().unwrap(), // dst
- );
+ wait();
- fake1.write(packet_p1_to_p2);
+ while let Some(p) = backup.pop() {
+ assert_eq!(
+ hex::encode(fake2.read()),
+ hex::encode(p),
+ "Failed to receive valid IPv4 packet unmodified and in-order"
+ );
+ }
+ }
+
+ // send IP packets (other direction)
+
+ {
+ let mut packets: Vec<Vec<u8>> = Vec::with_capacity(num_packets);
+
+ for id in 0..num_packets {
+ packets.push(make_packet(
+ 50 + 50 * id as usize, // size
+ "192.168.2.10".parse().unwrap(), // src
+ "192.168.1.20".parse().unwrap(), // dst
+ (id + 100) as u64, // prng seed
+ ));
+ }
+
+ let mut backup = packets.clone();
+
+ while let Some(p) = packets.pop() {
+ fake2.write(p);
+ }
+
+ wait();
+
+ while let Some(p) = backup.pop() {
+ assert_eq!(
+ hex::encode(fake1.read()),
+ hex::encode(p),
+ "Failed to receive valid IPv4 packet unmodified and in-order"
+ );
+ }
+ }
}
diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs
index 5ebc746..3b16bf6 100644
--- a/src/wireguard/timers.rs
+++ b/src/wireguard/timers.rs
@@ -7,10 +7,10 @@ use log::info;
use hjul::{Runner, Timer};
-use super::{bind, tun};
use super::constants::*;
-use super::router::{Callbacks, message_data_len};
+use super::router::{message_data_len, Callbacks};
use super::wireguard::{Peer, PeerInner};
+use super::{bind, tun};
pub struct Timers {
handshake_pending: AtomicBool,
@@ -32,16 +32,20 @@ impl Timers {
}
}
-impl <B: bind::Bind>PeerInner<B> {
+impl<B: bind::Bind> PeerInner<B> {
/* should be called after an authenticated data packet is sent */
pub fn timers_data_sent(&self) {
- self.timers().new_handshake.start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT);
+ self.timers()
+ .new_handshake
+ .start(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT);
}
/* should be called after an authenticated data packet is received */
pub fn timers_data_received(&self) {
if !self.timers().send_keepalive.start(KEEPALIVE_TIMEOUT) {
- self.timers().need_another_keepalive.store(true, Ordering::SeqCst)
+ self.timers()
+ .need_another_keepalive
+ .store(true, Ordering::SeqCst)
}
}
@@ -74,7 +78,9 @@ impl <B: bind::Bind>PeerInner<B> {
*/
pub fn timers_handshake_complete(&self) {
self.timers().handshake_attempts.store(0, Ordering::SeqCst);
- self.timers().sent_lastminute_handshake.store(false, Ordering::SeqCst);
+ self.timers()
+ .sent_lastminute_handshake
+ .store(false, Ordering::SeqCst);
// TODO: Store time in peer for config
// self.walltime_last_handshake
}
@@ -92,7 +98,9 @@ impl <B: bind::Bind>PeerInner<B> {
pub fn timers_any_authenticated_packet_traversal(&self) {
let keepalive = self.keepalive.load(Ordering::Acquire);
if keepalive > 0 {
- self.timers().send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64));
+ self.timers()
+ .send_persistent_keepalive
+ .reset(Duration::from_secs(keepalive as u64));
}
}
@@ -149,11 +157,7 @@ impl Timers {
new_handshake: {
let peer = peer.clone();
runner.timer(move || {
- info!(
- "Retrying handshake with {}, because we stopped hearing back after {} seconds",
- peer,
- (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT).as_secs()
- );
+ info!("Initiate new handshake with {}", peer);
peer.new_handshake();
peer.timers.read().handshake_begun();
})
@@ -171,10 +175,12 @@ impl Timers {
if keepalive > 0 {
peer.router.send_keepalive();
peer.timers().send_keepalive.stop();
- peer.timers().send_persistent_keepalive.start(Duration::from_secs(keepalive as u64));
+ peer.timers()
+ .send_persistent_keepalive
+ .start(Duration::from_secs(keepalive as u64));
}
})
- }
+ },
}
}
@@ -196,7 +202,8 @@ impl Timers {
pub fn updated_persistent_keepalive(&self, keepalive: usize) {
if keepalive > 0 {
- self.send_persistent_keepalive.reset(Duration::from_secs(keepalive as u64));
+ self.send_persistent_keepalive
+ .reset(Duration::from_secs(keepalive as u64));
}
}
@@ -210,7 +217,7 @@ impl Timers {
new_handshake: runner.timer(|| {}),
send_keepalive: runner.timer(|| {}),
send_persistent_keepalive: runner.timer(|| {}),
- zero_key_material: runner.timer(|| {})
+ zero_key_material: runner.timer(|| {}),
}
}
diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs
index 25544d9..233559e 100644
--- a/src/wireguard/wireguard.rs
+++ b/src/wireguard/wireguard.rs
@@ -21,6 +21,7 @@ use std::collections::HashMap;
use log::debug;
use rand::rngs::OsRng;
+use rand::Rng;
use spin::{Mutex, RwLock, RwLockReadGuard};
use byteorder::{ByteOrder, LittleEndian};
@@ -37,6 +38,8 @@ pub struct Peer<T: Tun, B: Bind> {
}
pub struct PeerInner<B: Bind> {
+ pub id: u64,
+
pub keepalive: AtomicUsize, // keepalive interval
pub rx_bytes: AtomicU64,
pub tx_bytes: AtomicU64,
@@ -50,6 +53,9 @@ pub struct PeerInner<B: Bind> {
}
pub struct WireguardInner<T: Tun, B: Bind> {
+ // identifier (for logging)
+ id: u32,
+
// provides access to the MTU value of the tun device
// (otherwise owned solely by the router and a dedicated read IO thread)
mtu: T::MTU,
@@ -96,7 +102,13 @@ impl<B: Bind> PeerInner<B> {
impl<T: Tun, B: Bind> fmt::Display for Peer<T, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "peer()")
+ write!(f, "peer(id = {})", self.id)
+ }
+}
+
+impl<T: Tun, B: Bind> fmt::Display for WireguardInner<T, B> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "wireguard({:x})", self.id)
}
}
@@ -209,7 +221,9 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}
pub fn new_peer(&self, pk: PublicKey) {
+ let mut rng = OsRng::new().unwrap();
let state = Arc::new(PeerInner {
+ id: rng.gen(),
pk,
last_handshake: Mutex::new(SystemTime::UNIX_EPOCH),
handshake_queued: AtomicBool::new(false),
@@ -277,11 +291,17 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
handshake::TYPE_COOKIE_REPLY
| handshake::TYPE_INITIATION
| handshake::TYPE_RESPONSE => {
+ debug!("{} : reader, received handshake message", wg);
+
+ let pending = wg.pending.fetch_add(1, Ordering::SeqCst);
+
// update under_load flag
- if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
+ if pending > THRESHOLD_UNDER_LOAD {
+ debug!("{} : reader, set under load (pending = {})", wg, pending);
last_under_load = Instant::now();
wg.under_load.store(true, Ordering::SeqCst);
} else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
+ debug!("{} : reader, clear under load", wg);
wg.under_load.store(false, Ordering::SeqCst);
}
@@ -291,6 +311,8 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
.unwrap();
}
router::TYPE_TRANSPORT => {
+ debug!("{} : reader, received transport message", wg);
+
// transport message
let _ = wg.router.recv(src, msg).map_err(|e| {
debug!("Failed to handle incoming transport message: {}", e);
@@ -313,6 +335,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
let mut rng = OsRng::new().unwrap();
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
let wg = Arc::new(WireguardInner {
+ id: rng.gen(),
mtu: mtu.clone(),
peers: RwLock::new(HashMap::new()),
send: RwLock::new(None),
@@ -331,12 +354,13 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
let wg = wg.clone();
let rx = rx.clone();
thread::spawn(move || {
+ debug!("{} : handshake worker, started", wg);
+
// prepare OsRng instance for this thread
let mut rng = OsRng::new().unwrap();
// process elements from the handshake queue
for job in rx {
- wg.pending.fetch_sub(1, Ordering::SeqCst);
let state = wg.handshake.read();
if !state.active {
continue;
@@ -344,6 +368,8 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
match job {
HandshakeJob::Message(msg, src) => {
+ wg.pending.fetch_sub(1, Ordering::SeqCst);
+
// feed message to handshake device
let src_validate = (&src).into_address(); // TODO avoid
@@ -352,6 +378,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
&mut rng,
&msg[..],
if wg.under_load.load(Ordering::Relaxed) {
+ debug!("{} : handshake worker, under load", wg);
Some(&src_validate)
} else {
None
@@ -364,9 +391,14 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
resp_len = msg.len() as u64;
let send: &Option<B::Writer> = &*wg.send.read();
if let Some(writer) = send.as_ref() {
+ debug!(
+ "{} : handshake worker, send response ({} bytes)",
+ wg, resp_len
+ );
let _ = writer.write(&msg[..], &src).map_err(|e| {
debug!(
- "handshake worker, failed to send response, error = {}",
+ "{} : handshake worker, failed to send response, error = {}",
+ wg,
e
)
});
@@ -387,11 +419,13 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
// update timers after sending handshake response
if resp_len > 0 {
+ debug!("{} : handshake worker, handshake response sent", wg);
peer.state.sent_handshake_response();
}
// add resulting keypair to peer
keypair.map(|kp| {
+ debug!("{} : handshake worker, new keypair", wg);
// free any unused ids
for id in peer.router.add_keypair(kp) {
state.device.release(id);
@@ -400,14 +434,15 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}
}
}
- Err(e) => debug!("handshake worker, error = {:?}", e),
+ Err(e) => debug!("{} : handshake worker, error = {:?}", wg, e),
}
}
HandshakeJob::New(pk) => {
+ debug!("{} : handshake worker, new handshake requested", wg);
let _ = state.device.begin(&mut rng, &pk).map(|msg| {
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
let _ = peer.router.send(&msg[..]).map_err(|e| {
- debug!("handshake worker, failed to send handshake initiation, error = {}", e)
+ debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e)
});
peer.state.sent_handshake_initiation();
}