From c82d3e554ba34305fa7ef759c830a74f4ba9559b Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sun, 6 Oct 2019 13:33:15 +0200 Subject: Restructure dummy implementations --- src/constants.rs | 2 + src/handshake/messages.rs | 13 +++ src/handshake/mod.rs | 2 +- src/main.rs | 21 ++++- src/router/tests.rs | 222 +++------------------------------------------- src/types/dummy.rs | 217 ++++++++++++++++++++++++++++++++++++++++++++ src/types/mod.rs | 3 + src/wireguard.rs | 70 +++++++++++---- 8 files changed, 320 insertions(+), 230 deletions(-) create mode 100644 src/types/dummy.rs (limited to 'src') diff --git a/src/constants.rs b/src/constants.rs index c4e3ae7..72de8d9 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -16,3 +16,5 @@ pub const TIMER_MAX_DURATION: Duration = Duration::from_secs(200); pub const TIMERS_TICK: Duration = Duration::from_millis(100); pub const TIMERS_SLOTS: usize = (TIMER_MAX_DURATION.as_micros() / TIMERS_TICK.as_micros()) as usize; pub const TIMERS_CAPACITY: usize = 1024; + +pub const MESSAGE_PADDING_MULTIPLE: usize = 16; diff --git a/src/handshake/messages.rs b/src/handshake/messages.rs index 8611609..796e3c0 100644 --- a/src/handshake/messages.rs +++ b/src/handshake/messages.rs @@ -4,6 +4,9 @@ use hex; #[cfg(test)] use std::fmt; +use std::cmp; +use std::mem; + use byteorder::LittleEndian; use zerocopy::byteorder::U32; use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified}; @@ -21,6 +24,16 @@ pub const TYPE_INITIATION: u32 = 1; pub const TYPE_RESPONSE: u32 = 2; pub const TYPE_COOKIE_REPLY: u32 = 3; +const fn max(a: usize, b: usize) -> usize { + let m: usize = (a > b) as usize; + m * a + (1 - m) * b +} + +pub const MAX_HANDSHAKE_MSG_SIZE: usize = max( + max(mem::size_of::(), mem::size_of::()), + mem::size_of::(), +); + /* Handshake messsages */ #[repr(packed)] diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index 6d017cc..071a41f 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -18,4 +18,4 @@ mod types; // publicly exposed interface pub use device::Device; -pub use messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE}; +pub use messages::{MAX_HANDSHAKE_MSG_SIZE, TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE}; diff --git a/src/main.rs b/src/main.rs index 26b39a2..6133884 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,7 +12,24 @@ mod timers; mod types; mod wireguard; -#[test] -fn test_pure_wireguard() {} +#[cfg(test)] +mod tests { + use crate::types::{dummy, Bind}; + use crate::wireguard::Wireguard; + + use std::thread; + use std::time::Duration; + + fn init() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + #[test] + fn test_pure_wireguard() { + init(); + let wg = Wireguard::new(dummy::TunTest::new(), dummy::VoidBind::new()); + thread::sleep(Duration::from_millis(500)); + } +} fn main() {} diff --git a/src/router/tests.rs b/src/router/tests.rs index 07afa5d..f42e1f6 100644 --- a/src/router/tests.rs +++ b/src/router/tests.rs @@ -12,209 +12,13 @@ use num_cpus; use pnet::packet::ipv4::MutableIpv4Packet; use pnet::packet::ipv6::MutableIpv6Packet; -use super::super::types::{Bind, Endpoint, Key, KeyPair, Tun}; +use super::super::types::{dummy, Bind, Endpoint, Key, KeyPair, Tun}; use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX}; extern crate test; const SIZE_KEEPALIVE: usize = 32; -/* Error implementation */ - -#[derive(Debug)] -enum BindError { - Disconnected, -} - -impl Error for BindError { - fn description(&self) -> &str { - "Generic Bind Error" - } - - fn source(&self) -> Option<&(dyn Error + 'static)> { - None - } -} - -impl fmt::Display for BindError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - BindError::Disconnected => write!(f, "PairBind disconnected"), - } - } -} - -/* TUN implementation */ - -#[derive(Debug)] -enum TunError {} - -impl Error for TunError { - fn description(&self) -> &str { - "Generic Tun Error" - } - - fn source(&self) -> Option<&(dyn Error + 'static)> { - None - } -} - -impl fmt::Display for TunError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Not Possible") - } -} - -/* Endpoint implementation */ - -#[derive(Clone, Copy)] -struct UnitEndpoint {} - -impl Endpoint for UnitEndpoint { - fn from_address(_: SocketAddr) -> UnitEndpoint { - UnitEndpoint {} - } - fn into_address(&self) -> SocketAddr { - "127.0.0.1:8080".parse().unwrap() - } -} - -#[derive(Clone, Copy)] -struct TunTest {} - -impl Tun for TunTest { - type Error = TunError; - - fn mtu(&self) -> usize { - 1500 - } - - fn read(&self, _buf: &mut [u8], _offset: usize) -> Result { - Ok(0) - } - - fn write(&self, _src: &[u8]) -> Result<(), Self::Error> { - Ok(()) - } -} - -/* Bind implemenentations */ - -#[derive(Clone, Copy)] -struct VoidBind {} - -impl Bind for VoidBind { - type Error = BindError; - type Endpoint = UnitEndpoint; - - fn new() -> VoidBind { - VoidBind {} - } - - fn set_port(&self, _port: u16) -> Result<(), Self::Error> { - Ok(()) - } - - fn get_port(&self) -> Option { - None - } - - fn recv(&self, _buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> { - Ok((0, UnitEndpoint {})) - } - - fn send(&self, _buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> { - Ok(()) - } -} - -#[derive(Clone)] -struct PairBind { - send: Arc>>>, - recv: Arc>>>, -} - -impl Bind for PairBind { - type Error = BindError; - type Endpoint = UnitEndpoint; - - fn new() -> PairBind { - PairBind { - send: Arc::new(Mutex::new(sync_channel(0).0)), - recv: Arc::new(Mutex::new(sync_channel(0).1)), - } - } - - fn set_port(&self, _port: u16) -> Result<(), Self::Error> { - Ok(()) - } - - fn get_port(&self) -> Option { - None - } - - fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> { - let vec = self - .recv - .lock() - .unwrap() - .recv() - .map_err(|_| BindError::Disconnected)?; - let len = vec.len(); - buf[..len].copy_from_slice(&vec[..]); - Ok((vec.len(), UnitEndpoint {})) - } - - fn send(&self, buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> { - let owned = buf.to_owned(); - match self.send.lock().unwrap().send(owned) { - Err(_) => Err(BindError::Disconnected), - Ok(_) => Ok(()), - } - } -} - -fn bind_pair() -> (PairBind, PairBind) { - let (tx1, rx1) = sync_channel(128); - let (tx2, rx2) = sync_channel(128); - ( - PairBind { - send: Arc::new(Mutex::new(tx1)), - recv: Arc::new(Mutex::new(rx2)), - }, - PairBind { - send: Arc::new(Mutex::new(tx2)), - recv: Arc::new(Mutex::new(rx1)), - }, - ) -} - -fn dummy_keypair(initiator: bool) -> KeyPair { - let k1 = Key { - key: [0x53u8; 32], - id: 0x646e6573, - }; - let k2 = Key { - key: [0x52u8; 32], - id: 0x76636572, - }; - if initiator { - KeyPair { - birth: Instant::now(), - initiator: true, - send: k1, - recv: k2, - } - } else { - KeyPair { - birth: Instant::now(), - initiator: false, - send: k2, - recv: k1, - } - } -} - #[cfg(test)] mod tests { use super::*; @@ -341,13 +145,13 @@ mod tests { } // create device - let router: Device = - Device::new(num_cpus::get(), TunTest {}, VoidBind::new()); + let router: Device = + Device::new(num_cpus::get(), dummy::TunTest {}, dummy::VoidBind::new()); // add new peer let opaque = Arc::new(AtomicUsize::new(0)); let peer = router.new_peer(opaque.clone()); - peer.add_keypair(dummy_keypair(true)); + peer.add_keypair(dummy::keypair(true)); // add subnet to peer let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20"); @@ -370,7 +174,8 @@ mod tests { init(); // create device - let router: Device = Device::new(1, TunTest {}, VoidBind::new()); + let router: Device = + Device::new(1, dummy::TunTest::new(), dummy::VoidBind::new()); let tests = vec![ ("192.168.1.0", 24, "192.168.1.20", true), @@ -404,9 +209,8 @@ mod tests { let opaque = Opaque::new(); let peer = router.new_peer(opaque.clone()); let mask: IpAddr = mask.parse().unwrap(); - if set_key { - peer.add_keypair(dummy_keypair(true)); + peer.add_keypair(dummy::keypair(true)); } // map subnet to peer @@ -512,9 +316,11 @@ mod tests { for (stage, p1, p2) in tests.iter() { // create matching devices - let (bind1, bind2) = bind_pair(); - let router1: Device = Device::new(1, TunTest {}, bind1.clone()); - let router2: Device = Device::new(1, TunTest {}, bind2.clone()); + let (bind1, bind2) = dummy::PairBind::pair(); + let router1: Device = + Device::new(1, dummy::TunTest::new(), bind1.clone()); + let router2: Device = + Device::new(1, dummy::TunTest::new(), bind2.clone()); // prepare opaque values for tracing callbacks @@ -527,7 +333,7 @@ mod tests { let peer1 = router1.new_peer(opaq1.clone()); let mask: IpAddr = mask.parse().unwrap(); peer1.add_subnet(mask, *len); - peer1.add_keypair(dummy_keypair(false)); + peer1.add_keypair(dummy::keypair(false)); let (mask, len, _ip, _okay) = p2; let peer2 = router2.new_peer(opaq2.clone()); @@ -557,7 +363,7 @@ mod tests { // this should cause a key-confirmation packet (keepalive or staged packet) // this also causes peer1 to learn the "endpoint" for peer2 assert!(peer1.get_endpoint().is_none()); - peer2.add_keypair(dummy_keypair(true)); + peer2.add_keypair(dummy::keypair(true)); wait(); assert!(opaq2.send().is_some()); diff --git a/src/types/dummy.rs b/src/types/dummy.rs new file mode 100644 index 0000000..e15abb0 --- /dev/null +++ b/src/types/dummy.rs @@ -0,0 +1,217 @@ +use std::error::Error; +use std::fmt; +use std::net::SocketAddr; +use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; +use std::sync::Arc; +use std::sync::Mutex; +use std::time::Instant; + +use super::{Bind, Endpoint, Key, KeyPair, Tun}; + +/* This submodule provides pure/dummy implementations of the IO interfaces + * for use in unit tests thoughout the project. + */ + +/* Error implementation */ + +#[derive(Debug)] +pub enum BindError { + Disconnected, +} + +impl Error for BindError { + fn description(&self) -> &str { + "Generic Bind Error" + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} + +impl fmt::Display for BindError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BindError::Disconnected => write!(f, "PairBind disconnected"), + } + } +} + +/* TUN implementation */ + +#[derive(Debug)] +pub enum TunError {} + +impl Error for TunError { + fn description(&self) -> &str { + "Generic Tun Error" + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} + +impl fmt::Display for TunError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Not Possible") + } +} + +/* Endpoint implementation */ + +#[derive(Clone, Copy)] +pub struct UnitEndpoint {} + +impl Endpoint for UnitEndpoint { + fn from_address(_: SocketAddr) -> UnitEndpoint { + UnitEndpoint {} + } + fn into_address(&self) -> SocketAddr { + "127.0.0.1:8080".parse().unwrap() + } +} + +#[derive(Clone, Copy)] +pub struct TunTest {} + +impl Tun for TunTest { + type Error = TunError; + + fn mtu(&self) -> usize { + 1500 + } + + fn read(&self, _buf: &mut [u8], _offset: usize) -> Result { + Ok(0) + } + + fn write(&self, _src: &[u8]) -> Result<(), Self::Error> { + Ok(()) + } +} + +impl TunTest { + pub fn new() -> TunTest { + TunTest {} + } +} + +/* Bind implemenentations */ + +#[derive(Clone, Copy)] +pub struct VoidBind {} + +impl Bind for VoidBind { + type Error = BindError; + type Endpoint = UnitEndpoint; + + fn new() -> VoidBind { + VoidBind {} + } + + fn set_port(&self, _port: u16) -> Result<(), Self::Error> { + Ok(()) + } + + fn get_port(&self) -> Option { + None + } + + fn recv(&self, _buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> { + Ok((0, UnitEndpoint {})) + } + + fn send(&self, _buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> { + Ok(()) + } +} + +#[derive(Clone)] +pub struct PairBind { + send: Arc>>>, + recv: Arc>>>, +} + +impl PairBind { + pub fn pair() -> (PairBind, PairBind) { + let (tx1, rx1) = sync_channel(128); + let (tx2, rx2) = sync_channel(128); + ( + PairBind { + send: Arc::new(Mutex::new(tx1)), + recv: Arc::new(Mutex::new(rx2)), + }, + PairBind { + send: Arc::new(Mutex::new(tx2)), + recv: Arc::new(Mutex::new(rx1)), + }, + ) + } +} + +impl Bind for PairBind { + type Error = BindError; + type Endpoint = UnitEndpoint; + + fn new() -> PairBind { + PairBind { + send: Arc::new(Mutex::new(sync_channel(0).0)), + recv: Arc::new(Mutex::new(sync_channel(0).1)), + } + } + + fn set_port(&self, _port: u16) -> Result<(), Self::Error> { + Ok(()) + } + + fn get_port(&self) -> Option { + None + } + + fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> { + let vec = self + .recv + .lock() + .unwrap() + .recv() + .map_err(|_| BindError::Disconnected)?; + let len = vec.len(); + buf[..len].copy_from_slice(&vec[..]); + Ok((vec.len(), UnitEndpoint {})) + } + + fn send(&self, buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> { + let owned = buf.to_owned(); + match self.send.lock().unwrap().send(owned) { + Err(_) => Err(BindError::Disconnected), + Ok(_) => Ok(()), + } + } +} + +pub fn keypair(initiator: bool) -> KeyPair { + let k1 = Key { + key: [0x53u8; 32], + id: 0x646e6573, + }; + let k2 = Key { + key: [0x52u8; 32], + id: 0x76636572, + }; + if initiator { + KeyPair { + birth: Instant::now(), + initiator: true, + send: k1, + recv: k2, + } + } else { + KeyPair { + birth: Instant::now(), + initiator: false, + send: k2, + recv: k1, + } + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index 8da6d45..07ca44d 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -3,6 +3,9 @@ mod keys; mod tun; mod udp; +#[cfg(test)] +pub mod dummy; + pub use endpoint::Endpoint; pub use keys::{Key, KeyPair}; pub use tun::Tun; diff --git a/src/wireguard.rs b/src/wireguard.rs index 182cec2..ea600d0 100644 --- a/src/wireguard.rs +++ b/src/wireguard.rs @@ -6,6 +6,7 @@ use crate::types::{Bind, Endpoint, Tun}; use hjul::Runner; +use std::cmp; use std::ops::Deref; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; @@ -86,8 +87,19 @@ pub struct Wireguard { state: Arc>, } +#[inline(always)] +const fn padding(size: usize, mtu: usize) -> usize { + #[inline(always)] + const fn min(a: usize, b: usize) -> usize { + let m = (a > b) as usize; + a * m + (1 - m) * b + } + let pad = MESSAGE_PADDING_MULTIPLE; + min(mtu, size + (pad - size % pad) % pad) +} + impl Wireguard { - fn set_key(&self, sk: Option) { + pub fn set_key(&self, sk: Option) { let mut handshake = self.state.handshake.write(); match sk { None => { @@ -102,7 +114,7 @@ impl Wireguard { } } - fn new_peer(&self, pk: PublicKey) -> Peer { + pub fn new_peer(&self, pk: PublicKey) -> Peer { let state = Arc::new(PeerInner { pk, queue: Mutex::new(self.state.queue.lock().clone()), @@ -111,11 +123,21 @@ impl Wireguard { tx_bytes: AtomicU64::new(0), timers: RwLock::new(Timers::dummy(&self.runner)), }); + let router = Arc::new(self.state.router.new_peer(state.clone())); - Peer { router, state } + + let peer = Peer { router, state }; + + /* The need for dummy timers arises from the chicken-egg + * problem of the timer callbacks being able to set timers themselves. + * + * This is in fact the only place where the write lock is ever taken. + */ + *peer.timers.write() = Timers::new(&self.runner, peer.clone()); + peer } - fn new(tun: T, bind: B) -> Wireguard { + pub fn new(tun: T, bind: B) -> Wireguard { // create device state let mut rng = OsRng::new().unwrap(); let (tx, rx): (Sender>, _) = bounded(SIZE_HANDSHAKE_QUEUE); @@ -215,10 +237,12 @@ impl Wireguard { Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000); loop { - // read UDP packet into vector - let size = tun.mtu() + 148; // maximum message size + // create vector big enough for any message given current MTU + let size = tun.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE; let mut msg: Vec = Vec::with_capacity(size); msg.resize(size, 0); + + // read UDP packet into vector let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error msg.truncate(size); @@ -226,7 +250,6 @@ impl Wireguard { if msg.len() < std::mem::size_of::() { continue; } - match LittleEndian::read_u32(&msg[..]) { handshake::TYPE_COOKIE_REPLY | handshake::TYPE_INITIATION @@ -246,9 +269,6 @@ impl Wireguard { } router::TYPE_TRANSPORT => { // transport message - - // pad the message - let _ = wg.router.recv(src, msg); } _ => (), @@ -261,20 +281,32 @@ impl Wireguard { { let wg = wg.clone(); thread::spawn(move || loop { - // read a new IP packet + // create vector big enough for any transport message (based on MTU) let mtu = tun.mtu(); - let size = mtu + 148; + let size = mtu + router::SIZE_MESSAGE_PREFIX; let mut msg: Vec = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); - let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap(); - msg.truncate(size); + msg.resize(size, 0); - // pad message to multiple of 16 bytes - while msg.len() < mtu && msg.len() % 16 != 0 { - msg.push(0); - } + // read a new IP packet + let payload = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap(); + debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu); + + // truncate padding + let payload = padding(payload, mtu); + msg.truncate(router::SIZE_MESSAGE_PREFIX + payload); + debug_assert!(payload <= mtu); + debug_assert_eq!( + if payload < mtu { + (msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE + } else { + 0 + }, + 0 + ); // crypt-key route - let _ = wg.router.send(msg); + let e = wg.router.send(msg); + debug!("TUN worker, router returned {:?}", e); }); } -- cgit v1.2.3-59-g8ed1b