aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2020-02-01 14:36:50 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2020-02-01 14:36:50 +0100
commit1e26a0bef44e65023a97a16ecf3b123e688d19f7 (patch)
tree2786e76e656739d4f6cdc260a0378751735265ee
parentClear src when sendmsg fails with EINVAL (diff)
downloadwireguard-rs-sticky-sockets.tar.xz
wireguard-rs-sticky-sockets.zip
Initial version of sticky sockets for Linuxsticky-sockets
-rw-r--r--src/platform/linux/udp.rs192
1 files changed, 129 insertions, 63 deletions
diff --git a/src/platform/linux/udp.rs b/src/platform/linux/udp.rs
index ab5b53b..9815ab1 100644
--- a/src/platform/linux/udp.rs
+++ b/src/platform/linux/udp.rs
@@ -21,16 +21,16 @@ fn errno() -> libc::c_int {
}
}
-#[repr(C)]
+#[repr(C, align(1))]
struct ControlHeaderV4 {
hdr: libc::cmsghdr,
info: libc::in_pktinfo,
}
-#[repr(C)]
+#[repr(C, align(1))]
struct ControlHeaderV6 {
hdr: libc::cmsghdr,
- body: libc::in6_pktinfo,
+ info: libc::in6_pktinfo,
}
pub struct EndpointV4 {
@@ -159,6 +159,7 @@ fn setsockopt<V: Sized>(
}
}
+#[inline(always)]
fn setsockopt_int(
fd: RawFd,
level: libc::c_int,
@@ -168,6 +169,21 @@ fn setsockopt_int(
setsockopt(fd, level, name, &value)
}
+#[allow(non_snake_case)]
+const fn CMSG_ALIGN(len: usize) -> usize {
+ (((len) + mem::size_of::<u32>() - 1) & !(mem::size_of::<u32>() - 1))
+}
+
+#[allow(non_snake_case)]
+const fn CMSG_LEN(len: usize) -> usize {
+ CMSG_ALIGN(len + mem::size_of::<libc::cmsghdr>())
+}
+
+#[inline(always)]
+fn safe_cast<T, D>(v: &mut T) -> *mut D {
+ (v as *mut T) as *mut D
+}
+
impl LinuxUDPReader {
fn read6(fd: RawFd, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), io::Error> {
log::trace!(
@@ -176,43 +192,41 @@ impl LinuxUDPReader {
buf.len()
);
- // this memory is mutated by the recvmsg call
- #[allow(unused_mut)]
- let mut control: ControlHeaderV6 = unsafe { mem::MaybeUninit::uninit().assume_init() };
-
- let iovs: [libc::iovec; 1] = [libc::iovec {
+ let mut iovs: [libc::iovec; 1] = [libc::iovec {
iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void,
iov_len: buf.len(),
}];
-
- let src: libc::sockaddr_in6 = unsafe { mem::MaybeUninit::uninit().assume_init() };
- let mut hdr = unsafe {
- libc::msghdr {
- msg_name: mem::transmute(&src),
- msg_namelen: mem::size_of_val(&src).try_into().unwrap(),
- msg_iov: mem::transmute(&iovs[0]),
- msg_iovlen: iovs.len(),
- msg_control: mem::transmute(&control),
- msg_controllen: mem::size_of_val(&control),
- msg_flags: 0, // ignored
- }
+ let mut src: libc::sockaddr_in6 = unsafe { mem::MaybeUninit::uninit().assume_init() };
+ let mut control: ControlHeaderV6 = unsafe { mem::MaybeUninit::uninit().assume_init() };
+ let mut hdr = libc::msghdr {
+ msg_name: safe_cast(&mut src),
+ msg_namelen: mem::size_of::<libc::sockaddr_in6> as u32,
+ msg_iov: iovs.as_mut_ptr(),
+ msg_iovlen: iovs.len(),
+ msg_control: safe_cast(&mut control),
+ msg_controllen: mem::size_of::<ControlHeaderV6>(),
+ msg_flags: 0,
};
+ debug_assert!(
+ hdr.msg_controllen
+ >= mem::size_of::<libc::cmsghdr>() + mem::size_of::<libc::in6_pktinfo>(),
+ );
+
let len = unsafe { libc::recvmsg(fd, &mut hdr as *mut libc::msghdr, 0) };
+
if len < 0 {
- log::trace!("failed to receive IPv6 packet (errno = {})", errno());
return Err(io::Error::new(
io::ErrorKind::NotConnected,
"failed to receive",
));
}
- log::trace!("received IPv6 packet ({} fd, {} bytes)", fd, len);
Ok((
len.try_into().unwrap(),
LinuxEndpoint::V6(EndpointV6 {
- info: control.body,
- dst: src,
+ info: control.info, // save pktinfo (sticky source)
+ dst: src, // our future destination is the source address
}),
))
}
@@ -224,52 +238,40 @@ impl LinuxUDPReader {
buf.len()
);
- let iovs: [libc::iovec; 1] = [libc::iovec {
+ let mut iovs: [libc::iovec; 1] = [libc::iovec {
iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void,
iov_len: buf.len(),
}];
-
- let src: libc::sockaddr_in = unsafe { mem::MaybeUninit::uninit().assume_init() };
-
- // this memory is mutated by the recvmsg call
- #[allow(unused_mut)]
+ let mut src: libc::sockaddr_in = unsafe { mem::MaybeUninit::uninit().assume_init() };
let mut control: ControlHeaderV4 = unsafe { mem::MaybeUninit::uninit().assume_init() };
-
- let mut hdr = unsafe {
- libc::msghdr {
- msg_name: mem::transmute(&src),
- msg_namelen: mem::size_of_val(&src).try_into().unwrap(), // constant
- msg_iov: mem::transmute(&iovs[0]),
- msg_iovlen: iovs.len(), // constant
- msg_control: mem::transmute(&control),
- msg_controllen: mem::size_of_val(&control), // constant
- msg_flags: 0, // ignored
- }
+ let mut hdr = libc::msghdr {
+ msg_name: safe_cast(&mut src),
+ msg_namelen: mem::size_of::<libc::sockaddr_in> as u32,
+ msg_iov: iovs.as_mut_ptr(),
+ msg_iovlen: iovs.len(),
+ msg_control: safe_cast(&mut control),
+ msg_controllen: mem::size_of::<ControlHeaderV4>(),
+ msg_flags: 0,
};
+ debug_assert!(
+ hdr.msg_controllen
+ >= mem::size_of::<libc::cmsghdr>() + mem::size_of::<libc::in_pktinfo>(),
+ );
+
let len = unsafe { libc::recvmsg(fd, &mut hdr as *mut libc::msghdr, 0) };
if len < 0 {
- log::trace!("failed to receive IPv4 packet (errno = {})", errno());
return Err(io::Error::new(
io::ErrorKind::NotConnected,
"failed to receive",
));
}
- log::trace!("read4, len: {}", len);
- log::trace!(
- "control: {{ hdr : {{ cmsg_level: {}, cmsg_type: {}, cmsg_len: {} }} }}",
- control.hdr.cmsg_level,
- control.hdr.cmsg_type,
- control.hdr.cmsg_len
- );
-
- log::trace!("received IPv4 packet ({} fd, {} bytes)", fd, len);
Ok((
len.try_into().unwrap(),
LinuxEndpoint::V4(EndpointV4 {
- info: control.info, // save pkinfo (sticky source)
+ info: control.info, // save pktinfo (sticky source)
dst: src, // our future destination is the source address
}),
))
@@ -288,23 +290,82 @@ impl Reader<LinuxEndpoint> for LinuxUDPReader {
}
impl LinuxUDPWriter {
- fn write6(fd: RawFd, buf: &[u8], dst: &EndpointV6) -> Result<(), io::Error> {
+ fn write6(fd: RawFd, buf: &[u8], dst: &mut EndpointV6) -> Result<(), io::Error> {
log::debug!("sending IPv6 packet ({} fd, {} bytes)", fd, buf.len());
- unimplemented!()
+ let mut iovs: [libc::iovec; 1] = [libc::iovec {
+ iov_base: buf.as_ptr() as *mut core::ffi::c_void,
+ iov_len: buf.len(),
+ }];
+
+ let mut control = ControlHeaderV6 {
+ hdr: libc::cmsghdr {
+ cmsg_len: CMSG_LEN(mem::size_of::<libc::in6_pktinfo>()),
+ cmsg_level: libc::IPPROTO_IPV6,
+ cmsg_type: libc::IPV6_PKTINFO,
+ },
+ info: dst.info,
+ };
+
+ debug_assert_eq!(
+ control.hdr.cmsg_len % mem::size_of::<u32>(),
+ 0,
+ "cmsg_len must be aligned to a long"
+ );
+
+ debug_assert_eq!(
+ dst.dst.sin6_family,
+ libc::AF_INET6 as libc::sa_family_t,
+ "this method only handles IPv6 destinations"
+ );
+
+ let mut hdr = libc::msghdr {
+ msg_name: safe_cast(&mut dst.dst),
+ msg_namelen: mem::size_of_val(&dst.dst).try_into().unwrap(),
+ msg_iov: iovs.as_mut_ptr(),
+ msg_iovlen: iovs.len(),
+ msg_control: safe_cast(&mut control),
+ msg_controllen: mem::size_of_val(&control),
+ msg_flags: 0,
+ };
+
+ let ret = unsafe { libc::sendmsg(fd, &hdr, 0) };
+
+ if ret < 0 {
+ if errno() == libc::EINVAL {
+ log::trace!("clear source and retry");
+ hdr.msg_control = ptr::null_mut();
+ hdr.msg_controllen = 0;
+ dst.info = unsafe { mem::zeroed() };
+ if unsafe { libc::sendmsg(fd, &hdr, 0) } < 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::NotConnected,
+ "failed to send IPv6 packet",
+ ));
+ } else {
+ return Ok(());
+ }
+ }
+ return Err(io::Error::new(
+ io::ErrorKind::NotConnected,
+ "failed to send IPv6 packet",
+ ));
+ }
+
+ Ok(())
}
fn write4(fd: RawFd, buf: &[u8], dst: &mut EndpointV4) -> Result<(), io::Error> {
log::debug!("sending IPv4 packet ({} fd, {} bytes)", fd, buf.len());
- let iovs: [libc::iovec; 1] = [libc::iovec {
+ let mut iovs: [libc::iovec; 1] = [libc::iovec {
iov_base: buf.as_ptr() as *mut core::ffi::c_void,
iov_len: buf.len(),
}];
let mut control = ControlHeaderV4 {
hdr: libc::cmsghdr {
- cmsg_len: mem::size_of::<ControlHeaderV4>(),
+ cmsg_len: CMSG_LEN(mem::size_of::<libc::in_pktinfo>()),
cmsg_level: libc::IPPROTO_IP,
cmsg_type: libc::IP_PKTINFO,
},
@@ -312,18 +373,23 @@ impl LinuxUDPWriter {
};
debug_assert_eq!(
- control.hdr.cmsg_len % mem::size_of::<usize>(),
+ control.hdr.cmsg_len % mem::size_of::<u32>(),
0,
- "cmsg_len must be aligned to a word"
+ "cmsg_len must be aligned to a long"
+ );
+
+ debug_assert_eq!(
+ dst.dst.sin_family,
+ libc::AF_INET as libc::sa_family_t,
+ "this method only handles IPv4 destinations"
);
- debug_assert_eq!(dst.dst.sin_family, libc::AF_INET as libc::sa_family_t);
let mut hdr = libc::msghdr {
- msg_name: unsafe { mem::transmute(&dst.dst as *const libc::sockaddr_in) },
+ msg_name: safe_cast(&mut dst.dst),
msg_namelen: mem::size_of_val(&dst.dst).try_into().unwrap(),
- msg_iov: iovs.as_ptr() as *mut libc::iovec,
+ msg_iov: iovs.as_mut_ptr(),
msg_iovlen: iovs.len(),
- msg_control: unsafe { mem::transmute(&control as *const ControlHeaderV4) },
+ msg_control: safe_cast(&mut control),
msg_controllen: mem::size_of_val(&control),
msg_flags: 0,
};
@@ -361,7 +427,7 @@ impl Writer<LinuxEndpoint> for LinuxUDPWriter {
fn write(&self, buf: &[u8], dst: &mut LinuxEndpoint) -> Result<(), Self::Error> {
match dst {
LinuxEndpoint::V4(ref mut end) => Self::write4(self.sock4, buf, end),
- LinuxEndpoint::V6(ref end) => Self::write6(self.sock6, buf, end),
+ LinuxEndpoint::V6(ref mut end) => Self::write6(self.sock6, buf, end),
}
}
}