aboutsummaryrefslogtreecommitdiffstats
path: root/src/crypto_pool.rs
blob: e0756e9eaef72709684369201cff7accae65b130 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
use consts::{PADDING_MULTIPLE, TRANSPORT_OVERHEAD, TRANSPORT_HEADER_SIZE};
use crossbeam;
use crossbeam_channel::{unbounded, Receiver, Sender};
use futures::sync::mpsc;
use futures::executor;
use futures::Sink;
use num_cpus;
use snow::AsyncTransportState;
use std::thread;
use udp::Endpoint;
use message;
use peer::SessionType;
use ip_packet::IpPacket;
use byteorder::{ByteOrder, LittleEndian};

/// A job submitted to the crypto pool.
///
/// Each variant carries the channel on which the worker sends its result,
/// paired with the input for that job.
pub enum Work {
    /// Decrypt a received transport message; the result is sent on the given sender.
    Decrypt((mpsc::UnboundedSender<DecryptResult>, DecryptWork)),
    /// Encrypt an outgoing packet; the result is sent on the given sender.
    Encrypt((mpsc::UnboundedSender<EncryptResult>, EncryptWork)),
}

/// Input for one encryption job.
pub struct EncryptWork {
    /// Transport state whose `write_transport_message` performs the encryption.
    pub transport: AsyncTransportState,
    /// Nonce/counter: written little-endian into header bytes 8..16 and
    /// passed as the nonce to `write_transport_message`.
    pub nonce: u64,
    /// Not used for encryption; copied unchanged into `EncryptResult::our_index`.
    pub our_index: u32,
    /// Receiver index written little-endian into header bytes 4..8.
    pub their_index: u32,
    /// Not used for encryption; copied unchanged into `EncryptResult::endpoint`.
    pub endpoint: Endpoint,
    /// Plaintext to encrypt (presumably an IP packet — confirm with callers).
    /// It is zero-padded up to a multiple of `PADDING_MULTIPLE` before encryption.
    pub in_packet: Vec<u8>,
}

/// Output of one encryption job.
pub struct EncryptResult {
    /// Copied from `EncryptWork::endpoint` (presumably the UDP destination — verify against caller).
    pub endpoint: Endpoint,
    /// Copied from `EncryptWork::our_index`.
    pub our_index: u32,
    /// Complete transport message: header (type byte, receiver index, nonce)
    /// followed by the ciphertext.
    pub out_packet: Vec<u8>,
}

/// Input for one decryption job.
pub struct DecryptWork {
    /// Transport state whose `read_transport_message` performs the decryption.
    pub transport: AsyncTransportState,
    /// Not used for decryption; copied unchanged into `DecryptResult::endpoint`
    /// (presumably the address the packet arrived from — confirm).
    pub endpoint: Endpoint,
    /// The received transport message; its `nonce()` and `payload()` feed the decryption.
    pub packet: message::Transport,
    /// Not used for decryption; copied unchanged into `DecryptResult::session_type`.
    pub session_type: SessionType,
}

/// Output of one decryption job.
pub struct DecryptResult {
    /// Copied from `DecryptWork::endpoint`.
    pub endpoint: Endpoint,
    /// The original (still encrypted) message from `DecryptWork::packet`, handed back to the caller.
    pub orig_packet: message::Transport,
    /// Decrypted IP packet trimmed to the length in its IP header, or empty when
    /// the decrypted plaintext was empty (presumably a keepalive).
    pub out_packet: Vec<u8>,
    /// Copied from `DecryptWork::session_type`.
    pub session_type: SessionType,
}

/// Spawn a thread pool to efficiently process
/// the CPU-intensive encryption/decryption.
pub fn create() -> Sender<Work> {
    let threads            = num_cpus::get() - 1; // One thread for I/O.
    let (sender, receiver) = unbounded();

    crossbeam::scope(|s| {
        for i in 0..threads {
            let rx = receiver.clone();
            s.spawn(move || worker(rx.clone()));
        }
    });

    sender
}

fn worker(receiver: Receiver<Work>) {
    select_loop! {
        recv(receiver, work) => {
            match work {
                Work::Decrypt((tx, element)) => {
                    let mut raw_packet = vec![0u8; element.packet.len()];
                    let nonce = element.packet.nonce();
                    let len = element.transport.read_transport_message(nonce, element.packet.payload(), &mut raw_packet).unwrap();
                    if len > 0 {
                        let len = IpPacket::new(&raw_packet[..len])
                            .ok_or_else(||format_err!("invalid IP packet (len {})", len)).unwrap()
                            .length();
                        raw_packet.truncate(len as usize);
                    } else {
                        raw_packet.truncate(0);
                    }

                    tx.unbounded_send(DecryptResult {
                        endpoint: element.endpoint,
                        orig_packet: element.packet,
                        out_packet: raw_packet,
                        session_type: element.session_type,
                    }).unwrap();
                },
                Work::Encrypt((tx, mut element)) => {
                    let padding        = if element.in_packet.len() % PADDING_MULTIPLE != 0 {
                        PADDING_MULTIPLE - (element.in_packet.len() % PADDING_MULTIPLE)
                    } else { 0 };
                    let padded_len     = element.in_packet.len() + padding;
                    let mut out_packet = vec![0u8; padded_len + TRANSPORT_OVERHEAD];

                    out_packet[0] = 4;
                    LittleEndian::write_u32(&mut out_packet[4..], element.their_index);
                    LittleEndian::write_u64(&mut out_packet[8..], element.nonce);

                    element.in_packet.resize(padded_len, 0);
                    let len = element.transport.write_transport_message(element.nonce,
                        &element.in_packet,
                        &mut out_packet[16..]).unwrap();
                    out_packet.truncate(TRANSPORT_HEADER_SIZE + len);

                    tx.unbounded_send(EncryptResult {
                        endpoint: element.endpoint,
                        our_index: element.our_index,
                        out_packet,
                    }).unwrap();
                }
            }
        }
    }
}