aboutsummaryrefslogtreecommitdiffstats
path: root/src/wireguard.rs
blob: 71b981edd9e5993104297abe1be58f5df269cd2b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
use crate::handshake;
use crate::router;
use crate::types::{Bind, Tun};

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::sync_channel;
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};

use byteorder::{ByteOrder, LittleEndian};
use crossbeam_channel::bounded;
use x25519_dalek::StaticSecret;

/// Nominal size of the handshake message queue.
const SIZE_HANDSHAKE_QUEUE: usize = 128;
/// Number of already-staged handshake messages above which the device
/// considers itself under load (a quarter of the queue size).
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE >> 2;
/// How long the device keeps treating itself as under load after the
/// threshold was last exceeded.
const DURATION_UNDER_LOAD: Duration = Duration::from_secs(10);

/// Timer state handed to the router callbacks as their opaque argument.
///
/// Holds no data yet.
pub struct Timers {
    // intentionally empty for now
}

/// Receiver of router callbacks (see its `router::Callbacks` impl).
pub struct Events();

/// Router callback hooks for this device; every hook is currently a no-op.
impl router::Callbacks for Events {
    /// Value passed by reference to each hook.
    type Opaque = Timers;

    /// `send` hook: ignored for now.
    fn send(_t: &Timers, _size: usize, _data: bool, _sent: bool) {}

    /// `recv` hook: ignored for now.
    fn recv(_t: &Timers, _size: usize, _data: bool, _sent: bool) {}

    /// `need_key` hook: ignored for now.
    fn need_key(_t: &Timers) {}
}

/// A WireGuard device built from a TUN interface `T` and a UDP bind `B`.
pub struct Wireguard<T, B>
where
    T: Tun,
    B: Bind,
{
    /// Reference-counted router device.
    router: Arc<router::Device<Events, T, B>>,
    /// Optional handshake device (not populated by `new` at present).
    handshake: Option<Arc<handshake::Device<()>>>,
}

impl<T: Tun, B: Bind> Wireguard<T, B> {
    fn start(&self) {}

    fn new(tun: T, bind: B) -> Wireguard<T, B> {
        let router = Arc::new(router::Device::new(
            num_cpus::get(),
            tun.clone(),
            bind.clone(),
        ));

        let handshake_staged = Arc::new(AtomicUsize::new(0));

        // start UDP read IO thread
        let (handshake_tx, handshake_rx) = bounded(128);
        {
            let tun = tun.clone();
            thread::spawn(move || {
                let mut under_load =
                    Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);

                loop {
                    // read UDP packet into vector
                    let size = tun.mtu() + 148; // maximum message size
                    let mut msg: Vec<u8> =
                        Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
                    msg.resize(size, 0);
                    let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
                    msg.truncate(size);

                    // message type de-multiplexer
                    if msg.len() < std::mem::size_of::<u32>() {
                        continue;
                    }

                    match LittleEndian::read_u32(&msg[..]) {
                        handshake::TYPE_COOKIE_REPLY
                        | handshake::TYPE_INITIATION
                        | handshake::TYPE_RESPONSE => {
                            // detect if under load
                            if handshake_staged.fetch_add(1, Ordering::SeqCst)
                                > THRESHOLD_UNDER_LOAD
                            {
                                under_load = Instant::now()
                            }

                            // pass source address along if under load
                            if under_load.elapsed() < DURATION_UNDER_LOAD {
                                handshake_tx.send((msg, Some(src))).unwrap();
                            } else {
                                handshake_tx.send((msg, None)).unwrap();
                            }
                        }
                        router::TYPE_TRANSPORT => {
                            // transport message
                        }
                        _ => (),
                    }
                }
            });
        }

        // start handshake workers
        for _ in 0..num_cpus::get() {
            let handshake_rx = handshake_rx.clone();
            thread::spawn(move || loop {
                let (msg, src) = handshake_rx.recv().unwrap(); // TODO handle error
            });
        }

        // start TUN read IO thread

        thread::spawn(move || {});

        Wireguard {
            router,
            handshake: None,
        }
    }
}