From 89e80e0c3c7f505944d5730136d210b436285694 Mon Sep 17 00:00:00 2001 From: Jake McGinty Date: Fri, 1 Jun 2018 21:37:09 -0500 Subject: crossbeam crypto pool --- Cargo.toml | 3 +- src/crypto_pool.rs | 113 +++++++++++++++++++++++++++++++++++++++++++ src/interface/mod.rs | 9 ++++ src/interface/peer_server.rs | 93 ++++++++++++++++++----------------- src/lib.rs | 4 +- src/peer.rs | 65 +++++++++---------------- 6 files changed, 200 insertions(+), 87 deletions(-) create mode 100644 src/crypto_pool.rs diff --git a/Cargo.toml b/Cargo.toml index 1e35080..06d7852 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,10 +34,10 @@ blake2-rfc = "0.2" byteorder = "^1.2" bytes = "0.4" chacha20-poly1305-aead = "^0.1" +crossbeam-channel = "0.1" derive_deref = "^1.0" failure = "^0.1" futures = "^0.1" -futures-cpupool = "^0.1" lazy_static = "^1" libc = { git = "https://github.com/rust-lang/libc" } log = "^0.4" @@ -46,6 +46,7 @@ notify = "^4.0" rand = "0.5.0-pre.2" nix = { git = "https://github.com/mcginty/nix", branch = "ipv6-pktinfo" } mio = "^0.6" +num_cpus = "1.0" pnet_packet = "^0.21" snow = { git = "https://github.com/mcginty/snow", branch = "wireguard" } socket2 = "^0.3" diff --git a/src/crypto_pool.rs b/src/crypto_pool.rs new file mode 100644 index 0000000..b77b9fd --- /dev/null +++ b/src/crypto_pool.rs @@ -0,0 +1,113 @@ +use consts::{PADDING_MULTIPLE, TRANSPORT_OVERHEAD, TRANSPORT_HEADER_SIZE}; +use crossbeam_channel::{unbounded, Receiver, Sender}; +use futures::sync::mpsc; +use futures::executor; +use futures::Sink; +use num_cpus; +use snow::AsyncTransportState; +use std::thread; +use udp::Endpoint; +use message; +use peer::SessionType; +use ip_packet::IpPacket; +use byteorder::{ByteOrder, LittleEndian}; + +pub enum Work { + Decrypt((mpsc::UnboundedSender, DecryptWork)), + Encrypt((mpsc::UnboundedSender, EncryptWork)), +} + +pub struct EncryptWork { + pub transport: AsyncTransportState, + pub nonce: u64, + pub our_index: u32, + pub their_index: u32, + pub endpoint: Endpoint, + pub in_packet: Vec, +} + +pub struct EncryptResult { + pub endpoint: Endpoint, + pub our_index: u32, + pub out_packet: Vec, +} + +pub struct DecryptWork { + pub transport: AsyncTransportState, + pub endpoint: Endpoint, + pub packet: message::Transport, + pub session_type: SessionType, +} + +pub struct DecryptResult { + pub endpoint: Endpoint, + pub orig_packet: message::Transport, + pub out_packet: Vec, + pub session_type: SessionType, +} + +/// Spawn a thread pool to efficiently process +/// the CPU-intensive encryption/decryption. +pub fn create() -> Sender { + let threads = num_cpus::get() - 1; // One thread for I/O. + let (sender, receiver) = unbounded(); + + for i in 0..threads { + let rx = receiver.clone(); + thread::Builder::new().name(format!("wireguard-rs-crypto-{}", i)) + .spawn(move || worker(rx.clone())).unwrap(); + } + + sender +} + +fn worker(receiver: Receiver) { + loop { + let work = receiver.recv().expect("channel to crypto worker thread broken."); + match work { + Work::Decrypt((tx, element)) => { + let mut raw_packet = vec![0u8; element.packet.len()]; + let nonce = element.packet.nonce(); + let len = element.transport.read_transport_message(nonce, element.packet.payload(), &mut raw_packet).unwrap(); + if len > 0 { + let len = IpPacket::new(&raw_packet[..len]) + .ok_or_else(||format_err!("invalid IP packet (len {})", len)).unwrap() + .length(); + raw_packet.truncate(len as usize); + } else { + raw_packet.truncate(0); + } + + executor::spawn(tx.send(DecryptResult { + endpoint: element.endpoint, + orig_packet: element.packet, + out_packet: raw_packet, + session_type: element.session_type, + })).wait_future(); + }, + Work::Encrypt((tx, mut element)) => { + let padding = if element.in_packet.len() % PADDING_MULTIPLE != 0 { + PADDING_MULTIPLE - (element.in_packet.len() % PADDING_MULTIPLE) + } else { 0 }; + let padded_len = element.in_packet.len() + padding; + let mut out_packet = vec![0u8; padded_len + TRANSPORT_OVERHEAD]; + + out_packet[0] = 4; + LittleEndian::write_u32(&mut out_packet[4..], element.their_index); + LittleEndian::write_u64(&mut out_packet[8..], element.nonce); + + element.in_packet.resize(padded_len, 0); + let len = element.transport.write_transport_message(element.nonce, + &element.in_packet, + &mut out_packet[16..]).unwrap(); + out_packet.truncate(TRANSPORT_HEADER_SIZE + len); + + executor::spawn(tx.send(EncryptResult { + endpoint: element.endpoint, + our_index: element.our_index, + out_packet, + })).wait_future(); + } + } + } +} \ No newline at end of file diff --git a/src/interface/mod.rs b/src/interface/mod.rs index c45e70f..acfb9be 100644 --- a/src/interface/mod.rs +++ b/src/interface/mod.rs @@ -66,6 +66,15 @@ impl UtunPacket { } } +impl From for Vec { + fn from(packet: UtunPacket) -> Vec { + use self::UtunPacket::*; + match packet { + Inet4(payload) | Inet6(payload) => payload, + } + } +} + impl UtunCodec for VecUtunCodec { type In = UtunPacket; type Out = Vec; diff --git a/src/interface/peer_server.rs b/src/interface/peer_server.rs index aa82fb7..39ba68d 100644 --- a/src/interface/peer_server.rs +++ b/src/interface/peer_server.rs @@ -2,6 +2,7 @@ use consts::{REKEY_TIMEOUT, KEEPALIVE_TIMEOUT, STALE_SESSION_TIMEOUT, MAX_CONTENT_SIZE, WIPE_AFTER_TIME, MAX_HANDSHAKE_ATTEMPTS, UNDER_LOAD_QUEUE_SIZE, UNDER_LOAD_TIME}; use cookie; +use crypto_pool::{self, DecryptResult, EncryptResult}; use interface::{SharedPeer, SharedState, State, UtunPacket}; use message::{Message, Initiation, Response, CookieReply, Transport}; use peer::{Peer, SessionType, SessionTransition}; @@ -10,9 +11,9 @@ use timestamp::Timestamp; use timer::{Timer, TimerMessage}; use byteorder::{ByteOrder, LittleEndian}; +use crossbeam_channel as crossbeam; use failure::{Error, err_msg}; -use futures::{Async, Future, Stream, Poll, unsync::mpsc, task}; -use futures_cpupool::CpuPool; +use futures::{Async, Future, Stream, Poll, sync, unsync::mpsc, task}; use rand::{self, Rng, ThreadRng}; use udp::{Endpoint, UdpSocket, PeerServerMessage, UdpChannel}; use tokio_core::reactor::Handle; @@ -37,6 +38,20 @@ struct Channel { rx: mpsc::UnboundedReceiver, } +struct SyncChannel { + tx: sync::mpsc::UnboundedSender, + rx: sync::mpsc::UnboundedReceiver, +} + +impl From<(sync::mpsc::UnboundedSender, sync::mpsc::UnboundedReceiver)> for SyncChannel { + fn from(pair: (sync::mpsc::UnboundedSender, sync::mpsc::UnboundedReceiver)) -> Self { + Self { + tx: pair.0, + rx: pair.1, + } + } +} + impl From<(mpsc::UnboundedSender, mpsc::UnboundedReceiver)> for Channel { fn from(pair: (mpsc::UnboundedSender, mpsc::UnboundedReceiver)) -> Self { Self { @@ -60,9 +75,9 @@ pub struct PeerServer { rate_limiter : RateLimiter, under_load_until : Instant, rng : ThreadRng, - cpu_pool : CpuPool, - decrypt_channel : Channel<(Endpoint, Transport, Vec, SessionType)>, - encrypt_channel : Channel<(SharedPeer, (Endpoint, Vec))>, + crypto_pool : crossbeam::Sender, + decrypt_channel : SyncChannel, + encrypt_channel : SyncChannel, } impl PeerServer { @@ -80,9 +95,9 @@ impl PeerServer { rate_limiter : RateLimiter::new(&handle)?, under_load_until : Instant::now(), rng : rand::thread_rng(), - cpu_pool : CpuPool::new_num_cpus(), - decrypt_channel : mpsc::unbounded().into(), - encrypt_channel : mpsc::unbounded().into(), + crypto_pool : crypto_pool::create(), + decrypt_channel : sync::mpsc::unbounded().into(), + encrypt_channel : sync::mpsc::unbounded().into(), }) } @@ -273,7 +288,7 @@ impl PeerServer { } } else { debug!("sending empty keepalive"); - self.encrypt_and_send(peer_ref.clone(), &mut peer, UtunPacket::Inet4(vec![]))?; + self.encrypt_and_send(peer_ref.clone(), &mut peer, vec![])?; } } else { error!("peer not ready for transport after processing handshake response. this shouldn't happen."); @@ -299,38 +314,27 @@ impl PeerServer { let mut peer = peer_ref.borrow_mut(); let tx = self.decrypt_channel.tx.clone(); - let f = self.cpu_pool.spawn(peer.handle_incoming_transport(addr, packet)?) - .and_then(move |result| { - tx.unbounded_send(result).expect("broken decrypt channel"); - Ok(()) - }) - .map_err(|e| warn!("{:?}", e)); - self.handle.spawn(f); + let work = crypto_pool::Work::Decrypt((tx, peer.handle_incoming_transport(addr, packet)?)); + self.crypto_pool.send(work)?; Ok(()) } - fn encrypt_and_send(&mut self, peer_ref: SharedPeer, peer: &mut Peer, packet: UtunPacket) -> Result<(), Error> { + fn encrypt_and_send(&mut self, peer_ref: SharedPeer, peer: &mut Peer, packet: Vec) -> Result<(), Error> { let tx = self.encrypt_channel.tx.clone(); - let f = self.cpu_pool.spawn(peer.handle_outgoing_transport(packet)?) - .and_then(move |result| { - tx.unbounded_send((peer_ref, result)).expect("broken decrypt channel"); - Ok(()) - }) - .map_err(|e| warn!("{:?}", e)); - self.handle.spawn(f); + let work = crypto_pool::Work::Encrypt((tx, peer.handle_outgoing_transport(packet)?)); + self.crypto_pool.send(work)?; Ok(()) } - fn handle_ingress_decrypted_transport(&mut self, addr: Endpoint, orig_packet: Transport, raw_packet: Vec, session_type: SessionType) - -> Result<(), Error> + fn handle_ingress_decrypted_transport(&mut self, result: DecryptResult) -> Result<(), Error> { - let peer_ref = self.shared_state.borrow().index_map.get(&orig_packet.our_index()) + let peer_ref = self.shared_state.borrow().index_map.get(&result.orig_packet.our_index()) .ok_or_else(|| err_msg("unknown our_index"))?.clone(); let needs_handshake = { let mut peer = peer_ref.borrow_mut(); - let transition = peer.handle_incoming_decrypted_transport(addr, &raw_packet, session_type)?; + let transition = peer.handle_incoming_decrypted_transport(result.endpoint, &result.out_packet, result.session_type)?; let shared_state = self.shared_state.clone(); let mut state = shared_state.borrow_mut(); if let SessionTransition::Transition(possible_dead_index) = transition { @@ -338,7 +342,7 @@ impl PeerServer { let _ = state.index_map.remove(&index); } - let outgoing: Vec = peer.outgoing_queue.drain(..).collect(); + let outgoing: Vec> = peer.outgoing_queue.drain(..).collect(); for packet in outgoing { self.encrypt_and_send(peer_ref.clone(), &mut peer, packet)?; @@ -354,14 +358,14 @@ impl PeerServer { self.send_handshake_init(&peer_ref)?; } - if raw_packet.is_empty() { + if result.out_packet.is_empty() { debug!("received keepalive."); return Ok(()) // short-circuit on keep-alives } - self.shared_state.borrow_mut().router.validate_source(&raw_packet, &peer_ref)?; + self.shared_state.borrow_mut().router.validate_source(&result.out_packet, &peer_ref)?; trace!("received transport packet"); - self.send_to_tunnel(raw_packet)?; + self.send_to_tunnel(result.out_packet)?; Ok(()) } @@ -374,7 +378,7 @@ impl PeerServer { let needs_handshake = { let mut peer = peer_ref.borrow_mut(); let needs_handshake = peer.needs_new_handshake(true); - peer.queue_egress(packet); + peer.queue_egress(packet.into()); if peer.ready_for_transport() { if peer.outgoing_queue.len() > 1 { @@ -396,11 +400,13 @@ impl PeerServer { Ok(()) } - fn handle_egress_encrypted_packet(&mut self, peer_ref: SharedPeer, endpoint: Endpoint, packet: Vec) -> Result<(), Error> { + fn handle_egress_encrypted_packet(&mut self, result: EncryptResult) -> Result<(), Error> { + let peer_ref = self.shared_state.borrow().index_map.get(&result.our_index) + .ok_or_else(|| err_msg("unknown our_index"))?.clone(); let mut peer = peer_ref.borrow_mut(); - peer.handle_outgoing_encrypted_transport(&packet); + peer.handle_outgoing_encrypted_transport(&result.out_packet); - self.send_to_peer((endpoint, packet)) + self.send_to_peer((result.endpoint, result.out_packet)) } fn send_cookie_reply(&mut self, addr: Endpoint, mac1: &[u8], index: u32) -> Result<(), Error> { @@ -508,7 +514,7 @@ impl PeerServer { } } - self.encrypt_and_send(upgraded_peer_ref.clone(), &mut peer, UtunPacket::Inet4(vec![]))?; + self.encrypt_and_send(upgraded_peer_ref.clone(), &mut peer, vec![])?; debug!("sent passive keepalive packet"); self.timer.send_after(*KEEPALIVE_TIMEOUT, PassiveKeepAlive(peer_ref.clone())); @@ -526,7 +532,7 @@ impl PeerServer { bail!("persistent keepalive tick (waiting ~{}s due to last authenticated packet time)", wait.as_secs()); } - self.encrypt_and_send(upgraded_peer_ref.clone(), &mut peer, UtunPacket::Inet4(vec![]))?; + self.encrypt_and_send(upgraded_peer_ref.clone(), &mut peer, vec![])?; let handle = self.timer.send_after(persistent_keepalive, PersistentKeepAlive(peer_ref.clone())); peer.timers.persistent_timer = Some(handle); debug!("sent persistent keepalive packet"); @@ -584,7 +590,7 @@ impl PeerServer { if let Some(keepalive) = peer.info.persistent_keepalive() { let handle = self.timer.send_after(keepalive, TimerMessage::PersistentKeepAlive(Rc::downgrade(&peer_ref))); peer.timers.persistent_timer = Some(handle); - self.encrypt_and_send(peer_ref.clone(), &mut peer, UtunPacket::Inet4(vec![]))?; + self.encrypt_and_send(peer_ref.clone(), &mut peer, vec![])?; debug!("set new keepalive timer and immediately sent new keepalive packet."); } } @@ -646,10 +652,9 @@ impl Future for PeerServer { } loop { - // Handle UDP packets from the outside world match self.decrypt_channel.rx.poll() { - Ok(Async::Ready(Some((addr, orig_packet, decrypted, session_type)))) => { - let _ = self.handle_ingress_decrypted_transport(addr, orig_packet, decrypted, session_type).map_err(|e| warn!("UDP ERR: {:?}", e)); + Ok(Async::Ready(Some(result))) => { + let _ = self.handle_ingress_decrypted_transport(result).map_err(|e| warn!("UDP ERR: {:?}", e)); }, Ok(Async::NotReady) => { break; }, Ok(Async::Ready(None)) => bail!("incoming udp stream ended unexpectedly"), @@ -660,8 +665,8 @@ impl Future for PeerServer { loop { // Handle UDP packets from the outside world match self.encrypt_channel.rx.poll() { - Ok(Async::Ready(Some((peer_ref, (endpoint, packet))))) => { - let _ = self.handle_egress_encrypted_packet(peer_ref, endpoint, packet).map_err(|e| warn!("UDP ERR: {:?}", e)); + Ok(Async::Ready(Some(result))) => { + let _ = self.handle_egress_encrypted_packet(result).map_err(|e| warn!("UDP ERR: {:?}", e)); }, Ok(Async::NotReady) => { break; }, Ok(Async::Ready(None)) => bail!("incoming udp stream ended unexpectedly"), diff --git a/src/lib.rs b/src/lib.rs index 9031322..0fb68f3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,12 +21,13 @@ extern crate blake2_rfc; extern crate byteorder; extern crate bytes; extern crate chacha20_poly1305_aead; -extern crate futures_cpupool; +extern crate crossbeam_channel; extern crate hex; extern crate libc; extern crate mio; extern crate nix; extern crate notify; +extern crate num_cpus; extern crate pnet_packet; extern crate rand; extern crate snow; @@ -50,6 +51,7 @@ pub mod types; mod anti_replay; mod consts; mod cookie; +mod crypto_pool; mod error; mod ip_packet; mod message; diff --git a/src/peer.rs b/src/peer.rs index 3a853ae..aace50d 100644 --- a/src/peer.rs +++ b/src/peer.rs @@ -1,13 +1,13 @@ use anti_replay::AntiReplay; use byteorder::{ByteOrder, LittleEndian}; -use consts::{TRANSPORT_OVERHEAD, TRANSPORT_HEADER_SIZE, REKEY_AFTER_MESSAGES, REKEY_AFTER_TIME, - REKEY_AFTER_TIME_RECV, REJECT_AFTER_TIME, REJECT_AFTER_MESSAGES, PADDING_MULTIPLE, +use consts::{REKEY_AFTER_MESSAGES, REKEY_AFTER_TIME, + REKEY_AFTER_TIME_RECV, REJECT_AFTER_TIME, REJECT_AFTER_MESSAGES, MAX_QUEUED_PACKETS, MAX_HANDSHAKE_ATTEMPTS}; +use crypto_pool::{DecryptWork, EncryptWork}; use cookie; use failure::{Error, err_msg}; use futures::{Future, future}; use interface::UtunPacket; -use ip_packet::IpPacket; use noise; use message::{Initiation, Response, CookieReply, Transport}; use std::{self, mem}; @@ -28,7 +28,7 @@ pub struct Peer { pub tx_bytes : u64, pub rx_bytes : u64, pub last_handshake_tai64n : Option, - pub outgoing_queue : VecDeque, + pub outgoing_queue : VecDeque>, pub cookie : cookie::Generator, } @@ -176,7 +176,7 @@ impl Peer { } } - pub fn queue_egress(&mut self, packet: UtunPacket) { + pub fn queue_egress(&mut self, packet: Vec) { if self.outgoing_queue.len() < MAX_QUEUED_PACKETS { self.outgoing_queue.push_back(packet); self.timers.handshake_attempts = 0; @@ -339,10 +339,7 @@ impl Peer { Ok(dead.map(|session| session.our_index)) } - pub fn handle_incoming_transport(&mut self, addr: Endpoint, packet: Transport) - -> Result, SessionType), Error = Error> + 'static + Send>, Error> - { - let mut raw_packet = vec![0u8; packet.len()]; + pub fn handle_incoming_transport(&mut self, endpoint: Endpoint, packet: Transport) -> Result { let nonce = packet.nonce(); let (session, session_type) = self.find_session(packet.our_index()).ok_or_else(|| err_msg("no session with index"))?; @@ -352,18 +349,12 @@ impl Peer { session.anti_replay.update(nonce)?; let mut transport = session.noise.get_async_transport_state()?.clone(); - Ok(Box::new(future::lazy(move || { - let len = transport.read_transport_message(nonce, packet.payload(), &mut raw_packet).unwrap(); - if len > 0 { - let len = IpPacket::new(&raw_packet[..len]) - .ok_or_else(||format_err!("invalid IP packet (len {})", len)).unwrap() - .length(); - raw_packet.truncate(len as usize); - } else { - raw_packet.truncate(0); - } - Ok((addr, packet, raw_packet, session_type)) - }))) + Ok(DecryptWork { + transport, + endpoint, + packet, + session_type + }) } pub fn handle_incoming_decrypted_transport(&mut self, addr: Endpoint, raw_packet: &[u8], session_type: SessionType) @@ -395,34 +386,26 @@ impl Peer { Ok(transition) } - pub fn handle_outgoing_transport(&mut self, packet: UtunPacket) - -> Result), Error = Error> + 'static + Send>, Error> - { - let session = self.sessions.current.as_mut().ok_or_else(|| err_msg("no current noise session"))?; - let endpoint = self.info.endpoint.ok_or_else(|| err_msg("no known peer endpoint"))?; - let padding = if packet.payload().len() % PADDING_MULTIPLE != 0 { - PADDING_MULTIPLE - (packet.payload().len() % PADDING_MULTIPLE) - } else { 0 }; - let padded_len = packet.payload().len() + padding; - let mut out_packet = vec![0u8; padded_len + TRANSPORT_OVERHEAD]; + pub fn handle_outgoing_transport(&mut self, packet: Vec) -> Result { + let session = self.sessions.current.as_mut().ok_or_else(|| err_msg("no current noise session"))?; + let endpoint = self.info.endpoint.ok_or_else(|| err_msg("no known peer endpoint"))?; ensure!(session.nonce < REJECT_AFTER_MESSAGES, "exceeded REJECT-AFTER-MESSAGES"); ensure!(session.birthday.elapsed() < *REJECT_AFTER_TIME, "exceeded REJECT-AFTER-TIME"); - let mut transport = session.noise.get_async_transport_state()?.clone(); session.nonce += 1; let nonce = session.nonce - 1; - out_packet[0] = 4; - LittleEndian::write_u32(&mut out_packet[4..], session.their_index); - LittleEndian::write_u64(&mut out_packet[8..], nonce); + let mut transport = session.noise.get_async_transport_state()?.clone(); - Ok(Box::new(future::lazy(move || { - let padded_packet = &[packet.payload(), &vec![0u8; padding]].concat(); - let len = transport.write_transport_message(nonce, padded_packet, &mut out_packet[16..])?; - out_packet.truncate(TRANSPORT_HEADER_SIZE + len); - Ok((endpoint, out_packet)) - }))) + Ok(EncryptWork { + transport, + nonce, + endpoint, + our_index: session.our_index, + their_index: session.their_index, + in_packet: packet + }) } pub fn handle_outgoing_encrypted_transport(&mut self, packet: &[u8]) { -- cgit v1.2.3-59-g8ed1b