aboutsummaryrefslogtreecommitdiffstats
path: root/src/crypto_pool.rs
blob: c99002a3201146a8e507defd5340dae3fb01afb7 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
use consts::{PADDING_MULTIPLE, TRANSPORT_OVERHEAD, TRANSPORT_HEADER_SIZE};
use crossbeam_channel::{bounded, Receiver, Sender};
use futures::task::Task;
use num_cpus;
use snow::AsyncTransportState;
use std::thread;
use udp::Endpoint;
use message;
use peer::SessionType;
use ip_packet::IpPacket;
use byteorder::{ByteOrder, LittleEndian};

/// A job for the crypto pool.
///
/// Each variant carries the channel the worker sends its result on, the task
/// the worker notifies after sending that result, and the job itself.
pub enum Work {
    /// Decrypt a received transport message.
    Decrypt((Sender<DecryptResult>, Task, DecryptWork)),
    /// Encrypt an outgoing packet into a transport message.
    Encrypt((Sender<EncryptResult>, Task, EncryptWork)),
}

/// Input for encrypting one outgoing packet.
pub struct EncryptWork {
    /// Transport state whose `write_transport_message` encrypts the packet.
    pub transport: AsyncTransportState,
    /// Counter passed to the encryption and also written into the message
    /// header (little-endian, at offset 8).
    pub nonce: u64,
    /// Not used by the encryption itself; copied into `EncryptResult`.
    pub our_index: u32,
    /// Written into the message header (little-endian, at offset 4).
    pub their_index: u32,
    /// Copied into `EncryptResult`; presumably where the message is to be
    /// sent — confirm with the caller.
    pub endpoint: Endpoint,
    /// Plaintext to encrypt; zero-padded up to a multiple of
    /// `PADDING_MULTIPLE` before encryption.
    pub in_packet: Vec<u8>,
}

/// Output of encrypting one packet.
pub struct EncryptResult {
    /// Copied unchanged from `EncryptWork::endpoint`.
    pub endpoint: Endpoint,
    /// Copied unchanged from `EncryptWork::our_index`.
    pub our_index: u32,
    /// The complete transport message: header followed by ciphertext.
    pub out_packet: Vec<u8>,
}

/// Input for decrypting one received transport message.
pub struct DecryptWork {
    /// Transport state whose `read_transport_message` decrypts the payload.
    pub transport: AsyncTransportState,
    /// Copied into `DecryptResult`; presumably the address the message came
    /// from — confirm with the caller.
    pub endpoint: Endpoint,
    /// The received message; its nonce and payload are fed to decryption.
    pub packet: message::Transport,
    /// Not used by the worker; copied into `DecryptResult`.
    pub session_type: SessionType,
}

/// Output of decrypting one transport message.
pub struct DecryptResult {
    /// Copied unchanged from `DecryptWork::endpoint`.
    pub endpoint: Endpoint,
    /// The received message, handed back unchanged.
    pub orig_packet: message::Transport,
    /// Decrypted plaintext, truncated according to the length field of its
    /// IP header; empty when the decrypted plaintext is empty.
    pub out_packet: Vec<u8>,
    /// Copied unchanged from `DecryptWork::session_type`.
    pub session_type: SessionType,
}

/// Spawn a thread pool to efficiently process
/// the CPU-intensive encryption/decryption.
///
/// Two CPUs are left free for the rest of the process (e.g. I/O), but at
/// least one worker is always started, so the pool still makes progress on
/// one- and two-core hosts. All workers pull jobs from a single shared queue
/// with capacity 4096; the returned `Sender` is how jobs are submitted.
pub fn create() -> Sender<Work> {
    // `saturating_sub` avoids usize underflow on a 1-CPU host, and `max(1)`
    // avoids spawning zero workers, which would leave every job queued forever.
    let threads            = num_cpus::get().saturating_sub(2).max(1);
    let (sender, receiver) = bounded(4096);

    debug!("spinning up a crypto pool with {} threads", threads);
    for _ in 0..threads {
        let rx = receiver.clone();
        thread::spawn(move || worker(rx));
    }

    sender
}

fn worker(receiver: Receiver<Work>) {
    while let Some(work) = receiver.recv() {
        match work {
            Work::Decrypt((tx, task, element)) => {
                let mut raw_packet = vec![0u8; element.packet.len()];
                let nonce = element.packet.nonce();
                let len = element.transport.read_transport_message(nonce, element.packet.payload(), &mut raw_packet).unwrap();
                if len > 0 {
                    let len = IpPacket::new(&raw_packet[..len])
                        .ok_or_else(||format_err!("invalid IP packet (len {})", len)).unwrap()
                        .length();
                    raw_packet.truncate(len as usize);
                } else {
                    raw_packet.truncate(0);
                }

                tx.send(DecryptResult {
                    endpoint: element.endpoint,
                    orig_packet: element.packet,
                    out_packet: raw_packet,
                    session_type: element.session_type,
                });

                task.notify();
            },
            Work::Encrypt((tx, task, mut element)) => {
                let padding        = if element.in_packet.len() % PADDING_MULTIPLE != 0 {
                    PADDING_MULTIPLE - (element.in_packet.len() % PADDING_MULTIPLE)
                } else { 0 };
                let padded_len     = element.in_packet.len() + padding;
                let mut out_packet = vec![0u8; padded_len + TRANSPORT_OVERHEAD];

                out_packet[0] = 4;
                LittleEndian::write_u32(&mut out_packet[4..], element.their_index);
                LittleEndian::write_u64(&mut out_packet[8..], element.nonce);

                element.in_packet.resize(padded_len, 0);
                let len = element.transport.write_transport_message(element.nonce,
                    &element.in_packet,
                    &mut out_packet[16..]).unwrap();
                out_packet.truncate(TRANSPORT_HEADER_SIZE + len);

                tx.send(EncryptResult {
                    endpoint: element.endpoint,
                    our_index: element.our_index,
                    out_packet,
                });

                task.notify();
            }
        }
    }
}