diff options
author | Jake McGinty <me@jake.su> | 2018-04-13 20:23:53 -0700 |
---|---|---|
committer | Jake McGinty <me@jake.su> | 2018-04-22 14:08:41 -0700 |
commit | 930f4effb5abc7cb27b178657f2ec99b29da9e34 (patch) | |
tree | 23c22810350ccd28fc105361b59e0516dd3ab9fb /src | |
parent | udp: remove the unused Connected/Unconnected UDP enum (diff) | |
download | wireguard-rs-930f4effb5abc7cb27b178657f2ec99b29da9e34.tar.xz wireguard-rs-930f4effb5abc7cb27b178657f2ec99b29da9e34.zip |
udp: dual-stack single socket -> dual socket
Diffstat (limited to '')
-rw-r--r-- | src/interface/peer_server.rs | 10 | ||||
-rw-r--r-- | src/udp/frame.rs | 12 | ||||
-rw-r--r-- | src/udp/mod.rs | 192 |
3 files changed, 64 insertions, 150 deletions
diff --git a/src/interface/peer_server.rs b/src/interface/peer_server.rs index 509d8c2..2b05b94 100644 --- a/src/interface/peer_server.rs +++ b/src/interface/peer_server.rs @@ -8,7 +8,7 @@ use time::Timestamp; use timer::{Timer, TimerMessage}; use std::convert::TryInto; -use std::net::{Ipv6Addr, SocketAddr}; +use std::net::SocketAddr; use std::time::Duration; use byteorder::{ByteOrder, LittleEndian}; @@ -68,8 +68,8 @@ impl PeerServer { return Ok(()) } - let socket = UdpSocket::bind((Ipv6Addr::unspecified(), port).into(), self.handle.clone())?; - info!("listening on {:?}", socket.local_addr()?); + let socket = UdpSocket::bind(port, self.handle.clone())?; + info!("listening on {:?}", socket.local_addrs()?); let udp: UdpChannel = socket.framed().into(); @@ -126,7 +126,7 @@ impl PeerServer { let mut state = self.shared_state.borrow_mut(); { let (mac_in, mac_out) = packet.split_at(116); - self.cookie.verify_mac1(mac_in, &mac_out[..16])?; + self.cookie.verify_mac1(&mac_in[..], &mac_out[..16])?; } debug!("got handshake initiation request (0x01)"); @@ -157,7 +157,7 @@ impl PeerServer { let mut state = self.shared_state.borrow_mut(); { let (mac_in, mac_out) = packet.split_at(60); - self.cookie.verify_mac1(mac_in, &mac_out[..16])?; + self.cookie.verify_mac1(&mac_in[..], &mac_out[..16])?; } debug!("got handshake response (0x02)"); diff --git a/src/udp/frame.rs b/src/udp/frame.rs index 860d421..f04b004 100644 --- a/src/udp/frame.rs +++ b/src/udp/frame.rs @@ -168,13 +168,15 @@ impl VecUdpCodec { pub struct UdpChannel { pub ingress : stream::SplitStream<UdpFramed>, pub egress : mpsc::Sender<PeerServerMessage>, - pub fd : RawFd, + pub fd4 : RawFd, + pub fd6 : RawFd, handle : Handle, } impl From<UdpFramed> for UdpChannel { fn from(framed: UdpFramed) -> Self { - let fd = framed.socket.as_raw_fd(); + let fd4 = framed.socket.as_raw_fd_v4(); + let fd6 = framed.socket.as_raw_fd_v6(); let handle = framed.socket.handle.clone(); let (udp_sink, ingress) = framed.split(); let (egress, egress_rx) = mpsc::channel(1024); @@ -189,7 +191,7 @@ impl From<UdpFramed> for UdpChannel { handle.spawn(udp_writethrough); - UdpChannel { egress, ingress, fd, handle } + UdpChannel { egress, ingress, fd4, fd6, handle } } } @@ -200,7 +202,9 @@ impl UdpChannel { #[cfg(target_os = "linux")] pub fn set_mark(&self, mark: u32) -> Result<(), Error> { - Ok(setsockopt(self.fd, sockopt::Mark, &mark)?) + setsockopt(self.fd4, sockopt::Mark, &mark)?; + setsockopt(self.fd6, sockopt::Mark, &mark)?; + Ok(()) } #[cfg(not(target_os = "linux"))] diff --git a/src/udp/mod.rs b/src/udp/mod.rs index 61851e2..30a0e0d 100644 --- a/src/udp/mod.rs +++ b/src/udp/mod.rs @@ -1,8 +1,8 @@ #![allow(unused)] use std::{fmt, io, mem}; -use std::net::{self, SocketAddr, Ipv4Addr, Ipv6Addr}; -use std::os::unix::io::AsRawFd; +use std::net::{self, SocketAddr, SocketAddrV4, SocketAddrV6, Ipv4Addr, Ipv6Addr}; +use std::os::unix::io::{AsRawFd, RawFd}; use futures::{Async, Future, Poll}; use libc; @@ -15,7 +15,8 @@ use tokio_core::reactor::{Handle, PollEvented}; /// An I/O object representing a UDP socket. pub struct UdpSocket { - io: PollEvented<mio::net::UdpSocket>, + io4: PollEvented<mio::net::UdpSocket>, + io6: PollEvented<mio::net::UdpSocket>, handle: Handle, } @@ -41,12 +42,9 @@ mod frame; pub use self::frame::{UdpChannel, UdpFramed, VecUdpCodec, PeerServerMessage}; impl UdpSocket { - /// Create a new UDP socket bound to the specified address. - /// - /// This function will create a new UDP socket and attempt to bind it to the - /// `addr` provided. If the result is `Ok`, the socket has successfully bound. - pub fn bind(addr: SocketAddr, handle: Handle) -> io::Result<UdpSocket> { - let socket = Socket::new(Domain::ipv6(), Type::dgram(), Some(Protocol::udp()))?; + pub fn bind(port: u16, handle: Handle) -> io::Result<UdpSocket> { + let socket4 = Socket::new(Domain::ipv4(), Type::dgram(), Some(Protocol::udp()))?; + let socket6 = Socket::new(Domain::ipv6(), Type::dgram(), Some(Protocol::udp()))?; let off: libc::c_int = 0; let on: libc::c_int = 1; @@ -64,7 +62,7 @@ impl UdpSocket { // } unsafe { - let ret = libc::setsockopt(socket.as_raw_fd(), + let ret = libc::setsockopt(socket6.as_raw_fd(), libc::IPPROTO_IPV6, IPV6_RECVPKTINFO, &on as *const _ as *const libc::c_void, @@ -77,107 +75,35 @@ impl UdpSocket { debug!("set IPV6_PKTINFO"); } - socket.set_only_v6(false)?; - socket.set_nonblocking(true)?; - socket.set_reuse_port(true)?; - socket.set_reuse_address(true)?; + socket6.set_only_v6(true)?; + socket6.set_nonblocking(true)?; + socket6.set_reuse_port(true)?; + socket6.set_reuse_address(true)?; - socket.bind(&addr.into())?; - Self::from_socket(socket.into_udp_socket(), handle) - } + socket4.bind(&SocketAddrV4::new(Ipv4Addr::unspecified(), port).into())?; + socket6.bind(&SocketAddrV6::new(Ipv6Addr::unspecified(), port, 0, 0).into())?; - fn new(socket: mio::net::UdpSocket, handle: Handle) -> io::Result<UdpSocket> { - let io = PollEvented::new(socket, &handle)?; - Ok(UdpSocket { io, handle }) - } + let socket4 = mio::net::UdpSocket::from_socket(socket4.into_udp_socket())?; + let socket6 = mio::net::UdpSocket::from_socket(socket6.into_udp_socket())?; - /// Creates a new `UdpSocket` from the previously bound socket provided. - /// - /// The socket given will be registered with the event loop that `handle` is - /// associated with. This function requires that `socket` has previously - /// been bound to an address to work correctly. - /// - /// This can be used in conjunction with net2's `UdpBuilder` interface to - /// configure a socket before it's handed off, such as setting options like - /// `reuse_address` or binding to multiple addresses. - pub fn from_socket(socket: net::UdpSocket, handle: Handle) -> io::Result<UdpSocket> { - let udp = mio::net::UdpSocket::from_socket(socket)?; - UdpSocket::new(udp, handle) + let io4 = PollEvented::new(socket4, &handle)?; + let io6 = PollEvented::new(socket6, &handle)?; + Ok(UdpSocket { io4, io6, handle }) } - /// Provides a `Stream` and `Sink` interface for reading and writing to this - /// `UdpSocket` object, using the provided `UdpCodec` to read and write the - /// raw data. - /// - /// Raw UDP sockets work with datagrams, but higher-level code usually - /// wants to batch these into meaningful chunks, called "frames". This - /// method layers framing on top of this socket by using the `UdpCodec` - /// trait to handle encoding and decoding of messages frames. Note that - /// the incoming and outgoing frame types may be distinct. - /// - /// This function returns a *single* object that is both `Stream` and - /// `Sink`; grouping this into a single object is often useful for layering - /// things which require both read and write access to the underlying - /// object. - /// - /// If you want to work more directly with the streams and sink, consider - /// calling `split` on the `UdpFramed` returned by this method, which will - /// break them into separate objects, allowing them to interact more - /// easily. pub fn framed(self) -> UdpFramed { frame::new(self) } /// Returns the local address that this stream is bound to. - pub fn local_addr(&self) -> io::Result<SocketAddr> { - self.io.get_ref().local_addr() + pub fn local_addrs(&self) -> io::Result<(SocketAddr, SocketAddr)> { + Ok((self.io4.get_ref().local_addr()?, self.io6.get_ref().local_addr()?)) } - /// Sends data on the socket to the address previously bound via connect(). - /// On success, returns the number of bytes written. - pub fn send(&self, buf: &[u8]) -> io::Result<usize> { - if let Async::NotReady = self.io.poll_write() { - return Err(io::ErrorKind::WouldBlock.into()) - } - match self.io.get_ref().send(buf) { - Ok(n) => Ok(n), - Err(e) => { - if e.kind() == io::ErrorKind::WouldBlock { - self.io.need_write(); - } - Err(e) - } - } - } - - /// Receives data from the socket previously bound with connect(). - /// On success, returns the number of bytes read. - pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> { - if let Async::NotReady = self.io.poll_read() { - return Err(io::ErrorKind::WouldBlock.into()) - } - let mut cmsgs = CmsgSpace::<in6_pktinfo>::new(); - let res = recvmsg(self.io.get_ref().as_raw_fd(), - &[IoVec::from_mut_slice(buf)], - Some(&mut cmsgs), - MsgFlags::empty()); - - match res { - Ok(msg) => { - debug!("address: {:?}", msg.address); - Ok(msg.bytes) - }, - Err(nix::Error::Sys(Errno::EAGAIN)) => { - debug!("EAGAIN"); - self.io.need_read(); - Err(io::ErrorKind::WouldBlock.into()) - }, - Err(nix::Error::Sys(errno)) => { - Err(io::Error::last_os_error()) - }, - Err(e) => { - Err(io::Error::new(io::ErrorKind::Other, e)) - } + fn get_io(&self, addr: &SocketAddr) -> &PollEvented<mio::net::UdpSocket> { + match *addr { + SocketAddr::V4(_) => &self.io4, + SocketAddr::V6(_) => &self.io6, } } @@ -188,7 +114,10 @@ impl UdpSocket { /// is only suitable for calling in a `Future::poll` method and will /// automatically handle ensuring a retry once the socket is readable again. pub fn poll_read(&self) -> Async<()> { - self.io.poll_read() + match self.io4.poll_read() { + Async::NotReady => self.io6.poll_read(), + res => res + } } /// Test whether this socket is ready to be written to or not. @@ -198,7 +127,10 @@ impl UdpSocket { /// is only suitable for calling in a `Future::poll` method and will /// automatically handle ensuring a retry once the socket is writable again. pub fn poll_write(&self) -> Async<()> { - self.io.poll_write() + match self.io4.poll_write() { + Async::NotReady => self.io6.poll_write(), + res => res + } } /// Sends data on the socket to the given address. On success, returns the @@ -207,14 +139,15 @@ impl UdpSocket { /// Address type can be any implementer of `ToSocketAddrs` trait. See its /// documentation for concrete examples. pub fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result<usize> { - if let Async::NotReady = self.io.poll_write() { + let io = self.get_io(target); + if let Async::NotReady = io.poll_write() { return Err(io::ErrorKind::WouldBlock.into()) } - match self.io.get_ref().send_to(buf, target) { + match io.get_ref().send_to(buf, target) { Ok(n) => Ok(n), Err(e) => { if e.kind() == io::ErrorKind::WouldBlock { - self.io.need_write(); + io.need_write(); } Err(e) } @@ -224,16 +157,16 @@ impl UdpSocket { /// Receives data from the socket. On success, returns the number of bytes /// read and the address from whence the data came. pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { - if let Async::NotReady = self.io.poll_read() { - return Err(io::ErrorKind::WouldBlock.into()) - } - if let Async::NotReady = self.io.poll_read() { - return Err(io::ErrorKind::WouldBlock.into()) - } - let mut cmsgs = CmsgSpace::<[u8; 1024]>::new(); - let res = recvmsg(self.io.get_ref().as_raw_fd(), + let io = match (self.io4.poll_read(), self.io6.poll_read()) { + (Async::Ready(_), _) => &self.io4, + (_, Async::Ready(_)) => &self.io6, + _ => return Err(io::ErrorKind::WouldBlock.into()), + }; + + let mut cmsgspace = CmsgSpace::<[u8; 1024]>::new(); + let res = recvmsg(io.get_ref().as_raw_fd(), &[IoVec::from_mut_slice(buf)], - Some(&mut cmsgs), + Some(&mut cmsgspace), MsgFlags::empty()); match res { @@ -253,7 +186,7 @@ impl UdpSocket { } }, Err(nix::Error::Sys(Errno::EAGAIN)) => { - self.io.need_read(); + io.need_read(); Err(io::ErrorKind::WouldBlock.into()) }, Err(nix::Error::Sys(errno)) => { @@ -264,36 +197,13 @@ impl UdpSocket { } } } -} -impl fmt::Debug for UdpSocket { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.io.get_ref().fmt(f) + fn as_raw_fd_v4(&self) -> RawFd { + self.io4.get_ref().as_raw_fd() } -} -#[cfg(all(unix, not(target_os = "fuchsia")))] -mod sys { - use std::os::unix::prelude::*; - use super::UdpSocket; - - impl AsRawFd for UdpSocket { - fn as_raw_fd(&self) -> RawFd { - self.io.get_ref().as_raw_fd() - } + fn as_raw_fd_v6(&self) -> RawFd { + self.io6.get_ref().as_raw_fd() } } -#[cfg(windows)] -mod sys { - // TODO: let's land these upstream with mio and then we can add them here. - // - // use std::os::windows::prelude::*; - // use super::UdpSocket; - // - // impl AsRawHandle for UdpSocket { - // fn as_raw_handle(&self) -> RawHandle { - // self.io.get_ref().as_raw_handle() - // } - // } -}
\ No newline at end of file |