summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/configuration/config.rs17
-rw-r--r--src/platform/dummy/udp.rs4
-rw-r--r--src/platform/linux/udp.rs103
-rw-r--r--src/platform/udp.rs2
-rw-r--r--src/wireguard/router/peer.rs2
-rw-r--r--src/wireguard/workers.rs4
6 files changed, 76 insertions, 56 deletions
diff --git a/src/configuration/config.rs b/src/configuration/config.rs
index d61cda5..59cef4a 100644
--- a/src/configuration/config.rs
+++ b/src/configuration/config.rs
@@ -1,3 +1,4 @@
+use std::mem;
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::Ordering;
use std::sync::{Arc, Mutex, MutexGuard};
@@ -266,24 +267,22 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireGuardConfig<T, B> {
fn set_listen_port(&self, port: u16) -> Result<(), ConfigError> {
log::trace!("Config, Set listen port: {:?}", port);
- // update port
- let listen: bool = {
+ // update port and take old bind
+ let old: Option<B::Owner> = {
let mut cfg = self.lock();
+ let old = mem::replace(&mut cfg.bind, None);
cfg.port = port;
- if cfg.bind.is_some() {
- cfg.bind = None;
- true
- } else {
- false
- }
+ old
};
// restart listener if bound
- if listen {
+ if old.is_some() {
self.start_listener()
} else {
Ok(())
}
+
+ // old bind is dropped, causing the file-descriptors to be released
}
fn set_fwmark(&self, mark: Option<u32>) -> Result<(), ConfigError> {
diff --git a/src/platform/dummy/udp.rs b/src/platform/dummy/udp.rs
index 6c126a9..88630af 100644
--- a/src/platform/dummy/udp.rs
+++ b/src/platform/dummy/udp.rs
@@ -54,7 +54,7 @@ impl Reader<UnitEndpoint> for VoidBind {
impl Writer<UnitEndpoint> for VoidBind {
type Error = BindError;
- fn write(&self, _buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
+ fn write(&self, _buf: &[u8], _dst: &mut UnitEndpoint) -> Result<(), Self::Error> {
Ok(())
}
}
@@ -105,7 +105,7 @@ impl Reader<UnitEndpoint> for PairReader<UnitEndpoint> {
impl Writer<UnitEndpoint> for PairWriter<UnitEndpoint> {
type Error = BindError;
- fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
+ fn write(&self, buf: &[u8], _dst: &mut UnitEndpoint) -> Result<(), Self::Error> {
debug!(
"dummy({}): write ({}, {})",
self.id,
diff --git a/src/platform/linux/udp.rs b/src/platform/linux/udp.rs
index 1552f69..ab5b53b 100644
--- a/src/platform/linux/udp.rs
+++ b/src/platform/linux/udp.rs
@@ -8,12 +8,13 @@ use std::io;
use std::mem;
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::os::unix::io::RawFd;
+use std::ptr;
fn errno() -> libc::c_int {
unsafe {
let ptr = libc::__errno_location();
if ptr.is_null() {
- -1
+ 0
} else {
*ptr
}
@@ -23,7 +24,7 @@ fn errno() -> libc::c_int {
#[repr(C)]
struct ControlHeaderV4 {
hdr: libc::cmsghdr,
- body: libc::in_pktinfo,
+ info: libc::in_pktinfo,
}
#[repr(C)]
@@ -120,12 +121,12 @@ impl Endpoint for LinuxEndpoint {
LinuxEndpoint::V4(EndpointV4 { ref dst, .. }) => {
SocketAddr::V4(SocketAddrV4::new(
u32::from_be(dst.sin_addr.s_addr).into(), // IPv4 addr
- dst.sin_port,
+ u16::from_be(dst.sin_port), // convert back to native byte-order
))
}
LinuxEndpoint::V6(EndpointV6 { ref dst, .. }) => SocketAddr::V6(SocketAddrV6::new(
u128::from_ne_bytes(dst.sin6_addr.s6_addr).into(), // IPv6 addr
- dst.sin6_port,
+ u16::from_be(dst.sin6_port), // convert back to native byte-order
dst.sin6_flowinfo,
dst.sin6_scope_id,
)),
@@ -178,12 +179,12 @@ impl LinuxUDPReader {
// this memory is mutated by the recvmsg call
#[allow(unused_mut)]
let mut control: ControlHeaderV6 = unsafe { mem::MaybeUninit::uninit().assume_init() };
- let iovs: [libc::iovec; 1] = [unsafe {
- libc::iovec {
- iov_base: mem::transmute(&buf[0] as *const u8),
- iov_len: buf.len(),
- }
+
+ let iovs: [libc::iovec; 1] = [libc::iovec {
+ iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void,
+ iov_len: buf.len(),
}];
+
let src: libc::sockaddr_in6 = unsafe { mem::MaybeUninit::uninit().assume_init() };
let mut hdr = unsafe {
libc::msghdr {
@@ -223,29 +224,31 @@ impl LinuxUDPReader {
buf.len()
);
+ let iovs: [libc::iovec; 1] = [libc::iovec {
+ iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void,
+ iov_len: buf.len(),
+ }];
+
+ let src: libc::sockaddr_in = unsafe { mem::MaybeUninit::uninit().assume_init() };
+
// this memory is mutated by the recvmsg call
#[allow(unused_mut)]
let mut control: ControlHeaderV4 = unsafe { mem::MaybeUninit::uninit().assume_init() };
- let iovs: [libc::iovec; 1] = [unsafe {
- libc::iovec {
- iov_base: mem::transmute(&buf[0] as *const u8),
- iov_len: buf.len(),
- }
- }];
- let src: libc::sockaddr_in = unsafe { mem::MaybeUninit::uninit().assume_init() };
+
let mut hdr = unsafe {
libc::msghdr {
msg_name: mem::transmute(&src),
- msg_namelen: mem::size_of_val(&src).try_into().unwrap(),
+ msg_namelen: mem::size_of_val(&src).try_into().unwrap(), // constant
msg_iov: mem::transmute(&iovs[0]),
- msg_iovlen: iovs.len(),
+ msg_iovlen: iovs.len(), // constant
msg_control: mem::transmute(&control),
- msg_controllen: mem::size_of_val(&control),
- msg_flags: 0, // ignored
+ msg_controllen: mem::size_of_val(&control), // constant
+ msg_flags: 0, // ignored
}
};
let len = unsafe { libc::recvmsg(fd, &mut hdr as *mut libc::msghdr, 0) };
+
if len < 0 {
log::trace!("failed to receive IPv4 packet (errno = {})", errno());
return Err(io::Error::new(
@@ -254,12 +257,20 @@ impl LinuxUDPReader {
));
}
+ log::trace!("read4, len: {}", len);
+ log::trace!(
+ "control: {{ hdr : {{ cmsg_level: {}, cmsg_type: {}, cmsg_len: {} }} }}",
+ control.hdr.cmsg_level,
+ control.hdr.cmsg_type,
+ control.hdr.cmsg_len
+ );
+
log::trace!("received IPv4 packet ({} fd, {} bytes)", fd, len);
Ok((
len.try_into().unwrap(),
LinuxEndpoint::V4(EndpointV4 {
- info: control.body,
- dst: src,
+ info: control.info, // save pkinfo (sticky source)
+ dst: src, // our future destination is the source address
}),
))
}
@@ -283,16 +294,21 @@ impl LinuxUDPWriter {
unimplemented!()
}
- fn write4(fd: RawFd, buf: &[u8], dst: &EndpointV4) -> Result<(), io::Error> {
+ fn write4(fd: RawFd, buf: &[u8], dst: &mut EndpointV4) -> Result<(), io::Error> {
log::debug!("sending IPv4 packet ({} fd, {} bytes)", fd, buf.len());
- let control = ControlHeaderV4 {
+ let iovs: [libc::iovec; 1] = [libc::iovec {
+ iov_base: buf.as_ptr() as *mut core::ffi::c_void,
+ iov_len: buf.len(),
+ }];
+
+ let mut control = ControlHeaderV4 {
hdr: libc::cmsghdr {
cmsg_len: mem::size_of::<ControlHeaderV4>(),
cmsg_level: libc::IPPROTO_IP,
cmsg_type: libc::IP_PKTINFO,
},
- body: dst.info,
+ info: dst.info,
};
debug_assert_eq!(
@@ -302,12 +318,7 @@ impl LinuxUDPWriter {
);
debug_assert_eq!(dst.dst.sin_family, libc::AF_INET as libc::sa_family_t);
- let iovs: [libc::iovec; 1] = [libc::iovec {
- iov_base: buf.as_ptr() as *mut core::ffi::c_void,
- iov_len: buf.len(),
- }];
-
- let hdr = libc::msghdr {
+ let mut hdr = libc::msghdr {
msg_name: unsafe { mem::transmute(&dst.dst as *const libc::sockaddr_in) },
msg_namelen: mem::size_of_val(&dst.dst).try_into().unwrap(),
msg_iov: iovs.as_ptr() as *mut libc::iovec,
@@ -317,19 +328,29 @@ impl LinuxUDPWriter {
msg_flags: 0,
};
- println!(
- "name : {}, controllen: {}",
- hdr.msg_namelen, hdr.msg_controllen
- );
-
- let err = unsafe { libc::sendmsg(fd, &hdr, 0) };
- if err < 0 {
- log::trace!("failed to send IPv4: (errno = {})", errno());
+ let ret = unsafe { libc::sendmsg(fd, &hdr, 0) };
+
+ if ret < 0 {
+ if errno() == libc::EINVAL {
+ log::trace!("clear source and retry");
+ hdr.msg_control = ptr::null_mut();
+ hdr.msg_controllen = 0;
+ dst.info = unsafe { mem::zeroed() };
+ if unsafe { libc::sendmsg(fd, &hdr, 0) } < 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::NotConnected,
+ "failed to send IPv4 packet",
+ ));
+ } else {
+ return Ok(());
+ }
+ }
return Err(io::Error::new(
io::ErrorKind::NotConnected,
"failed to send IPv4 packet",
));
}
+
Ok(())
}
}
@@ -337,9 +358,9 @@ impl LinuxUDPWriter {
impl Writer<LinuxEndpoint> for LinuxUDPWriter {
type Error = io::Error;
- fn write(&self, buf: &[u8], dst: &LinuxEndpoint) -> Result<(), Self::Error> {
+ fn write(&self, buf: &[u8], dst: &mut LinuxEndpoint) -> Result<(), Self::Error> {
match dst {
- LinuxEndpoint::V4(ref end) => Self::write4(self.sock4, buf, end),
+ LinuxEndpoint::V4(ref mut end) => Self::write4(self.sock4, buf, end),
LinuxEndpoint::V6(ref end) => Self::write6(self.sock6, buf, end),
}
}
diff --git a/src/platform/udp.rs b/src/platform/udp.rs
index 4685a1e..e1180fb 100644
--- a/src/platform/udp.rs
+++ b/src/platform/udp.rs
@@ -10,7 +10,7 @@ pub trait Reader<E: Endpoint>: Send + Sync {
pub trait Writer<E: Endpoint>: Send + Sync + Clone + 'static {
type Error: Error;
- fn write(&self, buf: &[u8], dst: &E) -> Result<(), Self::Error>;
+ fn write(&self, buf: &[u8], dst: &mut E) -> Result<(), Self::Error>;
}
pub trait UDP: Send + Sync + 'static {
diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs
index b8110f0..8fe2e1c 100644
--- a/src/wireguard/router/peer.rs
+++ b/src/wireguard/router/peer.rs
@@ -204,7 +204,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E,
debug!("peer.send");
// send to endpoint (if known)
- match self.endpoint.lock().as_ref() {
+ match self.endpoint.lock().as_mut() {
Some(endpoint) => {
let outbound = self.device.outbound.read();
if outbound.0 {
diff --git a/src/wireguard/workers.rs b/src/wireguard/workers.rs
index 9802232..c1a2af7 100644
--- a/src/wireguard/workers.rs
+++ b/src/wireguard/workers.rs
@@ -178,7 +178,7 @@ pub fn handshake_worker<T: Tun, B: UDP>(
// de-multiplex staged handshake jobs and handshake messages
match job {
- HandshakeJob::Message(msg, src) => {
+ HandshakeJob::Message(msg, mut src) => {
// process message
let device = wg.peers.read();
match device.process(
@@ -201,7 +201,7 @@ pub fn handshake_worker<T: Tun, B: UDP>(
"{} : handshake worker, send response ({} bytes)",
wg, resp_len
);
- let _ = writer.write(&msg[..], &src).map_err(|e| {
+ let _ = writer.write(&msg[..], &mut src).map_err(|e| {
debug!(
"{} : handshake worker, failed to send response, error = {}",
wg,