From 3d6e8f08a7408a3f68b7917ae4ff4ea804c36d00 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Fri, 11 Oct 2019 12:57:24 +0200 Subject: Enable adding TUN reader to WG interface --- src/main.rs | 21 +------ src/router/tests.rs | 10 ++-- src/tests.rs | 46 +++++++++++++++ src/types/bind.rs | 5 -- src/types/dummy.rs | 144 +++++++++++++++++++++++++++++--------------- src/types/endpoint.rs | 1 + src/wireguard.rs | 161 +++++++++++++++++++++++++++++--------------------- 7 files changed, 244 insertions(+), 144 deletions(-) create mode 100644 src/tests.rs diff --git a/src/main.rs b/src/main.rs index 9b69f54..3c59c67 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,25 +15,6 @@ mod types; mod wireguard; #[cfg(test)] -mod tests { - use crate::types::tun::Tun; - use crate::types::{bind, dummy, tun}; - use crate::wireguard::Wireguard; - - use std::thread; - use std::time::Duration; - - fn init() { - let _ = env_logger::builder().is_test(true).try_init(); - } - - #[test] - fn test_pure_wireguard() { - init(); - let (reader, writer, mtu) = dummy::TunTest::create("name").unwrap(); - let wg: Wireguard = Wireguard::new(reader, writer, mtu); - thread::sleep(Duration::from_millis(500)); - } -} +mod tests; fn main() {} diff --git a/src/router/tests.rs b/src/router/tests.rs index 3b6b941..6c385a8 100644 --- a/src/router/tests.rs +++ b/src/router/tests.rs @@ -145,8 +145,8 @@ mod tests { } // create device - let (_reader, tun_writer, _mtu) = dummy::TunTest::create("name").unwrap(); - let router: Device<_, BencherCallbacks, dummy::TunTest, dummy::VoidBind> = + let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false); + let router: Device< _, BencherCallbacks, dummy::TunWriter, dummy::VoidBind> = Device::new(num_cpus::get(), tun_writer); // add new peer @@ -175,7 +175,7 @@ mod tests { init(); // create device - let (_reader, tun_writer, _mtu) = dummy::TunTest::create("name").unwrap(); + let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(1500, false); let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer); router.set_outbound_writer(dummy::VoidBind::new()); @@ -321,8 +321,8 @@ mod tests { dummy::PairBind::pair(); // create matching device - let (tun_writer1, _, _) = dummy::TunTest::create("tun1").unwrap(); - let (tun_writer2, _, _) = dummy::TunTest::create("tun1").unwrap(); + let (_fake, _, tun_writer1, _) = dummy::TunTest::create(1500, false); + let (_fake, _, tun_writer2, _) = dummy::TunTest::create(1500, false); let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1); router1.set_outbound_writer(bind_writer1); diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..8e15037 --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,46 @@ +use crate::types::tun::Tun; +use crate::types::{bind, dummy, tun}; +use crate::wireguard::Wireguard; + +use std::thread; +use std::time::Duration; + +fn init() { + let _ = env_logger::builder().is_test(true).try_init(); +} + +/* Create and configure two matching pure instances of WireGuard + * + */ +#[test] +fn test_pure_wireguard() { + init(); + + // create WG instances for fake TUN devices + + let (fake1, tun_reader1, tun_writer1, mtu1) = dummy::TunTest::create(1500, true); + let wg1: Wireguard = + Wireguard::new(vec![tun_reader1], tun_writer1, mtu1); + + let (fake2, tun_reader2, tun_writer2, mtu2) = dummy::TunTest::create(1500, true); + let wg2: Wireguard = + Wireguard::new(vec![tun_reader2], tun_writer2, mtu2); + + // create pair bind to connect the interfaces "over the internet" + + let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) = dummy::PairBind::pair(); + + wg1.set_writer(bind_writer1); + wg2.set_writer(bind_writer2); + + wg1.add_reader(bind_reader1); + wg2.add_reader(bind_reader2); + + // generate (public, pivate) key pairs + + // configure cryptkey router + + // create IP packets + + thread::sleep(Duration::from_millis(500)); +} diff --git a/src/types/bind.rs b/src/types/bind.rs index fcc38c8..3d3f187 100644 --- a/src/types/bind.rs +++ b/src/types/bind.rs @@ -20,9 +20,4 @@ pub trait Bind: Send + Sync + 'static { /* Until Rust gets type equality constraints these have to be generic */ type Writer: Writer; type Reader: Reader; - - /* Used to close the reader/writer when binding to a new port */ - type Closer; - - fn bind(port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error>; } diff --git a/src/types/dummy.rs b/src/types/dummy.rs index 40a3bdd..2403c9b 100644 --- a/src/types/dummy.rs +++ b/src/types/dummy.rs @@ -1,11 +1,12 @@ use std::error::Error; use std::fmt; +use std::marker; use std::net::SocketAddr; use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::Arc; use std::sync::Mutex; use std::time::Instant; -use std::marker; +use std::sync::atomic::{Ordering, AtomicUsize}; use super::*; @@ -41,7 +42,9 @@ impl fmt::Display for BindError { /* TUN implementation */ #[derive(Debug)] -pub enum TunError {} +pub enum TunError { + Disconnected +} impl Error for TunError { fn description(&self) -> &str { @@ -68,54 +71,111 @@ impl Endpoint for UnitEndpoint { fn from_address(_: SocketAddr) -> UnitEndpoint { UnitEndpoint {} } + fn into_address(&self) -> SocketAddr { "127.0.0.1:8080".parse().unwrap() } + + fn clear_src(&self) {} } impl UnitEndpoint { pub fn new() -> UnitEndpoint { - UnitEndpoint{} + UnitEndpoint {} } } /* */ -#[derive(Clone, Copy)] pub struct TunTest {} -impl tun::Reader for TunTest { - type Error = TunError; +pub struct TunFakeIO { + store: bool, + tx: SyncSender>, + rx: Receiver> +} - fn read(&self, _buf: &mut [u8], _offset: usize) -> Result { - Ok(0) - } +pub struct TunReader { + rx: Receiver> } -impl tun::MTU for TunTest { - fn mtu(&self) -> usize { - 1500 +pub struct TunWriter { + store: bool, + tx: Mutex>> +} + +#[derive(Clone)] +pub struct TunMTU { + mtu: Arc +} + +impl tun::Reader for TunReader { + type Error = TunError; + + fn read(&self, buf: &mut [u8], offset: usize) -> Result { + match self.rx.recv() { + Ok(m) => { + buf[offset..].copy_from_slice(&m[..]); + Ok(m.len()) + } + Err(_) => Err(TunError::Disconnected) + } } } -impl tun::Writer for TunTest { +impl tun::Writer for TunWriter { type Error = TunError; - fn write(&self, _src: &[u8]) -> Result<(), Self::Error> { - Ok(()) + fn write(&self, src: &[u8]) -> Result<(), Self::Error> { + if self.store { + let m = src.to_owned(); + match self.tx.lock().unwrap().send(m) { + Ok(_) => Ok(()), + Err(_) => Err(TunError::Disconnected) + } + } else { + Ok(()) + } + } +} + +impl tun::MTU for TunMTU { + fn mtu(&self) -> usize { + self.mtu.load(Ordering::Acquire) } } impl tun::Tun for TunTest { - type Writer = TunTest; - type Reader = TunTest; - type MTU = TunTest; + type Writer = TunWriter; + type Reader = TunReader; + type MTU = TunMTU; type Error = TunError; } +impl TunFakeIO { + pub fn write(&self, msg : Vec) { + if self.store { + self.tx.send(msg).unwrap(); + } + } + + pub fn read(&self) -> Vec { + self.rx.recv().unwrap() + } +} + impl TunTest { - pub fn create(_name: &str) -> Result<(TunTest, TunTest, TunTest), TunError> { - Ok((TunTest {},TunTest {}, TunTest{})) + pub fn create(mtu : usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunMTU) { + + let (tx1, rx1) = if store { sync_channel(32) } else { sync_channel(1) }; + let (tx2, rx2) = if store { sync_channel(32) } else { sync_channel(1) }; + + let fake = TunFakeIO{tx: tx1, rx: rx2, store}; + let reader = TunReader{rx : rx1}; + let writer = TunWriter{tx : Mutex::new(tx2), store}; + let mtu = TunMTU{mtu : Arc::new(AtomicUsize::new(mtu))}; + + (fake, reader, writer, mtu) } } @@ -146,16 +206,11 @@ impl bind::Bind for VoidBind { type Reader = VoidBind; type Writer = VoidBind; - type Closer = (); - - fn bind(_ : u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error> { - Ok((VoidBind{}, VoidBind{}, (), 2600)) - } } impl VoidBind { pub fn new() -> VoidBind { - VoidBind{} + VoidBind {} } } @@ -203,45 +258,42 @@ pub struct PairWriter { pub struct PairBind {} impl PairBind { - pub fn pair() -> ((PairReader, PairWriter), (PairReader, PairWriter)) { + pub fn pair() -> ( + (PairReader, PairWriter), + (PairReader, PairWriter), + ) { let (tx1, rx1) = sync_channel(128); let (tx2, rx2) = sync_channel(128); ( ( - PairReader{ - - recv: Arc::new(Mutex::new(rx1)), - _marker: marker::PhantomData - }, - PairWriter{ + PairReader { + recv: Arc::new(Mutex::new(rx1)), + _marker: marker::PhantomData, + }, + PairWriter { send: Arc::new(Mutex::new(tx2)), - _marker: marker::PhantomData - } + _marker: marker::PhantomData, + }, ), ( - PairReader{ + PairReader { recv: Arc::new(Mutex::new(rx2)), - _marker: marker::PhantomData - }, - PairWriter{ + _marker: marker::PhantomData, + }, + PairWriter { send: Arc::new(Mutex::new(tx1)), - _marker: marker::PhantomData - } + _marker: marker::PhantomData, + }, ), ) } } impl bind::Bind for PairBind { - type Closer = (); type Error = BindError; type Endpoint = UnitEndpoint; type Reader = PairReader; type Writer = PairWriter; - - fn bind(_port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error> { - Err(BindError::Disconnected) - } } pub fn keypair(initiator: bool) -> KeyPair { diff --git a/src/types/endpoint.rs b/src/types/endpoint.rs index 74796aa..f4f93da 100644 --- a/src/types/endpoint.rs +++ b/src/types/endpoint.rs @@ -3,4 +3,5 @@ use std::net::SocketAddr; pub trait Endpoint: Send + 'static { fn from_address(addr: SocketAddr) -> Self; fn into_address(&self) -> SocketAddr; + fn clear_src(&self); } diff --git a/src/wireguard.rs b/src/wireguard.rs index bcb8592..f14a053 100644 --- a/src/wireguard.rs +++ b/src/wireguard.rs @@ -3,8 +3,10 @@ use crate::handshake; use crate::router; use crate::timers::{Events, Timers}; +use crate::types::bind::Reader as BindReader; use crate::types::bind::{Bind, Writer}; use crate::types::tun::{Reader, Tun, MTU}; + use crate::types::Endpoint; use hjul::Runner; @@ -53,7 +55,7 @@ pub struct PeerInner { pub timers: RwLock, // } -impl PeerInner { +impl PeerInner { #[inline(always)] pub fn timers(&self) -> RwLockReadGuard { self.timers.read() @@ -153,7 +155,7 @@ impl Wireguard { } pub fn get_sk(&self) -> Option { - let mut handshake = self.state.handshake.read(); + let handshake = self.state.handshake.read(); if handshake.active { Some(handshake.device.get_sk()) } else { @@ -184,66 +186,73 @@ impl Wireguard { peer } - pub fn new_bind(reader: B::Reader, writer: B::Writer, closer: B::Closer) { - - // drop existing closer - - // swap IO thread for new reader - - // start UDP read IO thread - - /* - { - let wg = wg.clone(); - let mtu = mtu.clone(); - thread::spawn(move || { - let mut last_under_load = - Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); - - loop { - // create vector big enough for any message given current MTU - let size = mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE; - let mut msg: Vec = Vec::with_capacity(size); - msg.resize(size, 0); - - // read UDP packet into vector - let (size, src) = reader.read(&mut msg).unwrap(); // TODO handle error - msg.truncate(size); + /* Begin consuming messages from the reader. + * + * Any previous reader thread is stopped by closing the previous reader, + * which unblocks the thread and causes an error on reader.read + */ + pub fn add_reader(&self, reader: B::Reader) { + let wg = self.state.clone(); + thread::spawn(move || { + let mut last_under_load = + Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); + + loop { + // create vector big enough for any message given current MTU + let size = wg.mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE; + let mut msg: Vec = Vec::with_capacity(size); + msg.resize(size, 0); - // message type de-multiplexer - if msg.len() < std::mem::size_of::() { - continue; + // read UDP packet into vector + let (size, src) = match reader.read(&mut msg) { + Err(e) => { + debug!("Bind reader closed with {}", e); + return; } - match LittleEndian::read_u32(&msg[..]) { - handshake::TYPE_COOKIE_REPLY - | handshake::TYPE_INITIATION - | handshake::TYPE_RESPONSE => { - // update under_load flag - if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD { - last_under_load = Instant::now(); - wg.under_load.store(true, Ordering::SeqCst); - } else if last_under_load.elapsed() > DURATION_UNDER_LOAD { - wg.under_load.store(false, Ordering::SeqCst); - } + Ok(v) => v, + }; + msg.truncate(size); - wg.queue - .lock() - .send(HandshakeJob::Message(msg, src)) - .unwrap(); - } - router::TYPE_TRANSPORT => { - // transport message - let _ = wg.router.recv(src, msg); + // message type de-multiplexer + if msg.len() < std::mem::size_of::() { + continue; + } + match LittleEndian::read_u32(&msg[..]) { + handshake::TYPE_COOKIE_REPLY + | handshake::TYPE_INITIATION + | handshake::TYPE_RESPONSE => { + // update under_load flag + if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD { + last_under_load = Instant::now(); + wg.under_load.store(true, Ordering::SeqCst); + } else if last_under_load.elapsed() > DURATION_UNDER_LOAD { + wg.under_load.store(false, Ordering::SeqCst); } - _ => (), + + wg.queue + .lock() + .send(HandshakeJob::Message(msg, src)) + .unwrap(); } + router::TYPE_TRANSPORT => { + // transport message + let _ = wg.router.recv(src, msg).map_err(|e| { + debug!("Failed to handle incoming transport message: {}", e); + }); + } + _ => (), } - }); - } - */ + } + }); } - pub fn new(reader: T::Reader, writer: T::Writer, mtu: T::MTU) -> Wireguard { + pub fn set_writer(&self, writer: B::Writer) { + // TODO: Consider unifying these and avoid Clone requirement on writer + *self.state.send.write() = Some(writer.clone()); + self.state.router.set_outbound_writer(writer); + } + + pub fn new(mut readers: Vec, writer: T::Writer, mtu: T::MTU) -> Wireguard { // create device state let mut rng = OsRng::new().unwrap(); let (tx, rx): (Sender>, _) = bounded(SIZE_HANDSHAKE_QUEUE); @@ -292,14 +301,16 @@ impl Wireguard { None }, ) { - Ok((pk, msg, keypair)) => { + Ok((pk, resp, keypair)) => { // send response - if let Some(msg) = msg { + let mut resp_len: u64 = 0; + if let Some(msg) = resp { + resp_len = msg.len() as u64; let send: &Option = &*wg.send.read(); if let Some(writer) = send.as_ref() { let _ = writer.write(&msg[..], &src).map_err(|e| { debug!( - "handshake worker, failed to send response, error = {:?}", + "handshake worker, failed to send response, error = {}", e ) }); @@ -308,16 +319,23 @@ impl Wireguard { // update timers if let Some(pk) = pk { + // authenticated handshake packet received if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { + // add to rx_bytes and tx_bytes + let req_len = msg.len() as u64; + peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed); + peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed); + // update endpoint peer.router.set_endpoint(src); - // add keypair to peer and free any unused ids - if let Some(keypair) = keypair { - for id in peer.router.add_keypair(keypair) { + // add keypair to peer + keypair.map(|kp| { + // free any unused ids + for id in peer.router.add_keypair(kp) { state.device.release(id); } - } + }); } } } @@ -325,20 +343,27 @@ impl Wireguard { } } HandshakeJob::New(pk) => { - let msg = state.device.begin(&mut rng, &pk).unwrap(); // TODO handle - if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { - peer.router.send(&msg[..]); - peer.timers.read().handshake_sent(); - } + let _ = state.device.begin(&mut rng, &pk).map(|msg| { + if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { + let _ = peer.router.send(&msg[..]).map_err(|e| { + debug!("handshake worker, failed to send handshake initiation, error = {}", e) + }); + } + }); } } } }); } - // start TUN read IO thread - { + // start TUN read IO threads (multiple threads to support multi-queue interfaces) + debug_assert!( + readers.len() > 0, + "attempted to create WG device without TUN readers" + ); + while let Some(reader) = readers.pop() { let wg = wg.clone(); + let mtu = mtu.clone(); thread::spawn(move || loop { // create vector big enough for any transport message (based on MTU) let mtu = mtu.mtu(); -- cgit v1.2.3-59-g8ed1b