From 605cc656ad235d09ba6cd12d03dee2c5e0a9a80a Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Thu, 30 Jan 2020 14:57:00 +0100 Subject: Clear src when sendmsg fails with EINVAL --- src/configuration/config.rs | 17 ++++--- src/platform/dummy/udp.rs | 4 +- src/platform/linux/udp.rs | 103 ++++++++++++++++++++++++++----------------- src/platform/udp.rs | 2 +- src/wireguard/router/peer.rs | 2 +- src/wireguard/workers.rs | 4 +- 6 files changed, 76 insertions(+), 56 deletions(-) diff --git a/src/configuration/config.rs b/src/configuration/config.rs index d61cda5..59cef4a 100644 --- a/src/configuration/config.rs +++ b/src/configuration/config.rs @@ -1,3 +1,4 @@ +use std::mem; use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::Ordering; use std::sync::{Arc, Mutex, MutexGuard}; @@ -266,24 +267,22 @@ impl Configuration for WireGuardConfig { fn set_listen_port(&self, port: u16) -> Result<(), ConfigError> { log::trace!("Config, Set listen port: {:?}", port); - // update port - let listen: bool = { + // update port and take old bind + let old: Option = { let mut cfg = self.lock(); + let old = mem::replace(&mut cfg.bind, None); cfg.port = port; - if cfg.bind.is_some() { - cfg.bind = None; - true - } else { - false - } + old }; // restart listener if bound - if listen { + if old.is_some() { self.start_listener() } else { Ok(()) } + + // old bind is dropped, causing the file-descriptors to be released } fn set_fwmark(&self, mark: Option) -> Result<(), ConfigError> { diff --git a/src/platform/dummy/udp.rs b/src/platform/dummy/udp.rs index 6c126a9..88630af 100644 --- a/src/platform/dummy/udp.rs +++ b/src/platform/dummy/udp.rs @@ -54,7 +54,7 @@ impl Reader for VoidBind { impl Writer for VoidBind { type Error = BindError; - fn write(&self, _buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { + fn write(&self, _buf: &[u8], _dst: &mut UnitEndpoint) -> Result<(), Self::Error> { Ok(()) } } @@ -105,7 +105,7 @@ impl Reader for PairReader { impl Writer for PairWriter { type Error = BindError; - fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { + fn write(&self, buf: &[u8], _dst: &mut UnitEndpoint) -> Result<(), Self::Error> { debug!( "dummy({}): write ({}, {})", self.id, diff --git a/src/platform/linux/udp.rs b/src/platform/linux/udp.rs index 1552f69..ab5b53b 100644 --- a/src/platform/linux/udp.rs +++ b/src/platform/linux/udp.rs @@ -8,12 +8,13 @@ use std::io; use std::mem; use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; use std::os::unix::io::RawFd; +use std::ptr; fn errno() -> libc::c_int { unsafe { let ptr = libc::__errno_location(); if ptr.is_null() { - -1 + 0 } else { *ptr } @@ -23,7 +24,7 @@ fn errno() -> libc::c_int { #[repr(C)] struct ControlHeaderV4 { hdr: libc::cmsghdr, - body: libc::in_pktinfo, + info: libc::in_pktinfo, } #[repr(C)] @@ -120,12 +121,12 @@ impl Endpoint for LinuxEndpoint { LinuxEndpoint::V4(EndpointV4 { ref dst, .. }) => { SocketAddr::V4(SocketAddrV4::new( u32::from_be(dst.sin_addr.s_addr).into(), // IPv4 addr - dst.sin_port, + u16::from_be(dst.sin_port), // convert back to native byte-order )) } LinuxEndpoint::V6(EndpointV6 { ref dst, .. }) => SocketAddr::V6(SocketAddrV6::new( u128::from_ne_bytes(dst.sin6_addr.s6_addr).into(), // IPv6 addr - dst.sin6_port, + u16::from_be(dst.sin6_port), // convert back to native byte-order dst.sin6_flowinfo, dst.sin6_scope_id, )), @@ -178,12 +179,12 @@ impl LinuxUDPReader { // this memory is mutated by the recvmsg call #[allow(unused_mut)] let mut control: ControlHeaderV6 = unsafe { mem::MaybeUninit::uninit().assume_init() }; - let iovs: [libc::iovec; 1] = [unsafe { - libc::iovec { - iov_base: mem::transmute(&buf[0] as *const u8), - iov_len: buf.len(), - } + + let iovs: [libc::iovec; 1] = [libc::iovec { + iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void, + iov_len: buf.len(), }]; + let src: libc::sockaddr_in6 = unsafe { mem::MaybeUninit::uninit().assume_init() }; let mut hdr = unsafe { libc::msghdr { @@ -223,29 +224,31 @@ impl LinuxUDPReader { buf.len() ); + let iovs: [libc::iovec; 1] = [libc::iovec { + iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void, + iov_len: buf.len(), + }]; + + let src: libc::sockaddr_in = unsafe { mem::MaybeUninit::uninit().assume_init() }; + // this memory is mutated by the recvmsg call #[allow(unused_mut)] let mut control: ControlHeaderV4 = unsafe { mem::MaybeUninit::uninit().assume_init() }; - let iovs: [libc::iovec; 1] = [unsafe { - libc::iovec { - iov_base: mem::transmute(&buf[0] as *const u8), - iov_len: buf.len(), - } - }]; - let src: libc::sockaddr_in = unsafe { mem::MaybeUninit::uninit().assume_init() }; + let mut hdr = unsafe { libc::msghdr { msg_name: mem::transmute(&src), - msg_namelen: mem::size_of_val(&src).try_into().unwrap(), + msg_namelen: mem::size_of_val(&src).try_into().unwrap(), // constant msg_iov: mem::transmute(&iovs[0]), - msg_iovlen: iovs.len(), + msg_iovlen: iovs.len(), // constant msg_control: mem::transmute(&control), - msg_controllen: mem::size_of_val(&control), - msg_flags: 0, // ignored + msg_controllen: mem::size_of_val(&control), // constant + msg_flags: 0, // ignored } }; let len = unsafe { libc::recvmsg(fd, &mut hdr as *mut libc::msghdr, 0) }; + if len < 0 { log::trace!("failed to receive IPv4 packet (errno = {})", errno()); return Err(io::Error::new( @@ -254,12 +257,20 @@ impl LinuxUDPReader { )); } + log::trace!("read4, len: {}", len); + log::trace!( + "control: {{ hdr : {{ cmsg_level: {}, cmsg_type: {}, cmsg_len: {} }} }}", + control.hdr.cmsg_level, + control.hdr.cmsg_type, + control.hdr.cmsg_len + ); + log::trace!("received IPv4 packet ({} fd, {} bytes)", fd, len); Ok(( len.try_into().unwrap(), LinuxEndpoint::V4(EndpointV4 { - info: control.body, - dst: src, + info: control.info, // save pkinfo (sticky source) + dst: src, // our future destination is the source address }), )) } @@ -283,16 +294,21 @@ impl LinuxUDPWriter { unimplemented!() } - fn write4(fd: RawFd, buf: &[u8], dst: &EndpointV4) -> Result<(), io::Error> { + fn write4(fd: RawFd, buf: &[u8], dst: &mut EndpointV4) -> Result<(), io::Error> { log::debug!("sending IPv4 packet ({} fd, {} bytes)", fd, buf.len()); - let control = ControlHeaderV4 { + let iovs: [libc::iovec; 1] = [libc::iovec { + iov_base: buf.as_ptr() as *mut core::ffi::c_void, + iov_len: buf.len(), + }]; + + let mut control = ControlHeaderV4 { hdr: libc::cmsghdr { cmsg_len: mem::size_of::(), cmsg_level: libc::IPPROTO_IP, cmsg_type: libc::IP_PKTINFO, }, - body: dst.info, + info: dst.info, }; debug_assert_eq!( @@ -302,12 +318,7 @@ impl LinuxUDPWriter { ); debug_assert_eq!(dst.dst.sin_family, libc::AF_INET as libc::sa_family_t); - let iovs: [libc::iovec; 1] = [libc::iovec { - iov_base: buf.as_ptr() as *mut core::ffi::c_void, - iov_len: buf.len(), - }]; - - let hdr = libc::msghdr { + let mut hdr = libc::msghdr { msg_name: unsafe { mem::transmute(&dst.dst as *const libc::sockaddr_in) }, msg_namelen: mem::size_of_val(&dst.dst).try_into().unwrap(), msg_iov: iovs.as_ptr() as *mut libc::iovec, @@ -317,19 +328,29 @@ impl LinuxUDPWriter { msg_flags: 0, }; - println!( - "name : {}, controllen: {}", - hdr.msg_namelen, hdr.msg_controllen - ); - - let err = unsafe { libc::sendmsg(fd, &hdr, 0) }; - if err < 0 { - log::trace!("failed to send IPv4: (errno = {})", errno()); + let ret = unsafe { libc::sendmsg(fd, &hdr, 0) }; + + if ret < 0 { + if errno() == libc::EINVAL { + log::trace!("clear source and retry"); + hdr.msg_control = ptr::null_mut(); + hdr.msg_controllen = 0; + dst.info = unsafe { mem::zeroed() }; + if unsafe { libc::sendmsg(fd, &hdr, 0) } < 0 { + return Err(io::Error::new( + io::ErrorKind::NotConnected, + "failed to send IPv4 packet", + )); + } else { + return Ok(()); + } + } return Err(io::Error::new( io::ErrorKind::NotConnected, "failed to send IPv4 packet", )); } + Ok(()) } } @@ -337,9 +358,9 @@ impl LinuxUDPWriter { impl Writer for LinuxUDPWriter { type Error = io::Error; - fn write(&self, buf: &[u8], dst: &LinuxEndpoint) -> Result<(), Self::Error> { + fn write(&self, buf: &[u8], dst: &mut LinuxEndpoint) -> Result<(), Self::Error> { match dst { - LinuxEndpoint::V4(ref end) => Self::write4(self.sock4, buf, end), + LinuxEndpoint::V4(ref mut end) => Self::write4(self.sock4, buf, end), LinuxEndpoint::V6(ref end) => Self::write6(self.sock6, buf, end), } } diff --git a/src/platform/udp.rs b/src/platform/udp.rs index 4685a1e..e1180fb 100644 --- a/src/platform/udp.rs +++ b/src/platform/udp.rs @@ -10,7 +10,7 @@ pub trait Reader: Send + Sync { pub trait Writer: Send + Sync + Clone + 'static { type Error: Error; - fn write(&self, buf: &[u8], dst: &E) -> Result<(), Self::Error>; + fn write(&self, buf: &[u8], dst: &mut E) -> Result<(), Self::Error>; } pub trait UDP: Send + Sync + 'static { diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index b8110f0..8fe2e1c 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -204,7 +204,7 @@ impl> PeerInner { let outbound = self.device.outbound.read(); if outbound.0 { diff --git a/src/wireguard/workers.rs b/src/wireguard/workers.rs index 9802232..c1a2af7 100644 --- a/src/wireguard/workers.rs +++ b/src/wireguard/workers.rs @@ -178,7 +178,7 @@ pub fn handshake_worker( // de-multiplex staged handshake jobs and handshake messages match job { - HandshakeJob::Message(msg, src) => { + HandshakeJob::Message(msg, mut src) => { // process message let device = wg.peers.read(); match device.process( @@ -201,7 +201,7 @@ pub fn handshake_worker( "{} : handshake worker, send response ({} bytes)", wg, resp_len ); - let _ = writer.write(&msg[..], &src).map_err(|e| { + let _ = writer.write(&msg[..], &mut src).map_err(|e| { debug!( "{} : handshake worker, failed to send response, error = {}", wg, -- cgit v1.2.3-59-g8ed1b