aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/handshake/device.rs10
-rw-r--r--src/main.rs1
-rw-r--r--src/router/peer.rs58
-rw-r--r--src/timers.rs65
-rw-r--r--src/types/keys.rs6
-rw-r--r--src/wireguard.rs116
6 files changed, 190 insertions, 66 deletions
diff --git a/src/handshake/device.rs b/src/handshake/device.rs
index 2a06fa7..6178831 100644
--- a/src/handshake/device.rs
+++ b/src/handshake/device.rs
@@ -64,10 +64,16 @@ impl Device {
self.macs = macs::Validator::new(pk);
// recalculate the shared secrets for every peer
- for &mut peer in self.pk_map.values_mut() {
- peer.reset_state().map(|id| self.release(id));
+ let mut ids = vec![];
+ for mut peer in self.pk_map.values_mut() {
+ peer.reset_state().map(|id| ids.push(id));
peer.ss = self.sk.diffie_hellman(&peer.pk)
}
+
+ // release ids from aborted handshakes
+ for id in ids {
+ self.release(id)
+ }
}
/// Add a new public key to the state machine
diff --git a/src/main.rs b/src/main.rs
index 103bc65..a52eecc 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -8,6 +8,7 @@ static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
mod constants;
mod handshake;
mod router;
+mod timers;
mod types;
mod wireguard;
diff --git a/src/router/peer.rs b/src/router/peer.rs
index 952e439..7a3ede8 100644
--- a/src/router/peer.rs
+++ b/src/router/peer.rs
@@ -36,7 +36,7 @@ pub struct KeyWheel {
next: Option<Arc<KeyPair>>, // next key state (unconfirmed)
current: Option<Arc<KeyPair>>, // current key state (used for encryption)
previous: Option<Arc<KeyPair>>, // old key state (used for decryption)
- retired: Option<u32>, // retired id (previous id, after confirming key-pair)
+ retired: Vec<u32>, // retired ids
}
pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> {
@@ -188,7 +188,7 @@ pub fn new_peer<C: Callbacks, T: Tun, B: Bind>(
next: None,
current: None,
previous: None,
- retired: None,
+ retired: vec![],
}),
staged_packets: spin::Mutex::new(ArrayDeque::new()),
})
@@ -375,6 +375,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
*self.state.endpoint.lock() = Some(B::Endpoint::from_address(address));
}
+ /// Returns the current endpoint of the peer (for configuration)
+ ///
+ /// # Note
+ ///
+ /// Does not convey potential "sticky socket" information
pub fn get_endpoint(&self) -> Option<SocketAddr> {
self.state
.endpoint
@@ -383,6 +388,30 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
.map(|e| e.into_address())
}
+ /// Zero all key-material related to the peer
+ pub fn zero_keys(&self) {
+ let mut release: Vec<u32> = Vec::with_capacity(3);
+ let mut keys = self.state.keys.lock();
+
+ // update key-wheel
+
+ mem::replace(&mut keys.next, None).map(|k| release.push(k.local_id()));
+ mem::replace(&mut keys.current, None).map(|k| release.push(k.local_id()));
+ mem::replace(&mut keys.previous, None).map(|k| release.push(k.local_id()));
+ keys.retired.extend(&release[..]);
+
+ // update inbound "recv" map
+ {
+ let mut recv = self.state.device.recv.write();
+ for id in release {
+ recv.remove(&id);
+ }
+ }
+
+ // clear encryption state
+ *self.state.ekey.lock() = None;
+ }
+
/// Add a new keypair
///
/// # Arguments
@@ -393,14 +422,16 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
///
/// A vector of ids which has been released.
/// These should be released in the handshake module.
+ ///
+ /// # Note
+ ///
+ /// The number of ids to be released can be at most 3,
+ /// since the only way to add additional keys to the peer is by using this method
+ /// and a peer can have at most 3 keys allocated in the router at any time.
pub fn add_keypair(&self, new: KeyPair) -> Vec<u32> {
- let mut keys = self.state.keys.lock();
- let mut release = Vec::with_capacity(2);
let new = Arc::new(new);
-
- // collect ids to be released
- keys.retired.map(|v| release.push(v));
- keys.previous.as_ref().map(|k| release.push(k.recv.id));
+ let mut keys = self.state.keys.lock();
+ let mut release = mem::replace(&mut keys.retired, vec![]);
// update key-wheel
if new.initiator {
@@ -420,10 +451,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
{
let mut recv = self.state.device.recv.write();
- // purge recv map of released ids
- for id in &release {
- recv.remove(&id);
- }
+ // purge recv map of previous id
+ keys.previous.as_ref().map(|k| {
+ recv.remove(&k.local_id());
+ release.push(k.local_id());
+ });
// map new id to decryption state
debug_assert!(!recv.contains_key(&new.recv.id));
@@ -442,7 +474,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
}
}
- // return the released id (for handshake state machine)
+ debug_assert!(release.len() <= 3);
release
}
diff --git a/src/timers.rs b/src/timers.rs
new file mode 100644
index 0000000..0d69c3f
--- /dev/null
+++ b/src/timers.rs
@@ -0,0 +1,65 @@
+use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
+use std::sync::Arc;
+use std::time::Duration;
+
+use hjul::{Runner, Timer};
+
+use crate::router::Callbacks;
+
+const ZERO_DURATION: Duration = Duration::from_micros(0);
+
+pub struct TimersInner {
+ handshake_pending: AtomicBool,
+ handshake_attempts: AtomicUsize,
+
+ retransmit_handshake: Timer,
+ send_keepalive: Timer,
+ zero_key_material: Timer,
+ new_handshake: Timer,
+
+ // stats
+ rx_bytes: AtomicU64,
+ tx_bytes: AtomicU64,
+}
+
+impl TimersInner {
+ pub fn new(runner: &Runner) -> Timers {
+ Arc::new(TimersInner {
+ handshake_pending: AtomicBool::new(false),
+ handshake_attempts: AtomicUsize::new(0),
+ retransmit_handshake: runner.timer(|| {}),
+ new_handshake: runner.timer(|| {}),
+ send_keepalive: runner.timer(|| {}),
+ zero_key_material: runner.timer(|| {}),
+ rx_bytes: AtomicU64::new(0),
+ tx_bytes: AtomicU64::new(0),
+ })
+ }
+
+ pub fn handshake_sent(&self) {
+ self.send_keepalive.stop();
+ }
+}
+
+pub type Timers = Arc<TimersInner>;
+
+pub struct Events();
+
+impl Callbacks for Events {
+ type Opaque = Timers;
+
+ fn send(t: &Timers, size: usize, data: bool, sent: bool) {
+ t.tx_bytes.fetch_add(size as u64, Ordering::Relaxed);
+ }
+
+ fn recv(t: &Timers, size: usize, data: bool, sent: bool) {
+ t.rx_bytes.fetch_add(size as u64, Ordering::Relaxed);
+ }
+
+ fn need_key(t: &Timers) {
+ if !t.handshake_pending.swap(true, Ordering::SeqCst) {
+ t.handshake_attempts.store(0, Ordering::SeqCst);
+ t.new_handshake.reset(ZERO_DURATION);
+ }
+ }
+}
diff --git a/src/types/keys.rs b/src/types/keys.rs
index 89cacf9..282c4ae 100644
--- a/src/types/keys.rs
+++ b/src/types/keys.rs
@@ -28,3 +28,9 @@ pub struct KeyPair {
pub send: Key, // key for outbound messages
pub recv: Key, // key for inbound messages
}
+
+impl KeyPair {
+ pub fn local_id(&self) -> u32 {
+ self.recv.id
+ }
+}
diff --git a/src/wireguard.rs b/src/wireguard.rs
index f98369f..3b4724e 100644
--- a/src/wireguard.rs
+++ b/src/wireguard.rs
@@ -1,5 +1,6 @@
use crate::handshake;
use crate::router;
+use crate::timers::{Events, Timers};
use crate::types::{Bind, Endpoint, Tun};
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
@@ -21,28 +22,19 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128;
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
-#[derive(Clone)]
-pub struct Peer<T: Tun, B: Bind>(Arc<PeerInner<T, B>>);
+type Peer<T: Tun, B: Bind> = Arc<PeerInner<T, B>>;
pub struct PeerInner<T: Tun, B: Bind> {
- router: router::Peer<Events, T, B>,
- timers: Timers,
- rx: AtomicU64,
- tx: AtomicU64,
+ queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue
+ router: router::Peer<Events, T, B>, // router peer
+ timers: Option<Timers>, //
}
-pub struct Timers {}
-
-pub struct Events();
-
-impl router::Callbacks for Events {
- type Opaque = Timers;
-
- fn send(t: &Timers, size: usize, data: bool, sent: bool) {}
-
- fn recv(t: &Timers, size: usize, data: bool, sent: bool) {}
-
- fn need_key(t: &Timers) {}
+impl<T: Tun, B: Bind> PeerInner<T, B> {
+ #[inline(always)]
+ fn timers(&self) -> &Timers {
+ self.timers.as_ref().unwrap()
+ }
}
struct Handshake {
@@ -50,6 +42,11 @@ struct Handshake {
active: bool,
}
+enum HandshakeJob<E> {
+ Message(Vec<u8>, E),
+ New(PublicKey),
+}
+
struct WireguardInner<T: Tun, B: Bind> {
// identify and configuration map
peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
@@ -61,7 +58,7 @@ struct WireguardInner<T: Tun, B: Bind> {
handshake: RwLock<Handshake>,
under_load: AtomicBool,
pending: AtomicUsize, // num of pending handshake packets in queue
- queue: Mutex<Sender<(Vec<u8>, B::Endpoint)>>,
+ queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>,
// IO
bind: B,
@@ -90,7 +87,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
fn new(tun: T, bind: B) -> Wireguard<T, B> {
// create device state
let mut rng = OsRng::new().unwrap();
- let (tx, rx): (Sender<(Vec<u8>, B::Endpoint)>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
+ let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
let wg = Arc::new(WireguardInner {
peers: RwLock::new(HashMap::new()),
router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()),
@@ -114,50 +111,64 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
let mut rng = OsRng::new().unwrap();
// process elements from the handshake queue
- for (msg, src) in rx {
+ for job in rx {
wg.pending.fetch_sub(1, Ordering::SeqCst);
-
- // feed message to handshake device
- let src_validate = (&src).into_address(); // TODO avoid
let state = wg.handshake.read();
if !state.active {
continue;
}
- // process message
- match state.device.process(
- &mut rng,
- &msg[..],
- if wg.under_load.load(Ordering::Relaxed) {
- Some(&src_validate)
- } else {
- None
- },
- ) {
- Ok((pk, msg, keypair)) => {
- // send response
- if let Some(msg) = msg {
- let _ = bind.send(&msg[..], &src).map_err(|e| {
- debug!(
+ match job {
+ HandshakeJob::Message(msg, src) => {
+ // feed message to handshake device
+ let src_validate = (&src).into_address(); // TODO avoid
+
+ // process message
+ match state.device.process(
+ &mut rng,
+ &msg[..],
+ if wg.under_load.load(Ordering::Relaxed) {
+ Some(&src_validate)
+ } else {
+ None
+ },
+ ) {
+ Ok((pk, msg, keypair)) => {
+ // send response
+ if let Some(msg) = msg {
+ let _ = bind.send(&msg[..], &src).map_err(|e| {
+ debug!(
"handshake worker, failed to send response, error = {:?}",
e
)
- });
- }
+ });
+ }
- // update timers
- if let Some(pk) = pk {
- // add keypair to peer and free any unused ids
- if let Some(keypair) = keypair {
- if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
- for id in peer.0.router.add_keypair(keypair) {
- state.device.release(id);
+ // update timers
+ if let Some(pk) = pk {
+ if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
+ // update endpoint (DISCUSS: right semantics?)
+ peer.router.set_endpoint(src_validate);
+
+ // add keypair to peer and free any unused ids
+ if let Some(keypair) = keypair {
+ for id in peer.router.add_keypair(keypair) {
+ state.device.release(id);
+ }
+ }
}
}
}
+ Err(e) => debug!("handshake worker, error = {:?}", e),
+ }
+ }
+ HandshakeJob::New(pk) => {
+ let msg = state.device.begin(&mut rng, &pk).unwrap(); // TODO handle
+ if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
+ peer.router.send(&msg[..]);
+ peer.timers().handshake_sent();
}
}
- Err(e) => debug!("handshake worker, error = {:?}", e),
}
}
});
@@ -197,7 +208,10 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
wg.under_load.store(false, Ordering::SeqCst);
}
- wg.queue.lock().send((msg, src)).unwrap();
+ wg.queue
+ .lock()
+ .send(HandshakeJob::Message(msg, src))
+ .unwrap();
}
router::TYPE_TRANSPORT => {
// transport message
@@ -223,7 +237,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
msg.truncate(size);
- // pad message to multiple of 16
+ // pad message to multiple of 16 bytes
while msg.len() < mtu && msg.len() % 16 != 0 {
msg.push(0);
}