diff options
Diffstat (limited to 'src/crypto_pool.rs')
-rw-r--r-- | src/crypto_pool.rs | 113 |
1 files changed, 113 insertions, 0 deletions
diff --git a/src/crypto_pool.rs b/src/crypto_pool.rs new file mode 100644 index 0000000..b77b9fd --- /dev/null +++ b/src/crypto_pool.rs @@ -0,0 +1,113 @@ +use consts::{PADDING_MULTIPLE, TRANSPORT_OVERHEAD, TRANSPORT_HEADER_SIZE}; +use crossbeam_channel::{unbounded, Receiver, Sender}; +use futures::sync::mpsc; +use futures::executor; +use futures::Sink; +use num_cpus; +use snow::AsyncTransportState; +use std::thread; +use udp::Endpoint; +use message; +use peer::SessionType; +use ip_packet::IpPacket; +use byteorder::{ByteOrder, LittleEndian}; + +pub enum Work { + Decrypt((mpsc::UnboundedSender<DecryptResult>, DecryptWork)), + Encrypt((mpsc::UnboundedSender<EncryptResult>, EncryptWork)), +} + +pub struct EncryptWork { + pub transport: AsyncTransportState, + pub nonce: u64, + pub our_index: u32, + pub their_index: u32, + pub endpoint: Endpoint, + pub in_packet: Vec<u8>, +} + +pub struct EncryptResult { + pub endpoint: Endpoint, + pub our_index: u32, + pub out_packet: Vec<u8>, +} + +pub struct DecryptWork { + pub transport: AsyncTransportState, + pub endpoint: Endpoint, + pub packet: message::Transport, + pub session_type: SessionType, +} + +pub struct DecryptResult { + pub endpoint: Endpoint, + pub orig_packet: message::Transport, + pub out_packet: Vec<u8>, + pub session_type: SessionType, +} + +/// Spawn a thread pool to efficiently process +/// the CPU-intensive encryption/decryption. +pub fn create() -> Sender<Work> { + let threads = num_cpus::get() - 1; // One thread for I/O. + let (sender, receiver) = unbounded(); + + for i in 0..threads { + let rx = receiver.clone(); + thread::Builder::new().name(format!("wireguard-rs-crypto-{}", i)) + .spawn(move || worker(rx.clone())).unwrap(); + } + + sender +} + +fn worker(receiver: Receiver<Work>) { + loop { + let work = receiver.recv().expect("channel to crypto worker thread broken."); + match work { + Work::Decrypt((tx, element)) => { + let mut raw_packet = vec![0u8; element.packet.len()]; + let nonce = element.packet.nonce(); + let len = element.transport.read_transport_message(nonce, element.packet.payload(), &mut raw_packet).unwrap(); + if len > 0 { + let len = IpPacket::new(&raw_packet[..len]) + .ok_or_else(||format_err!("invalid IP packet (len {})", len)).unwrap() + .length(); + raw_packet.truncate(len as usize); + } else { + raw_packet.truncate(0); + } + + executor::spawn(tx.send(DecryptResult { + endpoint: element.endpoint, + orig_packet: element.packet, + out_packet: raw_packet, + session_type: element.session_type, + })).wait_future(); + }, + Work::Encrypt((tx, mut element)) => { + let padding = if element.in_packet.len() % PADDING_MULTIPLE != 0 { + PADDING_MULTIPLE - (element.in_packet.len() % PADDING_MULTIPLE) + } else { 0 }; + let padded_len = element.in_packet.len() + padding; + let mut out_packet = vec![0u8; padded_len + TRANSPORT_OVERHEAD]; + + out_packet[0] = 4; + LittleEndian::write_u32(&mut out_packet[4..], element.their_index); + LittleEndian::write_u64(&mut out_packet[8..], element.nonce); + + element.in_packet.resize(padded_len, 0); + let len = element.transport.write_transport_message(element.nonce, + &element.in_packet, + &mut out_packet[16..]).unwrap(); + out_packet.truncate(TRANSPORT_HEADER_SIZE + len); + + executor::spawn(tx.send(EncryptResult { + endpoint: element.endpoint, + our_index: element.our_index, + out_packet, + })).wait_future(); + } + } + } +}
\ No newline at end of file |