From 04f507556baf2336a613cb684ec98f2cdf519163 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Wed, 27 Nov 2019 16:59:54 +0100 Subject: Work on netlink IF event code for Linux --- src/configuration/config.rs | 2 +- src/main.rs | 21 ++++- src/platform/linux/tun.rs | 155 ++++++++++++++++++++++++++---- src/wireguard/tests.rs | 14 ++- src/wireguard/timers.rs | 4 +- src/wireguard/wireguard.rs | 224 +++++++++++++++++++++++++++++--------------- 6 files changed, 311 insertions(+), 109 deletions(-) (limited to 'src') diff --git a/src/configuration/config.rs b/src/configuration/config.rs index e7d1ba5..c045d1e 100644 --- a/src/configuration/config.rs +++ b/src/configuration/config.rs @@ -261,7 +261,7 @@ impl Configuration for WireguardConfig { // add readers while let Some(reader) = readers.pop() { - cfg.wireguard.add_reader(reader); + cfg.wireguard.add_udp_reader(reader); } // create new UDP state diff --git a/src/main.rs b/src/main.rs index 1a9650b..5ea830f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,7 +26,7 @@ fn main() { let mut foreground = false; let mut args = env::args(); - args.next(); // skip path + args.next(); // skip path (argv[0]) for arg in args { match arg.as_str() { @@ -56,7 +56,7 @@ fn main() { }); // create TUN device - let (readers, writer, status) = plt::Tun::create(name.as_str()).unwrap_or_else(|e| { + let (mut readers, writer, status) = plt::Tun::create(name.as_str()).unwrap_or_else(|e| { eprintln!("Failed to create TUN device: {}", e); exit(-3); }); @@ -82,7 +82,15 @@ fn main() { if drop_privileges {} // create WireGuard device - let wg: wireguard::Wireguard = wireguard::Wireguard::new(readers, writer); + let wg: wireguard::Wireguard = wireguard::Wireguard::new(writer); + + // add all Tun readers + while let Some(reader) = readers.pop() { + wg.add_tun_reader(reader); + } + + // obtain handle for waiting + let wait = wg.wait(); // wrap in configuration interface let cfg = configuration::WireguardConfig::new(wg); @@ -124,7 +132,7 @@ fn main() { } // start UAPI server - loop { + thread::spawn(move || loop { match uapi.connect() { Ok(mut stream) => { let cfg = cfg.clone(); @@ -137,5 +145,8 @@ fn main() { break; } } - } + }); + + // block until all tun readers closed + wait.wait(); } diff --git a/src/platform/linux/tun.rs b/src/platform/linux/tun.rs index 82eb469..442d9bc 100644 --- a/src/platform/linux/tun.rs +++ b/src/platform/linux/tun.rs @@ -1,13 +1,12 @@ use super::super::tun::*; -use libc::*; +use libc; use std::error::Error; use std::fmt; +use std::mem; use std::os::raw::c_short; use std::os::unix::io::RawFd; -use std::thread; -use std::time::Duration; const IFNAMSIZ: usize = 16; const TUNSETIFF: u64 = 0x4004_54ca; @@ -30,6 +29,18 @@ struct Ifreq { _pad: [u8; 64], } +// man 7 rtnetlink +// Layout from: https://elixir.bootlin.com/linux/latest/source/include/uapi/linux/rtnetlink.h#L516 +#[repr(C)] +struct IfInfomsg { + ifi_family: libc::c_uchar, + __ifi_pad: libc::c_uchar, + ifi_type: libc::c_ushort, + ifi_index: libc::c_int, + ifi_flags: libc::c_uint, + ifi_change: libc::c_uint, +} + pub struct LinuxTun { events: Vec, } @@ -42,12 +53,9 @@ pub struct LinuxTunWriter { fd: RawFd, } -/* Listens for netlink messages - * announcing an MTU update for the interface - */ -#[derive(Clone)] pub struct LinuxTunStatus { - first: bool, + events: Vec, + fd: RawFd, } #[derive(Debug)] @@ -94,7 +102,7 @@ impl Reader for LinuxTunReader { ); */ let n: isize = - unsafe { read(self.fd, buf[offset..].as_mut_ptr() as _, buf.len() - offset) }; + unsafe { libc::read(self.fd, buf[offset..].as_mut_ptr() as _, buf.len() - offset) }; if n < 0 { Err(LinuxTunError::Closed) } else { @@ -108,7 +116,7 @@ impl Writer for LinuxTunWriter { type Error = LinuxTunError; fn write(&self, src: &[u8]) -> Result<(), Self::Error> { - match unsafe { write(self.fd, src.as_ptr() as _, src.len() as _) } { + match unsafe { libc::write(self.fd, src.as_ptr() as _, src.len() as _) } { -1 => Err(LinuxTunError::Closed), _ => Ok(()), } @@ -119,13 +127,124 @@ impl Status for LinuxTunStatus { type Error = LinuxTunError; fn event(&mut self) -> Result { - if self.first { - self.first = false; - return Ok(TunEvent::Up(1420)); - } + const DONE: u16 = libc::NLMSG_DONE as u16; + const ERROR: u16 = libc::NLMSG_ERROR as u16; + const INFO_SIZE: usize = mem::size_of::(); + const HDR_SIZE: usize = mem::size_of::(); + let mut buf = [0u8; 1 << 12]; + log::debug!("netlink, fetch event (fd = {})", self.fd); loop { - thread::sleep(Duration::from_secs(60 * 60)); + // attempt to return a buffered event + if let Some(event) = self.events.pop() { + return Ok(event); + } + + // read message + let size: libc::ssize_t = + unsafe { libc::recv(self.fd, mem::transmute(&mut buf), buf.len(), 0) }; + if size < 0 { + break Err(LinuxTunError::Closed); + } + + // cut buffer to size + let size: usize = size as usize; + let mut remain = &buf[..size]; + log::debug!("netlink, recieved message ({} bytes)", size); + + // handle messages + while remain.len() >= HDR_SIZE { + // extract the header + assert!(remain.len() > HDR_SIZE); + let mut hdr = [0u8; HDR_SIZE]; + hdr.copy_from_slice(&remain[..HDR_SIZE]); + let hdr: libc::nlmsghdr = unsafe { mem::transmute(hdr) }; + + // upcast length + let body: &[u8] = &remain[HDR_SIZE..]; + let msg_len: usize = hdr.nlmsg_len as usize; + assert!(msg_len <= remain.len(), "malformed netlink message"); + + // handle message body + match hdr.nlmsg_type { + DONE => break, + ERROR => break, + libc::RTM_NEWLINK => { + // extract info struct + if body.len() < INFO_SIZE { + return Err(LinuxTunError::Closed); + } + + let mut info = [0u8; INFO_SIZE]; + info.copy_from_slice(&body[..INFO_SIZE]); + log::debug!("netlink, RTM_NEWLINK {:?}", &info[..]); + let info: IfInfomsg = unsafe { mem::transmute(info) }; + + // trace log + log::trace!( + "netlink, IfInfomsg{{ family = {}, type = {}, index = {}, flags = {}, change = {}}}", + info.ifi_family, + info.ifi_type, + info.ifi_index, + info.ifi_flags, + info.ifi_change, + ); + debug_assert_eq!(info.__ifi_pad, 0); + + // handle up / down + if info.ifi_flags & (libc::IFF_UP as u32) != 0 { + log::trace!("netlink, up event"); + self.events.push(TunEvent::Up(1420)); + } else { + log::trace!("netlink, down event"); + self.events.push(TunEvent::Down); + } + } + _ => (), + }; + + // go to next message + remain = &remain[msg_len..]; + } + } + } +} + +impl LinuxTunStatus { + const RTNLGRP_LINK: libc::c_uint = 1; + const RTNLGRP_IPV4_IFADDR: libc::c_uint = 5; + const RTNLGRP_IPV6_IFADDR: libc::c_uint = 9; + + fn new() -> Result { + // create netlink socket + let fd = unsafe { libc::socket(libc::AF_NETLINK, libc::SOCK_RAW, libc::NETLINK_ROUTE) }; + if fd < 0 { + return Err(LinuxTunError::Closed); + } + + // prepare address (specify groups) + let groups = (1 << (Self::RTNLGRP_LINK - 1)) + | (1 << (Self::RTNLGRP_IPV4_IFADDR - 1)) + | (1 << (Self::RTNLGRP_IPV6_IFADDR - 1)); + + let mut sockaddr: libc::sockaddr_nl = unsafe { mem::zeroed() }; + sockaddr.nl_family = libc::AF_NETLINK as u16; + sockaddr.nl_groups = groups; + sockaddr.nl_pid = 0; + + // attempt to bind + let res = unsafe { + libc::bind( + fd, + mem::transmute(&mut sockaddr), + mem::size_of::() as u32, + ) + }; + + if res != 0 { + Err(LinuxTunError::Closed) + } else { + Ok(LinuxTunStatus { events: vec![], fd }) } } } @@ -155,14 +274,14 @@ impl PlatformTun for LinuxTun { req.name[..bs.len()].copy_from_slice(bs); // open clone device - let fd: RawFd = match unsafe { open(CLONE_DEVICE_PATH.as_ptr() as _, O_RDWR) } { + let fd: RawFd = match unsafe { libc::open(CLONE_DEVICE_PATH.as_ptr() as _, libc::O_RDWR) } { -1 => return Err(LinuxTunError::FailedToOpenCloneDevice), fd => fd, }; assert!(fd >= 0); // create TUN device - if unsafe { ioctl(fd, TUNSETIFF as _, &req) } < 0 { + if unsafe { libc::ioctl(fd, TUNSETIFF as _, &req) } < 0 { return Err(LinuxTunError::SetIFFIoctlFailed); } @@ -170,7 +289,7 @@ impl PlatformTun for LinuxTun { Ok(( vec![LinuxTunReader { fd }], // TODO: enable multi-queue for Linux LinuxTunWriter { fd }, - LinuxTunStatus { first: true }, + LinuxTunStatus::new()?, )) } } diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs index 7a18005..bf1bd5f 100644 --- a/src/wireguard/tests.rs +++ b/src/wireguard/tests.rs @@ -85,15 +85,13 @@ fn test_pure_wireguard() { // create WG instances for dummy TUN devices let (fake1, tun_reader1, tun_writer1, _) = dummy::TunTest::create(true); - let wg1: Wireguard = - Wireguard::new(vec![tun_reader1], tun_writer1); - + let wg1: Wireguard = Wireguard::new(tun_writer1); + wg1.add_tun_reader(tun_reader1); wg1.up(1500); let (fake2, tun_reader2, tun_writer2, _) = dummy::TunTest::create(true); - let wg2: Wireguard = - Wireguard::new(vec![tun_reader2], tun_writer2); - + let wg2: Wireguard = Wireguard::new(tun_writer2); + wg2.add_tun_reader(tun_reader2); wg2.up(1500); // create pair bind to connect the interfaces "over the internet" @@ -103,8 +101,8 @@ fn test_pure_wireguard() { wg1.set_writer(bind_writer1); wg2.set_writer(bind_writer2); - wg1.add_reader(bind_reader1); - wg2.add_reader(bind_reader2); + wg1.add_udp_reader(bind_reader1); + wg2.add_udp_reader(bind_reader2); // generate (public, pivate) key pairs diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index 18f49bf..0ce4210 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -221,14 +221,14 @@ impl PeerInner { impl Timers { - pub fn new(runner: &Runner, peer: Peer) -> Timers + pub fn new(runner: &Runner, running: bool, peer: Peer) -> Timers where T: tun::Tun, B: udp::UDP, { // create a timer instance for the provided peer Timers { - enabled: true, + enabled: running, keepalive_interval: 0, // disabled need_another_keepalive: AtomicBool::new(false), sent_lastminute_handshake: AtomicBool::new(false), diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 61f6428..2b0e779 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -22,6 +22,10 @@ use std::sync::Arc; use std::thread; use std::time::{Duration, Instant}; +// TODO: avoid +use std::sync::Condvar; +use std::sync::Mutex as StdMutex; + use std::collections::hash_map::Entry; use std::collections::HashMap; @@ -38,15 +42,51 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128; const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4; const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000); +#[derive(Clone)] +pub struct WaitHandle(Arc<(StdMutex, Condvar)>); + +impl WaitHandle { + pub fn wait(&self) { + let (lock, cvar) = &*self.0; + let mut nread = lock.lock().unwrap(); + while *nread > 0 { + nread = cvar.wait(nread).unwrap(); + } + } + + fn new() -> Self { + Self(Arc::new((StdMutex::new(0), Condvar::new()))) + } + + fn decrease(&self) { + let (lock, cvar) = &*self.0; + let mut nread = lock.lock().unwrap(); + assert!(*nread > 0); + *nread -= 1; + cvar.notify_all(); + } + + fn increase(&self) { + let (lock, _) = &*self.0; + let mut nread = lock.lock().unwrap(); + *nread += 1; + } +} + pub struct WireguardInner { // identifier (for logging) id: u32, - start: Instant, + + // device enabled + enabled: RwLock, + + // enables waiting for all readers to finish + tun_readers: WaitHandle, // current MTU mtu: AtomicUsize, - // provides access to the MTU value of the tun device + // outbound writer send: RwLock>, // identity and configuration map @@ -145,7 +185,12 @@ impl Wireguard { /// on both ends of the device. pub fn down(&self) { // ensure exclusive access (to avoid race with "up" call) - let peers = self.peers.write(); + let mut enabled = self.enabled.write(); + + // check if already down + if *enabled == false { + return; + } // set mtu self.state.mtu.store(0, Ordering::Relaxed); @@ -154,27 +199,36 @@ impl Wireguard { self.router.down(); // set all peers down (stops timers) - for peer in peers.values() { + for peer in self.peers.write().values() { peer.down(); } + + *enabled = false; } /// Brings the WireGuard device up. /// Usually called when the associated interface is brought up. pub fn up(&self, mtu: usize) { - // ensure exclusive access (to avoid race with "down" call) - let peers = self.peers.write(); + // ensure exclusive access (to avoid race with "up" call) + let mut enabled = self.enabled.write(); // set mtu self.state.mtu.store(mtu, Ordering::Relaxed); + // check if already up + if *enabled { + return; + } + // enable tranmission from router self.router.up(); // set all peers up (restarts timers) - for peer in peers.values() { + for peer in self.peers.write().values() { peer.up(); } + + *enabled = true; } pub fn clear_peers(&self) { @@ -232,7 +286,7 @@ impl Wireguard { pk, wg: self.state.clone(), walltime_last_handshake: Mutex::new(None), - last_handshake_sent: Mutex::new(self.state.start - TIME_HORIZON), + last_handshake_sent: Mutex::new(Instant::now() - TIME_HORIZON), handshake_queued: AtomicBool::new(false), queue: Mutex::new(self.state.queue.lock().clone()), rx_bytes: AtomicU64::new(0), @@ -246,24 +300,31 @@ impl Wireguard { // form WireGuard peer let peer = Peer { router, state }; - /* The need for dummy timers arises from the chicken-egg - * problem of the timer callbacks being able to set timers themselves. - * - * This is in fact the only place where the write lock is ever taken. - * TODO: Consider the ease of using atomic pointers instead. - */ - *peer.timers.write() = Timers::new(&self.runner, peer.clone()); - // finally, add the peer to the wireguard device let mut peers = self.state.peers.write(); match peers.entry(*pk.as_bytes()) { Entry::Occupied(_) => false, Entry::Vacant(vacancy) => { + // check that the public key does not cause conflict with the private key of the device let ok_pk = self.state.handshake.write().add(pk).is_ok(); - if ok_pk { - vacancy.insert(peer); + if !ok_pk { + return false; } - ok_pk + + // prevent up/down while inserting + let enabled = self.enabled.read(); + + /* The need for dummy timers arises from the chicken-egg + * problem of the timer callbacks being able to set timers themselves. + * + * This is in fact the only place where the write lock is ever taken. + * TODO: Consider the ease of using atomic pointers instead. + */ + *peer.timers.write() = Timers::new(&self.runner, *enabled, peer.clone()); + + // insert into peer map (takes ownership and ensures that the peer is not dropped) + vacancy.insert(peer); + true } } } @@ -273,7 +334,7 @@ impl Wireguard { /// /// Any previous reader thread is stopped by closing the previous reader, /// which unblocks the thread and causes an error on reader.read - pub fn add_reader(&self, reader: B::Reader) { + pub fn add_udp_reader(&self, reader: B::Reader) { let wg = self.state.clone(); thread::spawn(move || { let mut last_under_load = @@ -350,7 +411,72 @@ impl Wireguard { self.state.router.set_outbound_writer(writer); } - pub fn new(mut readers: Vec, writer: T::Writer) -> Wireguard { + pub fn add_tun_reader(&self, reader: T::Reader) { + fn worker(wg: &Arc>, reader: T::Reader) { + loop { + // create vector big enough for any transport message (based on MTU) + let mtu = wg.mtu.load(Ordering::Relaxed); + let size = mtu + router::SIZE_MESSAGE_PREFIX + 1; + let mut msg: Vec = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); + msg.resize(size, 0); + + // read a new IP packet + let payload = match reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX) { + Ok(payload) => payload, + Err(e) => { + debug!("TUN worker, failed to read from tun device: {}", e); + break; + } + }; + debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu); + + // TODO: start device down + if mtu == 0 { + continue; + } + + // truncate padding + let padded = padding(payload, mtu); + log::trace!( + "TUN worker, payload length = {}, padded length = {}", + payload, + padded + ); + msg.truncate(router::SIZE_MESSAGE_PREFIX + padded); + debug_assert!(padded <= mtu); + debug_assert_eq!( + if padded < mtu { + (msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE + } else { + 0 + }, + 0 + ); + + // crypt-key route + let e = wg.router.send(msg); + debug!("TUN worker, router returned {:?}", e); + } + } + + // start a thread for every reader + let wg = self.state.clone(); + + // increment reader count + wg.tun_readers.increase(); + + // start worker + thread::spawn(move || { + worker(&wg, reader); + wg.tun_readers.decrease(); + }); + } + + pub fn wait(&self) -> WaitHandle { + self.state.tun_readers.clone() + } + + pub fn new(writer: T::Writer) -> Wireguard { // create device state let mut rng = OsRng::new().unwrap(); @@ -358,7 +484,8 @@ impl Wireguard { let (tx, rx): (Sender>, _) = bounded(SIZE_HANDSHAKE_QUEUE); let wg = Arc::new(WireguardInner { - start: Instant::now(), + enabled: RwLock::new(false), + tun_readers: WaitHandle::new(), id: rng.gen(), mtu: AtomicUsize::new(0), peers: RwLock::new(HashMap::new()), @@ -486,59 +613,6 @@ impl Wireguard { }); } - // start TUN read IO threads (multiple threads to support multi-queue interfaces) - debug_assert!( - readers.len() > 0, - "attempted to create WG device without TUN readers" - ); - while let Some(reader) = readers.pop() { - let wg = wg.clone(); - thread::spawn(move || loop { - // create vector big enough for any transport message (based on MTU) - let mtu = wg.mtu.load(Ordering::Relaxed); - let size = mtu + router::SIZE_MESSAGE_PREFIX; - let mut msg: Vec = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); - msg.resize(size, 0); - - // read a new IP packet - let payload = match reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX) { - Ok(payload) => payload, - Err(e) => { - debug!("TUN worker, failed to read from tun device: {}", e); - return; - } - }; - debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu); - - // TODO: start device down - if mtu == 0 { - continue; - } - - // truncate padding - let padded = padding(payload, mtu); - log::trace!( - "TUN worker, payload length = {}, padded length = {}", - payload, - padded - ); - msg.truncate(router::SIZE_MESSAGE_PREFIX + padded); - debug_assert!(padded <= mtu); - debug_assert_eq!( - if padded < mtu { - (msg.len() - router::SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE - } else { - 0 - }, - 0 - ); - - // crypt-key route - let e = wg.router.send(msg); - debug!("TUN worker, router returned {:?}", e); - }); - } - Wireguard { state: wg, runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY), -- cgit v1.2.3-59-g8ed1b