From dd85201c15244fbd380eef8ee359a535335b7250 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Fri, 8 Nov 2019 19:00:12 +0100 Subject: Removal of secret key in the handshake module --- src/wireguard/handshake/device.rs | 147 ++++++++++++++++++++++++-------------- src/wireguard/handshake/noise.rs | 27 ++++--- src/wireguard/handshake/peer.rs | 15 ++-- src/wireguard/peer.rs | 37 +++------- src/wireguard/timers.rs | 5 +- src/wireguard/wireguard.rs | 128 +++++++++++++++++---------------- 6 files changed, 190 insertions(+), 169 deletions(-) (limited to 'src') diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs index c2e3a6e..85c2e45 100644 --- a/src/wireguard/handshake/device.rs +++ b/src/wireguard/handshake/device.rs @@ -21,10 +21,14 @@ use super::types::*; const MAX_PEER_PER_DEVICE: usize = 1 << 20; +pub struct KeyState { + pub sk: StaticSecret, // static secret key + pub pk: PublicKey, // static public key + macs: macs::Validator, // validator for the mac fields +} + pub struct Device { - pub sk: StaticSecret, // static secret key - pub pk: PublicKey, // static public key - macs: macs::Validator, // validator for the mac fields + keyst: Option, // secret/public key pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state id_map: RwLock>, // receiver ids -> public key limiter: Mutex, @@ -35,45 +39,68 @@ pub struct Device { */ impl Device { /// Initialize a new handshake state machine - /// - /// # Arguments - /// - /// * `sk` - x25519 scalar representing the local private key - pub fn new(sk: StaticSecret) -> Device { - let pk = PublicKey::from(&sk); + pub fn new() -> Device { Device { - pk, - sk, - macs: macs::Validator::new(pk), + keyst: None, pk_map: HashMap::new(), id_map: RwLock::new(HashMap::new()), limiter: Mutex::new(RateLimiter::new()), } } + fn update_ss(&self, peer: &mut Peer) -> Option { + if let Some(key) = self.keyst.as_ref() { + if *peer.pk.as_bytes() == *key.pk.as_bytes() { + return Some(peer.pk); + } + peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes(); + } else { + peer.ss = [0u8; 32]; + }; + None + } + /// Update the secret key of the device /// /// # Arguments /// /// * `sk` - x25519 scalar representing the local private key - pub fn set_sk(&mut self, sk: StaticSecret) { + pub fn set_sk(&mut self, sk: Option) -> Option { // update secret and public key - let pk = PublicKey::from(&sk); - self.sk = sk; - self.pk = pk; - self.macs = macs::Validator::new(pk); + self.keyst = sk.map(|sk| { + let pk = PublicKey::from(&sk); + let macs = macs::Validator::new(pk); + KeyState { pk, sk, macs } + }); - // recalculate the shared secrets for every peer + // recalculate / erase the shared secrets for every peer let mut ids = vec![]; + let mut same = None; for mut peer in self.pk_map.values_mut() { + // clear any existing handshake state peer.reset_state().map(|id| ids.push(id)); - peer.ss = self.sk.diffie_hellman(&peer.pk) + + // update precomputed shared secret + if let Some(key) = self.keyst.as_ref() { + peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes(); + if *peer.pk.as_bytes() == *key.pk.as_bytes() { + same = Some(peer.pk) + } + } else { + peer.ss = [0u8; 32]; + }; } // release ids from aborted handshakes for id in ids { self.release(id) } + + // if we found a peer matching the device public key, remove it. + same.map(|pk| { + self.pk_map.remove(pk.as_bytes()); + pk + }) } /// Return the secret key of the device @@ -81,8 +108,8 @@ impl Device { /// # Returns /// /// A secret key (x25519 scalar) - pub fn get_sk(&self) -> StaticSecret { - StaticSecret::from(self.sk.to_bytes()) + pub fn get_sk(&self) -> Option<&StaticSecret> { + self.keyst.as_ref().map(|key| &key.sk) } /// Add a new public key to the state machine @@ -93,28 +120,28 @@ impl Device { /// * `pk` - The public key to add /// * `identifier` - Associated identifier which can be used to distinguish the peers pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> { - // check that the pk is not added twice - if let Some(_) = self.pk_map.get(pk.as_bytes()) { - return Ok(()); - }; - - // check that the pk is not that of the device - if *self.pk.as_bytes() == *pk.as_bytes() { - return Err(ConfigError::new( - "Public key corresponds to secret key of interface", - )); - } - // ensure less than 2^20 peers if self.pk_map.len() > MAX_PEER_PER_DEVICE { return Err(ConfigError::new("Too many peers for device")); } - // map the public key to the peer state - self.pk_map - .insert(*pk.as_bytes(), Peer::new(pk, self.sk.diffie_hellman(&pk))); - - Ok(()) + // create peer and precompute static secret + let mut peer = Peer::new( + pk, + self.keyst + .as_ref() + .map(|key| *key.sk.diffie_hellman(&pk).as_bytes()) + .unwrap_or([0u8; 32]), + ); + + // add peer to device + match self.update_ss(&mut peer) { + Some(_) => Err(ConfigError::new("Public key of peer matches the device")), + None => { + self.pk_map.insert(*pk.as_bytes(), peer); + Ok(()) + } + } } /// Remove a peer by public key @@ -203,17 +230,17 @@ impl Device { rng: &mut R, pk: &PublicKey, ) -> Result, HandshakeError> { - match self.pk_map.get(pk.as_bytes()) { - None => Err(HandshakeError::UnknownPublicKey), - Some(peer) => { + match (self.keyst.as_ref(), self.pk_map.get(pk.as_bytes())) { + (_, None) => Err(HandshakeError::UnknownPublicKey), + (None, _) => Err(HandshakeError::UnknownPublicKey), + (Some(keyst), Some(peer)) => { let sender = self.allocate(rng, peer); - let mut msg = Initiation::default(); - noise::create_initiation(rng, self, peer, sender, &mut msg.noise)?; + // create noise part of initation + noise::create_initiation(rng, keyst, peer, sender, &mut msg.noise)?; // add macs to initation - peer.macs .lock() .generate(msg.noise.as_bytes(), &mut msg.macs); @@ -242,6 +269,15 @@ impl Device { return Err(HandshakeError::InvalidMessageFormat); } + // obtain reference to key state + // if no key is configured return a noop. + let keyst = match self.keyst.as_ref() { + Some(key) => key, + None => { + return Ok((None, None, None)); + } + }; + // de-multiplex the message type field match LittleEndian::read_u32(msg) { TYPE_INITIATION => { @@ -249,7 +285,7 @@ impl Device { let msg = Initiation::parse(msg)?; // check mac1 field - self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; + keyst.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; // address validation & DoS mitigation if let Some(src) = src { @@ -257,9 +293,9 @@ impl Device { let src = src.into(); // check mac2 field - if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { + if !keyst.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { let mut reply = Default::default(); - self.macs.create_cookie_reply( + keyst.macs.create_cookie_reply( rng, msg.noise.f_sender.get(), src, @@ -276,7 +312,7 @@ impl Device { } // consume the initiation - let (peer, st) = noise::consume_initiation(self, &msg.noise)?; + let (peer, st) = noise::consume_initiation(self, keyst, &msg.noise)?; // allocate new index for response let sender = self.allocate(rng, peer); @@ -304,7 +340,7 @@ impl Device { let msg = Response::parse(msg)?; // check mac1 field - self.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; + keyst.macs.check_mac1(msg.noise.as_bytes(), &msg.macs)?; // address validation & DoS mitigation if let Some(src) = src { @@ -312,9 +348,9 @@ impl Device { let src = src.into(); // check mac2 field - if !self.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { + if !keyst.macs.check_mac2(msg.noise.as_bytes(), src, &msg.macs) { let mut reply = Default::default(); - self.macs.create_cookie_reply( + keyst.macs.create_cookie_reply( rng, msg.noise.f_sender.get(), src, @@ -331,7 +367,7 @@ impl Device { } // consume inner playload - noise::consume_response(self, &msg.noise) + noise::consume_response(self, keyst, &msg.noise) } TYPE_COOKIE_REPLY => { let msg = CookieReply::parse(msg)?; @@ -421,8 +457,11 @@ mod tests { // intialize devices on both ends - let mut dev1 = Device::new(sk1); - let mut dev2 = Device::new(sk2); + let mut dev1 = Device::new(); + let mut dev2 = Device::new(); + + dev1.set_sk(Some(sk1)); + dev2.set_sk(Some(sk2)); dev1.add(pk2).unwrap(); dev2.add(pk1).unwrap(); diff --git a/src/wireguard/handshake/noise.rs b/src/wireguard/handshake/noise.rs index 68e738d..6db300a 100644 --- a/src/wireguard/handshake/noise.rs +++ b/src/wireguard/handshake/noise.rs @@ -22,7 +22,7 @@ use clear_on_drop::clear_stack_on_return; use subtle::ConstantTimeEq; -use super::device::Device; +use super::device::{Device, KeyState}; use super::messages::{NoiseInitiation, NoiseResponse}; use super::messages::{TYPE_INITIATION, TYPE_RESPONSE}; use super::peer::{Peer, State}; @@ -219,7 +219,7 @@ mod tests { pub fn create_initiation( rng: &mut R, - device: &Device, + keyst: &KeyState, peer: &Peer, sender: u32, msg: &mut NoiseInitiation, @@ -260,9 +260,9 @@ pub fn create_initiation( SEAL!( &key, - &hs, // ad - device.pk.as_bytes(), // pt - &mut msg.f_static // ct || tag + &hs, // ad + keyst.pk.as_bytes(), // pt + &mut msg.f_static // ct || tag ); // H := Hash(H || msg.static) @@ -271,7 +271,7 @@ pub fn create_initiation( // (C, k) := Kdf2(C, DH(S_priv, S_pub)) - let (ck, key) = KDF2!(&ck, peer.ss.as_bytes()); + let (ck, key) = KDF2!(&ck, &peer.ss); // msg.timestamp := Aead(k, 0, Timestamp(), H) @@ -301,6 +301,7 @@ pub fn create_initiation( pub fn consume_initiation<'a>( device: &'a Device, + keyst: &KeyState, msg: &NoiseInitiation, ) -> Result<(&'a Peer, TemporaryState), HandshakeError> { debug!("consume initation"); @@ -309,7 +310,7 @@ pub fn consume_initiation<'a>( let ck = INITIAL_CK; let hs = INITIAL_HS; - let hs = HASH!(&hs, device.pk.as_bytes()); + let hs = HASH!(&hs, keyst.pk.as_bytes()); // C := Kdf(C, E_pub) @@ -322,7 +323,7 @@ pub fn consume_initiation<'a>( // (C, k) := Kdf2(C, DH(E_priv, S_pub)) let eph_r_pk = PublicKey::from(msg.f_ephemeral); - let (ck, key) = KDF2!(&ck, device.sk.diffie_hellman(&eph_r_pk).as_bytes()); + let (ck, key) = KDF2!(&ck, keyst.sk.diffie_hellman(&eph_r_pk).as_bytes()); // msg.static := Aead(k, 0, S_pub, H) @@ -347,7 +348,7 @@ pub fn consume_initiation<'a>( // (C, k) := Kdf2(C, DH(S_priv, S_pub)) - let (ck, key) = KDF2!(&ck, peer.ss.as_bytes()); + let (ck, key) = KDF2!(&ck, &peer.ss); // msg.timestamp := Aead(k, 0, Timestamp(), H) @@ -461,7 +462,11 @@ pub fn create_response( * allow concurrent processing of potential responses to the initiation, * in order to better mitigate DoS from malformed response messages. */ -pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result { +pub fn consume_response( + device: &Device, + keyst: &KeyState, + msg: &NoiseResponse, +) -> Result { debug!("consume response"); clear_stack_on_return(CLEAR_PAGES, || { // retrieve peer and copy initiation state @@ -492,7 +497,7 @@ pub fn consume_response(device: &Device, msg: &NoiseResponse) -> Result, // constant state - pub(crate) pk: PublicKey, // public key of peer - pub(crate) ss: SharedSecret, // precomputed DH(static, static) - pub(crate) psk: Psk, // psk of peer + pub(crate) pk: PublicKey, // public key of peer + pub(crate) ss: [u8; 32], // precomputed DH(static, static) + pub(crate) psk: Psk, // psk of peer } pub enum State { @@ -62,17 +62,14 @@ impl Drop for State { } impl Peer { - pub fn new( - pk: PublicKey, // public key of peer - ss: SharedSecret, // precomputed DH(static, static) - ) -> Self { + pub fn new(pk: PublicKey, ss: [u8; 32]) -> Self { Self { macs: Mutex::new(macs::Generator::new(pk)), state: Mutex::new(State::Reset), timestamp: Mutex::new(None), last_initiation_consumption: Mutex::new(None), - pk: pk, - ss: ss, + pk, + ss, psk: [0u8; 32], } } diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs index b77e8c6..4f9d19f 100644 --- a/src/wireguard/peer.rs +++ b/src/wireguard/peer.rs @@ -5,6 +5,7 @@ use super::HandshakeJob; use super::bind::Bind; use super::tun::Tun; +use super::wireguard::WireguardInner; use std::fmt; use std::ops::Deref; @@ -19,13 +20,16 @@ use x25519_dalek::PublicKey; pub struct Peer { pub router: Arc, T::Writer, B::Writer>>, - pub state: Arc>, + pub state: Arc>, } -pub struct PeerInner { +pub struct PeerInner { // internal id (for logging) pub id: u64, + // wireguard device state + pub wg: Arc>, + // handshake state pub walltime_last_handshake: Mutex, pub last_handshake_sent: Mutex, // instant for last handshake @@ -50,7 +54,7 @@ impl Clone for Peer { } } -impl PeerInner { +impl PeerInner { #[inline(always)] pub fn timers(&self) -> RwLockReadGuard { self.timers.read() @@ -69,7 +73,7 @@ impl fmt::Display for Peer { } impl Deref for Peer { - type Target = PeerInner; + type Target = PeerInner; fn deref(&self) -> &Self::Target { &self.state } @@ -91,28 +95,3 @@ impl Peer { self.start_timers(); } } - -impl PeerInner { - /* Queue a handshake request for the parallel workers - * (if one does not already exist) - * - * The function is ratelimited. - */ - pub fn packet_send_handshake_initiation(&self) { - // the function is rate limited - - { - let mut lhs = self.last_handshake_sent.lock(); - if lhs.elapsed() < REKEY_TIMEOUT { - return; - } - *lhs = Instant::now(); - } - - // create a new handshake job for the peer - - if !self.handshake_queued.swap(true, Ordering::SeqCst) { - self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap(); - } - } -} diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index 038f6c6..33b089f 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -35,7 +35,7 @@ impl Timers { } } -impl PeerInner { +impl PeerInner { pub fn stop_timers(&self) { // take a write lock preventing simultaneous timer events or "start_timers" call let mut timers = self.timers_mut(); @@ -180,7 +180,6 @@ impl PeerInner { */ pub fn sent_handshake_initiation(&self) { *self.last_handshake_sent.lock() = Instant::now(); - self.handshake_queued.store(false, Ordering::SeqCst); self.timers_set_retransmit_handshake(); self.timers_any_authenticated_packet_traversal(); self.timers_any_authenticated_packet_sent(); @@ -333,7 +332,7 @@ impl Timers { pub struct Events(PhantomData<(T, B)>); impl Callbacks for Events { - type Opaque = Arc>; + type Opaque = Arc>; /* Called after the router encrypts a transport message destined for the peer. * This method is called, even if the encrypted payload is empty (keepalive) diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 6da428c..a890d5e 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -42,46 +42,56 @@ pub struct WireguardInner { mtu: T::MTU, send: RwLock>, - // identify and configuration map + // identity and configuration map peers: RwLock>>, // cryptokey router router: router::Device, T::Writer, B::Writer>, // handshake related state - handshake: RwLock, + handshake: RwLock, under_load: AtomicBool, pending: AtomicUsize, // num of pending handshake packets in queue queue: Mutex>>, } +impl PeerInner { + /* Queue a handshake request for the parallel workers + * (if one does not already exist) + * + * The function is ratelimited. + */ + pub fn packet_send_handshake_initiation(&self) { + // the function is rate limited + + { + let mut lhs = self.last_handshake_sent.lock(); + if lhs.elapsed() < REKEY_TIMEOUT { + return; + } + *lhs = Instant::now(); + } + + // create a new handshake job for the peer + + if !self.handshake_queued.swap(true, Ordering::SeqCst) { + self.wg.pending.fetch_add(1, Ordering::SeqCst); + self.queue.lock().send(HandshakeJob::New(self.pk)).unwrap(); + } + } +} + pub enum HandshakeJob { Message(Vec, E), New(PublicKey), } -#[derive(Clone)] -pub struct WireguardHandle { - inner: Arc>, -} - impl fmt::Display for WireguardInner { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "wireguard({:x})", self.id) } } -struct Handshake { - device: handshake::Device, - active: bool, -} - -impl Deref for WireguardHandle { - type Target = Arc>; - fn deref(&self) -> &Self::Target { - &self.inner - } -} impl Deref for Wireguard { type Target = Arc>; fn deref(&self) -> &Self::Target { @@ -91,7 +101,7 @@ impl Deref for Wireguard { pub struct Wireguard { runner: Runner, - state: WireguardHandle, + state: Arc>, } /* Returns the padded length of a message: @@ -181,31 +191,18 @@ impl Wireguard { } pub fn set_key(&self, sk: Option) { - let mut handshake = self.state.handshake.write(); - match sk { - None => { - let mut rng = OsRng::new().unwrap(); - handshake.device.set_sk(StaticSecret::new(&mut rng)); - handshake.active = false; - } - Some(sk) => { - handshake.device.set_sk(sk); - handshake.active = true; - } - } + self.handshake.write().set_sk(sk); } pub fn get_sk(&self) -> Option { - let handshake = self.state.handshake.read(); - if handshake.active { - Some(handshake.device.get_sk()) - } else { - None - } + self.handshake + .read() + .get_sk() + .map(|sk| StaticSecret::from(sk.to_bytes())) } pub fn set_psk(&self, pk: PublicKey, psk: Option<[u8; 32]>) -> bool { - self.state.handshake.write().device.set_psk(pk, psk).is_ok() + self.state.handshake.write().set_psk(pk, psk).is_ok() } pub fn add_peer(&self, pk: PublicKey) { @@ -217,6 +214,7 @@ impl Wireguard { let state = Arc::new(PeerInner { id: rng.gen(), pk, + wg: self.state.clone(), walltime_last_handshake: Mutex::new(SystemTime::UNIX_EPOCH), last_handshake_sent: Mutex::new(self.state.start - TIME_HORIZON), handshake_queued: AtomicBool::new(false), @@ -245,14 +243,14 @@ impl Wireguard { peers.entry(*pk.as_bytes()).or_insert(peer); // add to the handshake device - self.state.handshake.write().device.add(pk).unwrap(); // TODO: handle adding of public key for interface + self.state.handshake.write().add(pk).unwrap(); // TODO: handle adding of public key for interface } - /* Begin consuming messages from the reader. - * - * Any previous reader thread is stopped by closing the previous reader, - * which unblocks the thread and causes an error on reader.read - */ + /// Begin consuming messages from the reader. + /// Multiple readers can be added to support multi-queue and individual Ipv6/Ipv4 sockets interfaces + /// + /// Any previous reader thread is stopped by closing the previous reader, + /// which unblocks the thread and causes an error on reader.read pub fn add_reader(&self, reader: B::Reader) { let wg = self.state.clone(); thread::spawn(move || { @@ -285,6 +283,7 @@ impl Wireguard { | handshake::TYPE_RESPONSE => { debug!("{} : reader, received handshake message", wg); + // add one to pending let pending = wg.pending.fetch_add(1, Ordering::SeqCst); // update under_load flag @@ -297,6 +296,7 @@ impl Wireguard { wg.under_load.store(false, Ordering::SeqCst); } + // add to handshake queue wg.queue .lock() .send(HandshakeJob::Message(msg, src)) @@ -325,7 +325,10 @@ impl Wireguard { pub fn new(mut readers: Vec, writer: T::Writer, mtu: T::MTU) -> Wireguard { // create device state let mut rng = OsRng::new().unwrap(); + + // handshake queue let (tx, rx): (Sender>, _) = bounded(SIZE_HANDSHAKE_QUEUE); + let wg = Arc::new(WireguardInner { start: Instant::now(), id: rng.gen(), @@ -334,10 +337,7 @@ impl Wireguard { send: RwLock::new(None), router: router::Device::new(num_cpus::get(), writer), // router owns the writing half pending: AtomicUsize::new(0), - handshake: RwLock::new(Handshake { - device: handshake::Device::new(StaticSecret::new(&mut rng)), - active: false, - }), + handshake: RwLock::new(handshake::Device::new()), under_load: AtomicBool::new(false), queue: Mutex::new(tx), }); @@ -350,24 +350,22 @@ impl Wireguard { debug!("{} : handshake worker, started", wg); // prepare OsRng instance for this thread - let mut rng = OsRng::new().unwrap(); + let mut rng = OsRng::new().expect("Unable to obtain a CSPRNG"); // process elements from the handshake queue for job in rx { - let state = wg.handshake.read(); - if !state.active { - continue; - } + // decrement pending + wg.pending.fetch_sub(1, Ordering::SeqCst); + + let device = wg.handshake.read(); match job { HandshakeJob::Message(msg, src) => { - wg.pending.fetch_sub(1, Ordering::SeqCst); - // feed message to handshake device let src_validate = (&src).into_address(); // TODO avoid // process message - match state.device.process( + match device.process( &mut rng, &msg[..], if wg.under_load.load(Ordering::Relaxed) { @@ -428,7 +426,7 @@ impl Wireguard { // free any unused ids for id in peer.router.add_keypair(kp) { - state.device.release(id); + device.release(id); } }); } @@ -438,15 +436,19 @@ impl Wireguard { } } HandshakeJob::New(pk) => { - debug!("{} : handshake worker, new handshake requested", wg); - let _ = state.device.begin(&mut rng, &pk).map(|msg| { - if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { + if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { + debug!( + "{} : handshake worker, new handshake requested for {}", + wg, peer + ); + let _ = device.begin(&mut rng, &peer.pk).map(|msg| { let _ = peer.router.send(&msg[..]).map_err(|e| { debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e) }); peer.state.sent_handshake_initiation(); - } - }); + }); + peer.handshake_queued.store(false, Ordering::SeqCst); + } } } } @@ -498,7 +500,7 @@ impl Wireguard { } Wireguard { - state: WireguardHandle { inner: wg }, + state: wg, runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY), } } -- cgit v1.2.3-59-g8ed1b