aboutsummaryrefslogtreecommitdiffstats
path: root/src/wireguard.rs
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-10-09 15:08:26 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-10-09 15:08:26 +0200
commit761c46064d7510303f08cde27c9e13b07293f3af (patch)
tree7b914169725952e557223972b3f0b611c54e6829 /src/wireguard.rs
parentRestructure dummy implementations (diff)
downloadwireguard-rs-761c46064d7510303f08cde27c9e13b07293f3af.tar.xz
wireguard-rs-761c46064d7510303f08cde27c9e13b07293f3af.zip
Restructure IO traits.
Diffstat (limited to 'src/wireguard.rs')
-rw-r--r--src/wireguard.rs200
1 files changed, 128 insertions, 72 deletions
diff --git a/src/wireguard.rs b/src/wireguard.rs
index ea600d0..ba81f47 100644
--- a/src/wireguard.rs
+++ b/src/wireguard.rs
@@ -2,11 +2,13 @@ use crate::constants::*;
use crate::handshake;
use crate::router;
use crate::timers::{Events, Timers};
-use crate::types::{Bind, Endpoint, Tun};
+
+use crate::types::Endpoint;
+use crate::types::tun::{Tun, Reader, MTU};
+use crate::types::bind::{Bind, Writer};
use hjul::Runner;
-use std::cmp;
use std::ops::Deref;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
@@ -27,12 +29,20 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128;
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
-#[derive(Clone)]
pub struct Peer<T: Tun, B: Bind> {
- pub router: Arc<router::Peer<Events<T, B>, T, B>>,
+ pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>,
pub state: Arc<PeerInner<B>>,
}
+impl <T : Tun, B : Bind> Clone for Peer<T, B > {
+ fn clone(&self) -> Peer<T, B> {
+ Peer{
+ router: self.router.clone(),
+ state: self.state.clone()
+ }
+ }
+}
+
pub struct PeerInner<B: Bind> {
pub keepalive: AtomicUsize, // keepalive interval
pub rx_bytes: AtomicU64,
@@ -66,20 +76,22 @@ pub enum HandshakeJob<E> {
}
struct WireguardInner<T: Tun, B: Bind> {
+ // provides access to the MTU value of the tun device
+ // (otherwise owned solely by the router and a dedicated read IO thread)
+ mtu: T::MTU,
+ send: RwLock<Option<B::Writer>>,
+
// identify and configuration map
peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
// cryptkey router
- router: router::Device<Events<T, B>, T, B>,
+ router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>,
// handshake related state
handshake: RwLock<Handshake>,
under_load: AtomicBool,
pending: AtomicUsize, // num of pending handshake packets in queue
queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>,
-
- // IO
- bind: B,
}
pub struct Wireguard<T: Tun, B: Bind> {
@@ -87,6 +99,17 @@ pub struct Wireguard<T: Tun, B: Bind> {
state: Arc<WireguardInner<T, B>>,
}
+/* Returns the padded length of a message:
+ *
+ * # Arguments
+ *
+ * - `size` : Size of unpadded message
+ * - `mtu` : Maximum transmission unit of the device
+ *
+ * # Returns
+ *
+ * The padded length (always less than or equal to the MTU)
+ */
#[inline(always)]
const fn padding(size: usize, mtu: usize) -> usize {
#[inline(always)]
@@ -114,6 +137,15 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}
}
+ pub fn get_sk(&self) -> Option<StaticSecret> {
+ let mut handshake = self.state.handshake.read();
+ if handshake.active {
+ Some(handshake.device.get_sk())
+ } else {
+ None
+ }
+ }
+
pub fn new_peer(&self, pk: PublicKey) -> Peer<T, B> {
let state = Arc::new(PeerInner {
pk,
@@ -137,20 +169,92 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
peer
}
- pub fn new(tun: T, bind: B) -> Wireguard<T, B> {
+ pub fn new_bind(
+ reader: B::Reader,
+ writer: B::Writer,
+ closer: B::Closer
+ ) {
+
+ // drop existing closer
+
+
+ // swap IO thread for new reader
+
+
+ // start UDP read IO thread
+
+ /*
+ {
+ let wg = wg.clone();
+ let mtu = mtu.clone();
+ thread::spawn(move || {
+ let mut last_under_load =
+ Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
+
+ loop {
+ // create vector big enough for any message given current MTU
+ let size = mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
+ let mut msg: Vec<u8> = Vec::with_capacity(size);
+ msg.resize(size, 0);
+
+ // read UDP packet into vector
+ let (size, src) = reader.read(&mut msg).unwrap(); // TODO handle error
+ msg.truncate(size);
+
+ // message type de-multiplexer
+ if msg.len() < std::mem::size_of::<u32>() {
+ continue;
+ }
+ match LittleEndian::read_u32(&msg[..]) {
+ handshake::TYPE_COOKIE_REPLY
+ | handshake::TYPE_INITIATION
+ | handshake::TYPE_RESPONSE => {
+ // update under_load flag
+ if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
+ last_under_load = Instant::now();
+ wg.under_load.store(true, Ordering::SeqCst);
+ } else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
+ wg.under_load.store(false, Ordering::SeqCst);
+ }
+
+ wg.queue
+ .lock()
+ .send(HandshakeJob::Message(msg, src))
+ .unwrap();
+ }
+ router::TYPE_TRANSPORT => {
+ // transport message
+ let _ = wg.router.recv(src, msg);
+ }
+ _ => (),
+ }
+ }
+ });
+ }
+ */
+
+
+ }
+
+ pub fn new(
+ reader: T::Reader,
+ writer: T::Writer,
+ mtu: T::MTU,
+ ) -> Wireguard<T, B> {
// create device state
let mut rng = OsRng::new().unwrap();
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
let wg = Arc::new(WireguardInner {
+ mtu: mtu.clone(),
peers: RwLock::new(HashMap::new()),
- router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()),
+ send: RwLock::new(None),
+ router: router::Device::new(num_cpus::get(), writer), // router owns the writing half
pending: AtomicUsize::new(0),
handshake: RwLock::new(Handshake {
device: handshake::Device::new(StaticSecret::new(&mut rng)),
active: false,
}),
under_load: AtomicBool::new(false),
- bind: bind.clone(),
queue: Mutex::new(tx),
});
@@ -158,7 +262,6 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
for _ in 0..num_cpus::get() {
let wg = wg.clone();
let rx = rx.clone();
- let bind = bind.clone();
thread::spawn(move || {
// prepare OsRng instance for this thread
let mut rng = OsRng::new().unwrap();
@@ -189,19 +292,22 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
Ok((pk, msg, keypair)) => {
// send response
if let Some(msg) = msg {
- let _ = bind.send(&msg[..], &src).map_err(|e| {
- debug!(
- "handshake worker, failed to send response, error = {:?}",
- e
- )
- });
+ let send : &Option<B::Writer> = &*wg.send.read();
+ if let Some(writer) = send.as_ref() {
+ let _ = writer.write(&msg[..], &src).map_err(|e| {
+ debug!(
+ "handshake worker, failed to send response, error = {:?}",
+ e
+ )
+ });
+ }
}
// update timers
if let Some(pk) = pk {
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
- // update endpoint (DISCUSS: right semantics?)
- peer.router.set_endpoint(src_validate);
+ // update endpoint
+ peer.router.set_endpoint(src);
// add keypair to peer and free any unused ids
if let Some(keypair) = keypair {
@@ -227,68 +333,18 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
});
}
- // start UDP read IO thread
- {
- let wg = wg.clone();
- let tun = tun.clone();
- let bind = bind.clone();
- thread::spawn(move || {
- let mut last_under_load =
- Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
-
- loop {
- // create vector big enough for any message given current MTU
- let size = tun.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
- let mut msg: Vec<u8> = Vec::with_capacity(size);
- msg.resize(size, 0);
-
- // read UDP packet into vector
- let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
- msg.truncate(size);
-
- // message type de-multiplexer
- if msg.len() < std::mem::size_of::<u32>() {
- continue;
- }
- match LittleEndian::read_u32(&msg[..]) {
- handshake::TYPE_COOKIE_REPLY
- | handshake::TYPE_INITIATION
- | handshake::TYPE_RESPONSE => {
- // update under_load flag
- if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
- last_under_load = Instant::now();
- wg.under_load.store(true, Ordering::SeqCst);
- } else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
- wg.under_load.store(false, Ordering::SeqCst);
- }
-
- wg.queue
- .lock()
- .send(HandshakeJob::Message(msg, src))
- .unwrap();
- }
- router::TYPE_TRANSPORT => {
- // transport message
- let _ = wg.router.recv(src, msg);
- }
- _ => (),
- }
- }
- });
- }
-
// start TUN read IO thread
{
let wg = wg.clone();
thread::spawn(move || loop {
// create vector big enough for any transport message (based on MTU)
- let mtu = tun.mtu();
+ let mtu = mtu.mtu();
let size = mtu + router::SIZE_MESSAGE_PREFIX;
let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
msg.resize(size, 0);
// read a new IP packet
- let payload = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
+ let payload = reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu);
// truncate padding