summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/constants.rs2
-rw-r--r--src/handshake/messages.rs13
-rw-r--r--src/handshake/mod.rs2
-rw-r--r--src/main.rs21
-rw-r--r--src/router/tests.rs222
-rw-r--r--src/types/dummy.rs217
-rw-r--r--src/types/mod.rs3
-rw-r--r--src/wireguard.rs70
8 files changed, 320 insertions, 230 deletions
diff --git a/src/constants.rs b/src/constants.rs
index c4e3ae7..72de8d9 100644
--- a/src/constants.rs
+++ b/src/constants.rs
@@ -16,3 +16,5 @@ pub const TIMER_MAX_DURATION: Duration = Duration::from_secs(200);
pub const TIMERS_TICK: Duration = Duration::from_millis(100);
pub const TIMERS_SLOTS: usize = (TIMER_MAX_DURATION.as_micros() / TIMERS_TICK.as_micros()) as usize;
pub const TIMERS_CAPACITY: usize = 1024;
+
+pub const MESSAGE_PADDING_MULTIPLE: usize = 16;
diff --git a/src/handshake/messages.rs b/src/handshake/messages.rs
index 8611609..796e3c0 100644
--- a/src/handshake/messages.rs
+++ b/src/handshake/messages.rs
@@ -4,6 +4,9 @@ use hex;
#[cfg(test)]
use std::fmt;
+use std::cmp;
+use std::mem;
+
use byteorder::LittleEndian;
use zerocopy::byteorder::U32;
use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified};
@@ -21,6 +24,16 @@ pub const TYPE_INITIATION: u32 = 1;
pub const TYPE_RESPONSE: u32 = 2;
pub const TYPE_COOKIE_REPLY: u32 = 3;
+const fn max(a: usize, b: usize) -> usize {
+ let m: usize = (a > b) as usize;
+ m * a + (1 - m) * b
+}
+
+pub const MAX_HANDSHAKE_MSG_SIZE: usize = max(
+ max(mem::size_of::<Response>(), mem::size_of::<Initiation>()),
+ mem::size_of::<CookieReply>(),
+);
+
/* Handshake messsages */
#[repr(packed)]
diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs
index 6d017cc..071a41f 100644
--- a/src/handshake/mod.rs
+++ b/src/handshake/mod.rs
@@ -18,4 +18,4 @@ mod types;
// publicly exposed interface
pub use device::Device;
-pub use messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};
+pub use messages::{MAX_HANDSHAKE_MSG_SIZE, TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE};
diff --git a/src/main.rs b/src/main.rs
index 26b39a2..6133884 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -12,7 +12,24 @@ mod timers;
mod types;
mod wireguard;
-#[test]
-fn test_pure_wireguard() {}
+#[cfg(test)]
+mod tests {
+ use crate::types::{dummy, Bind};
+ use crate::wireguard::Wireguard;
+
+ use std::thread;
+ use std::time::Duration;
+
+ fn init() {
+ let _ = env_logger::builder().is_test(true).try_init();
+ }
+
+ #[test]
+ fn test_pure_wireguard() {
+ init();
+ let wg = Wireguard::new(dummy::TunTest::new(), dummy::VoidBind::new());
+ thread::sleep(Duration::from_millis(500));
+ }
+}
fn main() {}
diff --git a/src/router/tests.rs b/src/router/tests.rs
index 07afa5d..f42e1f6 100644
--- a/src/router/tests.rs
+++ b/src/router/tests.rs
@@ -12,209 +12,13 @@ use num_cpus;
use pnet::packet::ipv4::MutableIpv4Packet;
use pnet::packet::ipv6::MutableIpv6Packet;
-use super::super::types::{Bind, Endpoint, Key, KeyPair, Tun};
+use super::super::types::{dummy, Bind, Endpoint, Key, KeyPair, Tun};
use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX};
extern crate test;
const SIZE_KEEPALIVE: usize = 32;
-/* Error implementation */
-
-#[derive(Debug)]
-enum BindError {
- Disconnected,
-}
-
-impl Error for BindError {
- fn description(&self) -> &str {
- "Generic Bind Error"
- }
-
- fn source(&self) -> Option<&(dyn Error + 'static)> {
- None
- }
-}
-
-impl fmt::Display for BindError {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- match self {
- BindError::Disconnected => write!(f, "PairBind disconnected"),
- }
- }
-}
-
-/* TUN implementation */
-
-#[derive(Debug)]
-enum TunError {}
-
-impl Error for TunError {
- fn description(&self) -> &str {
- "Generic Tun Error"
- }
-
- fn source(&self) -> Option<&(dyn Error + 'static)> {
- None
- }
-}
-
-impl fmt::Display for TunError {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "Not Possible")
- }
-}
-
-/* Endpoint implementation */
-
-#[derive(Clone, Copy)]
-struct UnitEndpoint {}
-
-impl Endpoint for UnitEndpoint {
- fn from_address(_: SocketAddr) -> UnitEndpoint {
- UnitEndpoint {}
- }
- fn into_address(&self) -> SocketAddr {
- "127.0.0.1:8080".parse().unwrap()
- }
-}
-
-#[derive(Clone, Copy)]
-struct TunTest {}
-
-impl Tun for TunTest {
- type Error = TunError;
-
- fn mtu(&self) -> usize {
- 1500
- }
-
- fn read(&self, _buf: &mut [u8], _offset: usize) -> Result<usize, Self::Error> {
- Ok(0)
- }
-
- fn write(&self, _src: &[u8]) -> Result<(), Self::Error> {
- Ok(())
- }
-}
-
-/* Bind implemenentations */
-
-#[derive(Clone, Copy)]
-struct VoidBind {}
-
-impl Bind for VoidBind {
- type Error = BindError;
- type Endpoint = UnitEndpoint;
-
- fn new() -> VoidBind {
- VoidBind {}
- }
-
- fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
- Ok(())
- }
-
- fn get_port(&self) -> Option<u16> {
- None
- }
-
- fn recv(&self, _buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
- Ok((0, UnitEndpoint {}))
- }
-
- fn send(&self, _buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
- Ok(())
- }
-}
-
-#[derive(Clone)]
-struct PairBind {
- send: Arc<Mutex<SyncSender<Vec<u8>>>>,
- recv: Arc<Mutex<Receiver<Vec<u8>>>>,
-}
-
-impl Bind for PairBind {
- type Error = BindError;
- type Endpoint = UnitEndpoint;
-
- fn new() -> PairBind {
- PairBind {
- send: Arc::new(Mutex::new(sync_channel(0).0)),
- recv: Arc::new(Mutex::new(sync_channel(0).1)),
- }
- }
-
- fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
- Ok(())
- }
-
- fn get_port(&self) -> Option<u16> {
- None
- }
-
- fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
- let vec = self
- .recv
- .lock()
- .unwrap()
- .recv()
- .map_err(|_| BindError::Disconnected)?;
- let len = vec.len();
- buf[..len].copy_from_slice(&vec[..]);
- Ok((vec.len(), UnitEndpoint {}))
- }
-
- fn send(&self, buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
- let owned = buf.to_owned();
- match self.send.lock().unwrap().send(owned) {
- Err(_) => Err(BindError::Disconnected),
- Ok(_) => Ok(()),
- }
- }
-}
-
-fn bind_pair() -> (PairBind, PairBind) {
- let (tx1, rx1) = sync_channel(128);
- let (tx2, rx2) = sync_channel(128);
- (
- PairBind {
- send: Arc::new(Mutex::new(tx1)),
- recv: Arc::new(Mutex::new(rx2)),
- },
- PairBind {
- send: Arc::new(Mutex::new(tx2)),
- recv: Arc::new(Mutex::new(rx1)),
- },
- )
-}
-
-fn dummy_keypair(initiator: bool) -> KeyPair {
- let k1 = Key {
- key: [0x53u8; 32],
- id: 0x646e6573,
- };
- let k2 = Key {
- key: [0x52u8; 32],
- id: 0x76636572,
- };
- if initiator {
- KeyPair {
- birth: Instant::now(),
- initiator: true,
- send: k1,
- recv: k2,
- }
- } else {
- KeyPair {
- birth: Instant::now(),
- initiator: false,
- send: k2,
- recv: k1,
- }
- }
-}
-
#[cfg(test)]
mod tests {
use super::*;
@@ -341,13 +145,13 @@ mod tests {
}
// create device
- let router: Device<BencherCallbacks, TunTest, VoidBind> =
- Device::new(num_cpus::get(), TunTest {}, VoidBind::new());
+ let router: Device<BencherCallbacks, dummy::TunTest, dummy::VoidBind> =
+ Device::new(num_cpus::get(), dummy::TunTest {}, dummy::VoidBind::new());
// add new peer
let opaque = Arc::new(AtomicUsize::new(0));
let peer = router.new_peer(opaque.clone());
- peer.add_keypair(dummy_keypair(true));
+ peer.add_keypair(dummy::keypair(true));
// add subnet to peer
let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20");
@@ -370,7 +174,8 @@ mod tests {
init();
// create device
- let router: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, VoidBind::new());
+ let router: Device<TestCallbacks, _, _> =
+ Device::new(1, dummy::TunTest::new(), dummy::VoidBind::new());
let tests = vec![
("192.168.1.0", 24, "192.168.1.20", true),
@@ -404,9 +209,8 @@ mod tests {
let opaque = Opaque::new();
let peer = router.new_peer(opaque.clone());
let mask: IpAddr = mask.parse().unwrap();
-
if set_key {
- peer.add_keypair(dummy_keypair(true));
+ peer.add_keypair(dummy::keypair(true));
}
// map subnet to peer
@@ -512,9 +316,11 @@ mod tests {
for (stage, p1, p2) in tests.iter() {
// create matching devices
- let (bind1, bind2) = bind_pair();
- let router1: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, bind1.clone());
- let router2: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, bind2.clone());
+ let (bind1, bind2) = dummy::PairBind::pair();
+ let router1: Device<TestCallbacks, _, _> =
+ Device::new(1, dummy::TunTest::new(), bind1.clone());
+ let router2: Device<TestCallbacks, _, _> =
+ Device::new(1, dummy::TunTest::new(), bind2.clone());
// prepare opaque values for tracing callbacks
@@ -527,7 +333,7 @@ mod tests {
let peer1 = router1.new_peer(opaq1.clone());
let mask: IpAddr = mask.parse().unwrap();
peer1.add_subnet(mask, *len);
- peer1.add_keypair(dummy_keypair(false));
+ peer1.add_keypair(dummy::keypair(false));
let (mask, len, _ip, _okay) = p2;
let peer2 = router2.new_peer(opaq2.clone());
@@ -557,7 +363,7 @@ mod tests {
// this should cause a key-confirmation packet (keepalive or staged packet)
// this also causes peer1 to learn the "endpoint" for peer2
assert!(peer1.get_endpoint().is_none());
- peer2.add_keypair(dummy_keypair(true));
+ peer2.add_keypair(dummy::keypair(true));
wait();
assert!(opaq2.send().is_some());
diff --git a/src/types/dummy.rs b/src/types/dummy.rs
new file mode 100644
index 0000000..e15abb0
--- /dev/null
+++ b/src/types/dummy.rs
@@ -0,0 +1,217 @@
+use std::error::Error;
+use std::fmt;
+use std::net::SocketAddr;
+use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
+use std::sync::Arc;
+use std::sync::Mutex;
+use std::time::Instant;
+
+use super::{Bind, Endpoint, Key, KeyPair, Tun};
+
+/* This submodule provides pure/dummy implementations of the IO interfaces
+ * for use in unit tests thoughout the project.
+ */
+
+/* Error implementation */
+
+#[derive(Debug)]
+pub enum BindError {
+ Disconnected,
+}
+
+impl Error for BindError {
+ fn description(&self) -> &str {
+ "Generic Bind Error"
+ }
+
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ None
+ }
+}
+
+impl fmt::Display for BindError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ BindError::Disconnected => write!(f, "PairBind disconnected"),
+ }
+ }
+}
+
+/* TUN implementation */
+
+#[derive(Debug)]
+pub enum TunError {}
+
+impl Error for TunError {
+ fn description(&self) -> &str {
+ "Generic Tun Error"
+ }
+
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ None
+ }
+}
+
+impl fmt::Display for TunError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "Not Possible")
+ }
+}
+
+/* Endpoint implementation */
+
+#[derive(Clone, Copy)]
+pub struct UnitEndpoint {}
+
+impl Endpoint for UnitEndpoint {
+ fn from_address(_: SocketAddr) -> UnitEndpoint {
+ UnitEndpoint {}
+ }
+ fn into_address(&self) -> SocketAddr {
+ "127.0.0.1:8080".parse().unwrap()
+ }
+}
+
+#[derive(Clone, Copy)]
+pub struct TunTest {}
+
+impl Tun for TunTest {
+ type Error = TunError;
+
+ fn mtu(&self) -> usize {
+ 1500
+ }
+
+ fn read(&self, _buf: &mut [u8], _offset: usize) -> Result<usize, Self::Error> {
+ Ok(0)
+ }
+
+ fn write(&self, _src: &[u8]) -> Result<(), Self::Error> {
+ Ok(())
+ }
+}
+
+impl TunTest {
+ pub fn new() -> TunTest {
+ TunTest {}
+ }
+}
+
+/* Bind implemenentations */
+
+#[derive(Clone, Copy)]
+pub struct VoidBind {}
+
+impl Bind for VoidBind {
+ type Error = BindError;
+ type Endpoint = UnitEndpoint;
+
+ fn new() -> VoidBind {
+ VoidBind {}
+ }
+
+ fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
+ Ok(())
+ }
+
+ fn get_port(&self) -> Option<u16> {
+ None
+ }
+
+ fn recv(&self, _buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
+ Ok((0, UnitEndpoint {}))
+ }
+
+ fn send(&self, _buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
+ Ok(())
+ }
+}
+
+#[derive(Clone)]
+pub struct PairBind {
+ send: Arc<Mutex<SyncSender<Vec<u8>>>>,
+ recv: Arc<Mutex<Receiver<Vec<u8>>>>,
+}
+
+impl PairBind {
+ pub fn pair() -> (PairBind, PairBind) {
+ let (tx1, rx1) = sync_channel(128);
+ let (tx2, rx2) = sync_channel(128);
+ (
+ PairBind {
+ send: Arc::new(Mutex::new(tx1)),
+ recv: Arc::new(Mutex::new(rx2)),
+ },
+ PairBind {
+ send: Arc::new(Mutex::new(tx2)),
+ recv: Arc::new(Mutex::new(rx1)),
+ },
+ )
+ }
+}
+
+impl Bind for PairBind {
+ type Error = BindError;
+ type Endpoint = UnitEndpoint;
+
+ fn new() -> PairBind {
+ PairBind {
+ send: Arc::new(Mutex::new(sync_channel(0).0)),
+ recv: Arc::new(Mutex::new(sync_channel(0).1)),
+ }
+ }
+
+ fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
+ Ok(())
+ }
+
+ fn get_port(&self) -> Option<u16> {
+ None
+ }
+
+ fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
+ let vec = self
+ .recv
+ .lock()
+ .unwrap()
+ .recv()
+ .map_err(|_| BindError::Disconnected)?;
+ let len = vec.len();
+ buf[..len].copy_from_slice(&vec[..]);
+ Ok((vec.len(), UnitEndpoint {}))
+ }
+
+ fn send(&self, buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
+ let owned = buf.to_owned();
+ match self.send.lock().unwrap().send(owned) {
+ Err(_) => Err(BindError::Disconnected),
+ Ok(_) => Ok(()),
+ }
+ }
+}
+
+pub fn keypair(initiator: bool) -> KeyPair {
+ let k1 = Key {
+ key: [0x53u8; 32],
+ id: 0x646e6573,
+ };
+ let k2 = Key {
+ key: [0x52u8; 32],
+ id: 0x76636572,
+ };
+ if initiator {
+ KeyPair {
+ birth: Instant::now(),
+ initiator: true,
+ send: k1,
+ recv: k2,
+ }
+ } else {
+ KeyPair {
+ birth: Instant::now(),
+ initiator: false,
+ send: k2,
+ recv: k1,
+ }
+ }
+}
diff --git a/src/types/mod.rs b/src/types/mod.rs
index 8da6d45..07ca44d 100644
--- a/src/types/mod.rs
+++ b/src/types/mod.rs
@@ -3,6 +3,9 @@ mod keys;
mod tun;
mod udp;
+#[cfg(test)]
+pub mod dummy;
+
pub use endpoint::Endpoint;
pub use keys::{Key, KeyPair};
pub use tun::Tun;
diff --git a/src/wireguard.rs b/src/wireguard.rs
index 182cec2..ea600d0 100644
--- a/src/wireguard.rs
+++ b/src/wireguard.rs
@@ -6,6 +6,7 @@ use crate::types::{Bind, Endpoint, Tun};
use hjul::Runner;
+use std::cmp;
use std::ops::Deref;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
@@ -86,8 +87,19 @@ pub struct Wireguard<T: Tun, B: Bind> {
state: Arc<WireguardInner<T, B>>,
}
+#[inline(always)]
+const fn padding(size: usize, mtu: usize) -> usize {
+ #[inline(always)]
+ const fn min(a: usize, b: usize) -> usize {
+ let m = (a > b) as usize;
+ a * m + (1 - m) * b
+ }
+ let pad = MESSAGE_PADDING_MULTIPLE;
+ min(mtu, size + (pad - size % pad) % pad)
+}
+
impl<T: Tun, B: Bind> Wireguard<T, B> {
- fn set_key(&self, sk: Option<StaticSecret>) {
+ pub fn set_key(&self, sk: Option<StaticSecret>) {
let mut handshake = self.state.handshake.write();
match sk {
None => {
@@ -102,7 +114,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}
}
- fn new_peer(&self, pk: PublicKey) -> Peer<T, B> {
+ pub fn new_peer(&self, pk: PublicKey) -> Peer<T, B> {
let state = Arc::new(PeerInner {
pk,
queue: Mutex::new(self.state.queue.lock().clone()),
@@ -111,11 +123,21 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
tx_bytes: AtomicU64::new(0),
timers: RwLock::new(Timers::dummy(&self.runner)),
});
+
let router = Arc::new(self.state.router.new_peer(state.clone()));
- Peer { router, state }
+
+ let peer = Peer { router, state };
+
+ /* The need for dummy timers arises from the chicken-egg
+ * problem of the timer callbacks being able to set timers themselves.
+ *
+ * This is in fact the only place where the write lock is ever taken.
+ */
+ *peer.timers.write() = Timers::new(&self.runner, peer.clone());
+ peer
}
- fn new(tun: T, bind: B) -> Wireguard<T, B> {
+ pub fn new(tun: T, bind: B) -> Wireguard<T, B> {
// create device state
let mut rng = OsRng::new().unwrap();
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
@@ -215,10 +237,12 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
loop {
- // read UDP packet into vector
- let size = tun.mtu() + 148; // maximum message size
+ // create vector big enough for any message given current MTU
+ let size = tun.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
let mut msg: Vec<u8> = Vec::with_capacity(size);
msg.resize(size, 0);
+
+ // read UDP packet into vector
let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
msg.truncate(size);
@@ -226,7 +250,6 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
if msg.len() < std::mem::size_of::<u32>() {
continue;
}
-
match LittleEndian::read_u32(&msg[..]) {
handshake::TYPE_COOKIE_REPLY
| handshake::TYPE_INITIATION
@@ -246,9 +269,6 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}
router::TYPE_TRANSPORT => {
// transport message
-
- // pad the message
-
let _ = wg.router.recv(src, msg);
}
_ => (),
@@ -261,20 +281,32 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
{
let wg = wg.clone();
thread::spawn(move || loop {
- // read a new IP packet
+ // create vector big enough for any transport message (based on MTU)
let mtu = tun.mtu();
- let size = mtu + 148;
+ let size = mtu + router::SIZE_MESSAGE_PREFIX;
let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
- let size = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
- msg.truncate(size);
+ msg.resize(size, 0);
- // pad message to multiple of 16 bytes
- while msg.len() < mtu && msg.len() % 16 != 0 {
- msg.push(0);
- }
+ // read a new IP packet
+ let payload = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
+ debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu);
+
+ // truncate padding
+ let payload = padding(payload, mtu);
+ msg.truncate(router::SIZE_MESSAGE_PREFIX + payload);
+ debug_assert!(payload <= mtu);
+ debug_assert_eq!(
+ if payload < mtu {
+ (msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE
+ } else {
+ 0
+ },
+ 0
+ );
// crypt-key route
- let _ = wg.router.send(msg);
+ let e = wg.router.send(msg);
+ debug!("TUN worker, router returned {:?}", e);
});
}