diff options
Diffstat (limited to 'src/platform')
-rw-r--r-- | src/platform/dummy/tun.rs | 3 | ||||
-rw-r--r-- | src/platform/dummy/udp.rs | 13 | ||||
-rw-r--r-- | src/platform/linux/tun.rs | 2 | ||||
-rw-r--r-- | src/platform/linux/udp.rs | 667 | ||||
-rw-r--r-- | src/platform/udp.rs | 4 |
5 files changed, 640 insertions, 49 deletions
diff --git a/src/platform/dummy/tun.rs b/src/platform/dummy/tun.rs index 9836b48..1955884 100644 --- a/src/platform/dummy/tun.rs +++ b/src/platform/dummy/tun.rs @@ -165,8 +165,7 @@ impl TunTest { sync_channel(1) }; - let mut rng = OsRng::new().unwrap(); - let id: u32 = rng.gen(); + let id: u32 = OsRng.gen(); let fake = TunFakeIO { id, diff --git a/src/platform/dummy/udp.rs b/src/platform/dummy/udp.rs index 35c905d..88630af 100644 --- a/src/platform/dummy/udp.rs +++ b/src/platform/dummy/udp.rs @@ -54,7 +54,7 @@ impl Reader<UnitEndpoint> for VoidBind { impl Writer<UnitEndpoint> for VoidBind { type Error = BindError; - fn write(&self, _buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { + fn write(&self, _buf: &[u8], _dst: &mut UnitEndpoint) -> Result<(), Self::Error> { Ok(()) } } @@ -105,7 +105,7 @@ impl Reader<UnitEndpoint> for PairReader<UnitEndpoint> { impl Writer<UnitEndpoint> for PairWriter<UnitEndpoint> { type Error = BindError; - fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { + fn write(&self, buf: &[u8], _dst: &mut UnitEndpoint) -> Result<(), Self::Error> { debug!( "dummy({}): write ({}, {})", self.id, @@ -135,9 +135,8 @@ impl PairBind { (PairReader<E>, PairWriter<E>), (PairReader<E>, PairWriter<E>), ) { - let mut rng = OsRng::new().unwrap(); - let id1: u32 = rng.gen(); - let id2: u32 = rng.gen(); + let id1: u32 = OsRng.gen(); + let id2: u32 = OsRng.gen(); let (tx1, rx1) = sync_channel(128); let (tx2, rx2) = sync_channel(128); @@ -187,10 +186,6 @@ impl Owner for VoidOwner { fn get_port(&self) -> u16 { 0 } - - fn get_fwmark(&self) -> Option<u32> { - None - } } impl PlatformUDP for PairBind { diff --git a/src/platform/linux/tun.rs b/src/platform/linux/tun.rs index c282a4b..15ca1ec 100644 --- a/src/platform/linux/tun.rs +++ b/src/platform/linux/tun.rs @@ -199,7 +199,7 @@ impl Status for LinuxTunStatus { // cut buffer to size let size: usize = size as usize; let mut remain = &buf[..size]; - log::debug!("netlink, recieved message ({} bytes)", size); + log::debug!("netlink, received message ({} bytes)", size); // handle messages while remain.len() >= HDR_SIZE { diff --git a/src/platform/linux/udp.rs b/src/platform/linux/udp.rs index f871bce..9815ab1 100644 --- a/src/platform/linux/udp.rs +++ b/src/platform/linux/udp.rs @@ -1,84 +1,683 @@ use super::super::udp::*; use super::super::Endpoint; +use log; + +use std::convert::TryInto; use std::io; -use std::net::{SocketAddr, UdpSocket}; -use std::sync::Arc; +use std::mem; +use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::os::unix::io::RawFd; +use std::ptr; + +fn errno() -> libc::c_int { + unsafe { + let ptr = libc::__errno_location(); + if ptr.is_null() { + 0 + } else { + *ptr + } + } +} + +#[repr(C, align(1))] +struct ControlHeaderV4 { + hdr: libc::cmsghdr, + info: libc::in_pktinfo, +} + +#[repr(C, align(1))] +struct ControlHeaderV6 { + hdr: libc::cmsghdr, + info: libc::in6_pktinfo, +} + +pub struct EndpointV4 { + dst: libc::sockaddr_in, // destination IP + info: libc::in_pktinfo, // src & ifindex +} + +pub struct EndpointV6 { + dst: libc::sockaddr_in6, // destination IP + info: libc::in6_pktinfo, // src & zone id +} + +pub struct LinuxUDP(); + +pub struct LinuxOwner { + port: u16, + sock4: Option<RawFd>, + sock6: Option<RawFd>, +} + +pub enum LinuxUDPReader { + V4(RawFd), + V6(RawFd), +} #[derive(Clone)] -pub struct LinuxUDP(Arc<UdpSocket>); +pub struct LinuxUDPWriter { + sock4: RawFd, + sock6: RawFd, +} -pub struct LinuxOwner(Arc<UdpSocket>); +pub enum LinuxEndpoint { + V4(EndpointV4), + V6(EndpointV6), +} -impl Endpoint for SocketAddr { - fn clear_src(&mut self) {} +impl Endpoint for LinuxEndpoint { + fn clear_src(&mut self) { + match self { + LinuxEndpoint::V4(EndpointV4 { ref mut info, .. }) => { + info.ipi_ifindex = 0; + info.ipi_spec_dst = libc::in_addr { s_addr: 0 }; + } + LinuxEndpoint::V6(EndpointV6 { ref mut info, .. }) => { + info.ipi6_addr = libc::in6_addr { s6_addr: [0; 16] }; + info.ipi6_ifindex = 0; + } + }; + } fn from_address(addr: SocketAddr) -> Self { - addr + match addr { + SocketAddr::V4(addr) => LinuxEndpoint::V4(EndpointV4 { + dst: libc::sockaddr_in { + sin_family: libc::AF_INET as libc::sa_family_t, + sin_port: addr.port().to_be(), + sin_addr: libc::in_addr { + s_addr: u32::from(*addr.ip()).to_be(), + }, + sin_zero: [0; 8], + }, + info: libc::in_pktinfo { + ipi_ifindex: 0, // interface (0 is via routing table) + ipi_spec_dst: libc::in_addr { s_addr: 0 }, // src IP (dst of incoming packet) + ipi_addr: libc::in_addr { s_addr: 0 }, + }, + }), + SocketAddr::V6(addr) => LinuxEndpoint::V6(EndpointV6 { + dst: libc::sockaddr_in6 { + sin6_family: libc::AF_INET6 as libc::sa_family_t, + sin6_port: addr.port().to_be(), + sin6_flowinfo: addr.flowinfo(), + sin6_addr: libc::in6_addr { + s6_addr: addr.ip().octets(), + }, + sin6_scope_id: addr.scope_id(), + }, + info: libc::in6_pktinfo { + ipi6_addr: libc::in6_addr { s6_addr: [0; 16] }, // src IP + ipi6_ifindex: 0, // zone id + }, + }), + } } fn into_address(&self) -> SocketAddr { - *self + match self { + LinuxEndpoint::V4(EndpointV4 { ref dst, .. }) => { + SocketAddr::V4(SocketAddrV4::new( + u32::from_be(dst.sin_addr.s_addr).into(), // IPv4 addr + u16::from_be(dst.sin_port), // convert back to native byte-order + )) + } + LinuxEndpoint::V6(EndpointV6 { ref dst, .. }) => SocketAddr::V6(SocketAddrV6::new( + u128::from_ne_bytes(dst.sin6_addr.s6_addr).into(), // IPv6 addr + u16::from_be(dst.sin6_port), // convert back to native byte-order + dst.sin6_flowinfo, + dst.sin6_scope_id, + )), + } } } -impl Reader<SocketAddr> for LinuxUDP { - type Error = io::Error; +fn setsockopt<V: Sized>( + fd: RawFd, + level: libc::c_int, + name: libc::c_int, + value: &V, +) -> Result<(), io::Error> { + let res = unsafe { + libc::setsockopt( + fd, + level, + name, + mem::transmute(value), + mem::size_of_val(value).try_into().unwrap(), + ) + }; + if res == 0 { + Ok(()) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + format!("Failed to set sockopt (res = {}, errno = {})", res, errno()), + )) + } +} + +#[inline(always)] +fn setsockopt_int( + fd: RawFd, + level: libc::c_int, + name: libc::c_int, + value: libc::c_int, +) -> Result<(), io::Error> { + setsockopt(fd, level, name, &value) +} - fn read(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> { - self.0.recv_from(buf) +#[allow(non_snake_case)] +const fn CMSG_ALIGN(len: usize) -> usize { + (((len) + mem::size_of::<u32>() - 1) & !(mem::size_of::<u32>() - 1)) +} + +#[allow(non_snake_case)] +const fn CMSG_LEN(len: usize) -> usize { + CMSG_ALIGN(len + mem::size_of::<libc::cmsghdr>()) +} + +#[inline(always)] +fn safe_cast<T, D>(v: &mut T) -> *mut D { + (v as *mut T) as *mut D +} + +impl LinuxUDPReader { + fn read6(fd: RawFd, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), io::Error> { + log::trace!( + "receive IPv6 packet (block), (fd {}, max-len {})", + fd, + buf.len() + ); + + let mut iovs: [libc::iovec; 1] = [libc::iovec { + iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void, + iov_len: buf.len(), + }]; + let mut src: libc::sockaddr_in6 = unsafe { mem::MaybeUninit::uninit().assume_init() }; + let mut control: ControlHeaderV6 = unsafe { mem::MaybeUninit::uninit().assume_init() }; + let mut hdr = libc::msghdr { + msg_name: safe_cast(&mut src), + msg_namelen: mem::size_of::<libc::sockaddr_in6> as u32, + msg_iov: iovs.as_mut_ptr(), + msg_iovlen: iovs.len(), + msg_control: safe_cast(&mut control), + msg_controllen: mem::size_of::<ControlHeaderV6>(), + msg_flags: 0, + }; + + debug_assert!( + hdr.msg_controllen + >= mem::size_of::<libc::cmsghdr>() + mem::size_of::<libc::in6_pktinfo>(), + ); + + let len = unsafe { libc::recvmsg(fd, &mut hdr as *mut libc::msghdr, 0) }; + + if len < 0 { + return Err(io::Error::new( + io::ErrorKind::NotConnected, + "failed to receive", + )); + } + + Ok(( + len.try_into().unwrap(), + LinuxEndpoint::V6(EndpointV6 { + info: control.info, // save pktinfo (sticky source) + dst: src, // our future destination is the source address + }), + )) + } + + fn read4(fd: RawFd, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), io::Error> { + log::trace!( + "receive IPv4 packet (block), (fd {}, max-len {})", + fd, + buf.len() + ); + + let mut iovs: [libc::iovec; 1] = [libc::iovec { + iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void, + iov_len: buf.len(), + }]; + let mut src: libc::sockaddr_in = unsafe { mem::MaybeUninit::uninit().assume_init() }; + let mut control: ControlHeaderV4 = unsafe { mem::MaybeUninit::uninit().assume_init() }; + let mut hdr = libc::msghdr { + msg_name: safe_cast(&mut src), + msg_namelen: mem::size_of::<libc::sockaddr_in> as u32, + msg_iov: iovs.as_mut_ptr(), + msg_iovlen: iovs.len(), + msg_control: safe_cast(&mut control), + msg_controllen: mem::size_of::<ControlHeaderV4>(), + msg_flags: 0, + }; + + debug_assert!( + hdr.msg_controllen + >= mem::size_of::<libc::cmsghdr>() + mem::size_of::<libc::in_pktinfo>(), + ); + + let len = unsafe { libc::recvmsg(fd, &mut hdr as *mut libc::msghdr, 0) }; + + if len < 0 { + return Err(io::Error::new( + io::ErrorKind::NotConnected, + "failed to receive", + )); + } + + Ok(( + len.try_into().unwrap(), + LinuxEndpoint::V4(EndpointV4 { + info: control.info, // save pktinfo (sticky source) + dst: src, // our future destination is the source address + }), + )) } } -impl Writer<SocketAddr> for LinuxUDP { +impl Reader<LinuxEndpoint> for LinuxUDPReader { type Error = io::Error; - fn write(&self, buf: &[u8], dst: &SocketAddr) -> Result<(), Self::Error> { - self.0.send_to(buf, dst)?; + fn read(&self, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), Self::Error> { + match self { + Self::V4(fd) => Self::read4(*fd, buf), + Self::V6(fd) => Self::read6(*fd, buf), + } + } +} + +impl LinuxUDPWriter { + fn write6(fd: RawFd, buf: &[u8], dst: &mut EndpointV6) -> Result<(), io::Error> { + log::debug!("sending IPv6 packet ({} fd, {} bytes)", fd, buf.len()); + + let mut iovs: [libc::iovec; 1] = [libc::iovec { + iov_base: buf.as_ptr() as *mut core::ffi::c_void, + iov_len: buf.len(), + }]; + + let mut control = ControlHeaderV6 { + hdr: libc::cmsghdr { + cmsg_len: CMSG_LEN(mem::size_of::<libc::in6_pktinfo>()), + cmsg_level: libc::IPPROTO_IPV6, + cmsg_type: libc::IPV6_PKTINFO, + }, + info: dst.info, + }; + + debug_assert_eq!( + control.hdr.cmsg_len % mem::size_of::<u32>(), + 0, + "cmsg_len must be aligned to a long" + ); + + debug_assert_eq!( + dst.dst.sin6_family, + libc::AF_INET6 as libc::sa_family_t, + "this method only handles IPv6 destinations" + ); + + let mut hdr = libc::msghdr { + msg_name: safe_cast(&mut dst.dst), + msg_namelen: mem::size_of_val(&dst.dst).try_into().unwrap(), + msg_iov: iovs.as_mut_ptr(), + msg_iovlen: iovs.len(), + msg_control: safe_cast(&mut control), + msg_controllen: mem::size_of_val(&control), + msg_flags: 0, + }; + + let ret = unsafe { libc::sendmsg(fd, &hdr, 0) }; + + if ret < 0 { + if errno() == libc::EINVAL { + log::trace!("clear source and retry"); + hdr.msg_control = ptr::null_mut(); + hdr.msg_controllen = 0; + dst.info = unsafe { mem::zeroed() }; + if unsafe { libc::sendmsg(fd, &hdr, 0) } < 0 { + return Err(io::Error::new( + io::ErrorKind::NotConnected, + "failed to send IPv6 packet", + )); + } else { + return Ok(()); + } + } + return Err(io::Error::new( + io::ErrorKind::NotConnected, + "failed to send IPv6 packet", + )); + } + + Ok(()) + } + + fn write4(fd: RawFd, buf: &[u8], dst: &mut EndpointV4) -> Result<(), io::Error> { + log::debug!("sending IPv4 packet ({} fd, {} bytes)", fd, buf.len()); + + let mut iovs: [libc::iovec; 1] = [libc::iovec { + iov_base: buf.as_ptr() as *mut core::ffi::c_void, + iov_len: buf.len(), + }]; + + let mut control = ControlHeaderV4 { + hdr: libc::cmsghdr { + cmsg_len: CMSG_LEN(mem::size_of::<libc::in_pktinfo>()), + cmsg_level: libc::IPPROTO_IP, + cmsg_type: libc::IP_PKTINFO, + }, + info: dst.info, + }; + + debug_assert_eq!( + control.hdr.cmsg_len % mem::size_of::<u32>(), + 0, + "cmsg_len must be aligned to a long" + ); + + debug_assert_eq!( + dst.dst.sin_family, + libc::AF_INET as libc::sa_family_t, + "this method only handles IPv4 destinations" + ); + + let mut hdr = libc::msghdr { + msg_name: safe_cast(&mut dst.dst), + msg_namelen: mem::size_of_val(&dst.dst).try_into().unwrap(), + msg_iov: iovs.as_mut_ptr(), + msg_iovlen: iovs.len(), + msg_control: safe_cast(&mut control), + msg_controllen: mem::size_of_val(&control), + msg_flags: 0, + }; + + let ret = unsafe { libc::sendmsg(fd, &hdr, 0) }; + + if ret < 0 { + if errno() == libc::EINVAL { + log::trace!("clear source and retry"); + hdr.msg_control = ptr::null_mut(); + hdr.msg_controllen = 0; + dst.info = unsafe { mem::zeroed() }; + if unsafe { libc::sendmsg(fd, &hdr, 0) } < 0 { + return Err(io::Error::new( + io::ErrorKind::NotConnected, + "failed to send IPv4 packet", + )); + } else { + return Ok(()); + } + } + return Err(io::Error::new( + io::ErrorKind::NotConnected, + "failed to send IPv4 packet", + )); + } + Ok(()) } } -impl Owner for LinuxOwner { +impl Writer<LinuxEndpoint> for LinuxUDPWriter { type Error = io::Error; - fn get_port(&self) -> u16 { - self.0.local_addr().unwrap().port() // todo handle + fn write(&self, buf: &[u8], dst: &mut LinuxEndpoint) -> Result<(), Self::Error> { + match dst { + LinuxEndpoint::V4(ref mut end) => Self::write4(self.sock4, buf, end), + LinuxEndpoint::V6(ref mut end) => Self::write6(self.sock6, buf, end), + } } +} - fn get_fwmark(&self) -> Option<u32> { - None +impl Owner for LinuxOwner { + type Error = io::Error; + + fn get_port(&self) -> u16 { + self.port } - fn set_fwmark(&mut self, _value: Option<u32>) -> Result<(), Self::Error> { - Ok(()) + fn set_fwmark(&mut self, value: Option<u32>) -> Result<(), Self::Error> { + fn set_mark(fd: Option<RawFd>, value: u32) -> Result<(), io::Error> { + if let Some(fd) = fd { + setsockopt(fd, libc::SOL_SOCKET, libc::SO_MARK, &value) + } else { + Ok(()) + } + } + let value = value.unwrap_or(0); + set_mark(self.sock6, value)?; + set_mark(self.sock4, value) } } impl Drop for LinuxOwner { fn drop(&mut self) { - // TODO: close udp bind + log::trace!("closing the bind (port {})", self.port); + self.sock4.map(|fd| unsafe { + libc::shutdown(fd, libc::SHUT_RDWR); + libc::close(fd) + }); + self.sock6.map(|fd| unsafe { + libc::shutdown(fd, libc::SHUT_RDWR); + libc::close(fd) + }); } } impl UDP for LinuxUDP { type Error = io::Error; - type Endpoint = SocketAddr; - type Reader = Self; - type Writer = Self; + type Endpoint = LinuxEndpoint; + type Reader = LinuxUDPReader; + type Writer = LinuxUDPWriter; +} + +impl LinuxUDP { + /* Bind on all IPv6 interfaces + * + * Arguments: + * + * - 'port', port to bind to (0 = any) + * + * Returns: + * + * Returns a tuple of the resulting port and socket. + */ + fn bind6(port: u16) -> Result<(u16, RawFd), io::Error> { + log::trace!("attempting to bind on IPv6 (port {})", port); + + // create socket fd + let fd: RawFd = unsafe { libc::socket(libc::AF_INET6, libc::SOCK_DGRAM, 0) }; + if fd < 0 { + log::debug!("failed to create IPv6 socket"); + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + } + + setsockopt_int(fd, libc::SOL_SOCKET, libc::SO_REUSEADDR, 1)?; + setsockopt_int(fd, libc::IPPROTO_IPV6, libc::IPV6_RECVPKTINFO, 1)?; + setsockopt_int(fd, libc::IPPROTO_IPV6, libc::IPV6_V6ONLY, 1)?; + + // bind + let mut sockaddr = libc::sockaddr_in6 { + sin6_addr: libc::in6_addr { s6_addr: [0; 16] }, + sin6_family: libc::AF_INET6 as libc::sa_family_t, + sin6_port: port.to_be(), // convert to network (big-endian) byteorder + sin6_scope_id: 0, + sin6_flowinfo: 0, + }; + + let err = unsafe { + libc::bind( + fd, + mem::transmute(&sockaddr as *const libc::sockaddr_in6), + mem::size_of_val(&sockaddr).try_into().unwrap(), + ) + }; + + if err != 0 { + log::debug!("failed to bind IPv6 socket"); + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + } + + // get the assigned port + let mut socklen: libc::socklen_t = mem::size_of_val(&sockaddr).try_into().unwrap(); + let err = unsafe { + libc::getsockname( + fd, + mem::transmute(&mut sockaddr as *mut libc::sockaddr_in6), + &mut socklen as *mut libc::socklen_t, + ) + }; + if err != 0 { + log::debug!("failed to get port of IPv6 socket"); + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + } + + // basic sanity checks + let new_port = u16::from_be(sockaddr.sin6_port); + debug_assert_eq!(socklen, mem::size_of::<libc::sockaddr_in6>() as u32); + debug_assert_eq!(sockaddr.sin6_family, libc::AF_INET6 as libc::sa_family_t); + debug_assert_eq!(new_port, if port != 0 { port } else { new_port }); + log::trace!("bound IPv6 socket (port {}, fd {})", new_port, fd); + return Ok((new_port, fd)); + } + + /* Bind on all IPv4 interfaces. + * + * Arguments: + * + * - 'port', port to bind to (0 = any) + * + * Returns: + * + * Returns a tuple of the resulting port and socket. + */ + fn bind4(port: u16) -> Result<(u16, RawFd), io::Error> { + log::trace!("attempting to bind on IPv4 (port {})", port); + + // create socket fd + let fd: RawFd = unsafe { libc::socket(libc::AF_INET, libc::SOCK_DGRAM, 0) }; + if fd < 0 { + log::trace!("failed to create IPv4 socket (errno = {})", errno()); + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + } + + setsockopt_int(fd, libc::SOL_SOCKET, libc::SO_REUSEADDR, 1)?; + setsockopt_int(fd, libc::IPPROTO_IP, libc::IP_PKTINFO, 1)?; + + const INADDR_ANY: libc::in_addr = libc::in_addr { s_addr: 0 }; + + // bind + let mut sockaddr = libc::sockaddr_in { + sin_addr: INADDR_ANY, + sin_family: libc::AF_INET as libc::sa_family_t, + sin_port: port.to_be(), + sin_zero: [0; 8], + }; + + let err = unsafe { + libc::bind( + fd, + mem::transmute(&sockaddr as *const libc::sockaddr_in), + mem::size_of_val(&sockaddr).try_into().unwrap(), + ) + }; + + if err != 0 { + log::trace!("failed to bind IPv4 socket (errno = {})", errno()); + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + } + + // get the assigned port + let mut socklen: libc::socklen_t = mem::size_of_val(&sockaddr).try_into().unwrap(); + let err = unsafe { + libc::getsockname( + fd, + mem::transmute(&mut sockaddr as *mut libc::sockaddr_in), + &mut socklen as *mut libc::socklen_t, + ) + }; + if err != 0 { + log::trace!("failed to get port of IPv4 socket (errno = {})", errno()); + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + } + + // basic sanity checks + let new_port = u16::from_be(sockaddr.sin_port); + debug_assert_eq!(socklen, mem::size_of::<libc::sockaddr_in>() as u32); + debug_assert_eq!(sockaddr.sin_family, libc::AF_INET as libc::sa_family_t); + debug_assert_eq!(new_port, if port != 0 { port } else { new_port }); + log::trace!("bound IPv4 socket (port {}, fd {})", new_port, fd); + return Ok((new_port, fd)); + } } impl PlatformUDP for LinuxUDP { type Owner = LinuxOwner; - fn bind(port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> { - let socket = UdpSocket::bind(format!("0.0.0.0:{}", port))?; - let socket = Arc::new(socket); + fn bind(mut port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> { + log::debug!("bind to port {}", port); - Ok(( - vec![LinuxUDP(socket.clone())], - LinuxUDP(socket.clone()), - LinuxOwner(socket), - )) + // attempt to bind on ipv6 + let bind6 = Self::bind6(port); + if let Ok((new_port, _)) = bind6 { + port = new_port; + } + + // attempt to bind on ipv4 on the same port + let bind4 = Self::bind4(port); + if let Ok((new_port, _)) = bind4 { + port = new_port; + } + + // check if failed to bind on both + if bind4.is_err() && bind6.is_err() { + log::trace!("failed to bind for either IP version"); + return Err(bind6.unwrap_err()); + } + + let sock6 = bind6.ok().map(|(_, fd)| fd); + let sock4 = bind4.ok().map(|(_, fd)| fd); + + // create owner + let owner = LinuxOwner { + port, + sock6: sock6, + sock4: sock4, + }; + + // create readers + let mut readers: Vec<Self::Reader> = Vec::with_capacity(2); + sock6.map(|sock| readers.push(LinuxUDPReader::V6(sock))); + sock4.map(|sock| readers.push(LinuxUDPReader::V4(sock))); + debug_assert!(readers.len() > 0); + + // create writer + let writer = LinuxUDPWriter { + sock4: sock4.unwrap_or(-1), + sock6: sock6.unwrap_or(-1), + }; + + Ok((readers, writer, owner)) } } diff --git a/src/platform/udp.rs b/src/platform/udp.rs index 3671229..e1180fb 100644 --- a/src/platform/udp.rs +++ b/src/platform/udp.rs @@ -10,7 +10,7 @@ pub trait Reader<E: Endpoint>: Send + Sync { pub trait Writer<E: Endpoint>: Send + Sync + Clone + 'static { type Error: Error; - fn write(&self, buf: &[u8], dst: &E) -> Result<(), Self::Error>; + fn write(&self, buf: &[u8], dst: &mut E) -> Result<(), Self::Error>; } pub trait UDP: Send + Sync + 'static { @@ -30,8 +30,6 @@ pub trait Owner: Send { fn get_port(&self) -> u16; - fn get_fwmark(&self) -> Option<u32>; - fn set_fwmark(&mut self, value: Option<u32>) -> Result<(), Self::Error>; } |