From 1e26a0bef44e65023a97a16ecf3b123e688d19f7 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sat, 1 Feb 2020 14:36:50 +0100 Subject: Initial version of sticky sockets for Linux --- src/platform/linux/udp.rs | 192 +++++++++++++++++++++++++++++++--------------- 1 file changed, 129 insertions(+), 63 deletions(-) (limited to 'src/platform/linux/udp.rs') diff --git a/src/platform/linux/udp.rs b/src/platform/linux/udp.rs index ab5b53b..9815ab1 100644 --- a/src/platform/linux/udp.rs +++ b/src/platform/linux/udp.rs @@ -21,16 +21,16 @@ fn errno() -> libc::c_int { } } -#[repr(C)] +#[repr(C, align(1))] struct ControlHeaderV4 { hdr: libc::cmsghdr, info: libc::in_pktinfo, } -#[repr(C)] +#[repr(C, align(1))] struct ControlHeaderV6 { hdr: libc::cmsghdr, - body: libc::in6_pktinfo, + info: libc::in6_pktinfo, } pub struct EndpointV4 { @@ -159,6 +159,7 @@ fn setsockopt( } } +#[inline(always)] fn setsockopt_int( fd: RawFd, level: libc::c_int, @@ -168,6 +169,21 @@ fn setsockopt_int( setsockopt(fd, level, name, &value) } +#[allow(non_snake_case)] +const fn CMSG_ALIGN(len: usize) -> usize { + (((len) + mem::size_of::() - 1) & !(mem::size_of::() - 1)) +} + +#[allow(non_snake_case)] +const fn CMSG_LEN(len: usize) -> usize { + CMSG_ALIGN(len + mem::size_of::()) +} + +#[inline(always)] +fn safe_cast(v: &mut T) -> *mut D { + (v as *mut T) as *mut D +} + impl LinuxUDPReader { fn read6(fd: RawFd, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), io::Error> { log::trace!( @@ -176,43 +192,41 @@ impl LinuxUDPReader { buf.len() ); - // this memory is mutated by the recvmsg call - #[allow(unused_mut)] - let mut control: ControlHeaderV6 = unsafe { mem::MaybeUninit::uninit().assume_init() }; - - let iovs: [libc::iovec; 1] = [libc::iovec { + let mut iovs: [libc::iovec; 1] = [libc::iovec { iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void, iov_len: buf.len(), }]; - - let src: libc::sockaddr_in6 = unsafe { mem::MaybeUninit::uninit().assume_init() }; - let mut hdr = unsafe { - libc::msghdr { - msg_name: mem::transmute(&src), - msg_namelen: mem::size_of_val(&src).try_into().unwrap(), - msg_iov: mem::transmute(&iovs[0]), - msg_iovlen: iovs.len(), - msg_control: mem::transmute(&control), - msg_controllen: mem::size_of_val(&control), - msg_flags: 0, // ignored - } + let mut src: libc::sockaddr_in6 = unsafe { mem::MaybeUninit::uninit().assume_init() }; + let mut control: ControlHeaderV6 = unsafe { mem::MaybeUninit::uninit().assume_init() }; + let mut hdr = libc::msghdr { + msg_name: safe_cast(&mut src), + msg_namelen: mem::size_of:: as u32, + msg_iov: iovs.as_mut_ptr(), + msg_iovlen: iovs.len(), + msg_control: safe_cast(&mut control), + msg_controllen: mem::size_of::(), + msg_flags: 0, }; + debug_assert!( + hdr.msg_controllen + >= mem::size_of::() + mem::size_of::(), + ); + let len = unsafe { libc::recvmsg(fd, &mut hdr as *mut libc::msghdr, 0) }; + if len < 0 { - log::trace!("failed to receive IPv6 packet (errno = {})", errno()); return Err(io::Error::new( io::ErrorKind::NotConnected, "failed to receive", )); } - log::trace!("received IPv6 packet ({} fd, {} bytes)", fd, len); Ok(( len.try_into().unwrap(), LinuxEndpoint::V6(EndpointV6 { - info: control.body, - dst: src, + info: control.info, // save pktinfo (sticky source) + dst: src, // our future destination is the source address }), )) } @@ -224,52 +238,40 @@ impl LinuxUDPReader { buf.len() ); - let iovs: [libc::iovec; 1] = [libc::iovec { + let mut iovs: [libc::iovec; 1] = [libc::iovec { iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void, iov_len: buf.len(), }]; - - let src: libc::sockaddr_in = unsafe { mem::MaybeUninit::uninit().assume_init() }; - - // this memory is mutated by the recvmsg call - #[allow(unused_mut)] + let mut src: libc::sockaddr_in = unsafe { mem::MaybeUninit::uninit().assume_init() }; let mut control: ControlHeaderV4 = unsafe { mem::MaybeUninit::uninit().assume_init() }; - - let mut hdr = unsafe { - libc::msghdr { - msg_name: mem::transmute(&src), - msg_namelen: mem::size_of_val(&src).try_into().unwrap(), // constant - msg_iov: mem::transmute(&iovs[0]), - msg_iovlen: iovs.len(), // constant - msg_control: mem::transmute(&control), - msg_controllen: mem::size_of_val(&control), // constant - msg_flags: 0, // ignored - } + let mut hdr = libc::msghdr { + msg_name: safe_cast(&mut src), + msg_namelen: mem::size_of:: as u32, + msg_iov: iovs.as_mut_ptr(), + msg_iovlen: iovs.len(), + msg_control: safe_cast(&mut control), + msg_controllen: mem::size_of::(), + msg_flags: 0, }; + debug_assert!( + hdr.msg_controllen + >= mem::size_of::() + mem::size_of::(), + ); + let len = unsafe { libc::recvmsg(fd, &mut hdr as *mut libc::msghdr, 0) }; if len < 0 { - log::trace!("failed to receive IPv4 packet (errno = {})", errno()); return Err(io::Error::new( io::ErrorKind::NotConnected, "failed to receive", )); } - log::trace!("read4, len: {}", len); - log::trace!( - "control: {{ hdr : {{ cmsg_level: {}, cmsg_type: {}, cmsg_len: {} }} }}", - control.hdr.cmsg_level, - control.hdr.cmsg_type, - control.hdr.cmsg_len - ); - - log::trace!("received IPv4 packet ({} fd, {} bytes)", fd, len); Ok(( len.try_into().unwrap(), LinuxEndpoint::V4(EndpointV4 { - info: control.info, // save pkinfo (sticky source) + info: control.info, // save pktinfo (sticky source) dst: src, // our future destination is the source address }), )) @@ -288,23 +290,82 @@ impl Reader for LinuxUDPReader { } impl LinuxUDPWriter { - fn write6(fd: RawFd, buf: &[u8], dst: &EndpointV6) -> Result<(), io::Error> { + fn write6(fd: RawFd, buf: &[u8], dst: &mut EndpointV6) -> Result<(), io::Error> { log::debug!("sending IPv6 packet ({} fd, {} bytes)", fd, buf.len()); - unimplemented!() + let mut iovs: [libc::iovec; 1] = [libc::iovec { + iov_base: buf.as_ptr() as *mut core::ffi::c_void, + iov_len: buf.len(), + }]; + + let mut control = ControlHeaderV6 { + hdr: libc::cmsghdr { + cmsg_len: CMSG_LEN(mem::size_of::()), + cmsg_level: libc::IPPROTO_IPV6, + cmsg_type: libc::IPV6_PKTINFO, + }, + info: dst.info, + }; + + debug_assert_eq!( + control.hdr.cmsg_len % mem::size_of::(), + 0, + "cmsg_len must be aligned to a long" + ); + + debug_assert_eq!( + dst.dst.sin6_family, + libc::AF_INET6 as libc::sa_family_t, + "this method only handles IPv6 destinations" + ); + + let mut hdr = libc::msghdr { + msg_name: safe_cast(&mut dst.dst), + msg_namelen: mem::size_of_val(&dst.dst).try_into().unwrap(), + msg_iov: iovs.as_mut_ptr(), + msg_iovlen: iovs.len(), + msg_control: safe_cast(&mut control), + msg_controllen: mem::size_of_val(&control), + msg_flags: 0, + }; + + let ret = unsafe { libc::sendmsg(fd, &hdr, 0) }; + + if ret < 0 { + if errno() == libc::EINVAL { + log::trace!("clear source and retry"); + hdr.msg_control = ptr::null_mut(); + hdr.msg_controllen = 0; + dst.info = unsafe { mem::zeroed() }; + if unsafe { libc::sendmsg(fd, &hdr, 0) } < 0 { + return Err(io::Error::new( + io::ErrorKind::NotConnected, + "failed to send IPv6 packet", + )); + } else { + return Ok(()); + } + } + return Err(io::Error::new( + io::ErrorKind::NotConnected, + "failed to send IPv6 packet", + )); + } + + Ok(()) } fn write4(fd: RawFd, buf: &[u8], dst: &mut EndpointV4) -> Result<(), io::Error> { log::debug!("sending IPv4 packet ({} fd, {} bytes)", fd, buf.len()); - let iovs: [libc::iovec; 1] = [libc::iovec { + let mut iovs: [libc::iovec; 1] = [libc::iovec { iov_base: buf.as_ptr() as *mut core::ffi::c_void, iov_len: buf.len(), }]; let mut control = ControlHeaderV4 { hdr: libc::cmsghdr { - cmsg_len: mem::size_of::(), + cmsg_len: CMSG_LEN(mem::size_of::()), cmsg_level: libc::IPPROTO_IP, cmsg_type: libc::IP_PKTINFO, }, @@ -312,18 +373,23 @@ impl LinuxUDPWriter { }; debug_assert_eq!( - control.hdr.cmsg_len % mem::size_of::(), + control.hdr.cmsg_len % mem::size_of::(), 0, - "cmsg_len must be aligned to a word" + "cmsg_len must be aligned to a long" + ); + + debug_assert_eq!( + dst.dst.sin_family, + libc::AF_INET as libc::sa_family_t, + "this method only handles IPv4 destinations" ); - debug_assert_eq!(dst.dst.sin_family, libc::AF_INET as libc::sa_family_t); let mut hdr = libc::msghdr { - msg_name: unsafe { mem::transmute(&dst.dst as *const libc::sockaddr_in) }, + msg_name: safe_cast(&mut dst.dst), msg_namelen: mem::size_of_val(&dst.dst).try_into().unwrap(), - msg_iov: iovs.as_ptr() as *mut libc::iovec, + msg_iov: iovs.as_mut_ptr(), msg_iovlen: iovs.len(), - msg_control: unsafe { mem::transmute(&control as *const ControlHeaderV4) }, + msg_control: safe_cast(&mut control), msg_controllen: mem::size_of_val(&control), msg_flags: 0, }; @@ -361,7 +427,7 @@ impl Writer for LinuxUDPWriter { fn write(&self, buf: &[u8], dst: &mut LinuxEndpoint) -> Result<(), Self::Error> { match dst { LinuxEndpoint::V4(ref mut end) => Self::write4(self.sock4, buf, end), - LinuxEndpoint::V6(ref end) => Self::write6(self.sock6, buf, end), + LinuxEndpoint::V6(ref mut end) => Self::write6(self.sock6, buf, end), } } } -- cgit v1.2.3-59-g8ed1b