diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/configuration/config.rs | 19 | ||||
-rw-r--r-- | src/configuration/uapi/get.rs | 3 | ||||
-rw-r--r-- | src/configuration/uapi/mod.rs | 1 | ||||
-rw-r--r-- | src/main.rs | 2 | ||||
-rw-r--r-- | src/platform/dummy/tun.rs | 3 | ||||
-rw-r--r-- | src/platform/dummy/udp.rs | 13 | ||||
-rw-r--r-- | src/platform/linux/tun.rs | 2 | ||||
-rw-r--r-- | src/platform/linux/udp.rs | 667 | ||||
-rw-r--r-- | src/platform/udp.rs | 4 | ||||
-rw-r--r-- | src/wireguard/handshake/device.rs | 198 | ||||
-rw-r--r-- | src/wireguard/handshake/macs.rs | 6 | ||||
-rw-r--r-- | src/wireguard/handshake/noise.rs | 46 | ||||
-rw-r--r-- | src/wireguard/handshake/peer.rs | 26 | ||||
-rw-r--r-- | src/wireguard/handshake/tests.rs | 62 | ||||
-rw-r--r-- | src/wireguard/handshake/types.rs | 14 | ||||
-rw-r--r-- | src/wireguard/peer.rs | 2 | ||||
-rw-r--r-- | src/wireguard/router/device.rs | 2 | ||||
-rw-r--r-- | src/wireguard/router/peer.rs | 2 | ||||
-rw-r--r-- | src/wireguard/wireguard.rs | 85 | ||||
-rw-r--r-- | src/wireguard/workers.rs | 80 |
20 files changed, 929 insertions, 308 deletions
diff --git a/src/configuration/config.rs b/src/configuration/config.rs index aec943f..59cef4a 100644 --- a/src/configuration/config.rs +++ b/src/configuration/config.rs @@ -1,3 +1,4 @@ +use std::mem; use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::Ordering; use std::sync::{Arc, Mutex, MutexGuard}; @@ -205,7 +206,7 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireGuardConfig<T, B> { } fn get_fwmark(&self) -> Option<u32> { - self.lock().bind.as_ref().and_then(|own| own.get_fwmark()) + self.lock().fwmark } fn set_private_key(&self, sk: Option<StaticSecret>) { @@ -266,24 +267,22 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireGuardConfig<T, B> { fn set_listen_port(&self, port: u16) -> Result<(), ConfigError> { log::trace!("Config, Set listen port: {:?}", port); - // update port - let listen: bool = { + // update port and take old bind + let old: Option<B::Owner> = { let mut cfg = self.lock(); + let old = mem::replace(&mut cfg.bind, None); cfg.port = port; - if cfg.bind.is_some() { - cfg.bind = None; - true - } else { - false - } + old }; // restart listener if bound - if listen { + if old.is_some() { self.start_listener() } else { Ok(()) } + + // old bind is dropped, causing the file-descriptors to be released } fn set_fwmark(&self, mark: Option<u32>) -> Result<(), ConfigError> { diff --git a/src/configuration/uapi/get.rs b/src/configuration/uapi/get.rs index 9e6ab36..00048cd 100644 --- a/src/configuration/uapi/get.rs +++ b/src/configuration/uapi/get.rs @@ -2,7 +2,6 @@ use log; use std::io; use super::Configuration; -use super::Endpoint; pub fn serialize<C: Configuration, W: io::Write>(writer: &mut W, config: &C) -> io::Result<()> { let mut write = |key: &'static str, value: String| { @@ -46,7 +45,7 @@ pub fn serialize<C: Configuration, W: io::Write>(writer: &mut W, config: &C) -> } if let Some(endpoint) = p.endpoint { - write("endpoint", endpoint.into_address().to_string())?; + write("endpoint", endpoint.to_string())?; } for (ip, cidr) in p.allowed_ips { diff --git a/src/configuration/uapi/mod.rs b/src/configuration/uapi/mod.rs index 4f0b741..9f54775 100644 --- a/src/configuration/uapi/mod.rs +++ b/src/configuration/uapi/mod.rs @@ -4,7 +4,6 @@ mod set; use log; use std::io::{Read, Write}; -use super::Endpoint; use super::{ConfigError, Configuration}; use get::serialize; diff --git a/src/main.rs b/src/main.rs index a0f4a23..a8e4ad2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -112,6 +112,8 @@ fn main() { .try_init() .expect("Failed to initialize event logger"); + log::info!("starting {} wireguard device", name); + // drop privileges if drop_privileges {} diff --git a/src/platform/dummy/tun.rs b/src/platform/dummy/tun.rs index 9836b48..1955884 100644 --- a/src/platform/dummy/tun.rs +++ b/src/platform/dummy/tun.rs @@ -165,8 +165,7 @@ impl TunTest { sync_channel(1) }; - let mut rng = OsRng::new().unwrap(); - let id: u32 = rng.gen(); + let id: u32 = OsRng.gen(); let fake = TunFakeIO { id, diff --git a/src/platform/dummy/udp.rs b/src/platform/dummy/udp.rs index 35c905d..88630af 100644 --- a/src/platform/dummy/udp.rs +++ b/src/platform/dummy/udp.rs @@ -54,7 +54,7 @@ impl Reader<UnitEndpoint> for VoidBind { impl Writer<UnitEndpoint> for VoidBind { type Error = BindError; - fn write(&self, _buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { + fn write(&self, _buf: &[u8], _dst: &mut UnitEndpoint) -> Result<(), Self::Error> { Ok(()) } } @@ -105,7 +105,7 @@ impl Reader<UnitEndpoint> for PairReader<UnitEndpoint> { impl Writer<UnitEndpoint> for PairWriter<UnitEndpoint> { type Error = BindError; - fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { + fn write(&self, buf: &[u8], _dst: &mut UnitEndpoint) -> Result<(), Self::Error> { debug!( "dummy({}): write ({}, {})", self.id, @@ -135,9 +135,8 @@ impl PairBind { (PairReader<E>, PairWriter<E>), (PairReader<E>, PairWriter<E>), ) { - let mut rng = OsRng::new().unwrap(); - let id1: u32 = rng.gen(); - let id2: u32 = rng.gen(); + let id1: u32 = OsRng.gen(); + let id2: u32 = OsRng.gen(); let (tx1, rx1) = sync_channel(128); let (tx2, rx2) = sync_channel(128); @@ -187,10 +186,6 @@ impl Owner for VoidOwner { fn get_port(&self) -> u16 { 0 } - - fn get_fwmark(&self) -> Option<u32> { - None - } } impl PlatformUDP for PairBind { diff --git a/src/platform/linux/tun.rs b/src/platform/linux/tun.rs index c282a4b..15ca1ec 100644 --- a/src/platform/linux/tun.rs +++ b/src/platform/linux/tun.rs @@ -199,7 +199,7 @@ impl Status for LinuxTunStatus { // cut buffer to size let size: usize = size as usize; let mut remain = &buf[..size]; - log::debug!("netlink, recieved message ({} bytes)", size); + log::debug!("netlink, received message ({} bytes)", size); // handle messages while remain.len() >= HDR_SIZE { diff --git a/src/platform/linux/udp.rs b/src/platform/linux/udp.rs index f871bce..9815ab1 100644 --- a/src/platform/linux/udp.rs +++ b/src/platform/linux/udp.rs @@ -1,84 +1,683 @@ use super::super::udp::*; use super::super::Endpoint; +use log; + +use std::convert::TryInto; use std::io; -use std::net::{SocketAddr, UdpSocket}; -use std::sync::Arc; +use std::mem; +use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::os::unix::io::RawFd; +use std::ptr; + +fn errno() -> libc::c_int { + unsafe { + let ptr = libc::__errno_location(); + if ptr.is_null() { + 0 + } else { + *ptr + } + } +} + +#[repr(C, align(1))] +struct ControlHeaderV4 { + hdr: libc::cmsghdr, + info: libc::in_pktinfo, +} + +#[repr(C, align(1))] +struct ControlHeaderV6 { + hdr: libc::cmsghdr, + info: libc::in6_pktinfo, +} + +pub struct EndpointV4 { + dst: libc::sockaddr_in, // destination IP + info: libc::in_pktinfo, // src & ifindex +} + +pub struct EndpointV6 { + dst: libc::sockaddr_in6, // destination IP + info: libc::in6_pktinfo, // src & zone id +} + +pub struct LinuxUDP(); + +pub struct LinuxOwner { + port: u16, + sock4: Option<RawFd>, + sock6: Option<RawFd>, +} + +pub enum LinuxUDPReader { + V4(RawFd), + V6(RawFd), +} #[derive(Clone)] -pub struct LinuxUDP(Arc<UdpSocket>); +pub struct LinuxUDPWriter { + sock4: RawFd, + sock6: RawFd, +} -pub struct LinuxOwner(Arc<UdpSocket>); +pub enum LinuxEndpoint { + V4(EndpointV4), + V6(EndpointV6), +} -impl Endpoint for SocketAddr { - fn clear_src(&mut self) {} +impl Endpoint for LinuxEndpoint { + fn clear_src(&mut self) { + match self { + LinuxEndpoint::V4(EndpointV4 { ref mut info, .. }) => { + info.ipi_ifindex = 0; + info.ipi_spec_dst = libc::in_addr { s_addr: 0 }; + } + LinuxEndpoint::V6(EndpointV6 { ref mut info, .. }) => { + info.ipi6_addr = libc::in6_addr { s6_addr: [0; 16] }; + info.ipi6_ifindex = 0; + } + }; + } fn from_address(addr: SocketAddr) -> Self { - addr + match addr { + SocketAddr::V4(addr) => LinuxEndpoint::V4(EndpointV4 { + dst: libc::sockaddr_in { + sin_family: libc::AF_INET as libc::sa_family_t, + sin_port: addr.port().to_be(), + sin_addr: libc::in_addr { + s_addr: u32::from(*addr.ip()).to_be(), + }, + sin_zero: [0; 8], + }, + info: libc::in_pktinfo { + ipi_ifindex: 0, // interface (0 is via routing table) + ipi_spec_dst: libc::in_addr { s_addr: 0 }, // src IP (dst of incoming packet) + ipi_addr: libc::in_addr { s_addr: 0 }, + }, + }), + SocketAddr::V6(addr) => LinuxEndpoint::V6(EndpointV6 { + dst: libc::sockaddr_in6 { + sin6_family: libc::AF_INET6 as libc::sa_family_t, + sin6_port: addr.port().to_be(), + sin6_flowinfo: addr.flowinfo(), + sin6_addr: libc::in6_addr { + s6_addr: addr.ip().octets(), + }, + sin6_scope_id: addr.scope_id(), + }, + info: libc::in6_pktinfo { + ipi6_addr: libc::in6_addr { s6_addr: [0; 16] }, // src IP + ipi6_ifindex: 0, // zone id + }, + }), + } } fn into_address(&self) -> SocketAddr { - *self + match self { + LinuxEndpoint::V4(EndpointV4 { ref dst, .. }) => { + SocketAddr::V4(SocketAddrV4::new( + u32::from_be(dst.sin_addr.s_addr).into(), // IPv4 addr + u16::from_be(dst.sin_port), // convert back to native byte-order + )) + } + LinuxEndpoint::V6(EndpointV6 { ref dst, .. }) => SocketAddr::V6(SocketAddrV6::new( + u128::from_ne_bytes(dst.sin6_addr.s6_addr).into(), // IPv6 addr + u16::from_be(dst.sin6_port), // convert back to native byte-order + dst.sin6_flowinfo, + dst.sin6_scope_id, + )), + } } } -impl Reader<SocketAddr> for LinuxUDP { - type Error = io::Error; +fn setsockopt<V: Sized>( + fd: RawFd, + level: libc::c_int, + name: libc::c_int, + value: &V, +) -> Result<(), io::Error> { + let res = unsafe { + libc::setsockopt( + fd, + level, + name, + mem::transmute(value), + mem::size_of_val(value).try_into().unwrap(), + ) + }; + if res == 0 { + Ok(()) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + format!("Failed to set sockopt (res = {}, errno = {})", res, errno()), + )) + } +} + +#[inline(always)] +fn setsockopt_int( + fd: RawFd, + level: libc::c_int, + name: libc::c_int, + value: libc::c_int, +) -> Result<(), io::Error> { + setsockopt(fd, level, name, &value) +} - fn read(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> { - self.0.recv_from(buf) +#[allow(non_snake_case)] +const fn CMSG_ALIGN(len: usize) -> usize { + (((len) + mem::size_of::<u32>() - 1) & !(mem::size_of::<u32>() - 1)) +} + +#[allow(non_snake_case)] +const fn CMSG_LEN(len: usize) -> usize { + CMSG_ALIGN(len + mem::size_of::<libc::cmsghdr>()) +} + +#[inline(always)] +fn safe_cast<T, D>(v: &mut T) -> *mut D { + (v as *mut T) as *mut D +} + +impl LinuxUDPReader { + fn read6(fd: RawFd, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), io::Error> { + log::trace!( + "receive IPv6 packet (block), (fd {}, max-len {})", + fd, + buf.len() + ); + + let mut iovs: [libc::iovec; 1] = [libc::iovec { + iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void, + iov_len: buf.len(), + }]; + let mut src: libc::sockaddr_in6 = unsafe { mem::MaybeUninit::uninit().assume_init() }; + let mut control: ControlHeaderV6 = unsafe { mem::MaybeUninit::uninit().assume_init() }; + let mut hdr = libc::msghdr { + msg_name: safe_cast(&mut src), + msg_namelen: mem::size_of::<libc::sockaddr_in6> as u32, + msg_iov: iovs.as_mut_ptr(), + msg_iovlen: iovs.len(), + msg_control: safe_cast(&mut control), + msg_controllen: mem::size_of::<ControlHeaderV6>(), + msg_flags: 0, + }; + + debug_assert!( + hdr.msg_controllen + >= mem::size_of::<libc::cmsghdr>() + mem::size_of::<libc::in6_pktinfo>(), + ); + + let len = unsafe { libc::recvmsg(fd, &mut hdr as *mut libc::msghdr, 0) }; + + if len < 0 { + return Err(io::Error::new( + io::ErrorKind::NotConnected, + "failed to receive", + )); + } + + Ok(( + len.try_into().unwrap(), + LinuxEndpoint::V6(EndpointV6 { + info: control.info, // save pktinfo (sticky source) + dst: src, // our future destination is the source address + }), + )) + } + + fn read4(fd: RawFd, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), io::Error> { + log::trace!( + "receive IPv4 packet (block), (fd {}, max-len {})", + fd, + buf.len() + ); + + let mut iovs: [libc::iovec; 1] = [libc::iovec { + iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void, + iov_len: buf.len(), + }]; + let mut src: libc::sockaddr_in = unsafe { mem::MaybeUninit::uninit().assume_init() }; + let mut control: ControlHeaderV4 = unsafe { mem::MaybeUninit::uninit().assume_init() }; + let mut hdr = libc::msghdr { + msg_name: safe_cast(&mut src), + msg_namelen: mem::size_of::<libc::sockaddr_in> as u32, + msg_iov: iovs.as_mut_ptr(), + msg_iovlen: iovs.len(), + msg_control: safe_cast(&mut control), + msg_controllen: mem::size_of::<ControlHeaderV4>(), + msg_flags: 0, + }; + + debug_assert!( + hdr.msg_controllen + >= mem::size_of::<libc::cmsghdr>() + mem::size_of::<libc::in_pktinfo>(), + ); + + let len = unsafe { libc::recvmsg(fd, &mut hdr as *mut libc::msghdr, 0) }; + + if len < 0 { + return Err(io::Error::new( + io::ErrorKind::NotConnected, + "failed to receive", + )); + } + + Ok(( + len.try_into().unwrap(), + LinuxEndpoint::V4(EndpointV4 { + info: control.info, // save pktinfo (sticky source) + dst: src, // our future destination is the source address + }), + )) } } -impl Writer<SocketAddr> for LinuxUDP { +impl Reader<LinuxEndpoint> for LinuxUDPReader { type Error = io::Error; - fn write(&self, buf: &[u8], dst: &SocketAddr) -> Result<(), Self::Error> { - self.0.send_to(buf, dst)?; + fn read(&self, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), Self::Error> { + match self { + Self::V4(fd) => Self::read4(*fd, buf), + Self::V6(fd) => Self::read6(*fd, buf), + } + } +} + +impl LinuxUDPWriter { + fn write6(fd: RawFd, buf: &[u8], dst: &mut EndpointV6) -> Result<(), io::Error> { + log::debug!("sending IPv6 packet ({} fd, {} bytes)", fd, buf.len()); + + let mut iovs: [libc::iovec; 1] = [libc::iovec { + iov_base: buf.as_ptr() as *mut core::ffi::c_void, + iov_len: buf.len(), + }]; + + let mut control = ControlHeaderV6 { + hdr: libc::cmsghdr { + cmsg_len: CMSG_LEN(mem::size_of::<libc::in6_pktinfo>()), + cmsg_level: libc::IPPROTO_IPV6, + cmsg_type: libc::IPV6_PKTINFO, + }, + info: dst.info, + }; + + debug_assert_eq!( + control.hdr.cmsg_len % mem::size_of::<u32>(), + 0, + "cmsg_len must be aligned to a long" + ); + + debug_assert_eq!( + dst.dst.sin6_family, + libc::AF_INET6 as libc::sa_family_t, + "this method only handles IPv6 destinations" + ); + + let mut hdr = libc::msghdr { + msg_name: safe_cast(&mut dst.dst), + msg_namelen: mem::size_of_val(&dst.dst).try_into().unwrap(), + msg_iov: iovs.as_mut_ptr(), + msg_iovlen: iovs.len(), + msg_control: safe_cast(&mut control), + msg_controllen: mem::size_of_val(&control), + msg_flags: 0, + }; + + let ret = unsafe { libc::sendmsg(fd, &hdr, 0) }; + + if ret < 0 { + if errno() == libc::EINVAL { + log::trace!("clear source and retry"); + hdr.msg_control = ptr::null_mut(); + hdr.msg_controllen = 0; + dst.info = unsafe { mem::zeroed() }; + if unsafe { libc::sendmsg(fd, &hdr, 0) } < 0 { + return Err(io::Error::new( + io::ErrorKind::NotConnected, + "failed to send IPv6 packet", + )); + } else { + return Ok(()); + } + } + return Err(io::Error::new( + io::ErrorKind::NotConnected, + "failed to send IPv6 packet", + )); + } + + Ok(()) + } + + fn write4(fd: RawFd, buf: &[u8], dst: &mut EndpointV4) -> Result<(), io::Error> { + log::debug!("sending IPv4 packet ({} fd, {} bytes)", fd, buf.len()); + + let mut iovs: [libc::iovec; 1] = [libc::iovec { + iov_base: buf.as_ptr() as *mut core::ffi::c_void, + iov_len: buf.len(), + }]; + + let mut control = ControlHeaderV4 { + hdr: libc::cmsghdr { + cmsg_len: CMSG_LEN(mem::size_of::<libc::in_pktinfo>()), + cmsg_level: libc::IPPROTO_IP, + cmsg_type: libc::IP_PKTINFO, + }, + info: dst.info, + }; + + debug_assert_eq!( + control.hdr.cmsg_len % mem::size_of::<u32>(), + 0, + "cmsg_len must be aligned to a long" + ); + + debug_assert_eq!( + dst.dst.sin_family, + libc::AF_INET as libc::sa_family_t, + "this method only handles IPv4 destinations" + ); + + let mut hdr = libc::msghdr { + msg_name: safe_cast(&mut dst.dst), + msg_namelen: mem::size_of_val(&dst.dst).try_into().unwrap(), + msg_iov: iovs.as_mut_ptr(), + msg_iovlen: iovs.len(), + msg_control: safe_cast(&mut control), + msg_controllen: mem::size_of_val(&control), + msg_flags: 0, + }; + + let ret = unsafe { libc::sendmsg(fd, &hdr, 0) }; + + if ret < 0 { + if errno() == libc::EINVAL { + log::trace!("clear source and retry"); + hdr.msg_control = ptr::null_mut(); + hdr.msg_controllen = 0; + dst.info = unsafe { mem::zeroed() }; + if unsafe { libc::sendmsg(fd, &hdr, 0) } < 0 { + return Err(io::Error::new( + io::ErrorKind::NotConnected, + "failed to send IPv4 packet", + )); + } else { + return Ok(()); + } + } + return Err(io::Error::new( + io::ErrorKind::NotConnected, + "failed to send IPv4 packet", + )); + } + Ok(()) } } -impl Owner for LinuxOwner { +impl Writer<LinuxEndpoint> for LinuxUDPWriter { type Error = io::Error; - fn get_port(&self) -> u16 { - self.0.local_addr().unwrap().port() // todo handle + fn write(&self, buf: &[u8], dst: &mut LinuxEndpoint) -> Result<(), Self::Error> { + match dst { + LinuxEndpoint::V4(ref mut end) => Self::write4(self.sock4, buf, end), + LinuxEndpoint::V6(ref mut end) => Self::write6(self.sock6, buf, end), + } } +} - fn get_fwmark(&self) -> Option<u32> { - None +impl Owner for LinuxOwner { + type Error = io::Error; + + fn get_port(&self) -> u16 { + self.port } - fn set_fwmark(&mut self, _value: Option<u32>) -> Result<(), Self::Error> { - Ok(()) + fn set_fwmark(&mut self, value: Option<u32>) -> Result<(), Self::Error> { + fn set_mark(fd: Option<RawFd>, value: u32) -> Result<(), io::Error> { + if let Some(fd) = fd { + setsockopt(fd, libc::SOL_SOCKET, libc::SO_MARK, &value) + } else { + Ok(()) + } + } + let value = value.unwrap_or(0); + set_mark(self.sock6, value)?; + set_mark(self.sock4, value) } } impl Drop for LinuxOwner { fn drop(&mut self) { - // TODO: close udp bind + log::trace!("closing the bind (port {})", self.port); + self.sock4.map(|fd| unsafe { + libc::shutdown(fd, libc::SHUT_RDWR); + libc::close(fd) + }); + self.sock6.map(|fd| unsafe { + libc::shutdown(fd, libc::SHUT_RDWR); + libc::close(fd) + }); } } impl UDP for LinuxUDP { type Error = io::Error; - type Endpoint = SocketAddr; - type Reader = Self; - type Writer = Self; + type Endpoint = LinuxEndpoint; + type Reader = LinuxUDPReader; + type Writer = LinuxUDPWriter; +} + +impl LinuxUDP { + /* Bind on all IPv6 interfaces + * + * Arguments: + * + * - 'port', port to bind to (0 = any) + * + * Returns: + * + * Returns a tuple of the resulting port and socket. + */ + fn bind6(port: u16) -> Result<(u16, RawFd), io::Error> { + log::trace!("attempting to bind on IPv6 (port {})", port); + + // create socket fd + let fd: RawFd = unsafe { libc::socket(libc::AF_INET6, libc::SOCK_DGRAM, 0) }; + if fd < 0 { + log::debug!("failed to create IPv6 socket"); + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + } + + setsockopt_int(fd, libc::SOL_SOCKET, libc::SO_REUSEADDR, 1)?; + setsockopt_int(fd, libc::IPPROTO_IPV6, libc::IPV6_RECVPKTINFO, 1)?; + setsockopt_int(fd, libc::IPPROTO_IPV6, libc::IPV6_V6ONLY, 1)?; + + // bind + let mut sockaddr = libc::sockaddr_in6 { + sin6_addr: libc::in6_addr { s6_addr: [0; 16] }, + sin6_family: libc::AF_INET6 as libc::sa_family_t, + sin6_port: port.to_be(), // convert to network (big-endian) byteorder + sin6_scope_id: 0, + sin6_flowinfo: 0, + }; + + let err = unsafe { + libc::bind( + fd, + mem::transmute(&sockaddr as *const libc::sockaddr_in6), + mem::size_of_val(&sockaddr).try_into().unwrap(), + ) + }; + + if err != 0 { + log::debug!("failed to bind IPv6 socket"); + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + } + + // get the assigned port + let mut socklen: libc::socklen_t = mem::size_of_val(&sockaddr).try_into().unwrap(); + let err = unsafe { + libc::getsockname( + fd, + mem::transmute(&mut sockaddr as *mut libc::sockaddr_in6), + &mut socklen as *mut libc::socklen_t, + ) + }; + if err != 0 { + log::debug!("failed to get port of IPv6 socket"); + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + } + + // basic sanity checks + let new_port = u16::from_be(sockaddr.sin6_port); + debug_assert_eq!(socklen, mem::size_of::<libc::sockaddr_in6>() as u32); + debug_assert_eq!(sockaddr.sin6_family, libc::AF_INET6 as libc::sa_family_t); + debug_assert_eq!(new_port, if port != 0 { port } else { new_port }); + log::trace!("bound IPv6 socket (port {}, fd {})", new_port, fd); + return Ok((new_port, fd)); + } + + /* Bind on all IPv4 interfaces. + * + * Arguments: + * + * - 'port', port to bind to (0 = any) + * + * Returns: + * + * Returns a tuple of the resulting port and socket. + */ + fn bind4(port: u16) -> Result<(u16, RawFd), io::Error> { + log::trace!("attempting to bind on IPv4 (port {})", port); + + // create socket fd + let fd: RawFd = unsafe { libc::socket(libc::AF_INET, libc::SOCK_DGRAM, 0) }; + if fd < 0 { + log::trace!("failed to create IPv4 socket (errno = {})", errno()); + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + } + + setsockopt_int(fd, libc::SOL_SOCKET, libc::SO_REUSEADDR, 1)?; + setsockopt_int(fd, libc::IPPROTO_IP, libc::IP_PKTINFO, 1)?; + + const INADDR_ANY: libc::in_addr = libc::in_addr { s_addr: 0 }; + + // bind + let mut sockaddr = libc::sockaddr_in { + sin_addr: INADDR_ANY, + sin_family: libc::AF_INET as libc::sa_family_t, + sin_port: port.to_be(), + sin_zero: [0; 8], + }; + + let err = unsafe { + libc::bind( + fd, + mem::transmute(&sockaddr as *const libc::sockaddr_in), + mem::size_of_val(&sockaddr).try_into().unwrap(), + ) + }; + + if err != 0 { + log::trace!("failed to bind IPv4 socket (errno = {})", errno()); + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + } + + // get the assigned port + let mut socklen: libc::socklen_t = mem::size_of_val(&sockaddr).try_into().unwrap(); + let err = unsafe { + libc::getsockname( + fd, + mem::transmute(&mut sockaddr as *mut libc::sockaddr_in), + &mut socklen as *mut libc::socklen_t, + ) + }; + if err != 0 { + log::trace!("failed to get port of IPv4 socket (errno = {})", errno()); + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to create socket", + )); + } + + // basic sanity checks + let new_port = u16::from_be(sockaddr.sin_port); + debug_assert_eq!(socklen, mem::size_of::<libc::sockaddr_in>() as u32); + debug_assert_eq!(sockaddr.sin_family, libc::AF_INET as libc::sa_family_t); + debug_assert_eq!(new_port, if port != 0 { port } else { new_port }); + log::trace!("bound IPv4 socket (port {}, fd {})", new_port, fd); + return Ok((new_port, fd)); + } } impl PlatformUDP for LinuxUDP { type Owner = LinuxOwner; - fn bind(port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> { - let socket = UdpSocket::bind(format!("0.0.0.0:{}", port))?; - let socket = Arc::new(socket); + fn bind(mut port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> { + log::debug!("bind to port {}", port); - Ok(( - vec![LinuxUDP(socket.clone())], - LinuxUDP(socket.clone()), - LinuxOwner(socket), - )) + // attempt to bind on ipv6 + let bind6 = Self::bind6(port); + if let Ok((new_port, _)) = bind6 { + port = new_port; + } + + // attempt to bind on ipv4 on the same port + let bind4 = Self::bind4(port); + if let Ok((new_port, _)) = bind4 { + port = new_port; + } + + // check if failed to bind on both + if bind4.is_err() && bind6.is_err() { + log::trace!("failed to bind for either IP version"); + return Err(bind6.unwrap_err()); + } + + let sock6 = bind6.ok().map(|(_, fd)| fd); + let sock4 = bind4.ok().map(|(_, fd)| fd); + + // create owner + let owner = LinuxOwner { + port, + sock6: sock6, + sock4: sock4, + }; + + // create readers + let mut readers: Vec<Self::Reader> = Vec::with_capacity(2); + sock6.map(|sock| readers.push(LinuxUDPReader::V6(sock))); + sock4.map(|sock| readers.push(LinuxUDPReader::V4(sock))); + debug_assert!(readers.len() > 0); + + // create writer + let writer = LinuxUDPWriter { + sock4: sock4.unwrap_or(-1), + sock6: sock6.unwrap_or(-1), + }; + + Ok((readers, writer, owner)) } } diff --git a/src/platform/udp.rs b/src/platform/udp.rs index 3671229..e1180fb 100644 --- a/src/platform/udp.rs +++ b/src/platform/udp.rs @@ -10,7 +10,7 @@ pub trait Reader<E: Endpoint>: Send + Sync { pub trait Writer<E: Endpoint>: Send + Sync + Clone + 'static { type Error: Error; - fn write(&self, buf: &[u8], dst: &E) -> Result<(), Self::Error>; + fn write(&self, buf: &[u8], dst: &mut E) -> Result<(), Self::Error>; } pub trait UDP: Send + Sync + 'static { @@ -30,8 +30,6 @@ pub trait Owner: Send { fn get_port(&self) -> u16; - fn get_fwmark(&self) -> Option<u32>; - fn set_fwmark(&mut self, value: Option<u32>) -> Result<(), Self::Error>; } diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs index edd1a07..4b5d8f6 100644 --- a/src/wireguard/handshake/device.rs +++ b/src/wireguard/handshake/device.rs @@ -1,4 +1,5 @@ use spin::RwLock; +use std::collections::hash_map; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Mutex; @@ -6,7 +7,10 @@ use zerocopy::AsBytes; use byteorder::{ByteOrder, LittleEndian}; -use rand::prelude::*; +use rand::Rng; +use rand_core::{CryptoRng, RngCore}; + +use clear_on_drop::clear::Clear; use x25519_dalek::PublicKey; use x25519_dalek::StaticSecret; @@ -22,42 +26,101 @@ use super::types::*; const MAX_PEER_PER_DEVICE: usize = 1 << 20; pub struct KeyState { - pub sk: StaticSecret, // static secret key - pub pk: PublicKey, // static public key - macs: macs::Validator, // validator for the mac fields + pub(super) sk: StaticSecret, // static secret key + pub(super) pk: PublicKey, // static public key + macs: macs::Validator, // validator for the mac fields } -pub struct Device { - keyst: Option<KeyState>, // secret/public key - pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state - id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key +/// The device is generic over an "opaque" type +/// which can be used to associate the public key with this value. +/// (the instance is a Peer object in the parent module) +pub struct Device<O> { + keyst: Option<KeyState>, + id_map: RwLock<HashMap<u32, [u8; 32]>>, + pk_map: HashMap<[u8; 32], Peer<O>>, limiter: Mutex<RateLimiter>, } +pub struct Iter<'a, O> { + iter: hash_map::Iter<'a, [u8; 32], Peer<O>>, +} + +impl<'a, O> Iterator for Iter<'a, O> { + type Item = (PublicKey, &'a O); + + fn next(&mut self) -> Option<Self::Item> { + self.iter + .next() + .map(|(pk, peer)| (PublicKey::from(*pk), &peer.opaque)) + } +} + +/* These methods enable the Device to act as a map + * from public keys to the set of contained opaque values. + * + * It also abstracts away the problem of PublicKey not being hashable. + */ +impl<O> Device<O> { + pub fn clear(&mut self) { + self.id_map.write().clear(); + self.pk_map.clear(); + } + + pub fn len(&self) -> usize { + self.pk_map.len() + } + + /// Enables enumeration of (public key, opaque) pairs + /// without exposing internal peer type. + pub fn iter(&self) -> Iter<O> { + Iter { + iter: self.pk_map.iter(), + } + } + + /// Enables lookup by public key without exposing internal peer type. + pub fn get(&self, pk: &PublicKey) -> Option<&O> { + self.pk_map.get(pk.as_bytes()).map(|peer| &peer.opaque) + } + + pub fn contains_key(&self, pk: &PublicKey) -> bool { + self.pk_map.contains_key(pk.as_bytes()) + } +} + /* A mutable reference to the device needs to be held during configuration. * Wrapping the device in a RwLock enables peer config after "configuration time" */ -impl Device { +impl<O> Device<O> { /// Initialize a new handshake state machine - pub fn new() -> Device { + pub fn new() -> Device<O> { Device { keyst: None, - pk_map: HashMap::new(), id_map: RwLock::new(HashMap::new()), + pk_map: HashMap::new(), limiter: Mutex::new(RateLimiter::new()), } } - fn update_ss(&self, peer: &mut Peer) -> Option<PublicKey> { - if let Some(key) = self.keyst.as_ref() { - if *peer.pk.as_bytes() == *key.pk.as_bytes() { - return Some(peer.pk); + fn update_ss(&mut self) -> (Vec<u32>, Option<PublicKey>) { + let mut same = None; + let mut ids = Vec::with_capacity(self.pk_map.len()); + for (pk, peer) in self.pk_map.iter_mut() { + if let Some(key) = self.keyst.as_ref() { + if key.pk.as_bytes() == pk { + same = Some(PublicKey::from(*pk)); + peer.ss.clear() + } else { + let pk = PublicKey::from(*pk); + peer.ss = *key.sk.diffie_hellman(&pk).as_bytes(); + } + } else { + peer.ss.clear(); } - peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes(); - } else { - peer.ss = [0u8; 32]; - }; - None + peer.reset_state().map(|id| ids.push(id)); + } + + (ids, same) } /// Update the secret key of the device @@ -74,29 +137,15 @@ impl Device { }); // recalculate / erase the shared secrets for every peer - let mut ids = vec![]; - let mut same = None; - for mut peer in self.pk_map.values_mut() { - // clear any existing handshake state - peer.reset_state().map(|id| ids.push(id)); - - // update precomputed shared secret - if let Some(key) = self.keyst.as_ref() { - peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes(); - if *peer.pk.as_bytes() == *key.pk.as_bytes() { - same = Some(peer.pk) - } - } else { - peer.ss = [0u8; 32]; - }; - } + let (ids, same) = self.update_ss(); // release ids from aborted handshakes for id in ids { self.release(id) } - // if we found a peer matching the device public key, remove it. + // if we found a peer matching the device public key + // remove it and return its value to the caller same.map(|pk| { self.pk_map.remove(pk.as_bytes()); pk @@ -119,29 +168,32 @@ impl Device { /// /// * `pk` - The public key to add /// * `identifier` - Associated identifier which can be used to distinguish the peers - pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> { + pub fn add(&mut self, pk: PublicKey, opaque: O) -> Result<(), ConfigError> { // ensure less than 2^20 peers if self.pk_map.len() > MAX_PEER_PER_DEVICE { return Err(ConfigError::new("Too many peers for device")); } - // create peer and precompute static secret - let mut peer = Peer::new( - pk, - self.keyst - .as_ref() - .map(|key| *key.sk.diffie_hellman(&pk).as_bytes()) - .unwrap_or([0u8; 32]), - ); - - // add peer to device - match self.update_ss(&mut peer) { - Some(_) => Err(ConfigError::new("Public key of peer matches the device")), - None => { - self.pk_map.insert(*pk.as_bytes(), peer); - Ok(()) + // error if public key matches device + if let Some(key) = self.keyst.as_ref() { + if pk.as_bytes() == key.pk.as_bytes() { + return Err(ConfigError::new("Public key of peer matches the device")); } } + + // pre-compute shared secret and add to pk_map + self.pk_map.insert( + *pk.as_bytes(), + Peer::new( + pk, + self.keyst + .as_ref() + .map(|key| *key.sk.diffie_hellman(&pk).as_bytes()) + .unwrap_or([0u8; 32]), + opaque, + ), + ); + Ok(()) } /// Remove a peer by public key @@ -163,7 +215,7 @@ impl Device { .remove(pk.as_bytes()) .ok_or(ConfigError::new("Public key not in device"))?; - // pruge the id map (linear scan) + // purge the id map (linear scan) id_map.retain(|_, v| v != pk.as_bytes()); Ok(()) } @@ -231,11 +283,11 @@ impl Device { (_, None) => Err(HandshakeError::UnknownPublicKey), (None, _) => Err(HandshakeError::UnknownPublicKey), (Some(keyst), Some(peer)) => { - let local = self.allocate(rng, peer); + let local = self.allocate(rng, pk); let mut msg = Initiation::default(); // create noise part of initation - noise::create_initiation(rng, keyst, peer, local, &mut msg.noise)?; + noise::create_initiation(rng, keyst, peer, pk, local, &mut msg.noise)?; // add macs to initation peer.macs @@ -253,11 +305,11 @@ impl Device { /// /// * `msg` - Byte slice containing the message (untrusted input) pub fn process<'a, R: RngCore + CryptoRng>( - &self, - rng: &mut R, // rng instance to sample randomness from - msg: &[u8], // message buffer + &'a self, + rng: &mut R, // rng instance to sample randomness from + msg: &[u8], // message buffer src: Option<SocketAddr>, // optional source endpoint, set when "under load" - ) -> Result<Output, HandshakeError> { + ) -> Result<Output<'a, O>, HandshakeError> { // ensure type read in-range if msg.len() < 4 { return Err(HandshakeError::InvalidMessageFormat); @@ -303,17 +355,17 @@ impl Device { } // consume the initiation - let (peer, st) = noise::consume_initiation(self, keyst, &msg.noise)?; + let (peer, pk, st) = noise::consume_initiation(self, keyst, &msg.noise)?; // allocate new index for response - let local = self.allocate(rng, peer); + let local = self.allocate(rng, &pk); // prepare memory for response, TODO: take slice for zero allocation let mut resp = Response::default(); // create response (release id on error) - let keys = - noise::create_response(rng, peer, local, st, &mut resp.noise).map_err(|e| { + let keys = noise::create_response(rng, peer, &pk, local, st, &mut resp.noise) + .map_err(|e| { self.release(local); e })?; @@ -324,7 +376,11 @@ impl Device { .generate(resp.noise.as_bytes(), &mut resp.macs); // return unconfirmed keypair and the response as vector - Ok((Some(peer.pk), Some(resp.as_bytes().to_owned()), Some(keys))) + Ok(( + Some(&peer.opaque), + Some(resp.as_bytes().to_owned()), + Some(keys), + )) } TYPE_RESPONSE => { let msg = Response::parse(msg)?; @@ -363,7 +419,7 @@ impl Device { let msg = CookieReply::parse(msg)?; // lookup peer - let peer = self.lookup_id(msg.f_receiver.get())?; + let (peer, _) = self.lookup_id(msg.f_receiver.get())?; // validate cookie reply peer.macs.lock().process(&msg)?; @@ -379,7 +435,7 @@ impl Device { // Internal function // // Return the peer associated with the public key - pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer, HandshakeError> { + pub(super) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer<O>, HandshakeError> { self.pk_map .get(pk.as_bytes()) .ok_or(HandshakeError::UnknownPublicKey) @@ -388,11 +444,11 @@ impl Device { // Internal function // // Return the peer currently associated with the receiver identifier - pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer, HandshakeError> { + pub(super) fn lookup_id(&self, id: u32) -> Result<(&Peer<O>, PublicKey), HandshakeError> { let im = self.id_map.read(); let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?; match self.pk_map.get(pk) { - Some(peer) => Ok(peer), + Some(peer) => Ok((peer, PublicKey::from(*pk))), _ => unreachable!(), // if the id-lookup succeeded, the peer should exist } } @@ -400,7 +456,7 @@ impl Device { // Internal function // // Allocated a new receiver identifier for the peer - fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, peer: &Peer) -> u32 { + fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, pk: &PublicKey) -> u32 { loop { let id = rng.gen(); @@ -412,7 +468,7 @@ impl Device { // take write lock and add index let mut m = self.id_map.write(); if !m.contains_key(&id) { - m.insert(id, *peer.pk.as_bytes()); + m.insert(id, *pk.as_bytes()); return id; } } diff --git a/src/wireguard/handshake/macs.rs b/src/wireguard/handshake/macs.rs index 689826b..cb5d7d4 100644 --- a/src/wireguard/handshake/macs.rs +++ b/src/wireguard/handshake/macs.rs @@ -286,8 +286,7 @@ mod tests { use x25519_dalek::StaticSecret; fn new_validator_generator() -> (Validator, Generator) { - let mut rng = OsRng::new().unwrap(); - let sk = StaticSecret::new(&mut rng); + let sk = StaticSecret::new(&mut OsRng); let pk = PublicKey::from(&sk); (Validator::new(pk), Generator::new(pk)) } @@ -296,7 +295,6 @@ mod tests { #[test] fn test_cookie_reply(inner1 : Vec<u8>, inner2 : Vec<u8>, receiver : u32) { let mut msg = CookieReply::default(); - let mut rng = OsRng::new().expect("failed to create rng"); let mut macs = MacsFooter::default(); let src = "192.0.2.16:8080".parse().unwrap(); let (validator, mut generator) = new_validator_generator(); @@ -309,7 +307,7 @@ mod tests { // check validity of mac1 validator.check_mac1(&inner1[..], &macs).expect("mac1 of inner1 did not validate"); assert_eq!(validator.check_mac2(&inner1[..], &src, &macs), false, "mac2 of inner2 did not validate"); - validator.create_cookie_reply(&mut rng, receiver, &src, &macs, &mut msg); + validator.create_cookie_reply(&mut OsRng, receiver, &src, &macs, &mut msg); // consume cookie reply generator.process(&msg).expect("failed to process CookieReply"); diff --git a/src/wireguard/handshake/noise.rs b/src/wireguard/handshake/noise.rs index 072ac13..9e431cf 100644 --- a/src/wireguard/handshake/noise.rs +++ b/src/wireguard/handshake/noise.rs @@ -10,7 +10,7 @@ use hmac::Hmac; use aead::{Aead, NewAead, Payload}; use chacha20poly1305::ChaCha20Poly1305; -use rand::{CryptoRng, RngCore}; +use rand_core::{CryptoRng, RngCore}; use log::debug; @@ -215,20 +215,21 @@ mod tests { } } -pub fn create_initiation<R: RngCore + CryptoRng>( +pub(super) fn create_initiation<R: RngCore + CryptoRng, O>( rng: &mut R, keyst: &KeyState, - peer: &Peer, + peer: &Peer<O>, + pk: &PublicKey, local: u32, msg: &mut NoiseInitiation, ) -> Result<(), HandshakeError> { - debug!("create initation"); + debug!("create initiation"); clear_stack_on_return(CLEAR_PAGES, || { // initialize state let ck = INITIAL_CK; let hs = INITIAL_HS; - let hs = HASH!(&hs, peer.pk.as_bytes()); + let hs = HASH!(&hs, pk.as_bytes()); msg.f_type.set(TYPE_INITIATION as u32); msg.f_sender.set(local); // from us @@ -252,7 +253,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>( // (C, k) := Kdf2(C, DH(E_priv, S_pub)) - let (ck, key) = KDF2!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes()); + let (ck, key) = KDF2!(&ck, eph_sk.diffie_hellman(&pk).as_bytes()); // msg.static := Aead(k, 0, S_pub, H) @@ -297,12 +298,12 @@ pub fn create_initiation<R: RngCore + CryptoRng>( }) } -pub fn consume_initiation<'a>( - device: &'a Device, +pub(super) fn consume_initiation<'a, O>( + device: &'a Device<O>, keyst: &KeyState, msg: &NoiseInitiation, -) -> Result<(&'a Peer, TemporaryState), HandshakeError> { - debug!("consume initation"); +) -> Result<(&'a Peer<O>, PublicKey, TemporaryState), HandshakeError> { + debug!("consume initiation"); clear_stack_on_return(CLEAR_PAGES, || { // initialize new state @@ -369,13 +370,18 @@ pub fn consume_initiation<'a>( // return state (to create response) - Ok((peer, (msg.f_sender.get(), eph_r_pk, hs, ck))) + Ok(( + peer, + PublicKey::from(pk), + (msg.f_sender.get(), eph_r_pk, hs, ck), + )) }) } -pub fn create_response<R: RngCore + CryptoRng>( +pub(super) fn create_response<R: RngCore + CryptoRng, O>( rng: &mut R, - peer: &Peer, + peer: &Peer<O>, + pk: &PublicKey, local: u32, // sending identifier state: TemporaryState, // state from "consume_initiation" msg: &mut NoiseResponse, // resulting response @@ -388,7 +394,7 @@ pub fn create_response<R: RngCore + CryptoRng>( msg.f_type.set(TYPE_RESPONSE as u32); msg.f_sender.set(local); // from us - msg.f_receiver.set(receiver); // to the sender of the initation + msg.f_receiver.set(receiver); // to the sender of the initiation // (E_priv, E_pub) := DH-Generate() @@ -413,7 +419,7 @@ pub fn create_response<R: RngCore + CryptoRng>( // C := Kdf1(C, DH(E_priv, S_pub)) - let ck = KDF1!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes()); + let ck = KDF1!(&ck, eph_sk.diffie_hellman(&pk).as_bytes()); // (C, tau, k) := Kdf3(C, Q) @@ -460,15 +466,15 @@ pub fn create_response<R: RngCore + CryptoRng>( * allow concurrent processing of potential responses to the initiation, * in order to better mitigate DoS from malformed response messages. */ -pub fn consume_response( - device: &Device, +pub(super) fn consume_response<'a, O>( + device: &'a Device<O>, keyst: &KeyState, msg: &NoiseResponse, -) -> Result<Output, HandshakeError> { +) -> Result<Output<'a, O>, HandshakeError> { debug!("consume response"); clear_stack_on_return(CLEAR_PAGES, || { // retrieve peer and copy initiation state - let peer = device.lookup_id(msg.f_receiver.get())?; + let (peer, _) = device.lookup_id(msg.f_receiver.get())?; let (hs, ck, local, eph_sk) = match *peer.state.lock() { State::InitiationSent { @@ -537,7 +543,7 @@ pub fn consume_response( // return confirmed key-pair Ok(( - Some(peer.pk), + Some(&peer.opaque), None, Some(KeyPair { birth, diff --git a/src/wireguard/handshake/peer.rs b/src/wireguard/handshake/peer.rs index a4df560..f4d15fc 100644 --- a/src/wireguard/handshake/peer.rs +++ b/src/wireguard/handshake/peer.rs @@ -22,19 +22,21 @@ const TIME_BETWEEN_INITIATIONS: Duration = Duration::from_millis(20); * * This type is only for internal use and not exposed. */ -pub struct Peer { +pub(super) struct Peer<O> { + // opaque type which identifies a peer + pub opaque: O, + // mutable state - pub(crate) state: Mutex<State>, - pub(crate) timestamp: Mutex<Option<timestamp::TAI64N>>, - pub(crate) last_initiation_consumption: Mutex<Option<Instant>>, + pub state: Mutex<State>, + pub timestamp: Mutex<Option<timestamp::TAI64N>>, + pub last_initiation_consumption: Mutex<Option<Instant>>, // state related to DoS mitigation fields - pub(crate) macs: Mutex<macs::Generator>, + pub macs: Mutex<macs::Generator>, // constant state - pub(crate) pk: PublicKey, // public key of peer - pub(crate) ss: [u8; 32], // precomputed DH(static, static) - pub(crate) psk: Psk, // psk of peer + pub ss: [u8; 32], // precomputed DH(static, static) + pub psk: Psk, // psk of peer } pub enum State { @@ -60,14 +62,14 @@ impl Drop for State { } } -impl Peer { - pub fn new(pk: PublicKey, ss: [u8; 32]) -> Self { +impl<O> Peer<O> { + pub fn new(pk: PublicKey, ss: [u8; 32], opaque: O) -> Self { Self { + opaque, macs: Mutex::new(macs::Generator::new(pk)), state: Mutex::new(State::Reset), timestamp: Mutex::new(None), last_initiation_consumption: Mutex::new(None), - pk, ss, psk: [0u8; 32], } @@ -88,7 +90,7 @@ impl Peer { /// * ts_new - The associated timestamp pub fn check_replay_flood( &self, - device: &Device, + device: &Device<O>, timestamp_new: ×tamp::TAI64N, ) -> Result<(), HandshakeError> { let mut state = self.state.lock(); diff --git a/src/wireguard/handshake/tests.rs b/src/wireguard/handshake/tests.rs index ff27b3e..bfdc5ab 100644 --- a/src/wireguard/handshake/tests.rs +++ b/src/wireguard/handshake/tests.rs @@ -12,8 +12,10 @@ use x25519_dalek::StaticSecret; use super::messages::{Initiation, Response}; -fn setup_devices<R: RngCore + CryptoRng>(rng: &mut R) -> (PublicKey, Device, PublicKey, Device) { - // generate new keypairs +fn setup_devices<R: RngCore + CryptoRng, O: Default>( + rng: &mut R, +) -> (PublicKey, Device<O>, PublicKey, Device<O>) { + // generate new key pairs let sk1 = StaticSecret::new(rng); let pk1 = PublicKey::from(&sk1); @@ -26,7 +28,7 @@ fn setup_devices<R: RngCore + CryptoRng>(rng: &mut R) -> (PublicKey, Device, Pub let mut psk = [0u8; 32]; rng.fill_bytes(&mut psk[..]); - // intialize devices on both ends + // initialize devices on both ends let mut dev1 = Device::new(); let mut dev2 = Device::new(); @@ -34,8 +36,8 @@ fn setup_devices<R: RngCore + CryptoRng>(rng: &mut R) -> (PublicKey, Device, Pub dev1.set_sk(Some(sk1)); dev2.set_sk(Some(sk2)); - dev1.add(pk2).unwrap(); - dev2.add(pk1).unwrap(); + dev1.add(pk2, O::default()).unwrap(); + dev2.add(pk1, O::default()).unwrap(); dev1.set_psk(pk2, psk).unwrap(); dev2.set_psk(pk1, psk).unwrap(); @@ -49,45 +51,44 @@ fn wait() { /* Test longest possible handshake interaction (7 messages): * - * 1. I -> R (initation) + * 1. I -> R (initiation) * 2. I <- R (cookie reply) - * 3. I -> R (initation) + * 3. I -> R (initiation) * 4. I <- R (response) * 5. I -> R (cookie reply) - * 6. I -> R (initation) + * 6. I -> R (initiation) * 7. I <- R (response) */ #[test] fn handshake_under_load() { - let mut rng = OsRng::new().unwrap(); - let (_pk1, dev1, pk2, dev2) = setup_devices(&mut rng); + let (_pk1, dev1, pk2, dev2): (_, Device<usize>, _, _) = setup_devices(&mut OsRng); let src1: SocketAddr = "172.16.0.1:8080".parse().unwrap(); let src2: SocketAddr = "172.16.0.2:7070".parse().unwrap(); - // 1. device-1 : create first initation - let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + // 1. device-1 : create first initiation + let msg_init = dev1.begin(&mut OsRng, &pk2).unwrap(); // 2. device-2 : responds with CookieReply - let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(src1)).unwrap() { + let msg_cookie = match dev2.process(&mut OsRng, &msg_init, Some(src1)).unwrap() { (None, Some(msg), None) => msg, _ => panic!("unexpected response"), }; // device-1 : processes CookieReply (no response) - match dev1.process(&mut rng, &msg_cookie, Some(src2)).unwrap() { + match dev1.process(&mut OsRng, &msg_cookie, Some(src2)).unwrap() { (None, None, None) => (), _ => panic!("unexpected response"), } - // avoid initation flood detection + // avoid initiation flood detection wait(); - // 3. device-1 : create second initation - let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + // 3. device-1 : create second initiation + let msg_init = dev1.begin(&mut OsRng, &pk2).unwrap(); // 4. device-2 : responds with noise response - let msg_response = match dev2.process(&mut rng, &msg_init, Some(src1)).unwrap() { + let msg_response = match dev2.process(&mut OsRng, &msg_init, Some(src1)).unwrap() { (Some(_), Some(msg), Some(kp)) => { assert_eq!(kp.initiator, false); msg @@ -96,25 +97,25 @@ fn handshake_under_load() { }; // 5. device-1 : responds with CookieReply - let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(src2)).unwrap() { + let msg_cookie = match dev1.process(&mut OsRng, &msg_response, Some(src2)).unwrap() { (None, Some(msg), None) => msg, _ => panic!("unexpected response"), }; // device-2 : processes CookieReply (no response) - match dev2.process(&mut rng, &msg_cookie, Some(src1)).unwrap() { + match dev2.process(&mut OsRng, &msg_cookie, Some(src1)).unwrap() { (None, None, None) => (), _ => panic!("unexpected response"), } - // avoid initation flood detection + // avoid initiation flood detection wait(); - // 6. device-1 : create third initation - let msg_init = dev1.begin(&mut rng, &pk2).unwrap(); + // 6. device-1 : create third initiation + let msg_init = dev1.begin(&mut OsRng, &pk2).unwrap(); // 7. device-2 : responds with noise response - let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(src1)).unwrap() { + let (msg_response, kp1) = match dev2.process(&mut OsRng, &msg_init, Some(src1)).unwrap() { (Some(_), Some(msg), Some(kp)) => { assert_eq!(kp.initiator, false); (msg, kp) @@ -123,7 +124,7 @@ fn handshake_under_load() { }; // device-1 : process noise response - let kp2 = match dev1.process(&mut rng, &msg_response, Some(src2)).unwrap() { + let kp2 = match dev1.process(&mut OsRng, &msg_response, Some(src2)).unwrap() { (Some(_), None, Some(kp)) => { assert_eq!(kp.initiator, true); kp @@ -137,8 +138,7 @@ fn handshake_under_load() { #[test] fn handshake_no_load() { - let mut rng = OsRng::new().unwrap(); - let (pk1, mut dev1, pk2, mut dev2) = setup_devices(&mut rng); + let (pk1, mut dev1, pk2, mut dev2): (_, Device<usize>, _, _) = setup_devices(&mut OsRng); // do a few handshakes (every handshake should succeed) @@ -147,7 +147,7 @@ fn handshake_no_load() { // create initiation - let msg1 = dev1.begin(&mut rng, &pk2).unwrap(); + let msg1 = dev1.begin(&mut OsRng, &pk2).unwrap(); println!("msg1 = {} : {} bytes", hex::encode(&msg1[..]), msg1.len()); println!( @@ -158,7 +158,7 @@ fn handshake_no_load() { // process initiation and create response let (_, msg2, ks_r) = dev2 - .process(&mut rng, &msg1, None) + .process(&mut OsRng, &msg1, None) .expect("failed to process initiation"); let ks_r = ks_r.unwrap(); @@ -175,7 +175,7 @@ fn handshake_no_load() { // process response and obtain confirmed key-pair let (_, msg3, ks_i) = dev1 - .process(&mut rng, &msg2, None) + .process(&mut OsRng, &msg2, None) .expect("failed to process response"); let ks_i = ks_i.unwrap(); @@ -188,7 +188,7 @@ fn handshake_no_load() { dev1.release(ks_i.local_id()); dev2.release(ks_r.local_id()); - // avoid initation flood detection + // avoid initiation flood detection wait(); } diff --git a/src/wireguard/handshake/types.rs b/src/wireguard/handshake/types.rs index 5f984cc..ed2fcbb 100644 --- a/src/wireguard/handshake/types.rs +++ b/src/wireguard/handshake/types.rs @@ -1,10 +1,8 @@ +use super::super::types::KeyPair; + use std::error::Error; use std::fmt; -use x25519_dalek::PublicKey; - -use super::super::types::KeyPair; - /* Internal types for the noise IKpsk2 implementation */ // config error @@ -79,10 +77,10 @@ impl Error for HandshakeError { } } -pub type Output = ( - Option<PublicKey>, // external identifier associated with peer - Option<Vec<u8>>, // message to send - Option<KeyPair>, // resulting key-pair of successful handshake +pub type Output<'a, O> = ( + Option<&'a O>, // external identifier associated with peer + Option<Vec<u8>>, // message to send + Option<KeyPair>, // resulting key-pair of successful handshake ); // preshared key diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs index 1af4df3..b3656fe 100644 --- a/src/wireguard/peer.rs +++ b/src/wireguard/peer.rs @@ -31,7 +31,7 @@ pub struct PeerInner<T: Tun, B: UDP> { pub handshake_queued: AtomicBool, // is a handshake job currently queued for the peer? // stats and configuration - pub pk: PublicKey, // public key, DISCUSS: avoid this. TODO: remove + pub pk: PublicKey, // public key pub rx_bytes: AtomicU64, // received bytes pub tx_bytes: AtomicU64, // transmitted bytes diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index f903a8e..6c59491 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -142,7 +142,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle< }; // start worker threads - let mut threads = Vec::with_capacity(num_workers); + let mut threads = Vec::with_capacity(4 * num_workers); // inbound/decryption workers for _ in 0..num_workers { diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index b8110f0..8fe2e1c 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -204,7 +204,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E, debug!("peer.send"); // send to endpoint (if known) - match self.endpoint.lock().as_ref() { + match self.endpoint.lock().as_mut() { Some(endpoint) => { let outbound = self.device.outbound.read(); if outbound.0 { diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index bf550ef..ecbb9c1 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -21,9 +21,6 @@ use std::sync::Mutex as StdMutex; use std::thread; use std::time::Instant; -use std::collections::hash_map::Entry; -use std::collections::HashMap; - use hjul::Runner; use rand::rngs::OsRng; use rand::Rng; @@ -50,14 +47,13 @@ pub struct WireguardInner<T: Tun, B: UDP> { // outbound writer pub send: RwLock<Option<B::Writer>>, - // identity and configuration map - pub peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>, + // peer map + pub peers: RwLock<handshake::Device<Peer<T, B>>>, // cryptokey router pub router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>, // handshake related state - pub handshake: RwLock<handshake::Device>, pub last_under_load: Mutex<Instant>, pub pending: AtomicUsize, // number of pending handshake packets in queue pub queue: ParallelQueue<HandshakeJob<B::Endpoint>>, @@ -142,7 +138,7 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { self.router.down(); // set all peers down (stops timers) - for peer in self.peers.write().values() { + for (_, peer) in self.peers.write().iter() { peer.down(); } @@ -163,11 +159,11 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { return; } - // enable tranmission from router + // enable transmission from router self.router.up(); // set all peers up (restarts timers) - for peer in self.peers.write().values() { + for (_, peer) in self.peers.write().iter() { peer.up(); } @@ -179,54 +175,51 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { } pub fn remove_peer(&self, pk: &PublicKey) { - if self.handshake.write().remove(pk).is_ok() { - self.peers.write().remove(pk.as_bytes()); - } + let _ = self.peers.write().remove(pk); } pub fn lookup_peer(&self, pk: &PublicKey) -> Option<Peer<T, B>> { - self.peers.read().get(pk.as_bytes()).map(|p| p.clone()) + self.peers.read().get(pk).map(|p| p.clone()) } pub fn list_peers(&self) -> Vec<Peer<T, B>> { let peers = self.peers.read(); let mut list = Vec::with_capacity(peers.len()); for (k, v) in peers.iter() { - debug_assert!(k == v.pk.as_bytes()); + debug_assert!(k.as_bytes() == v.pk.as_bytes()); list.push(v.clone()); } list } pub fn set_key(&self, sk: Option<StaticSecret>) { - let mut handshake = self.handshake.write(); - handshake.set_sk(sk); + let mut peers = self.peers.write(); + peers.set_sk(sk); self.router.clear_sending_keys(); - // handshake lock is released and new handshakes can be initated } pub fn get_sk(&self) -> Option<StaticSecret> { - self.handshake + self.peers .read() .get_sk() .map(|sk| StaticSecret::from(sk.to_bytes())) } pub fn set_psk(&self, pk: PublicKey, psk: [u8; 32]) -> bool { - self.handshake.write().set_psk(pk, psk).is_ok() + self.peers.write().set_psk(pk, psk).is_ok() } pub fn get_psk(&self, pk: &PublicKey) -> Option<[u8; 32]> { - self.handshake.read().get_psk(pk).ok() + self.peers.read().get_psk(pk).ok() } pub fn add_peer(&self, pk: PublicKey) -> bool { - if self.peers.read().contains_key(pk.as_bytes()) { + let mut peers = self.peers.write(); + if peers.contains_key(&pk) { return false; } - let mut rng = OsRng::new().unwrap(); let state = Arc::new(PeerInner { - id: rng.gen(), + id: OsRng.gen(), pk, wg: self.clone(), walltime_last_handshake: Mutex::new(None), @@ -243,33 +236,19 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { // form WireGuard peer let peer = Peer { router, state }; + // prevent up/down while inserting + let enabled = self.enabled.read(); + + /* The need for dummy timers arises from the chicken-egg + * problem of the timer callbacks being able to set timers themselves. + * + * This is in fact the only place where the write lock is ever taken. + * TODO: Consider the ease of using atomic pointers instead. + */ + *peer.timers.write() = Timers::new(&*self.runner.lock(), *enabled, peer.clone()); + // finally, add the peer to the wireguard device - let mut peers = self.peers.write(); - match peers.entry(*pk.as_bytes()) { - Entry::Occupied(_) => false, - Entry::Vacant(vacancy) => { - // check that the public key does not cause conflict with the private key of the device - let ok_pk = self.handshake.write().add(pk).is_ok(); - if !ok_pk { - return false; - } - - // prevent up/down while inserting - let enabled = self.enabled.read(); - - /* The need for dummy timers arises from the chicken-egg - * problem of the timer callbacks being able to set timers themselves. - * - * This is in fact the only place where the write lock is ever taken. - * TODO: Consider the ease of using atomic pointers instead. - */ - *peer.timers.write() = Timers::new(&*self.runner.lock(), *enabled, peer.clone()); - - // insert into peer map (takes ownership and ensures that the peer is not dropped) - vacancy.insert(peer); - true - } - } + peers.add(pk, peer).is_ok() } /// Begin consuming messages from the reader. @@ -311,9 +290,6 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { // workers equal to number of physical cores let cpus = num_cpus::get(); - // create device state - let mut rng = OsRng::new().unwrap(); - // create handshake queue let (tx, mut rxs) = ParallelQueue::new(cpus, 128); @@ -322,14 +298,13 @@ impl<T: Tun, B: UDP> WireGuard<T, B> { inner: Arc::new(WireguardInner { enabled: RwLock::new(false), tun_readers: WaitCounter::new(), - id: rng.gen(), + id: OsRng.gen(), mtu: AtomicUsize::new(0), - peers: RwLock::new(HashMap::new()), last_under_load: Mutex::new(Instant::now() - TIME_HORIZON), send: RwLock::new(None), router: router::Device::new(num_cpus::get(), writer), // router owns the writing half pending: AtomicUsize::new(0), - handshake: RwLock::new(handshake::Device::new()), + peers: RwLock::new(handshake::Device::new()), runner: Mutex::new(Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY)), queue: tx, }), diff --git a/src/wireguard/workers.rs b/src/wireguard/workers.rs index e1d3899..c1a2af7 100644 --- a/src/wireguard/workers.rs +++ b/src/wireguard/workers.rs @@ -152,9 +152,6 @@ pub fn handshake_worker<T: Tun, B: UDP>( ) { debug!("{} : handshake worker, started", wg); - // prepare OsRng instance for this thread - let mut rng = OsRng::new().expect("Unable to obtain a CSPRNG"); - // process elements from the handshake queue for job in rx { // check if under load @@ -181,11 +178,11 @@ pub fn handshake_worker<T: Tun, B: UDP>( // de-multiplex staged handshake jobs and handshake messages match job { - HandshakeJob::Message(msg, src) => { + HandshakeJob::Message(msg, mut src) => { // process message - let device = wg.handshake.read(); + let device = wg.peers.read(); match device.process( - &mut rng, + &mut OsRng, &msg[..], if under_load { Some(src.into_address()) @@ -193,7 +190,7 @@ pub fn handshake_worker<T: Tun, B: UDP>( None }, ) { - Ok((pk, resp, keypair)) => { + Ok((peer, resp, keypair)) => { // send response (might be cookie reply or handshake response) let mut resp_len: u64 = 0; if let Some(msg) = resp { @@ -204,7 +201,7 @@ pub fn handshake_worker<T: Tun, B: UDP>( "{} : handshake worker, send response ({} bytes)", wg, resp_len ); - let _ = writer.write(&msg[..], &src).map_err(|e| { + let _ = writer.write(&msg[..], &mut src).map_err(|e| { debug!( "{} : handshake worker, failed to send response, error = {}", wg, @@ -215,56 +212,55 @@ pub fn handshake_worker<T: Tun, B: UDP>( } // update peer state - if let Some(pk) = pk { + if let Some(peer) = peer { // authenticated handshake packet received - if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { - // add to rx_bytes and tx_bytes - let req_len = msg.len() as u64; - peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed); - peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed); - // update endpoint - peer.router.set_endpoint(src); + // add to rx_bytes and tx_bytes + let req_len = msg.len() as u64; + peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed); + peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed); - if resp_len > 0 { - // update timers after sending handshake response - debug!("{} : handshake worker, handshake response sent", wg); - peer.state.sent_handshake_response(); - } else { - // update timers after receiving handshake response - debug!( - "{} : handshake worker, handshake response was received", - wg - ); - peer.state.timers_handshake_complete(); - } + // update endpoint + peer.router.set_endpoint(src); + + if resp_len > 0 { + // update timers after sending handshake response + debug!("{} : handshake worker, handshake response sent", wg); + peer.state.sent_handshake_response(); + } else { + // update timers after receiving handshake response + debug!( + "{} : handshake worker, handshake response was received", + wg + ); + peer.state.timers_handshake_complete(); + } - // add any new keypair to peer - keypair.map(|kp| { - debug!("{} : handshake worker, new keypair for {}", wg, peer); + // add any new keypair to peer + keypair.map(|kp| { + debug!("{} : handshake worker, new keypair for {}", wg, peer); - // this means that a handshake response was processed or sent - peer.timers_session_derived(); + // this means that a handshake response was processed or sent + peer.timers_session_derived(); - // free any unused ids - for id in peer.router.add_keypair(kp) { - device.release(id); - } - }); - } + // free any unused ids + for id in peer.router.add_keypair(kp) { + device.release(id); + } + }); } } Err(e) => debug!("{} : handshake worker, error = {:?}", wg, e), } } HandshakeJob::New(pk) => { - if let Some(peer) = wg.peers.read().get(pk.as_bytes()) { + if let Some(peer) = wg.peers.read().get(&pk) { debug!( "{} : handshake worker, new handshake requested for {}", wg, peer ); - let device = wg.handshake.read(); - let _ = device.begin(&mut rng, &peer.pk).map(|msg| { + let device = wg.peers.read(); + let _ = device.begin(&mut OsRng, &peer.pk).map(|msg| { let _ = peer.router.send(&msg[..]).map_err(|e| { debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e) }); |