aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/handshake/device.rs58
-rw-r--r--src/handshake/noise.rs23
-rw-r--r--src/handshake/peer.rs23
-rw-r--r--src/handshake/types.rs10
-rw-r--r--src/router/device.rs5
-rw-r--r--src/wireguard.rs222
6 files changed, 212 insertions, 129 deletions
diff --git a/src/handshake/device.rs b/src/handshake/device.rs
index 638d63f..2a06fa7 100644
--- a/src/handshake/device.rs
+++ b/src/handshake/device.rs
@@ -21,11 +21,11 @@ use super::types::*;
const MAX_PEER_PER_DEVICE: usize = 1 << 20;
-pub struct Device<T> {
+pub struct Device {
pub sk: StaticSecret, // static secret key
pub pk: PublicKey, // static public key
macs: macs::Validator, // validator for the mac fields
- pk_map: HashMap<[u8; 32], Peer<T>>, // public key -> peer state
+ pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state
id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key
limiter: Mutex<RateLimiter>,
}
@@ -33,16 +33,13 @@ pub struct Device<T> {
/* A mutable reference to the device needs to be held during configuration.
* Wrapping the device in a RwLock enables peer config after "configuration time"
*/
-impl<T> Device<T>
-where
- T: Clone,
-{
+impl Device {
/// Initialize a new handshake state machine
///
/// # Arguments
///
/// * `sk` - x25519 scalar representing the local private key
- pub fn new(sk: StaticSecret) -> Device<T> {
+ pub fn new(sk: StaticSecret) -> Device {
let pk = PublicKey::from(&sk);
Device {
pk,
@@ -54,6 +51,25 @@ where
}
}
+ /// Update the secret key of the device
+ ///
+ /// # Arguments
+ ///
+ /// * `sk` - x25519 scalar representing the local private key
+ pub fn set_sk(&mut self, sk: StaticSecret) {
+ // update secret and public key
+ let pk = PublicKey::from(&sk);
+ self.sk = sk;
+ self.pk = pk;
+ self.macs = macs::Validator::new(pk);
+
+ // recalculate the shared secrets for every peer
+ for &mut peer in self.pk_map.values_mut() {
+ peer.reset_state().map(|id| self.release(id));
+ peer.ss = self.sk.diffie_hellman(&peer.pk)
+ }
+ }
+
/// Add a new public key to the state machine
/// To remove public keys, you must create a new machine instance
///
@@ -61,7 +77,7 @@ where
///
/// * `pk` - The public key to add
/// * `identifier` - Associated identifier which can be used to distinguish the peers
- pub fn add(&mut self, pk: PublicKey, identifier: T) -> Result<(), ConfigError> {
+ pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> {
// check that the pk is not added twice
if let Some(_) = self.pk_map.get(pk.as_bytes()) {
return Err(ConfigError::new("Duplicate public key"));
@@ -80,10 +96,8 @@ where
}
// map the public key to the peer state
- self.pk_map.insert(
- *pk.as_bytes(),
- Peer::new(identifier, pk, self.sk.diffie_hellman(&pk)),
- );
+ self.pk_map
+ .insert(*pk.as_bytes(), Peer::new(pk, self.sk.diffie_hellman(&pk)));
Ok(())
}
@@ -204,7 +218,7 @@ where
rng: &mut R, // rng instance to sample randomness from
msg: &[u8], // message buffer
src: Option<&'a S>, // optional source endpoint, set when "under load"
- ) -> Result<Output<T>, HandshakeError>
+ ) -> Result<Output, HandshakeError>
where
&'a S: Into<&'a SocketAddr>,
{
@@ -269,11 +283,7 @@ where
.generate(resp.noise.as_bytes(), &mut resp.macs);
// return unconfirmed keypair and the response as vector
- Ok((
- Some(peer.identifier.clone()),
- Some(resp.as_bytes().to_owned()),
- Some(keys),
- ))
+ Ok((Some(peer.pk), Some(resp.as_bytes().to_owned()), Some(keys)))
}
TYPE_RESPONSE => {
let msg = Response::parse(msg)?;
@@ -328,7 +338,7 @@ where
// Internal function
//
// Return the peer associated with the public key
- pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer<T>, HandshakeError> {
+ pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer, HandshakeError> {
self.pk_map
.get(pk.as_bytes())
.ok_or(HandshakeError::UnknownPublicKey)
@@ -337,7 +347,7 @@ where
// Internal function
//
// Return the peer currently associated with the receiver identifier
- pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer<T>, HandshakeError> {
+ pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer, HandshakeError> {
let im = self.id_map.read();
let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?;
match self.pk_map.get(pk) {
@@ -349,7 +359,7 @@ where
// Internal function
//
// Allocated a new receiver identifier for the peer
- fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, peer: &Peer<T>) -> u32 {
+ fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, peer: &Peer) -> u32 {
loop {
let id = rng.gen();
@@ -380,7 +390,7 @@ mod tests {
fn setup_devices<R: RngCore + CryptoRng>(
rng: &mut R,
- ) -> (PublicKey, Device<usize>, PublicKey, Device<usize>) {
+ ) -> (PublicKey, Device, PublicKey, Device) {
// generate new keypairs
let sk1 = StaticSecret::new(rng);
@@ -399,8 +409,8 @@ mod tests {
let mut dev1 = Device::new(sk1);
let mut dev2 = Device::new(sk2);
- dev1.add(pk2, 1337).unwrap();
- dev2.add(pk1, 2600).unwrap();
+ dev1.add(pk2).unwrap();
+ dev2.add(pk1).unwrap();
dev1.set_psk(pk2, Some(psk)).unwrap();
dev2.set_psk(pk1, Some(psk)).unwrap();
diff --git a/src/handshake/noise.rs b/src/handshake/noise.rs
index eafb9e9..1dc8402 100644
--- a/src/handshake/noise.rs
+++ b/src/handshake/noise.rs
@@ -215,10 +215,10 @@ mod tests {
}
}
-pub fn create_initiation<T: Clone, R: RngCore + CryptoRng>(
+pub fn create_initiation<R: RngCore + CryptoRng>(
rng: &mut R,
- device: &Device<T>,
- peer: &Peer<T>,
+ device: &Device,
+ peer: &Peer,
sender: u32,
msg: &mut NoiseInitiation,
) -> Result<(), HandshakeError> {
@@ -296,10 +296,10 @@ pub fn create_initiation<T: Clone, R: RngCore + CryptoRng>(
})
}
-pub fn consume_initiation<'a, T: Clone>(
- device: &'a Device<T>,
+pub fn consume_initiation<'a>(
+ device: &'a Device,
msg: &NoiseInitiation,
-) -> Result<(&'a Peer<T>, TemporaryState), HandshakeError> {
+) -> Result<(&'a Peer, TemporaryState), HandshakeError> {
clear_stack_on_return(CLEAR_PAGES, || {
// initialize new state
@@ -370,9 +370,9 @@ pub fn consume_initiation<'a, T: Clone>(
})
}
-pub fn create_response<T: Clone, R: RngCore + CryptoRng>(
+pub fn create_response<R: RngCore + CryptoRng>(
rng: &mut R,
- peer: &Peer<T>,
+ peer: &Peer,
sender: u32, // sending identifier
state: TemporaryState, // state from "consume_initiation"
msg: &mut NoiseResponse, // resulting response
@@ -456,10 +456,7 @@ pub fn create_response<T: Clone, R: RngCore + CryptoRng>(
* allow concurrent processing of potential responses to the initiation,
* in order to better mitigate DoS from malformed response messages.
*/
-pub fn consume_response<T: Clone>(
- device: &Device<T>,
- msg: &NoiseResponse,
-) -> Result<Output<T>, HandshakeError> {
+pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result<Output, HandshakeError> {
clear_stack_on_return(CLEAR_PAGES, || {
// retrieve peer and copy initiation state
let peer = device.lookup_id(msg.f_receiver.get())?;
@@ -530,7 +527,7 @@ pub fn consume_response<T: Clone>(
// return confirmed key-pair
Ok((
- Some(peer.identifier.clone()),
+ Some(peer.pk),
None,
Some(KeyPair {
birth,
diff --git a/src/handshake/peer.rs b/src/handshake/peer.rs
index 4c6f2fd..6a85cee 100644
--- a/src/handshake/peer.rs
+++ b/src/handshake/peer.rs
@@ -1,5 +1,7 @@
use lazy_static::lazy_static;
use spin::Mutex;
+
+use std::mem;
use std::time::{Duration, Instant};
use generic_array::typenum::U32;
@@ -24,10 +26,7 @@ lazy_static! {
*
* This type is only for internal use and not exposed.
*/
-pub struct Peer<T> {
- // external identifier
- pub(crate) identifier: T,
-
+pub struct Peer {
// mutable state
pub(crate) state: Mutex<State>,
pub(crate) timestamp: Mutex<Option<timestamp::TAI64N>>,
@@ -65,18 +64,13 @@ impl Drop for State {
}
}
-impl<T> Peer<T>
-where
- T: Clone,
-{
+impl Peer {
pub fn new(
- identifier: T, // external identifier
pk: PublicKey, // public key of peer
ss: SharedSecret, // precomputed DH(static, static)
) -> Self {
Self {
macs: Mutex::new(macs::Generator::new(pk)),
- identifier: identifier,
state: Mutex::new(State::Reset),
timestamp: Mutex::new(None),
last_initiation_consumption: Mutex::new(None),
@@ -94,6 +88,13 @@ where
*self.state.lock() = state_new;
}
+ pub fn reset_state(&self) -> Option<u32> {
+ match mem::replace(&mut *self.state.lock(), State::Reset) {
+ State::InitiationSent { sender, .. } => Some(sender),
+ _ => None,
+ }
+ }
+
/// Set the mutable state of the peer conditioned on the timestamp being newer
///
/// # Arguments
@@ -102,7 +103,7 @@ where
/// * ts_new - The associated timestamp
pub fn check_replay_flood(
&self,
- device: &Device<T>,
+ device: &Device,
timestamp_new: &timestamp::TAI64N,
) -> Result<(), HandshakeError> {
let mut state = self.state.lock();
diff --git a/src/handshake/types.rs b/src/handshake/types.rs
index 7b190ec..ba71ec4 100644
--- a/src/handshake/types.rs
+++ b/src/handshake/types.rs
@@ -1,6 +1,8 @@
use std::error::Error;
use std::fmt;
+use x25519_dalek::PublicKey;
+
use crate::types::KeyPair;
/* Internal types for the noise IKpsk2 implementation */
@@ -77,10 +79,10 @@ impl Error for HandshakeError {
}
}
-pub type Output<T> = (
- Option<T>, // external identifier associated with peer
- Option<Vec<u8>>, // message to send
- Option<KeyPair>, // resulting key-pair of successful handshake
+pub type Output = (
+ Option<PublicKey>, // external identifier associated with peer
+ Option<Vec<u8>>, // message to send
+ Option<KeyPair>, // resulting key-pair of successful handshake
);
// preshared key
diff --git a/src/router/device.rs b/src/router/device.rs
index e9e0fb3..e8250cb 100644
--- a/src/router/device.rs
+++ b/src/router/device.rs
@@ -121,7 +121,6 @@ fn get_route<C: Callbacks, T: Tun, B: Bind>(
}
impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
-
pub fn new(num_workers: usize, tun: T, bind: B) -> Device<C, T, B> {
// allocate shared device state
let mut inner = DeviceInner {
@@ -149,6 +148,10 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
}
}
+ /// A new secret key has been set for the device.
+ /// According to WireGuard semantics, this should cause all "sending" keys to be discarded.
+ pub fn new_sk(&self) {}
+
/// Adds a new peer to the device
///
/// # Returns
diff --git a/src/wireguard.rs b/src/wireguard.rs
index 2c166b4..f98369f 100644
--- a/src/wireguard.rs
+++ b/src/wireguard.rs
@@ -2,17 +2,20 @@ use crate::handshake;
use crate::router;
use crate::types::{Bind, Endpoint, Tun};
-use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
+use std::collections::HashMap;
+
use log::debug;
use rand::rngs::OsRng;
+use spin::{Mutex, RwLock};
use byteorder::{ByteOrder, LittleEndian};
-use crossbeam_channel::bounded;
-use x25519_dalek::StaticSecret;
+use crossbeam_channel::{bounded, Sender};
+use x25519_dalek::{PublicKey, StaticSecret};
const SIZE_HANDSHAKE_QUEUE: usize = 128;
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
@@ -22,8 +25,10 @@ const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
pub struct Peer<T: Tun, B: Bind>(Arc<PeerInner<T, B>>);
pub struct PeerInner<T: Tun, B: Bind> {
- peer: router::Peer<Events, T, B>,
+ router: router::Peer<Events, T, B>,
timers: Timers,
+ rx: AtomicU64,
+ tx: AtomicU64,
}
pub struct Timers {}
@@ -40,96 +45,96 @@ impl router::Callbacks for Events {
fn need_key(t: &Timers) {}
}
-pub struct Wireguard<T: Tun, B: Bind> {
- router: Arc<router::Device<Events, T, B>>,
- handshake: Option<Arc<handshake::Device<()>>>,
+struct Handshake {
+ device: handshake::Device,
+ active: bool,
}
-impl<T: Tun, B: Bind> Wireguard<T, B> {
- fn start(&self) {}
-
- fn new(tun: T, bind: B, sk: StaticSecret) -> Wireguard<T, B> {
- let router = Arc::new(router::Device::new(
- num_cpus::get(),
- tun.clone(),
- bind.clone(),
- ));
-
- let handshake_staged = Arc::new(AtomicUsize::new(0));
- let handshake_device: Arc<handshake::Device<Peer<T, B>>> =
- Arc::new(handshake::Device::new(sk));
+struct WireguardInner<T: Tun, B: Bind> {
+ // identify and configuration map
+ peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
- // start UDP read IO thread
- let (handshake_tx, handshake_rx) = bounded(128);
- {
- let tun = tun.clone();
- let bind = bind.clone();
- thread::spawn(move || {
- let mut under_load =
- Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
+ // cryptkey routing
+ router: router::Device<Events, T, B>,
- loop {
- // read UDP packet into vector
- let size = tun.mtu() + 148; // maximum message size
- let mut msg: Vec<u8> =
- Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
- msg.resize(size, 0);
- let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
- msg.truncate(size);
+ // handshake related state
+ handshake: RwLock<Handshake>,
+ under_load: AtomicBool,
+ pending: AtomicUsize, // num of pending handshake packets in queue
+ queue: Mutex<Sender<(Vec<u8>, B::Endpoint)>>,
- // message type de-multiplexer
- if msg.len() < std::mem::size_of::<u32>() {
- continue;
- }
+ // IO
+ bind: B,
+}
- match LittleEndian::read_u32(&msg[..]) {
- handshake::TYPE_COOKIE_REPLY
- | handshake::TYPE_INITIATION
- | handshake::TYPE_RESPONSE => {
- // detect if under load
- if handshake_staged.fetch_add(1, Ordering::SeqCst)
- > THRESHOLD_UNDER_LOAD
- {
- under_load = Instant::now()
- }
+pub struct Wireguard<T: Tun, B: Bind> {
+ state: Arc<WireguardInner<T, B>>,
+}
- // pass source address along if under load
- handshake_tx
- .send((msg, src, under_load.elapsed() < DURATION_UNDER_LOAD))
- .unwrap();
- }
- router::TYPE_TRANSPORT => {
- // transport message
- }
- _ => (),
- }
- }
- });
+impl<T: Tun, B: Bind> Wireguard<T, B> {
+ fn set_key(&self, sk: Option<StaticSecret>) {
+ let mut handshake = self.state.handshake.write();
+ match sk {
+ None => {
+ let mut rng = OsRng::new().unwrap();
+ handshake.device.set_sk(StaticSecret::new(&mut rng));
+ handshake.active = false;
+ }
+ Some(sk) => {
+ handshake.device.set_sk(sk);
+ handshake.active = true;
+ }
}
+ }
+
+ fn new(tun: T, bind: B) -> Wireguard<T, B> {
+ // create device state
+ let mut rng = OsRng::new().unwrap();
+ let (tx, rx): (Sender<(Vec<u8>, B::Endpoint)>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
+ let wg = Arc::new(WireguardInner {
+ peers: RwLock::new(HashMap::new()),
+ router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()),
+ pending: AtomicUsize::new(0),
+ handshake: RwLock::new(Handshake {
+ device: handshake::Device::new(StaticSecret::new(&mut rng)),
+ active: false,
+ }),
+ under_load: AtomicBool::new(false),
+ bind: bind.clone(),
+ queue: Mutex::new(tx),
+ });
// start handshake workers
for _ in 0..num_cpus::get() {
+ let wg = wg.clone();
+ let rx = rx.clone();
let bind = bind.clone();
- let handshake_rx = handshake_rx.clone();
- let handshake_device = handshake_device.clone();
thread::spawn(move || {
// prepare OsRng instance for this thread
let mut rng = OsRng::new().unwrap();
// process elements from the handshake queue
- for (msg, src, under_load) in handshake_rx {
+ for (msg, src) in rx {
+ wg.pending.fetch_sub(1, Ordering::SeqCst);
+
// feed message to handshake device
let src_validate = (&src).into_address(); // TODO avoid
- match handshake_device.process(
+ let state = wg.handshake.read();
+ if !state.active {
+ continue;
+ }
+
+ // process message
+ match state.device.process(
&mut rng,
&msg[..],
- if under_load {
+ if wg.under_load.load(Ordering::Relaxed) {
Some(&src_validate)
} else {
None
},
) {
- Ok((identity, msg, keypair)) => {
+ Ok((pk, msg, keypair)) => {
// send response
if let Some(msg) = msg {
let _ = bind.send(&msg[..], &src).map_err(|e| {
@@ -141,11 +146,13 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}
// update timers
- if let Some(identity) = identity {
+ if let Some(pk) = pk {
// add keypair to peer and free any unused ids
if let Some(keypair) = keypair {
- for id in identity.0.peer.add_keypair(keypair) {
- handshake_device.release(id);
+ if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
+ for id in peer.0.router.add_keypair(keypair) {
+ state.device.release(id);
+ }
}
}
}
@@ -156,13 +163,76 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
});
}
- // start TUN read IO thread
+ // start UDP read IO thread
+ {
+ let wg = wg.clone();
+ let tun = tun.clone();
+ let bind = bind.clone();
+ thread::spawn(move || {
+ let mut last_under_load =
+ Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
+
+ loop {
+ // read UDP packet into vector
+ let size = tun.mtu() + 148; // maximum message size
+ let mut msg: Vec<u8> = Vec::with_capacity(size);
+ msg.resize(size, 0);
+ let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
+ msg.truncate(size);
- thread::spawn(move || {});
+ // message type de-multiplexer
+ if msg.len() < std::mem::size_of::<u32>() {
+ continue;
+ }
- Wireguard {
- router,
- handshake: None,
+ match LittleEndian::read_u32(&msg[..]) {
+ handshake::TYPE_COOKIE_REPLY
+ | handshake::TYPE_INITIATION
+ | handshake::TYPE_RESPONSE => {
+ // update under_load flag
+ if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
+ last_under_load = Instant::now();
+ wg.under_load.store(true, Ordering::SeqCst);
+ } else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
+ wg.under_load.store(false, Ordering::SeqCst);
+ }
+
+ wg.queue.lock().send((msg, src)).unwrap();
+ }
+ router::TYPE_TRANSPORT => {
+ // transport message
+
+ // pad the message
+
+ let _ = wg.router.recv(src, msg);
+ }
+ _ => (),
+ }
+ }
+ });
+ }
+
+ // start TUN read IO thread
+ {
+ let wg = wg.clone();
+ thread::spawn(move || loop {
+ // read a new IP packet
+ let mtu = tun.mtu();
+ let size = mtu + 148;
+ let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
+ let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
+ msg.truncate(size);
+
+ // pad message to multiple of 16
+ while msg.len() < mtu && msg.len() % 16 != 0 {
+ msg.push(0);
+ }
+
+ // crypt-key route
+ let _ = wg.router.send(msg);
+ });
}
+
+ Wireguard { state: wg }
}
}