diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2020-01-09 11:24:13 +0100 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2020-01-09 11:24:13 +0100 |
commit | acbca236b70598c20c24de474690bcad883241d4 (patch) | |
tree | 10a33cc8c6ae5b4a56bac191b3f8072e6bd4b717 | |
parent | Fixed typo in under load code (diff) | |
download | wireguard-rs-acbca236b70598c20c24de474690bcad883241d4.tar.xz wireguard-rs-acbca236b70598c20c24de474690bcad883241d4.zip |
Work on sticky sockets
-rw-r--r-- | Cargo.toml | 1 | ||||
-rw-r--r-- | src/configuration/config.rs | 2 | ||||
-rw-r--r-- | src/configuration/uapi/get.rs | 3 | ||||
-rw-r--r-- | src/platform/dummy/udp.rs | 4 | ||||
-rw-r--r-- | src/platform/linux/udp.rs | 391 | ||||
-rw-r--r-- | src/platform/udp.rs | 2 |
6 files changed, 359 insertions, 44 deletions
@@ -3,7 +3,6 @@ name = "wireguard-rs" version = "0.1.0" authors = ["Mathias Hall-Andersen <mathias@hall-andersen.dk>"] edition = "2018" -license = "MIT" [dependencies] hex = "0.3" diff --git a/src/configuration/config.rs b/src/configuration/config.rs index aec943f..d61cda5 100644 --- a/src/configuration/config.rs +++ b/src/configuration/config.rs @@ -205,7 +205,7 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireGuardConfig<T, B> { } fn get_fwmark(&self) -> Option<u32> { - self.lock().bind.as_ref().and_then(|own| own.get_fwmark()) + self.lock().fwmark } fn set_private_key(&self, sk: Option<StaticSecret>) { diff --git a/src/configuration/uapi/get.rs b/src/configuration/uapi/get.rs index 9e6ab36..00048cd 100644 --- a/src/configuration/uapi/get.rs +++ b/src/configuration/uapi/get.rs @@ -2,7 +2,6 @@ use log; use std::io; use super::Configuration; -use super::Endpoint; pub fn serialize<C: Configuration, W: io::Write>(writer: &mut W, config: &C) -> io::Result<()> { let mut write = |key: &'static str, value: String| { @@ -46,7 +45,7 @@ pub fn serialize<C: Configuration, W: io::Write>(writer: &mut W, config: &C) -> } if let Some(endpoint) = p.endpoint { - write("endpoint", endpoint.into_address().to_string())?; + write("endpoint", endpoint.to_string())?; } for (ip, cidr) in p.allowed_ips { diff --git a/src/platform/dummy/udp.rs b/src/platform/dummy/udp.rs index 35c905d..d521851 100644 --- a/src/platform/dummy/udp.rs +++ b/src/platform/dummy/udp.rs @@ -187,10 +187,6 @@ impl Owner for VoidOwner { fn get_port(&self) -> u16 { 0 } - - fn get_fwmark(&self) -> Option<u32> { - None - } } impl PlatformUDP for PairBind { diff --git a/src/platform/linux/udp.rs b/src/platform/linux/udp.rs index f871bce..91be56e 100644 --- a/src/platform/linux/udp.rs +++ b/src/platform/linux/udp.rs @@ -1,41 +1,157 @@ use super::super::udp::*; use super::super::Endpoint; +use log; + +use std::convert::TryInto; use std::io; -use std::net::{SocketAddr, UdpSocket}; -use std::sync::Arc; +use std::mem; +use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::os::unix::io::RawFd; + +pub struct EndpointV4 { + dst: libc::sockaddr_in, // destination IP + info: libc::in_pktinfo, // src & ifindex +} + +pub struct EndpointV6 { + dst: libc::sockaddr_in6, // destination IP + info: libc::in6_pktinfo, // src & zone id +} + +pub struct LinuxUDP(); + +pub struct LinuxOwner { + port: u16, + sock4: Option<RawFd>, + sock6: Option<RawFd>, +} + +pub enum LinuxUDPReader { + V4(RawFd), + V6(RawFd), +} #[derive(Clone)] -pub struct LinuxUDP(Arc<UdpSocket>); +pub struct LinuxUDPWriter { + sock4: RawFd, + sock6: RawFd, +} -pub struct LinuxOwner(Arc<UdpSocket>); +pub enum LinuxEndpoint { + V4(EndpointV4), + V6(EndpointV6), +} -impl Endpoint for SocketAddr { - fn clear_src(&mut self) {} +impl Endpoint for LinuxEndpoint { + fn clear_src(&mut self) { + match self { + LinuxEndpoint::V4(EndpointV4 { ref mut info, .. }) => { + info.ipi_ifindex = 0; + info.ipi_spec_dst = libc::in_addr { s_addr: 0 }; + } + LinuxEndpoint::V6(EndpointV6 { ref mut info, .. }) => { + info.ipi6_addr = libc::in6_addr { s6_addr: [0; 16] }; + info.ipi6_ifindex = 0; + } + }; + } fn from_address(addr: SocketAddr) -> Self { - addr + match addr { + SocketAddr::V4(addr) => LinuxEndpoint::V4(EndpointV4 { + dst: libc::sockaddr_in { + sin_family: libc::AF_INET as libc::sa_family_t, + sin_port: addr.port(), + sin_addr: libc::in_addr { + s_addr: u32::from(*addr.ip()), + }, + sin_zero: [0; 8], + }, + info: libc::in_pktinfo { + ipi_ifindex: 0, // interface (0 is via routing table) + ipi_spec_dst: libc::in_addr { s_addr: 0 }, // src IP (dst of incoming packet) + ipi_addr: libc::in_addr { + // dst IP + s_addr: u32::from(*addr.ip()), + }, + }, + }), + SocketAddr::V6(addr) => LinuxEndpoint::V6(EndpointV6 { + dst: libc::sockaddr_in6 { + sin6_family: libc::AF_INET6 as libc::sa_family_t, + sin6_port: addr.port(), + sin6_flowinfo: addr.flowinfo(), + sin6_addr: libc::in6_addr { + s6_addr: addr.ip().octets(), + }, + sin6_scope_id: addr.scope_id(), + }, + info: libc::in6_pktinfo { + ipi6_addr: libc::in6_addr { s6_addr: [0; 16] }, // src IP + ipi6_ifindex: 0, // zone id + }, + }), + } } fn into_address(&self) -> SocketAddr { - *self + match self { + LinuxEndpoint::V4(EndpointV4 { ref dst, .. }) => { + SocketAddr::V4(SocketAddrV4::new( + dst.sin_addr.s_addr.into(), // IPv4 addr + dst.sin_port, + )) + } + LinuxEndpoint::V6(EndpointV6 { ref dst, .. }) => SocketAddr::V6(SocketAddrV6::new( + u128::from_ne_bytes(dst.sin6_addr.s6_addr).into(), // IPv6 addr + dst.sin6_port, + dst.sin6_flowinfo, + dst.sin6_scope_id, + )), + } + } +} + +impl LinuxUDPReader { + fn read6(fd: RawFd, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), io::Error> { + unimplemented!() + } + + fn read4(fd: RawFd, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), io::Error> { + unimplemented!() } } -impl Reader<SocketAddr> for LinuxUDP { +impl Reader<LinuxEndpoint> for LinuxUDPReader { type Error = io::Error; - fn read(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> { - self.0.recv_from(buf) + fn read(&self, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), Self::Error> { + match self { + Self::V4(fd) => Self::read4(*fd, buf), + Self::V6(fd) => Self::read6(*fd, buf), + } } } -impl Writer<SocketAddr> for LinuxUDP { +impl LinuxUDPWriter { + fn write6(fd: RawFd, buf: &[u8], dst: &EndpointV6) -> Result<(), io::Error> { + unimplemented!() + } + + fn write4(fd: RawFd, buf: &[u8], dst: &EndpointV4) -> Result<(), io::Error> { + unimplemented!() + } +} + +impl Writer<LinuxEndpoint> for LinuxUDPWriter { type Error = io::Error; - fn write(&self, buf: &[u8], dst: &SocketAddr) -> Result<(), Self::Error> { - self.0.send_to(buf, dst)?; - Ok(()) + fn write(&self, buf: &[u8], dst: &LinuxEndpoint) -> Result<(), Self::Error> { + match dst { + LinuxEndpoint::V4(ref end) => Self::write4(self.sock4, buf, end), + LinuxEndpoint::V6(ref end) => Self::write6(self.sock6, buf, end), + } } } @@ -43,42 +159,249 @@ impl Owner for LinuxOwner { type Error = io::Error; fn get_port(&self) -> u16 { - self.0.local_addr().unwrap().port() // todo handle - } - - fn get_fwmark(&self) -> Option<u32> { - None + self.port } - fn set_fwmark(&mut self, _value: Option<u32>) -> Result<(), Self::Error> { - Ok(()) + fn set_fwmark(&mut self, value: Option<u32>) -> Result<(), Self::Error> { + fn set_mark(fd: Option<RawFd>, value: u32) -> Result<(), io::Error> { + if let Some(fd) = fd { + let err = unsafe { + libc::setsockopt( + fd, + libc::SOL_SOCKET, + libc::SO_MARK, + mem::transmute(&value as *const u32), + mem::size_of_val(&value).try_into().unwrap(), + ) + }; + if err != 0 { + log::debug!("Failed to set fwmark: {}", err); + return Err(io::Error::new( + io::ErrorKind::PermissionDenied, + "failed to set fwmark", + )); + } + } + Ok(()) + } + let value = value.unwrap_or(0); + set_mark(self.sock6, value)?; + set_mark(self.sock4, value) } } impl Drop for LinuxOwner { fn drop(&mut self) { - // TODO: close udp bind + self.sock4.map(|fd| unsafe { libc::close(fd) }); + self.sock6.map(|fd| unsafe { libc::close(fd) }); } } impl UDP for LinuxUDP { type Error = io::Error; - type Endpoint = SocketAddr; - type Reader = Self; - type Writer = Self; + type Endpoint = LinuxEndpoint; + type Reader = LinuxUDPReader; + type Writer = LinuxUDPWriter; +} + +impl LinuxUDP { + /* Bind on all interfaces with IPv6. + * + * Arguments: + * + * - 'port', port to bind to (0 = any) + * + * Returns: + * + * Returns a tuple of the resulting port and socket. + */ + fn bind6(port: u16) -> Result<(u16, RawFd), io::Error> { + // create socket fd + let fd: RawFd = unsafe { libc::socket(libc::AF_INET6, libc::SOCK_DGRAM, 0) }; + if fd < 0 { + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + } + + // bind + let mut sockaddr = libc::sockaddr_in6 { + sin6_addr: libc::in6_addr { s6_addr: [0; 16] }, + sin6_family: libc::AF_INET6.try_into().unwrap(), + sin6_port: port.to_be(), // convert to network (big-endian) byteorder + sin6_scope_id: 0, + sin6_flowinfo: 0, + }; + + let err = unsafe { + libc::bind( + fd, + mem::transmute(&sockaddr as *const libc::sockaddr_in6), + mem::size_of_val(&sockaddr).try_into().unwrap(), + ) + }; + + if err != 0 { + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + } + + // listen + let err = unsafe { libc::listen(fd, 0) }; + if err != 0 { + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + }; + + // get the assigned port + let mut socklen: libc::socklen_t = 0; + let err = unsafe { + libc::getsockname( + fd, + mem::transmute(&mut sockaddr as *mut libc::sockaddr_in6), + &mut socklen as *mut libc::socklen_t, + ) + }; + if err != 0 { + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + } + + // basic sanity checks + let new_port = u16::from_be(sockaddr.sin6_port); + debug_assert_eq!(socklen, mem::size_of::<libc::sockaddr_in6>() as u32); + debug_assert_eq!(sockaddr.sin6_family, libc::AF_INET6 as libc::sa_family_t); + debug_assert_eq!(new_port, if port != 0 { port } else { new_port }); + return Ok((new_port, fd)); + } + + /* Bind on all interfaces with IPv4. + * + * Arguments: + * + * - 'port', port to bind to (0 = any) + * + * Returns: + * + * Returns a tuple of the resulting port and socket. + */ + fn bind4(port: u16) -> Result<(u16, RawFd), io::Error> { + // create socket fd + let fd: RawFd = unsafe { libc::socket(libc::AF_INET, libc::SOCK_DGRAM, 0) }; + if fd < 0 { + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + } + + // bind + let mut sockaddr = libc::sockaddr_in { + sin_addr: libc::in_addr { s_addr: 0 }, + sin_family: libc::AF_INET as libc::sa_family_t, + sin_port: port.to_be(), // convert to network (big-endian) byteorder + sin_zero: [0; 8], + }; + + let err = unsafe { + libc::bind( + fd, + mem::transmute(&sockaddr as *const libc::sockaddr_in), + mem::size_of_val(&sockaddr).try_into().unwrap(), + ) + }; + + if err != 0 { + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + } + + // listen + let err = unsafe { libc::listen(fd, 0) }; + if err != 0 { + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + }; + + // get the assigned port + let mut socklen: libc::socklen_t = 0; + let err = unsafe { + libc::getsockname( + fd, + mem::transmute(&mut sockaddr as *mut libc::sockaddr_in), + &mut socklen as *mut libc::socklen_t, + ) + }; + if err != 0 { + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + } + + // basic sanity checks + let new_port = u16::from_be(sockaddr.sin_port); + debug_assert_eq!(socklen, mem::size_of::<libc::sockaddr_in>() as u32); + debug_assert_eq!(sockaddr.sin_family, libc::AF_INET as libc::sa_family_t); + debug_assert_eq!(new_port, if port != 0 { port } else { new_port }); + return Ok((new_port, fd)); + } } impl PlatformUDP for LinuxUDP { type Owner = LinuxOwner; - fn bind(port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> { - let socket = UdpSocket::bind(format!("0.0.0.0:{}", port))?; - let socket = Arc::new(socket); + fn bind(mut port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> { + // attempt to bind on ipv6 + let bind6 = Self::bind6(port); + if let Ok((new_port, _)) = bind6 { + port = new_port; + } + + // attempt to bind on ipv4 on the same port + let bind4 = Self::bind4(port); + if let Ok((new_port, _)) = bind4 { + port = new_port; + } + + // check if failed to bind on both + if bind4.is_err() && bind6.is_err() { + return Err(bind6.unwrap_err()); + } + + let sock6 = bind6.ok().map(|(_, fd)| fd); + let sock4 = bind4.ok().map(|(_, fd)| fd); + + // create owner + let owner = LinuxOwner { + port, + sock6: sock6, + sock4: sock4, + }; + + // create readers + let mut readers: Vec<Self::Reader> = Vec::with_capacity(2); + sock6.map(|sock| readers.push(LinuxUDPReader::V6(sock))); + sock4.map(|sock| readers.push(LinuxUDPReader::V4(sock))); + debug_assert!(readers.len() > 0); + + // create writer + let writer = LinuxUDPWriter { + sock4: sock4.unwrap_or(-1), + sock6: sock6.unwrap_or(-1), + }; - Ok(( - vec![LinuxUDP(socket.clone())], - LinuxUDP(socket.clone()), - LinuxOwner(socket), - )) + Ok((readers, writer, owner)) } } diff --git a/src/platform/udp.rs b/src/platform/udp.rs index 3671229..4685a1e 100644 --- a/src/platform/udp.rs +++ b/src/platform/udp.rs @@ -30,8 +30,6 @@ pub trait Owner: Send { fn get_port(&self) -> u16; - fn get_fwmark(&self) -> Option<u32>; - fn set_fwmark(&mut self, value: Option<u32>) -> Result<(), Self::Error>; } |