summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/main.rs21
-rw-r--r--src/router/tests.rs10
-rw-r--r--src/tests.rs46
-rw-r--r--src/types/bind.rs5
-rw-r--r--src/types/dummy.rs144
-rw-r--r--src/types/endpoint.rs1
-rw-r--r--src/wireguard.rs161
7 files changed, 244 insertions, 144 deletions
diff --git a/src/main.rs b/src/main.rs
index 9b69f54..3c59c67 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -15,25 +15,6 @@ mod types;
mod wireguard;
#[cfg(test)]
-mod tests {
- use crate::types::tun::Tun;
- use crate::types::{bind, dummy, tun};
- use crate::wireguard::Wireguard;
-
- use std::thread;
- use std::time::Duration;
-
- fn init() {
- let _ = env_logger::builder().is_test(true).try_init();
- }
-
- #[test]
- fn test_pure_wireguard() {
- init();
- let (reader, writer, mtu) = dummy::TunTest::create("name").unwrap();
- let wg: Wireguard<dummy::TunTest, dummy::PairBind> = Wireguard::new(reader, writer, mtu);
- thread::sleep(Duration::from_millis(500));
- }
-}
+mod tests;
fn main() {}
diff --git a/src/router/tests.rs b/src/router/tests.rs
index 3b6b941..6c385a8 100644
--- a/src/router/tests.rs
+++ b/src/router/tests.rs
@@ -145,8 +145,8 @@ mod tests {
}
// create device
- let (_reader, tun_writer, _mtu) = dummy::TunTest::create("name").unwrap();
- let router: Device<_, BencherCallbacks, dummy::TunTest, dummy::VoidBind> =
+ let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false);
+ let router: Device< _, BencherCallbacks, dummy::TunWriter, dummy::VoidBind> =
Device::new(num_cpus::get(), tun_writer);
// add new peer
@@ -175,7 +175,7 @@ mod tests {
init();
// create device
- let (_reader, tun_writer, _mtu) = dummy::TunTest::create("name").unwrap();
+ let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false);
let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer);
router.set_outbound_writer(dummy::VoidBind::new());
@@ -321,8 +321,8 @@ mod tests {
dummy::PairBind::pair();
// create matching device
- let (tun_writer1, _, _) = dummy::TunTest::create("tun1").unwrap();
- let (tun_writer2, _, _) = dummy::TunTest::create("tun1").unwrap();
+ let (_fake, _, tun_writer1, _) = dummy::TunTest::create(1500, false);
+ let (_fake, _, tun_writer2, _) = dummy::TunTest::create(1500, false);
let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1);
router1.set_outbound_writer(bind_writer1);
diff --git a/src/tests.rs b/src/tests.rs
new file mode 100644
index 0000000..8e15037
--- /dev/null
+++ b/src/tests.rs
@@ -0,0 +1,46 @@
+use crate::types::tun::Tun;
+use crate::types::{bind, dummy, tun};
+use crate::wireguard::Wireguard;
+
+use std::thread;
+use std::time::Duration;
+
+fn init() {
+ let _ = env_logger::builder().is_test(true).try_init();
+}
+
+/* Create and configure two matching pure instances of WireGuard
+ *
+ */
+#[test]
+fn test_pure_wireguard() {
+ init();
+
+ // create WG instances for fake TUN devices
+
+ let (fake1, tun_reader1, tun_writer1, mtu1) = dummy::TunTest::create(1500, true);
+ let wg1: Wireguard<dummy::TunTest, dummy::PairBind> =
+ Wireguard::new(vec![tun_reader1], tun_writer1, mtu1);
+
+ let (fake2, tun_reader2, tun_writer2, mtu2) = dummy::TunTest::create(1500, true);
+ let wg2: Wireguard<dummy::TunTest, dummy::PairBind> =
+ Wireguard::new(vec![tun_reader2], tun_writer2, mtu2);
+
+ // create pair bind to connect the interfaces "over the internet"
+
+ let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) = dummy::PairBind::pair();
+
+ wg1.set_writer(bind_writer1);
+ wg2.set_writer(bind_writer2);
+
+ wg1.add_reader(bind_reader1);
+ wg2.add_reader(bind_reader2);
+
+ // generate (public, pivate) key pairs
+
+ // configure cryptkey router
+
+ // create IP packets
+
+ thread::sleep(Duration::from_millis(500));
+}
diff --git a/src/types/bind.rs b/src/types/bind.rs
index fcc38c8..3d3f187 100644
--- a/src/types/bind.rs
+++ b/src/types/bind.rs
@@ -20,9 +20,4 @@ pub trait Bind: Send + Sync + 'static {
/* Until Rust gets type equality constraints these have to be generic */
type Writer: Writer<Self::Endpoint>;
type Reader: Reader<Self::Endpoint>;
-
- /* Used to close the reader/writer when binding to a new port */
- type Closer;
-
- fn bind(port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error>;
}
diff --git a/src/types/dummy.rs b/src/types/dummy.rs
index 40a3bdd..2403c9b 100644
--- a/src/types/dummy.rs
+++ b/src/types/dummy.rs
@@ -1,11 +1,12 @@
use std::error::Error;
use std::fmt;
+use std::marker;
use std::net::SocketAddr;
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Instant;
-use std::marker;
+use std::sync::atomic::{Ordering, AtomicUsize};
use super::*;
@@ -41,7 +42,9 @@ impl fmt::Display for BindError {
/* TUN implementation */
#[derive(Debug)]
-pub enum TunError {}
+pub enum TunError {
+ Disconnected
+}
impl Error for TunError {
fn description(&self) -> &str {
@@ -68,54 +71,111 @@ impl Endpoint for UnitEndpoint {
fn from_address(_: SocketAddr) -> UnitEndpoint {
UnitEndpoint {}
}
+
fn into_address(&self) -> SocketAddr {
"127.0.0.1:8080".parse().unwrap()
}
+
+ fn clear_src(&self) {}
}
impl UnitEndpoint {
pub fn new() -> UnitEndpoint {
- UnitEndpoint{}
+ UnitEndpoint {}
}
}
/* */
-#[derive(Clone, Copy)]
pub struct TunTest {}
-impl tun::Reader for TunTest {
- type Error = TunError;
+pub struct TunFakeIO {
+ store: bool,
+ tx: SyncSender<Vec<u8>>,
+ rx: Receiver<Vec<u8>>
+}
- fn read(&self, _buf: &mut [u8], _offset: usize) -> Result<usize, Self::Error> {
- Ok(0)
- }
+pub struct TunReader {
+ rx: Receiver<Vec<u8>>
}
-impl tun::MTU for TunTest {
- fn mtu(&self) -> usize {
- 1500
+pub struct TunWriter {
+ store: bool,
+ tx: Mutex<SyncSender<Vec<u8>>>
+}
+
+#[derive(Clone)]
+pub struct TunMTU {
+ mtu: Arc<AtomicUsize>
+}
+
+impl tun::Reader for TunReader {
+ type Error = TunError;
+
+ fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error> {
+ match self.rx.recv() {
+ Ok(m) => {
+ buf[offset..].copy_from_slice(&m[..]);
+ Ok(m.len())
+ }
+ Err(_) => Err(TunError::Disconnected)
+ }
}
}
-impl tun::Writer for TunTest {
+impl tun::Writer for TunWriter {
type Error = TunError;
- fn write(&self, _src: &[u8]) -> Result<(), Self::Error> {
- Ok(())
+ fn write(&self, src: &[u8]) -> Result<(), Self::Error> {
+ if self.store {
+ let m = src.to_owned();
+ match self.tx.lock().unwrap().send(m) {
+ Ok(_) => Ok(()),
+ Err(_) => Err(TunError::Disconnected)
+ }
+ } else {
+ Ok(())
+ }
+ }
+}
+
+impl tun::MTU for TunMTU {
+ fn mtu(&self) -> usize {
+ self.mtu.load(Ordering::Acquire)
}
}
impl tun::Tun for TunTest {
- type Writer = TunTest;
- type Reader = TunTest;
- type MTU = TunTest;
+ type Writer = TunWriter;
+ type Reader = TunReader;
+ type MTU = TunMTU;
type Error = TunError;
}
+impl TunFakeIO {
+ pub fn write(&self, msg : Vec<u8>) {
+ if self.store {
+ self.tx.send(msg).unwrap();
+ }
+ }
+
+ pub fn read(&self) -> Vec<u8> {
+ self.rx.recv().unwrap()
+ }
+}
+
impl TunTest {
- pub fn create(_name: &str) -> Result<(TunTest, TunTest, TunTest), TunError> {
- Ok((TunTest {},TunTest {}, TunTest{}))
+ pub fn create(mtu : usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunMTU) {
+
+ let (tx1, rx1) = if store { sync_channel(32) } else { sync_channel(1) };
+ let (tx2, rx2) = if store { sync_channel(32) } else { sync_channel(1) };
+
+ let fake = TunFakeIO{tx: tx1, rx: rx2, store};
+ let reader = TunReader{rx : rx1};
+ let writer = TunWriter{tx : Mutex::new(tx2), store};
+ let mtu = TunMTU{mtu : Arc::new(AtomicUsize::new(mtu))};
+
+ (fake, reader, writer, mtu)
}
}
@@ -146,16 +206,11 @@ impl bind::Bind for VoidBind {
type Reader = VoidBind;
type Writer = VoidBind;
- type Closer = ();
-
- fn bind(_ : u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error> {
- Ok((VoidBind{}, VoidBind{}, (), 2600))
- }
}
impl VoidBind {
pub fn new() -> VoidBind {
- VoidBind{}
+ VoidBind {}
}
}
@@ -203,45 +258,42 @@ pub struct PairWriter<E> {
pub struct PairBind {}
impl PairBind {
- pub fn pair<E>() -> ((PairReader<E>, PairWriter<E>), (PairReader<E>, PairWriter<E>)) {
+ pub fn pair<E>() -> (
+ (PairReader<E>, PairWriter<E>),
+ (PairReader<E>, PairWriter<E>),
+ ) {
let (tx1, rx1) = sync_channel(128);
let (tx2, rx2) = sync_channel(128);
(
(
- PairReader{
-
- recv: Arc::new(Mutex::new(rx1)),
- _marker: marker::PhantomData
- },
- PairWriter{
+ PairReader {
+ recv: Arc::new(Mutex::new(rx1)),
+ _marker: marker::PhantomData,
+ },
+ PairWriter {
send: Arc::new(Mutex::new(tx2)),
- _marker: marker::PhantomData
- }
+ _marker: marker::PhantomData,
+ },
),
(
- PairReader{
+ PairReader {
recv: Arc::new(Mutex::new(rx2)),
- _marker: marker::PhantomData
- },
- PairWriter{
+ _marker: marker::PhantomData,
+ },
+ PairWriter {
send: Arc::new(Mutex::new(tx1)),
- _marker: marker::PhantomData
- }
+ _marker: marker::PhantomData,
+ },
),
)
}
}
impl bind::Bind for PairBind {
- type Closer = ();
type Error = BindError;
type Endpoint = UnitEndpoint;
type Reader = PairReader<Self::Endpoint>;
type Writer = PairWriter<Self::Endpoint>;
-
- fn bind(_port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error> {
- Err(BindError::Disconnected)
- }
}
pub fn keypair(initiator: bool) -> KeyPair {
diff --git a/src/types/endpoint.rs b/src/types/endpoint.rs
index 74796aa..f4f93da 100644
--- a/src/types/endpoint.rs
+++ b/src/types/endpoint.rs
@@ -3,4 +3,5 @@ use std::net::SocketAddr;
pub trait Endpoint: Send + 'static {
fn from_address(addr: SocketAddr) -> Self;
fn into_address(&self) -> SocketAddr;
+ fn clear_src(&self);
}
diff --git a/src/wireguard.rs b/src/wireguard.rs
index bcb8592..f14a053 100644
--- a/src/wireguard.rs
+++ b/src/wireguard.rs
@@ -3,8 +3,10 @@ use crate::handshake;
use crate::router;
use crate::timers::{Events, Timers};
+use crate::types::bind::Reader as BindReader;
use crate::types::bind::{Bind, Writer};
use crate::types::tun::{Reader, Tun, MTU};
+
use crate::types::Endpoint;
use hjul::Runner;
@@ -53,7 +55,7 @@ pub struct PeerInner<B: Bind> {
pub timers: RwLock<Timers>, //
}
-impl <B:Bind > PeerInner<B> {
+impl<B: Bind> PeerInner<B> {
#[inline(always)]
pub fn timers(&self) -> RwLockReadGuard<Timers> {
self.timers.read()
@@ -153,7 +155,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}
pub fn get_sk(&self) -> Option<StaticSecret> {
- let mut handshake = self.state.handshake.read();
+ let handshake = self.state.handshake.read();
if handshake.active {
Some(handshake.device.get_sk())
} else {
@@ -184,66 +186,73 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
peer
}
- pub fn new_bind(reader: B::Reader, writer: B::Writer, closer: B::Closer) {
-
- // drop existing closer
-
- // swap IO thread for new reader
-
- // start UDP read IO thread
-
- /*
- {
- let wg = wg.clone();
- let mtu = mtu.clone();
- thread::spawn(move || {
- let mut last_under_load =
- Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
-
- loop {
- // create vector big enough for any message given current MTU
- let size = mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
- let mut msg: Vec<u8> = Vec::with_capacity(size);
- msg.resize(size, 0);
-
- // read UDP packet into vector
- let (size, src) = reader.read(&mut msg).unwrap(); // TODO handle error
- msg.truncate(size);
+ /* Begin consuming messages from the reader.
+ *
+ * Any previous reader thread is stopped by closing the previous reader,
+ * which unblocks the thread and causes an error on reader.read
+ */
+ pub fn add_reader(&self, reader: B::Reader) {
+ let wg = self.state.clone();
+ thread::spawn(move || {
+ let mut last_under_load =
+ Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
+
+ loop {
+ // create vector big enough for any message given current MTU
+ let size = wg.mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
+ let mut msg: Vec<u8> = Vec::with_capacity(size);
+ msg.resize(size, 0);
- // message type de-multiplexer
- if msg.len() < std::mem::size_of::<u32>() {
- continue;
+ // read UDP packet into vector
+ let (size, src) = match reader.read(&mut msg) {
+ Err(e) => {
+ debug!("Bind reader closed with {}", e);
+ return;
}
- match LittleEndian::read_u32(&msg[..]) {
- handshake::TYPE_COOKIE_REPLY
- | handshake::TYPE_INITIATION
- | handshake::TYPE_RESPONSE => {
- // update under_load flag
- if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
- last_under_load = Instant::now();
- wg.under_load.store(true, Ordering::SeqCst);
- } else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
- wg.under_load.store(false, Ordering::SeqCst);
- }
+ Ok(v) => v,
+ };
+ msg.truncate(size);
- wg.queue
- .lock()
- .send(HandshakeJob::Message(msg, src))
- .unwrap();
- }
- router::TYPE_TRANSPORT => {
- // transport message
- let _ = wg.router.recv(src, msg);
+ // message type de-multiplexer
+ if msg.len() < std::mem::size_of::<u32>() {
+ continue;
+ }
+ match LittleEndian::read_u32(&msg[..]) {
+ handshake::TYPE_COOKIE_REPLY
+ | handshake::TYPE_INITIATION
+ | handshake::TYPE_RESPONSE => {
+ // update under_load flag
+ if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
+ last_under_load = Instant::now();
+ wg.under_load.store(true, Ordering::SeqCst);
+ } else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
+ wg.under_load.store(false, Ordering::SeqCst);
}
- _ => (),
+
+ wg.queue
+ .lock()
+ .send(HandshakeJob::Message(msg, src))
+ .unwrap();
}
+ router::TYPE_TRANSPORT => {
+ // transport message
+ let _ = wg.router.recv(src, msg).map_err(|e| {
+ debug!("Failed to handle incoming transport message: {}", e);
+ });
+ }
+ _ => (),
}
- });
- }
- */
+ }
+ });
}
- pub fn new(reader: T::Reader, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> {
+ pub fn set_writer(&self, writer: B::Writer) {
+ // TODO: Consider unifying these and avoid Clone requirement on writer
+ *self.state.send.write() = Some(writer.clone());
+ self.state.router.set_outbound_writer(writer);
+ }
+
+ pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer, mtu: T::MTU) -> Wireguard<T, B> {
// create device state
let mut rng = OsRng::new().unwrap();
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
@@ -292,14 +301,16 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
None
},
) {
- Ok((pk, msg, keypair)) => {
+ Ok((pk, resp, keypair)) => {
// send response
- if let Some(msg) = msg {
+ let mut resp_len: u64 = 0;
+ if let Some(msg) = resp {
+ resp_len = msg.len() as u64;
let send: &Option<B::Writer> = &*wg.send.read();
if let Some(writer) = send.as_ref() {
let _ = writer.write(&msg[..], &src).map_err(|e| {
debug!(
- "handshake worker, failed to send response, error = {:?}",
+ "handshake worker, failed to send response, error = {}",
e
)
});
@@ -308,16 +319,23 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
// update timers
if let Some(pk) = pk {
+ // authenticated handshake packet received
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
+ // add to rx_bytes and tx_bytes
+ let req_len = msg.len() as u64;
+ peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed);
+ peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed);
+
// update endpoint
peer.router.set_endpoint(src);
- // add keypair to peer and free any unused ids
- if let Some(keypair) = keypair {
- for id in peer.router.add_keypair(keypair) {
+ // add keypair to peer
+ keypair.map(|kp| {
+ // free any unused ids
+ for id in peer.router.add_keypair(kp) {
state.device.release(id);
}
- }
+ });
}
}
}
@@ -325,20 +343,27 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}
}
HandshakeJob::New(pk) => {
- let msg = state.device.begin(&mut rng, &pk).unwrap(); // TODO handle
- if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
- peer.router.send(&msg[..]);
- peer.timers.read().handshake_sent();
- }
+ let _ = state.device.begin(&mut rng, &pk).map(|msg| {
+ if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
+ let _ = peer.router.send(&msg[..]).map_err(|e| {
+ debug!("handshake worker, failed to send handshake initiation, error = {}", e)
+ });
+ }
+ });
}
}
}
});
}
- // start TUN read IO thread
- {
+ // start TUN read IO threads (multiple threads to support multi-queue interfaces)
+ debug_assert!(
+ readers.len() > 0,
+ "attempted to create WG device without TUN readers"
+ );
+ while let Some(reader) = readers.pop() {
let wg = wg.clone();
+ let mtu = mtu.clone();
thread::spawn(move || loop {
// create vector big enough for any transport message (based on MTU)
let mtu = mtu.mtu();