aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/configuration/config.rs2
-rw-r--r--src/main.rs21
-rw-r--r--src/platform/linux/tun.rs155
-rw-r--r--src/wireguard/tests.rs14
-rw-r--r--src/wireguard/timers.rs4
-rw-r--r--src/wireguard/wireguard.rs224
6 files changed, 311 insertions, 109 deletions
diff --git a/src/configuration/config.rs b/src/configuration/config.rs
index e7d1ba5..c045d1e 100644
--- a/src/configuration/config.rs
+++ b/src/configuration/config.rs
@@ -261,7 +261,7 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireguardConfig<T, B> {
// add readers
while let Some(reader) = readers.pop() {
- cfg.wireguard.add_reader(reader);
+ cfg.wireguard.add_udp_reader(reader);
}
// create new UDP state
diff --git a/src/main.rs b/src/main.rs
index 1a9650b..5ea830f 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -26,7 +26,7 @@ fn main() {
let mut foreground = false;
let mut args = env::args();
- args.next(); // skip path
+ args.next(); // skip path (argv[0])
for arg in args {
match arg.as_str() {
@@ -56,7 +56,7 @@ fn main() {
});
// create TUN device
- let (readers, writer, status) = plt::Tun::create(name.as_str()).unwrap_or_else(|e| {
+ let (mut readers, writer, status) = plt::Tun::create(name.as_str()).unwrap_or_else(|e| {
eprintln!("Failed to create TUN device: {}", e);
exit(-3);
});
@@ -82,7 +82,15 @@ fn main() {
if drop_privileges {}
// create WireGuard device
- let wg: wireguard::Wireguard<plt::Tun, plt::UDP> = wireguard::Wireguard::new(readers, writer);
+ let wg: wireguard::Wireguard<plt::Tun, plt::UDP> = wireguard::Wireguard::new(writer);
+
+ // add all Tun readers
+ while let Some(reader) = readers.pop() {
+ wg.add_tun_reader(reader);
+ }
+
+ // obtain handle for waiting
+ let wait = wg.wait();
// wrap in configuration interface
let cfg = configuration::WireguardConfig::new(wg);
@@ -124,7 +132,7 @@ fn main() {
}
// start UAPI server
- loop {
+ thread::spawn(move || loop {
match uapi.connect() {
Ok(mut stream) => {
let cfg = cfg.clone();
@@ -137,5 +145,8 @@ fn main() {
break;
}
}
- }
+ });
+
+ // block until all tun readers closed
+ wait.wait();
}
diff --git a/src/platform/linux/tun.rs b/src/platform/linux/tun.rs
index 82eb469..442d9bc 100644
--- a/src/platform/linux/tun.rs
+++ b/src/platform/linux/tun.rs
@@ -1,13 +1,12 @@
use super::super::tun::*;
-use libc::*;
+use libc;
use std::error::Error;
use std::fmt;
+use std::mem;
use std::os::raw::c_short;
use std::os::unix::io::RawFd;
-use std::thread;
-use std::time::Duration;
const IFNAMSIZ: usize = 16;
const TUNSETIFF: u64 = 0x4004_54ca;
@@ -30,6 +29,18 @@ struct Ifreq {
_pad: [u8; 64],
}
+// man 7 rtnetlink
+// Layout from: https://elixir.bootlin.com/linux/latest/source/include/uapi/linux/rtnetlink.h#L516
+#[repr(C)]
+struct IfInfomsg {
+ ifi_family: libc::c_uchar,
+ __ifi_pad: libc::c_uchar,
+ ifi_type: libc::c_ushort,
+ ifi_index: libc::c_int,
+ ifi_flags: libc::c_uint,
+ ifi_change: libc::c_uint,
+}
+
pub struct LinuxTun {
events: Vec<TunEvent>,
}
@@ -42,12 +53,9 @@ pub struct LinuxTunWriter {
fd: RawFd,
}
-/* Listens for netlink messages
- * announcing an MTU update for the interface
- */
-#[derive(Clone)]
pub struct LinuxTunStatus {
- first: bool,
+ events: Vec<TunEvent>,
+ fd: RawFd,
}
#[derive(Debug)]
@@ -94,7 +102,7 @@ impl Reader for LinuxTunReader {
);
*/
let n: isize =
- unsafe { read(self.fd, buf[offset..].as_mut_ptr() as _, buf.len() - offset) };
+ unsafe { libc::read(self.fd, buf[offset..].as_mut_ptr() as _, buf.len() - offset) };
if n < 0 {
Err(LinuxTunError::Closed)
} else {
@@ -108,7 +116,7 @@ impl Writer for LinuxTunWriter {
type Error = LinuxTunError;
fn write(&self, src: &[u8]) -> Result<(), Self::Error> {
- match unsafe { write(self.fd, src.as_ptr() as _, src.len() as _) } {
+ match unsafe { libc::write(self.fd, src.as_ptr() as _, src.len() as _) } {
-1 => Err(LinuxTunError::Closed),
_ => Ok(()),
}
@@ -119,13 +127,124 @@ impl Status for LinuxTunStatus {
type Error = LinuxTunError;
fn event(&mut self) -> Result<TunEvent, Self::Error> {
- if self.first {
- self.first = false;
- return Ok(TunEvent::Up(1420));
- }
+ const DONE: u16 = libc::NLMSG_DONE as u16;
+ const ERROR: u16 = libc::NLMSG_ERROR as u16;
+ const INFO_SIZE: usize = mem::size_of::<IfInfomsg>();
+ const HDR_SIZE: usize = mem::size_of::<libc::nlmsghdr>();
+ let mut buf = [0u8; 1 << 12];
+ log::debug!("netlink, fetch event (fd = {})", self.fd);
loop {
- thread::sleep(Duration::from_secs(60 * 60));
+ // attempt to return a buffered event
+ if let Some(event) = self.events.pop() {
+ return Ok(event);
+ }
+
+ // read message
+ let size: libc::ssize_t =
+ unsafe { libc::recv(self.fd, mem::transmute(&mut buf), buf.len(), 0) };
+ if size < 0 {
+ break Err(LinuxTunError::Closed);
+ }
+
+ // cut buffer to size
+ let size: usize = size as usize;
+ let mut remain = &buf[..size];
+ log::debug!("netlink, recieved message ({} bytes)", size);
+
+ // handle messages
+ while remain.len() >= HDR_SIZE {
+ // extract the header
+ assert!(remain.len() > HDR_SIZE);
+ let mut hdr = [0u8; HDR_SIZE];
+ hdr.copy_from_slice(&remain[..HDR_SIZE]);
+ let hdr: libc::nlmsghdr = unsafe { mem::transmute(hdr) };
+
+ // upcast length
+ let body: &[u8] = &remain[HDR_SIZE..];
+ let msg_len: usize = hdr.nlmsg_len as usize;
+ assert!(msg_len <= remain.len(), "malformed netlink message");
+
+ // handle message body
+ match hdr.nlmsg_type {
+ DONE => break,
+ ERROR => break,
+ libc::RTM_NEWLINK => {
+ // extract info struct
+ if body.len() < INFO_SIZE {
+ return Err(LinuxTunError::Closed);
+ }
+
+ let mut info = [0u8; INFO_SIZE];
+ info.copy_from_slice(&body[..INFO_SIZE]);
+ log::debug!("netlink, RTM_NEWLINK {:?}", &info[..]);
+ let info: IfInfomsg = unsafe { mem::transmute(info) };
+
+ // trace log
+ log::trace!(
+ "netlink, IfInfomsg{{ family = {}, type = {}, index = {}, flags = {}, change = {}}}",
+ info.ifi_family,
+ info.ifi_type,
+ info.ifi_index,
+ info.ifi_flags,
+ info.ifi_change,
+ );
+ debug_assert_eq!(info.__ifi_pad, 0);
+
+ // handle up / down
+ if info.ifi_flags & (libc::IFF_UP as u32) != 0 {
+ log::trace!("netlink, up event");
+ self.events.push(TunEvent::Up(1420));
+ } else {
+ log::trace!("netlink, down event");
+ self.events.push(TunEvent::Down);
+ }
+ }
+ _ => (),
+ };
+
+ // go to next message
+ remain = &remain[msg_len..];
+ }
+ }
+ }
+}
+
+impl LinuxTunStatus {
+ const RTNLGRP_LINK: libc::c_uint = 1;
+ const RTNLGRP_IPV4_IFADDR: libc::c_uint = 5;
+ const RTNLGRP_IPV6_IFADDR: libc::c_uint = 9;
+
+ fn new() -> Result<LinuxTunStatus, LinuxTunError> {
+ // create netlink socket
+ let fd = unsafe { libc::socket(libc::AF_NETLINK, libc::SOCK_RAW, libc::NETLINK_ROUTE) };
+ if fd < 0 {
+ return Err(LinuxTunError::Closed);
+ }
+
+ // prepare address (specify groups)
+ let groups = (1 << (Self::RTNLGRP_LINK - 1))
+ | (1 << (Self::RTNLGRP_IPV4_IFADDR - 1))
+ | (1 << (Self::RTNLGRP_IPV6_IFADDR - 1));
+
+ let mut sockaddr: libc::sockaddr_nl = unsafe { mem::zeroed() };
+ sockaddr.nl_family = libc::AF_NETLINK as u16;
+ sockaddr.nl_groups = groups;
+ sockaddr.nl_pid = 0;
+
+ // attempt to bind
+ let res = unsafe {
+ libc::bind(
+ fd,
+ mem::transmute(&mut sockaddr),
+ mem::size_of::<libc::sockaddr_nl>() as u32,
+ )
+ };
+
+ if res != 0 {
+ Err(LinuxTunError::Closed)
+ } else {
+ Ok(LinuxTunStatus { events: vec![], fd })
}
}
}
@@ -155,14 +274,14 @@ impl PlatformTun for LinuxTun {
req.name[..bs.len()].copy_from_slice(bs);
// open clone device
- let fd: RawFd = match unsafe { open(CLONE_DEVICE_PATH.as_ptr() as _, O_RDWR) } {
+ let fd: RawFd = match unsafe { libc::open(CLONE_DEVICE_PATH.as_ptr() as _, libc::O_RDWR) } {
-1 => return Err(LinuxTunError::FailedToOpenCloneDevice),
fd => fd,
};
assert!(fd >= 0);
// create TUN device
- if unsafe { ioctl(fd, TUNSETIFF as _, &req) } < 0 {
+ if unsafe { libc::ioctl(fd, TUNSETIFF as _, &req) } < 0 {
return Err(LinuxTunError::SetIFFIoctlFailed);
}
@@ -170,7 +289,7 @@ impl PlatformTun for LinuxTun {
Ok((
vec![LinuxTunReader { fd }], // TODO: enable multi-queue for Linux
LinuxTunWriter { fd },
- LinuxTunStatus { first: true },
+ LinuxTunStatus::new()?,
))
}
}
diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs
index 7a18005..bf1bd5f 100644
--- a/src/wireguard/tests.rs
+++ b/src/wireguard/tests.rs
@@ -85,15 +85,13 @@ fn test_pure_wireguard() {
// create WG instances for dummy TUN devices
let (fake1, tun_reader1, tun_writer1, _) = dummy::TunTest::create(true);
- let wg1: Wireguard<dummy::TunTest, dummy::PairBind> =
- Wireguard::new(vec![tun_reader1], tun_writer1);
-
+ let wg1: Wireguard<dummy::TunTest, dummy::PairBind> = Wireguard::new(tun_writer1);
+ wg1.add_tun_reader(tun_reader1);
wg1.up(1500);
let (fake2, tun_reader2, tun_writer2, _) = dummy::TunTest::create(true);
- let wg2: Wireguard<dummy::TunTest, dummy::PairBind> =
- Wireguard::new(vec![tun_reader2], tun_writer2);
-
+ let wg2: Wireguard<dummy::TunTest, dummy::PairBind> = Wireguard::new(tun_writer2);
+ wg2.add_tun_reader(tun_reader2);
wg2.up(1500);
// create pair bind to connect the interfaces "over the internet"
@@ -103,8 +101,8 @@ fn test_pure_wireguard() {
wg1.set_writer(bind_writer1);
wg2.set_writer(bind_writer2);
- wg1.add_reader(bind_reader1);
- wg2.add_reader(bind_reader2);
+ wg1.add_udp_reader(bind_reader1);
+ wg2.add_udp_reader(bind_reader2);
// generate (public, pivate) key pairs
diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs
index 18f49bf..0ce4210 100644
--- a/src/wireguard/timers.rs
+++ b/src/wireguard/timers.rs
@@ -221,14 +221,14 @@ impl<T: tun::Tun, B: udp::UDP> PeerInner<T, B> {
impl Timers {
- pub fn new<T, B>(runner: &Runner, peer: Peer<T, B>) -> Timers
+ pub fn new<T, B>(runner: &Runner, running: bool, peer: Peer<T, B>) -> Timers
where
T: tun::Tun,
B: udp::UDP,
{
// create a timer instance for the provided peer
Timers {
- enabled: true,
+ enabled: running,
keepalive_interval: 0, // disabled
need_another_keepalive: AtomicBool::new(false),
sent_lastminute_handshake: AtomicBool::new(false),
diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs
index 61f6428..2b0e779 100644
--- a/src/wireguard/wireguard.rs
+++ b/src/wireguard/wireguard.rs
@@ -22,6 +22,10 @@ use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
+// TODO: avoid
+use std::sync::Condvar;
+use std::sync::Mutex as StdMutex;
+
use std::collections::hash_map::Entry;
use std::collections::HashMap;
@@ -38,15 +42,51 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128;
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
+#[derive(Clone)]
+pub struct WaitHandle(Arc<(StdMutex<usize>, Condvar)>);
+
+impl WaitHandle {
+ pub fn wait(&self) {
+ let (lock, cvar) = &*self.0;
+ let mut nread = lock.lock().unwrap();
+ while *nread > 0 {
+ nread = cvar.wait(nread).unwrap();
+ }
+ }
+
+ fn new() -> Self {
+ Self(Arc::new((StdMutex::new(0), Condvar::new())))
+ }
+
+ fn decrease(&self) {
+ let (lock, cvar) = &*self.0;
+ let mut nread = lock.lock().unwrap();
+ assert!(*nread > 0);
+ *nread -= 1;
+ cvar.notify_all();
+ }
+
+ fn increase(&self) {
+ let (lock, _) = &*self.0;
+ let mut nread = lock.lock().unwrap();
+ *nread += 1;
+ }
+}
+
pub struct WireguardInner<T: tun::Tun, B: udp::UDP> {
// identifier (for logging)
id: u32,
- start: Instant,
+
+ // device enabled
+ enabled: RwLock<bool>,
+
+ // enables waiting for all readers to finish
+ tun_readers: WaitHandle,
// current MTU
mtu: AtomicUsize,
- // provides access to the MTU value of the tun device
+ // outbound writer
send: RwLock<Option<B::Writer>>,
// identity and configuration map
@@ -145,7 +185,12 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
/// on both ends of the device.
pub fn down(&self) {
// ensure exclusive access (to avoid race with "up" call)
- let peers = self.peers.write();
+ let mut enabled = self.enabled.write();
+
+ // check if already down
+ if *enabled == false {
+ return;
+ }
// set mtu
self.state.mtu.store(0, Ordering::Relaxed);
@@ -154,27 +199,36 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
self.router.down();
// set all peers down (stops timers)
- for peer in peers.values() {
+ for peer in self.peers.write().values() {
peer.down();
}
+
+ *enabled = false;
}
/// Brings the WireGuard device up.
/// Usually called when the associated interface is brought up.
pub fn up(&self, mtu: usize) {
- // ensure exclusive access (to avoid race with "down" call)
- let peers = self.peers.write();
+ // ensure exclusive access (to avoid race with "up" call)
+ let mut enabled = self.enabled.write();
// set mtu
self.state.mtu.store(mtu, Ordering::Relaxed);
+ // check if already up
+ if *enabled {
+ return;
+ }
+
// enable tranmission from router
self.router.up();
// set all peers up (restarts timers)
- for peer in peers.values() {
+ for peer in self.peers.write().values() {
peer.up();
}
+
+ *enabled = true;
}
pub fn clear_peers(&self) {
@@ -232,7 +286,7 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
pk,
wg: self.state.clone(),
walltime_last_handshake: Mutex::new(None),
- last_handshake_sent: Mutex::new(self.state.start - TIME_HORIZON),
+ last_handshake_sent: Mutex::new(Instant::now() - TIME_HORIZON),
handshake_queued: AtomicBool::new(false),
queue: Mutex::new(self.state.queue.lock().clone()),
rx_bytes: AtomicU64::new(0),
@@ -246,24 +300,31 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
// form WireGuard peer
let peer = Peer { router, state };
- /* The need for dummy timers arises from the chicken-egg
- * problem of the timer callbacks being able to set timers themselves.
- *
- * This is in fact the only place where the write lock is ever taken.
- * TODO: Consider the ease of using atomic pointers instead.
- */
- *peer.timers.write() = Timers::new(&self.runner, peer.clone());
-
// finally, add the peer to the wireguard device
let mut peers = self.state.peers.write();
match peers.entry(*pk.as_bytes()) {
Entry::Occupied(_) => false,
Entry::Vacant(vacancy) => {
+ // check that the public key does not cause conflict with the private key of the device
let ok_pk = self.state.handshake.write().add(pk).is_ok();
- if ok_pk {
- vacancy.insert(peer);
+ if !ok_pk {
+ return false;
}
- ok_pk
+
+ // prevent up/down while inserting
+ let enabled = self.enabled.read();
+
+ /* The need for dummy timers arises from the chicken-egg
+ * problem of the timer callbacks being able to set timers themselves.
+ *
+ * This is in fact the only place where the write lock is ever taken.
+ * TODO: Consider the ease of using atomic pointers instead.
+ */
+ *peer.timers.write() = Timers::new(&self.runner, *enabled, peer.clone());
+
+ // insert into peer map (takes ownership and ensures that the peer is not dropped)
+ vacancy.insert(peer);
+ true
}
}
}
@@ -273,7 +334,7 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
///
/// Any previous reader thread is stopped by closing the previous reader,
/// which unblocks the thread and causes an error on reader.read
- pub fn add_reader(&self, reader: B::Reader) {
+ pub fn add_udp_reader(&self, reader: B::Reader) {
let wg = self.state.clone();
thread::spawn(move || {
let mut last_under_load =
@@ -350,7 +411,72 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
self.state.router.set_outbound_writer(writer);
}
- pub fn new(mut readers: Vec<T::Reader>, writer: T::Writer) -> Wireguard<T, B> {
+ pub fn add_tun_reader(&self, reader: T::Reader) {
+ fn worker<T: tun::Tun, B: udp::UDP>(wg: &Arc<WireguardInner<T, B>>, reader: T::Reader) {
+ loop {
+ // create vector big enough for any transport message (based on MTU)
+ let mtu = wg.mtu.load(Ordering::Relaxed);
+ let size = mtu + router::SIZE_MESSAGE_PREFIX + 1;
+ let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
+ msg.resize(size, 0);
+
+ // read a new IP packet
+ let payload = match reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX) {
+ Ok(payload) => payload,
+ Err(e) => {
+ debug!("TUN worker, failed to read from tun device: {}", e);
+ break;
+ }
+ };
+ debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu);
+
+ // TODO: start device down
+ if mtu == 0 {
+ continue;
+ }
+
+ // truncate padding
+ let padded = padding(payload, mtu);
+ log::trace!(
+ "TUN worker, payload length = {}, padded length = {}",
+ payload,
+ padded
+ );
+ msg.truncate(router::SIZE_MESSAGE_PREFIX + padded);
+ debug_assert!(padded <= mtu);
+ debug_assert_eq!(
+ if padded < mtu {
+ (msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE
+ } else {
+ 0
+ },
+ 0
+ );
+
+ // crypt-key route
+ let e = wg.router.send(msg);
+ debug!("TUN worker, router returned {:?}", e);
+ }
+ }
+
+ // start a thread for every reader
+ let wg = self.state.clone();
+
+ // increment reader count
+ wg.tun_readers.increase();
+
+ // start worker
+ thread::spawn(move || {
+ worker(&wg, reader);
+ wg.tun_readers.decrease();
+ });
+ }
+
+ pub fn wait(&self) -> WaitHandle {
+ self.state.tun_readers.clone()
+ }
+
+ pub fn new(writer: T::Writer) -> Wireguard<T, B> {
// create device state
let mut rng = OsRng::new().unwrap();
@@ -358,7 +484,8 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
let wg = Arc::new(WireguardInner {
- start: Instant::now(),
+ enabled: RwLock::new(false),
+ tun_readers: WaitHandle::new(),
id: rng.gen(),
mtu: AtomicUsize::new(0),
peers: RwLock::new(HashMap::new()),
@@ -486,59 +613,6 @@ impl<T: tun::Tun, B: udp::UDP> Wireguard<T, B> {
});
}
- // start TUN read IO threads (multiple threads to support multi-queue interfaces)
- debug_assert!(
- readers.len() > 0,
- "attempted to create WG device without TUN readers"
- );
- while let Some(reader) = readers.pop() {
- let wg = wg.clone();
- thread::spawn(move || loop {
- // create vector big enough for any transport message (based on MTU)
- let mtu = wg.mtu.load(Ordering::Relaxed);
- let size = mtu + router::SIZE_MESSAGE_PREFIX;
- let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
- msg.resize(size, 0);
-
- // read a new IP packet
- let payload = match reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX) {
- Ok(payload) => payload,
- Err(e) => {
- debug!("TUN worker, failed to read from tun device: {}", e);
- return;
- }
- };
- debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu);
-
- // TODO: start device down
- if mtu == 0 {
- continue;
- }
-
- // truncate padding
- let padded = padding(payload, mtu);
- log::trace!(
- "TUN worker, payload length = {}, padded length = {}",
- payload,
- padded
- );
- msg.truncate(router::SIZE_MESSAGE_PREFIX + padded);
- debug_assert!(padded <= mtu);
- debug_assert_eq!(
- if padded < mtu {
- (msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE
- } else {
- 0
- },
- 0
- );
-
- // crypt-key route
- let e = wg.router.send(msg);
- debug!("TUN worker, router returned {:?}", e);
- });
- }
-
Wireguard {
state: wg,
runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY),